Compare commits


No commits in common. 'main' and '0.02' have entirely different histories.
main ... 0.02

.gitignore (vendored)

@ -5,12 +5,6 @@
*.xlsx
*.xls
wandb/
data/
vocab.json
*.sh
*log/
test/
tools/
# Byte-compiled / optimized / DLL files
__pycache__/

@ -1,201 +1,25 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
BSD 2-Clause License
Copyright (c) 2021, PENG Bo
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -1,502 +1,4 @@
# The RWKV Language Model (and my LM tricks)
## RWKV: Parallelizable RNN with Transformer-level LLM Performance (pronounced as "RwaKuv", from 4 major params: R W K V)
RWKV is an RNN with Transformer-level LLM performance, which can also be directly trained like a GPT transformer (parallelizable). And it's 100% attention-free. You only need the hidden state at position t to compute the state at position t+1. You can use the "GPT" mode to quickly compute the hidden state for the "RNN" mode.
So it's combining the best of RNN and transformer - **great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding** (using the final hidden state).
**HuggingFace Gradio demo (14B ctx8192)**: https://huggingface.co/spaces/BlinkDL/ChatRWKV-gradio
Raven (7B finetuned on Alpaca) Demo: https://huggingface.co/spaces/BlinkDL/Raven-RWKV-7B
**ChatRWKV:** with "stream" and "split" strategies and INT8. **3G VRAM is enough to run RWKV 14B :)** https://github.com/BlinkDL/ChatRWKV
**RWKV pip package**: https://pypi.org/project/rwkv/
```python
import os
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV # pip install rwkv
model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16')
out, state = model.forward([187, 510, 1563, 310, 247], None) # use 20B_tokenizer.json
print(out.detach().cpu().numpy()) # get logits
out, state = model.forward([187, 510], None)
out, state = model.forward([1563], state) # RNN has state (use deepcopy if you want to clone it)
out, state = model.forward([310, 247], state)
print(out.detach().cpu().numpy()) # same result as above
```
**Download RWKV-4 0.1/0.4/1.5/3/7/14B weights**: https://huggingface.co/BlinkDL
## Join Our Discord: https://discord.gg/bDSBUMeFpc (lots of developers)
**Twitter**: https://twitter.com/BlinkDL_AI
**RWKV in 150 lines** (model, inference, text generation): https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py
ChatRWKV with RWKV 14B ctx8192:
![RWKV-chat](RWKV-chat.png)
You are welcome to join the RWKV discord https://discord.gg/bDSBUMeFpc to build upon it. We have plenty of potential compute (A100 40Gs) now (thanks to Stability and EleutherAI), so if you have interesting ideas I can run them.
![RWKV-eval2](RWKV-eval2.png)
RWKV [loss vs token position] for 10000 ctx4k+ documents in Pile. RWKV 1B5-4k is mostly flat after ctx1500, but 3B-4k and 7B-4k and 14B-4k have some slopes, and they are getting better. This debunks the old view that RNNs cannot model long ctxlens. We can predict that RWKV 100B will be great, and RWKV 1T is probably all you need :)
![RWKV-ctxlen](RWKV-ctxlen.png)
I believe RNN is a better candidate for fundamental models, because: (1) It's more friendly for ASICs (no kv cache). (2) It's more friendly for RL. (3) When we write, our brain is more similar to RNN. (4) The universe is like an RNN too (because of locality). Transformers are non-local models.
RWKV-3 1.5B on A40 (tf32) = always 0.015 sec/token, tested using simple pytorch code (no CUDA), GPU utilization 45%, VRAM 7823M
GPT2-XL 1.3B on A40 (tf32) = 0.032 sec/token (for ctxlen 1000), tested using HF, GPU utilization 45% too (interesting), VRAM 9655M
Training speed: (new training code) RWKV-4 14B BF16 ctxlen4096 = 114K tokens/s on 8x8 A100 80G (ZERO2+CP). (old training code) RWKV-4 1.5B BF16 ctxlen1024 = 106K tokens/s on 8xA100 40G.
I am doing image experiments too (For example: https://huggingface.co/BlinkDL/clip-guided-binary-autoencoder) and RWKV will be able to do txt2img diffusion :) My idea: 256x256 rgb image -> 32x32x13bit latents -> apply RWKV to compute transition probability for each of the 32x32 grid -> pretend the grids are independent and "diffuse" using these probabilities.
Smooth training - no loss spikes! (lr & bsz change around 15G tokens)
![RWKV-loss](RWKV-loss.png)
![RWKV-eval](RWKV-eval.png)
All of the trained models will be open-source. Inference is very fast (only matrix-vector multiplications, no matrix-matrix multiplications) even on CPUs, so you can even run an LLM on your phone.
How it works: RWKV gathers information to a number of channels, which are also decaying with different speeds as you move to the next token. It's very simple once you understand it.
**RWKV is parallelizable because the time-decay of each channel is data-independent (and trainable)**. For example, in usual RNN you can adjust the time-decay of a channel from say 0.8 to 0.5 (these are called "gates"), while in RWKV you simply move the information from a W-0.8-channel to a W-0.5-channel to achieve the same effect. Moreover, you can fine-tune RWKV into a non-parallelizable RNN (then you can use outputs of later layers of the previous token) if you want extra performance.
![RWKV-formula](RWKV-formula.png)
Here are some of my TODOs. Let's work together :)
* HuggingFace integration (check https://github.com/huggingface/transformers/issues/17230), and optimized CPU & iOS & Android & WASM & WebGL inference. RWKV is an RNN and very friendly for edge devices. Let's make it possible to run an LLM on your phone.
* Test it on bidirectional & MLM tasks, and image & audio & video tokens. I think RWKV can support Encoder-Decoder via this: for each decoder token, use a learned mixture of [decoder previous hidden state] & [encoder final hidden state]. Hence all decoder tokens will have access to the encoder output.
* Now training RWKV-4a with a single tiny extra attention (just a few extra lines compared with RWKV-4) to further improve some difficult zero-shot tasks (such as LAMBADA) for smaller models. See https://github.com/BlinkDL/RWKV-LM/commit/a268cd2e40351ee31c30c5f8a5d1266d35b41829
User feedback:
> *I've so far toyed around the character-based model on our relatively small pre-training dataset (around 10GB of text), and the results are extremely good - similar ppl to models taking much, much longer to train.*
> *dear god rwkv is fast. i switched to another tab after starting training it from scratch & when i returned it was emitting plausible english & maori words, i left to go microwave some coffee & when i came back it was producing fully grammatically correct sentences.*
Tweet from Sepp Hochreiter (thank you!): https://twitter.com/HochreiterSepp/status/1524270961314484227
You can find me (BlinkDL) in the EleutherAI Discord too: https://www.eleuther.ai/get-involved/
![RWKV-demo](RWKV-demo.png)
## Quick start
Use https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v4neo (latest code, compatible with v4).
Here is a great prompt for testing Q&A of LLMs. Works for any model: (found by minimizing ChatGPT ppls for RWKV 1.5B)
```python
prompt = f'\nQ & A\n\nQuestion:\n{qq}\n\nDetailed Expert Answer:\n' # let the model generate after this
```
**Cool Community RWKV Projects (check them!)**:
https://pypi.org/project/rwkvstic/ a pip package (with 8bit & offload for low VRAM GPUs)
https://github.com/harrisonvanderbyl/rwkv_chatbot a chatbot
https://github.com/hizkifw/WebChatRWKVstic WebUI (WIP)
https://github.com/gururise/rwkv_gradio RWKV Gradio
https://github.com/cryscan/eloise RWKV QQ bot
https://github.com/Blealtan/RWKV-LM-LoRA LoRA fine-tuning
https://github.com/mrsteyk/RWKV-LM-jax
https://github.com/wozeparrot/tinyrwkv RWKV in tinygrad (nice simple DL framework)
https://github.com/huggingface/transformers/issues/17230 RWKV HF package (WIP)
https://github.com/ArEnSc/Production-RWKV RWKV HF package source
https://github.com/nlpodyssey/verbaflow RWKV in Go
https://github.com/nlpodyssey/rwkv RWKV in Go
https://github.com/mrsteyk/rwkvk-rs RWKV in Rust
https://github.com/josephrocca/rwkv-v4-web RWKV in browser
https://github.com/imxcstar/CSharp-RWKV-V4 RWKV in C#
https://github.com/mrsteyk/RWKV-LM-deepspeed Another training fork
https://github.com/resloved/RWKV-notebooks RWKV colab notebooks
https://colab.research.google.com/github/harrisonvanderbyl/rwkvstic/blob/master/notebooks/chatbot.ipynb RWKV chatbot colab notebook
https://github.com/Pathos14489/RWKVDistributedInference RWKV Distributed Inference
https://github.com/AXKuhta/rwkv-onnx-dml RWKV ONNX
### Inference
**Run RWKV-4 Pile models:** Download models from https://huggingface.co/BlinkDL. Set TOKEN_MODE = 'pile' in run.py and run it. It's fast even on CPU (the default mode).
**Colab for RWKV-4 Pile 1.5B**: https://colab.research.google.com/drive/1F7tZoPZaWJf1fsCmZ5tjw6sYHiFOYVWM
Run RWKV-4 Pile models in your browser (and onnx version): see this issue https://github.com/BlinkDL/RWKV-LM/issues/7
RWKV-4 Web Demo: https://josephrocca.github.io/rwkv-v4-web/demo/ (note: only greedy sampling for now)
For the old RWKV-2: see the release here for a 27M params model on enwik8 with 0.72 BPC (dev). Run run.py in https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN. You can even run it in your browser: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng https://blinkdl.github.io/AI-Writer/eng/ (this is using tf.js WASM single-thread mode).
### Training / Fine-tuning
**Training RWKV-4 from scratch:** run train.py, which by default is using the enwik8 dataset (unzip https://data.deepai.org/enwik8.zip).
You will be training the "GPT" version because it's parallelizable and faster to train. RWKV-4 can extrapolate, so training with ctxLen 1024 can work for ctxLen of 2500+. You can fine-tune the model with longer ctxLen and it can quickly adapt to longer ctxLens.
**Fine-tuning RWKV-4 Pile models:** use 'prepare-data.py' in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/tree/main/RWKV-v3 to tokenize .txt into train.npy data. Then set EXPRESS_PILE_MODE to True in train.py, and run it.
Read the inference code in src/model.py and try using the final hidden state (.xx .aa .bb) as a faithful sentence embedding for other tasks. Probably you should begin with .xx and .aa/.bb (.aa divided by .bb).
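For example, a minimal sketch (the .xx/.aa/.bb names follow src/model.py, but the exact state layout here is an assumption):
```python
import torch

# Hedged sketch: assume the final layer's RNN state exposes tensors
# state.xx, state.aa, state.bb (as in src/model.py).
def sentence_embedding(state):
    # .aa / .bb is the decayed weighted average of values; .xx is the
    # token-shift input from the last token.
    return torch.cat([state.xx, state.aa / state.bb], dim=-1)
```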
Colab for fine-tuning RWKV-4 Pile models: https://colab.research.google.com/github/resloved/RWKV-notebooks/blob/master/RWKV_v4_RNN_Pile_Fine_Tuning.ipynb
**Large corpus:** Use https://github.com/EleutherAI/gpt-neox to convert .jsonl into .bin and .idx
```
python tools/preprocess_data.py --input ./my_data.jsonl --output-prefix ./data/my_data --vocab ./20B_tokenizer.json --dataset-impl mmap --tokenizer-type HFTokenizer --append-eod
```
The jsonl format sample (one line for each document):
```
{"meta": {"ID": 101}, "text": "This is the first document."}
{"meta": {"ID": 102}, "text": "Hello\nWorld"}
{"meta": {"ID": 103}, "text": "1+1=2\n1+2=3\n2+2=4"}
```
generated by code like this:
```
import json

ss = json.dumps({"meta": meta, "text": text}, ensure_ascii=False)
out.write(ss + "\n")
```
## Towards RWKV-5 (just to record some new ideas)
### Some ideas
1. Now time decay is like 0.999^T (0.999 is learnable). Change it to something like (0.999^T + 0.1) where 0.1 is learnable too. The 0.1 part will be kept forever. Or, A^T + B^T + C = fast-decay + slow-decay + constant (see the toy sketch after this list). Can even use different formulas (for example, K^2 instead of e^K for a decay component, or no normalization).
2. Use complex-valued decay (so, rotation instead of decay) in some channels.
3. Inject some trainable and extrapolatable positional encoding?
4. Aside from 2d rotation, we can try other Lie groups such as 3d rotation ( SO(3) ). Non-abelian RWKV lol.
5. RWKV might be great on analog devices (search for Analog Matrix-vector multiplication & Photonic Matrix-vector multiplication). The RNN mode is very hardware-friendly (processing-in-memory). Can be a SNN too (https://github.com/ridgerchu/SpikeGPT). I wonder if it can be optimized for quantum computation.
6. Trainable initial hidden state (xx aa bb pp xx).
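For idea 1, a toy sketch (the A, B, C values are made up; in the model they would be learnable per channel):
```python
import numpy as np

T = np.arange(64)
A, B, C = 0.9, 0.999, 0.1   # hypothetical fast / slow / constant parts
decay = A**T + B**T + C     # fast-decay + slow-decay + constant
# the C part never decays, so some information is kept forever
```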
### Vision Tasks
1. I find it's good to add a 2d pos encoding:
```
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
...
x = x + pos_emb_x + pos_emb_y
```
2. In a BPE language model, it's best to use [tokenShift of 1 token] (you can mix more tokens in a char-level English model). However, you can try [tokenShift of N (or N-1) (or N+1) tokens] if the image size is N x N, because that will be like mixing [the token above the current position (or the token above the to-be-predicted position)] with [the current token]. You can try different tokenShift styles for "ATT" & "FFN", or mix different tokenShift styles - such as mixing [token A] with [token A-1] and [token A-(N-1)], etc.
### Misc
I have an idea to improve tokenization. We can hardcode some channels to have meanings. Example:
Channel 0 = "space"
Channel 1 = "capitalize first letter"
Channel 2 = "capitalize all letters"
Therefore:
Embedding of "abc": [0, 0, 0, x0, x1, x2 , ..]
Embedding of " abc": [1, 0, 0, x0, x1, x2, ..]
Embedding of " Abc": [1, 1, 0, x0, x1, x2, ..]
Embedding of "ABC": [0, 0, 1, x0, x1, x2, ...]
......
so they will share most of the embedding. And we can rapidly compute the output probability of all variations of "abc".
Note: the above method is assuming that p(" xyz") / p("xyz") is the same for any "xyz", which can be wrong.
Better: define emb_space emb_capitalize_first emb_capitalize_all to be a function of emb.
Maybe the best: let 'abc', ' abc', etc. share the last 90% of their embeddings.
At this moment, all our tokenizers spend too many items to represent all variations of 'abc', ' abc', ' Abc', etc. Moreover, the model cannot discover that these are actually similar if some of these variations are rare in the dataset. The method here can improve this. I plan to test it in a new version of RWKV.
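A minimal sketch of the idea (all sizes and the flag layout are hypothetical):
```python
import torch
import torch.nn as nn

n_flags, n_embd, n_stems = 3, 768, 50000             # made-up sizes
stem_emb = nn.Embedding(n_stems, n_embd - n_flags)   # one row per lowercase stem

def embed(stem_id, leading_space, cap_first, cap_all):
    flags = torch.tensor([leading_space, cap_first, cap_all], dtype=torch.float)
    return torch.cat([flags, stem_emb(torch.tensor(stem_id))])

# "abc", " abc", " Abc", "ABC" all share the same stem row:
e_abc  = embed(123, 0, 0, 0)   # "abc"
e_sabc = embed(123, 1, 0, 0)   # " abc"
e_sAbc = embed(123, 1, 1, 0)   # " Abc"
e_ABC  = embed(123, 0, 0, 1)   # "ABC"
```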
## How it works
RWKV is inspired by Apple's AFT (https://arxiv.org/abs/2105.14103).
Moreover it's using a number of my tricks, such as:
* SmallInitEmb: https://github.com/BlinkDL/SmallInitEmb (applicable to all transformers) which helps the embedding quality, and stabilizes Post-LN (which is what I am using).
* Token-shift: https://github.com/BlinkDL/RWKV-LM#token-shift-time-shift-mixing (applicable to all transformers), especially helpful for char-level models.
* Head-QK: https://github.com/BlinkDL/RWKV-LM#the-head-qk-trick-learning-to-copy-and-avoid-tokens (applicable to all transformers). Note: it's helpful, but I disabled it in the Pile model to keep it 100% RNN.
* Extra R-gate in the FFN (applicable to all transformers). I am also using reluSquared from Primer.
* Better initialization: I init most of the matrices to ZERO (see RWKV_Init in https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v2-RNN/src/model.py).
* You can transfer some parameters from a small model to a large model (note: I sort & smooth them too), for faster and better convergence (see https://www.reddit.com/r/MachineLearning/comments/umq908/r_rwkvv2rnn_a_parallelizable_rnn_with/).
* My CUDA kernel: https://github.com/BlinkDL/RWKV-CUDA to speedup training.
## The pseudocode (execution from top to bottom):
![RWKV-v2-RNN](RWKV-v2-RNN.png)
The a b c d factors work together to build a time-decay curve: [X, 1, W, W^2, W^3, ...].
Write out the formulas for "token at pos 2" and "token at pos 3" and you will get the idea:
* a and b: EMAs of kv and k.
* c and d: these are a and b combined with "self-attention".
kv / k is the memory mechanism. The token with high k can be remembered for a long duration, if W is close to 1 in the channel.
The R-gate is important for performance. k = info strength of this token (to be passed to future tokens). r = whether to apply the info to this token.
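A toy sketch of this recurrence for a single channel (W, X and the (k, v) stream are made up; X = e^{time_first} is the current-token bonus):
```python
import numpy as np

W, X = 0.97, 1.5              # hypothetical decay and current-token bonus
a = b = 0.0                   # running EMAs of kv and k
for k, v in [(0.2, 1.0), (3.0, -0.5), (0.1, 2.0)]:   # toy (k, v) stream
    ek = np.exp(k)
    c = a + X * ek * v        # numerator: EMA of kv, plus the current token
    d = b + X * ek            # denominator: EMA of k, plus the current token
    out = c / d               # weights follow the [X, 1, W, W^2, ...] curve
    a = W * a + ek * v        # decay the memory, then absorb the current token
    b = W * b + ek
```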
## RWKV-3 improvements
Use different trainable TimeMix factors for R / K / V in SA and FF layers. Example:
```python
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
```
Use preLN instead of postLN (more stable & faster convergence):
```python
if self.layer_id == 0:
x = self.ln0(x)
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
```
## Explaining the code for RWKV-3 GPT mode
### The GPT mode - overview
The building blocks of RWKV-3 GPT mode are similar to that of a usual preLN GPT.
The only difference is an extra LN after embedding. Note you can absorb this LN into the embedding after finishing the training.
```python
x = self.emb(idx) # input: idx = token indices
x = self.ln_emb(x) # extra LN after embedding
x = x + self.att_0(self.ln_att_0(x)) # preLN
x = x + self.ffn_0(self.ln_ffn_0(x))
...
x = x + self.att_n(self.ln_att_n(x))
x = x + self.ffn_n(self.ln_ffn_n(x))
x = self.ln_head(x) # final LN before projection
x = self.head(x) # output: x = logits
```
It is important to initialize emb to tiny values, such as nn.init.uniform_(a=-1e-4, b=1e-4), to utilize my trick https://github.com/BlinkDL/SmallInitEmb.
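For example (a sketch with hypothetical sizes):
```python
import torch.nn as nn

vocab_size, n_embd = 50277, 2048                 # hypothetical sizes
emb = nn.Embedding(vocab_size, n_embd)
nn.init.uniform_(emb.weight, a=-1e-4, b=1e-4)    # tiny init (SmallInitEmb)
```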
For the 1.5B RWKV-3, I use Adam (no wd, no dropout) optimizer on 8 * A100 40G.
batchSz = 32 * 896, ctxLen = 896. I am using tf32 so the batchSz is a bit small.
For the first 15B tokens, LR is fixed at 3e-4, and beta=(0.9, 0.99).
Then I set beta=(0.9, 0.999), and do an exponential decay of LR, reaching 1e-5 at 332B tokens.
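A sketch of this schedule as a function of tokens seen (in billions); the exponential rate follows from the two endpoints:
```python
import math

def lr_at(tokens_B):
    if tokens_B <= 15:
        return 3e-4                            # fixed-LR phase
    frac = (tokens_B - 15) / (332 - 15)        # decay 3e-4 -> 1e-5 over 15B..332B
    return 3e-4 * math.exp(frac * math.log(1e-5 / 3e-4))

print(lr_at(15), lr_at(332))                   # 0.0003  ~1e-05
```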
### The GPT mode - ATT block
The RWKV-3 does not have any attention in the usual sense, but we will call this block ATT anyway.
```python
B, T, C = x.size() # x = (Batch,Time,Channel)
# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# Use xk, xv, xr to produce k, v, r
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
k = torch.clamp(k, max=60) # clamp k to avoid overflow
k = torch.exp(k)
kv = k * v
# Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)]
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(x.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)
# Use W to mix kv and k respectively. Add K_EPS to wk to avoid divide-by-zero
if RUN_DEVICE == 'cuda':
wkv = TimeX.apply(w, kv, B,C,T, 0)
wk = TimeX.apply(w, k, B,C,T, K_EPS)
else:
w = w[:,-T:].unsqueeze(1)
wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + K_EPS
# The RWKV formula
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv) # final output projection
```
The self.key, self.receptance, self.output matrices are all initialized to zero.
The time_mix, time_decay, time_first vectors are transferred from a smaller trained model (note: I sort & smooth them too).
### The GPT mode - FFN block
The FFN block has three tricks compared with the usual GPT:
1. My time_mix trick.
2. The sqReLU from the Primer paper.
3. An extra receptance-gate (similar to the receptance-gate in ATT block).
```python
# Mix x with the previous timestep to produce xk, xr
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# The usual FFN operation
k = self.key(xk)
k = torch.square(torch.relu(k)) # from the Primer paper
kv = self.value(k)
# Apply an extra receptance-gate to kv
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
```
The self.value, self.receptance matrices are all initialized to zero.
## RWKV-4 improvements
![RWKV-v3-plan](RWKV-v3-plan.png)
## From GPT to RWKV (the formulas)
Let F[t] be the system state at t.
Let x[t] be the new external input at t.
In GPT, predicting F[t+1] requires considering F[0], F[1], .. F[t]. So it takes O(T^2) to generate a length T sequence.
The **simplified formula** for GPT:
![F[\mathrm{t}+1]=\frac{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}]) \cdot(\mathbf{V}F[\mathrm{i}])}{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}])}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5B%5Cmathrm%7Bt%7D%2B1%5D%3D%5Cfrac%7B%5Csum_%7B%5Cmathrm%7Bi%7D%3D0%7D%5E%7B%5Cmathrm%7Bt%7D%7D+%5Cexp+%28%5Cmathbf%7BQ%7Dx%5B%5Cmathrm%7Bt%7D%5D+%2A+%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+%5Ccdot%28%5Cmathbf%7BV%7DF%5B%5Cmathrm%7Bi%7D%5D%29%7D%7B%5Csum_%7B%5Cmathrm%7Bi%7D%3D0%7D%5E%7B%5Cmathrm%7Bt%7D%7D+%5Cexp+%28%5Cmathbf%7BQ%7Dx%5B%5Cmathrm%7Bt%7D%5D+%2A+%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29%7D)
It's very capable in theory, however that **does not mean we can fully utilize its capability with usual optimizers**. I suspect the loss landscape is too difficult for our current methods.
Compare with the **simplified formula** for RWKV (the parallel mode, looks similar to Apple's AFT):
![F[\mathrm{t}+1]=\sigma(\mathbf{R}x[\mathrm{t}]) \cdot \frac{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K}F[\mathrm{i}]) \cdot(\mathbf{V}F[\mathrm{i}])}{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K }F[\mathrm{i}])}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5B%5Cmathrm%7Bt%7D%2B1%5D%3D%5Csigma%28%5Cmathbf%7BR%7Dx%5B%5Cmathrm%7Bt%7D%5D%29+%5Ccdot+%5Cfrac%7B%5Csum_%7B%5Cmathrm%7Bi%7D%3D0%7D%5E%7B%5Cmathrm%7Bt%7D%7D+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+%5Ccdot%28%5Cmathbf%7BV%7DF%5B%5Cmathrm%7Bi%7D%5D%29%7D%7B%5Csum_%7B%5Cmathrm%7Bi%7D%3D0%7D%5E%7B%5Cmathrm%7Bt%7D%7D+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B%5Cmathrm%7Bi%7D%5D%29%7D)
The R, K, V are trainable matrices, and W is a trainable vector (time-decay factor for each channel).
In GPT, the contribution of F[i] to F[t+1] is weighted by ![ \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}]) ](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle++%5Cexp+%28%5Cmathbf%7BQ%7Dx%5B%5Cmathrm%7Bt%7D%5D+%2A+%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+).
In RWKV-2, the contribution of F[i] to F[t+1] is weighted by ![\sigma(\mathbf{R}x[\mathrm{t}]) \cdot \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K}F[\mathrm{i}]) ](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Csigma%28%5Cmathbf%7BR%7Dx%5B%5Cmathrm%7Bt%7D%5D%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+).
* The ![\sigma](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Csigma) is a non-linearity and we can use sigmoid.
* Note ![\sigma(\mathbf{R}x[\mathrm{t}])](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Csigma%28%5Cmathbf%7BR%7Dx%5B%5Cmathrm%7Bt%7D%5D%29) is not in the denominator, and I call R the "receptance".
* The ![\exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i}))](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29) is the time-decay factor. I proposed the same idea (scaling the attention by distance) in Aug 2020 and called it the "time-weighting" (check the commit history of https://github.com/BlinkDL/minGPT-tuned).
Here comes the punchline: we can rewrite it into an RNN (recursive formula). Note:
![F[1]=\sigma(\mathbf{R }x[0]) \cdot \frac{ \exp (\mathbf{K }F[0]) \cdot(\mathbf{V }F[0])}{\exp (\mathbf{K }F[0])}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5B1%5D%3D%5Csigma%28%5Cmathbf%7BR+%7Dx%5B0%5D%29+%5Ccdot+%5Cfrac%7B+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B0%5D%29+%5Ccdot%28%5Cmathbf%7BV+%7DF%5B0%5D%29%7D%7B%5Cexp+%28%5Cmathbf%7BK+%7DF%5B0%5D%29%7D)
![F[2]=\sigma(\mathbf{R }x[1]) \cdot \frac{ \exp (\mathbf{K }F[1]) \cdot(\mathbf{V }F[1])+\exp (\mathbf{W} ) \cdot \exp (\mathbf{K }F[0]) \cdot(\mathbf{V }F[0])}{ \exp (\mathbf{K }F[1])+\exp (\mathbf{W} ) \cdot \exp (\mathbf{K }F[0])}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5B2%5D%3D%5Csigma%28%5Cmathbf%7BR+%7Dx%5B1%5D%29+%5Ccdot+%5Cfrac%7B+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B1%5D%29+%5Ccdot%28%5Cmathbf%7BV+%7DF%5B1%5D%29%2B%5Cexp+%28%5Cmathbf%7BW%7D+%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B0%5D%29+%5Ccdot%28%5Cmathbf%7BV+%7DF%5B0%5D%29%7D%7B+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B1%5D%29%2B%5Cexp+%28%5Cmathbf%7BW%7D+%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B0%5D%29%7D)
Therefore it's straightforward to verify:
![F[t+1]=\sigma(\mathbf{R }x[t]) \cdot \frac{\exp (\mathbf{K}F[\mathrm{t}]) \cdot(\mathbf{V}F[\mathrm{t}])+\exp (\mathbf{W}) \cdot A[\mathrm{t}]}{ \exp (\mathbf{K}F[\mathrm{t}])+\exp (\mathbf{W}) \cdot B[\mathrm{t}]}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5Bt%2B1%5D%3D%5Csigma%28%5Cmathbf%7BR+%7Dx%5Bt%5D%29+%5Ccdot+%5Cfrac%7B%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bt%7D%5D%29+%5Ccdot%28%5Cmathbf%7BV%7DF%5B%5Cmathrm%7Bt%7D%5D%29%2B%5Cexp+%28%5Cmathbf%7BW%7D%29+%5Ccdot+A%5B%5Cmathrm%7Bt%7D%5D%7D%7B+%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bt%7D%5D%29%2B%5Cexp+%28%5Cmathbf%7BW%7D%29+%5Ccdot+B%5B%5Cmathrm%7Bt%7D%5D%7D)
where A[t] and B[t] are the numerator and denominator of the previous step, respectively.
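A minimal sketch of one step of this RNN mode (the R, K, V matrices and W vector are placeholders for the trained parameters):
```python
import torch

def rwkv_rnn_step(x_t, F_t, A, B, R, K, V, W):
    kF = torch.exp(K @ F_t)
    num = kF * (V @ F_t) + torch.exp(W) * A      # becomes A[t+1]
    den = kF + torch.exp(W) * B                  # becomes B[t+1]
    F_next = torch.sigmoid(R @ x_t) * num / den
    return F_next, num, den
```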
I believe RWKV is performant because W is like repeatedly applying a diagonal matrix. Note (P^{-1} D P)^n = P^{-1} D^n P, so it is similar to repeatedly applying a general diagonalizable matrix.
Moreover it's possible to turn it into a continuous ODE (a bit similar to State Space Models). I will write about it later.
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=BlinkDL/RWKV-LM&type=Date)](https://star-history.com/#BlinkDL/RWKV-LM&Date)
## Multimodal ideas
I have an idea for [text --> 32x32 RGB image] using an LM (transformer, RWKV, etc.). Will test it soon.
Firstly, LM loss (instead of L2 loss), so the image will not be blurry.
Secondly, color quantization. For example, only allowing 8 levels for R/G/B. Then the image vocab size is 8x8x8 = 512 (for each pixel), instead of 2^24.
Therefore, a 32x32 RGB image = a len1024 sequence of vocab512 (image tokens), which is a typical input for usual LMs.
(Later we can use diffusion models to upsample and generate RGB888 images. We might be able to use an LM for this too.)
Thirdly, 2D positional embeddings that are easy for the model to understand.
For example, add one-hot X & Y coords to the first 64(=32+32) channels. Say if the pixel is at x=8, y=20, then we will add 1 to channel 8 and channel 52 (=32+20).
Moreover, we can probably add the float X & Y coords (normalized to the 0~1 range) to another 2 channels. Other periodic positional encodings might help too (will test).
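A sketch of these positional channels (n_embd is a made-up size):
```python
import torch

def pos_channels(x, y, n_embd=512):
    v = torch.zeros(n_embd)
    v[x] = 1.0            # one-hot X in channels 0..31
    v[32 + y] = 1.0       # one-hot Y in channels 32..63
    v[64] = x / 31.0      # float X, normalized to 0~1
    v[65] = y / 31.0      # float Y
    return v

pos = pos_channels(8, 20)   # sets channel 8 and channel 52, as above
```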
Finally, RandRound when doing the color quantization in the DataLoader.
For example, if the float level is 4.578, then there is a 57.8% chance to use 5, and (1-57.8%) chance to use 4.
And we can allow both 4 and 5 in the prediction, but the loss will be higher if the prediction is 4.
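A sketch of the quantization + RandRound step for the DataLoader (the token layout follows the description above):
```python
import torch

def rand_round(x):
    # e.g. 4.578 -> 5 with probability 0.578, otherwise 4
    lo = torch.floor(x)
    return (lo + (torch.rand_like(x) < x - lo).float()).long()

def image_to_tokens(img):                   # img: (32, 32, 3) floats in [0, 1]
    rgb = rand_round(img * 7)               # 8 levels (0..7) per R/G/B channel
    r, g, b = rgb.unbind(-1)
    return (r * 64 + g * 8 + b).flatten()   # len-1024 sequence, vocab 8*8*8 = 512
```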
Multi-task training might help too. I will try this dataset format:
[TxtFirst] [Desc of Img (txt tokens)] [Img] [img tokens]
and sometimes
[ImgFirst] [img tokens] [Txt] [Desc of Img (txt tokens)]
... the order of the imgs should be randomized in the DataLoader, and [TxtFirst] [ImgFirst] [Img] [Txt] are special tokens
and do random sampling of the full dataset. So sometimes the model will see the img tokens first and then the corresponding txt tokens, which is a [img -> txt] task. And the model will see some partial imgs and partial txts. I think a char-level LM might help the model to write correct text on images.
## How to sample a large dataset (for training)
I am using a trick to sample the Pile deterministically yet randomly enough.
Let's say the pile has x chunks (a chunk = ctx_len tokens).
Pick a prime number p just below x, and make sure p = 2 (mod 3).
Use (step * step * step) mod p to sample it. Add some bias to step for extra randomness.
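A sketch of why this works: when p is prime and p = 2 (mod 3), cubing is a bijection mod p, so (step^3 mod p) visits every chunk exactly once in a scrambled but deterministic order:
```python
def chunk_index(step, p, bias=0):
    s = (step + bias) % p     # add some bias to step for extra randomness
    return (s * s * s) % p

p = 11   # toy prime with p % 3 == 2; in practice, the largest such prime < x
assert sorted(chunk_index(i, p) for i in range(p)) == list(range(p))  # a permutation
```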
## The top-p-x sampling method (for inference)
We propose a new sampling method called top-p-x:
it's like top-p, and the only difference is you also keep all tokens whose prob > x.
Try x = 0.01 first.
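A minimal sketch of top-p-x for 1-D logits (not the repo's implementation):
```python
import torch
import torch.nn.functional as F

def sample_top_p_x(logits, top_p=0.9, x=0.01):
    probs = F.softmax(logits, dim=-1)
    sp, idx = torch.sort(probs, descending=True)
    keep = (torch.cumsum(sp, dim=-1) - sp) < top_p   # the usual top-p nucleus...
    keep |= sp > x                                   # ...plus all tokens with prob > x
    kept = torch.zeros_like(probs).scatter_(-1, idx, sp * keep.float())
    return torch.multinomial(kept / kept.sum(), num_samples=1)
```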
## Better Learning Rate Schedule via Variational Method of Loss Curve
I propose a simple new method to find better LR schedules. The method is cost-efficient and practical for large LMs. The takeaway is that we can model the loss curve dynamics (phenomenology) w.r.t. the LR, and a nice closed-form LR curve can be directly computed from it using the variational method. Moreover, we can predict the final loss with reasonable accuracy.
UPDATE: In "Conclusion 1.", use the best-fitting regime (ignore the initial steps where our approximations break down) to fit the parameters.
Try this: fixed lr for 1 hr, then exponential decay to 0.2 * lr in 12 hrs, and choose the t=[1hr, 13hr] segment.
In the last three plots, black = predicted loss curve of the new LR schedule, blue = original (unoptimized) real loss curve, orange = new LR schedule.
![better_lr_schedule](Research/better_lr_schedule.png)
# RWKV v1
# The RWKV Language Model
We propose the RWKV language model, with alternating time-mix and channel-mix layers:
@ -518,8 +20,6 @@
softmax_t(K_{u,c}) = exp(K_{u,c}) / Σ_{v≤t} exp(K_{v,c})
**(UPDATE: We are using the original AFT normalization in v2)**
Initialize K and R matrices (and the output projection matrix) to ZERO for fast & stable convergence.
(2) We decompose W_{t,u,c} and introduce multi-head W (here h is the corresponding head of c):
@ -530,25 +30,15 @@
W_{t,u,c} = f_h(t-u) · α_h(u) · β_h(t)
Moreover, we multiply the final output of the Time-mix layer by γ(t). The reason for the α β γ factors is that the effective context size is smaller when t is small, and this can be compensated for using the α β γ factors.
**(UPDATE: We remove α β γ factors in v2-RNN and restrict W to be of a simple form and hence able to rewrite it as RNN)**
* The Channel-mix is similar to GeGLU (https://arxiv.org/abs/2002.05202) with an extra R factor. Initialize R and W matrices to ZERO for fast & stable convergence.
* Finally, we add extra token-shift (time-shift mixing) as in (https://github.com/BlinkDL/minGPT-tuned).
# Token-shift (time-shift mixing)
The token-shift explicitly uses (half the channels of this token) & (half the channels of the previous token) to generate all vectors (QKV, RWKV, ...).
```
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
```
Dividing channels by 2 and shift-1 works great for char-level English and char-level Chinese LM.
However for BPE-level English LM, it's only effective if your embedding is large enough (at least 1024 - so the usual small L12-D768 model is not enough).
My theory on the effectiveness of token-shift:
@ -560,44 +50,19 @@ When we train a GPT, the hidden representation of a token has to accomplish two different objectives: (1) predict the next token, and (2) collect the context information that later tokens will need.
The shifted channels can focus on (2), so we have good propagation of info. It's like some kind of residual connection, or a small RNN inside the transformer.
You can use token-shift in usual QKV self-attention too. I looked at the weights, and found V really likes the shifted channels, less so for Q. Makes sense if you think about it. I also found you may want to use less mixing in higher layers.
p.s. There is a MHA_pro model in this repo with strong performance. Give it a try :)
# The Head-QK Trick: learning to copy and avoid tokens
In a usual transformer, a small model has difficulty copying tokens (such as person names) in the context. We add extra Q & K to the final output so that the model can directly copy (or avoid) tokens in the context. Afterwards, the model teaches itself NER (named entity recognition), as you can see from the learned weights.
```
q = self.head_q(x)[:,:T,:] # projecting to 256-d
k = self.head_k(x)[:,:T,:] # projecting to 256-d
c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
x = self.head(x) + c
```
Note: when a token occurs multiple times in the context, it might be better to use max(prob) instead of sum(prob).
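A sketch of the max(prob) variant, reusing c, idx, T from the snippet above (B is the batch size; scatter_reduce_ needs PyTorch >= 1.12):
```
# instead of summing c over repeated tokens (what "c @ one_hot" does),
# take the max per vocab id:
c_max = torch.zeros(B, T, self.config.vocab_size, device=c.device)
c_max.scatter_reduce_(-1, idx.unsqueeze(1).expand(B, T, T), c,
                      reduce="amax", include_self=False)
x = self.head(x) + c_max
```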
# The top-a sampling method
We also propose a new sampling method called top-a (as in src/utils.py):
(1) Find the max probability p_max after softmax.
(2) Remove all entries whose probability is lower than 0.02 * pow(p_max, 2). So it's adaptive, hence "top-a".
The idea of top-a:
1. If max_prob=0.9, then remove all tokens with prob < 0.0162 (so, removing all alternatives)
2. If max_prob=0.5, then remove all tokens with prob < 0.005 (so, allowing more choices)
3. If max_prob=0.1, then remove all tokens with prob < 0.0002 (so, allowing lots of possibilities)
```
probs = F.softmax(logits, dim=-1)
limit = torch.pow(torch.max(probs), 2) * 0.02
logits[probs < limit] = -float('Inf')
```
(3) Feel free to tune the 0.02 and the 2 factor. Tune 0.02 first.
# Performance

(11 binary image files removed in this diff; contents not shown)

@ -1,172 +0,0 @@
#include <stdio.h>
// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax and BF and BB are passed by compiler)
#define F4(A, B) ((float4 *)(A))[(B) >> 2]
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
const F eps, const int B, const int C, const int T) {
const int i = blockIdx.y;
const int ij = (B * C) / BF;
const int t = threadIdx.x << 2;
__shared__ F ww[Tmax];
__shared__ F kk[Tmax * BF];
F4(ww, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BF; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
}
__syncthreads();
float4 s[BF];
#pragma unroll
for (int j = 0; j < BF; j++) {
s[j] = {eps, eps, eps, eps};
}
const F *__restrict__ const w = ww + T - t - 4;
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BF; j++) {
const F x = kk[u + Tmax * j];
s[j].x += w[u + 3] * x;
s[j].y += w[u + 2] * x;
s[j].z += w[u + 1] * x;
s[j].w += w[u + 0] * x;
}
}
#pragma unroll
for (int j = 0; j < BF; j++) {
const F *__restrict__ const k = kk + Tmax * j;
s[j].y += w[t + 3] * k[t + 1];
s[j].z += w[t + 2] * k[t + 1];
s[j].z += w[t + 3] * k[t + 2];
s[j].w += w[t + 1] * k[t + 1];
s[j].w += w[t + 2] * k[t + 2];
s[j].w += w[t + 3] * k[t + 3];
F4(x, t + T * (i + ij * j)) = s[j];
}
}
template <typename F>
__global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
const int i = blockIdx.y;
const int t = threadIdx.x << 2;
__shared__ F k[Tmax];
__shared__ F gg[Tmax];
F4(k, t) = F4(__k, t + T * i);
F4(gg, t) = F4(__gwk, t + T * i);
__syncthreads();
float4 s = {0, 0, 0, 0};
const F *__restrict__ const g = gg + T - t - 4;
for (int u = 0; u <= t; u++) {
F x = k[u];
s.x += g[u + 3] * x;
s.y += g[u + 2] * x;
s.z += g[u + 1] * x;
s.w += g[u + 0] * x;
}
s.y += g[t + 3] * k[t + 1];
s.z += g[t + 2] * k[t + 1];
s.z += g[t + 3] * k[t + 2];
s.w += g[t + 1] * k[t + 1];
s.w += g[t + 2] * k[t + 2];
s.w += g[t + 3] * k[t + 3];
F4(gw, t + T * i) = s;
}
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
dim3 gridDim(1, B * C / BF);
dim3 blockDim(T >> 2);
kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
}
template <typename F>
__global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
const int i = blockIdx.y;
const int ij = (B * C) / BB;
const int t = threadIdx.x << 2;
__shared__ F w[Tmax];
__shared__ F kk[Tmax * BB];
__shared__ F gg[Tmax * BB];
F4(w, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BB; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
}
__syncthreads();
float4 s[BB];
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
F x = kk[u + Tmax * j];
s[j].x += g[u + 3] * x;
s[j].y += g[u + 2] * x;
s[j].z += g[u + 1] * x;
s[j].w += g[u + 0] * x;
}
}
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const k = kk + Tmax * j;
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
s[j].y += g[t + 3] * k[t + 1];
s[j].z += g[t + 2] * k[t + 1];
s[j].z += g[t + 3] * k[t + 2];
s[j].w += g[t + 1] * k[t + 1];
s[j].w += g[t + 2] * k[t + 2];
s[j].w += g[t + 3] * k[t + 3];
F4(gw, t + T * (i + ij * j)) = s[j];
}
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = t + 3; u < T; u++) {
F x = w[u];
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - u] * x;
s[j].y += g[3 - u] * x;
s[j].z += g[4 - u] * x;
s[j].w += g[5 - u] * x;
}
}
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - t] * w[t + 0];
s[j].x += g[1 - t] * w[t + 1];
s[j].x += g[0 - t] * w[t + 2];
s[j].y += g[2 - t] * w[t + 1];
s[j].y += g[1 - t] * w[t + 2];
s[j].z += g[2 - t] * w[t + 2];
F4(gk, t + T * (i + ij * j)) = s[j];
}
}
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
dim3 gridDim(1, B * C / BB);
dim3 blockDim(T >> 2);
kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
}

@ -1,21 +0,0 @@
#include <torch/extension.h>
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);
void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T);
}
void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "timex forward");
m.def("backward", &backward, "timex backward");
}
TORCH_LIBRARY(timex, m) {
m.def("forward", forward);
m.def("backward", backward);
}

Binary file not shown.

@ -1,133 +0,0 @@
# -*- coding:utf-8 -*-
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
### Step 1: set model ##################################################################################
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
# your trained model
MODEL_NAME = 'trained-31'
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
# ########## Uncomment these to test my 27M params enwik8 model ##########
# MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
# WORD_NAME = 'enwik8-vocab'
# EVAL_DATA = 'enwik8' # uncomment this for EVAL MODE (no text generation)
# ########################################################################
# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False # True False - show softmax output
### Step 2: set context ################################################################################
context = "\nIn the" # ==> this is your prompt
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 500
TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9
########################################################################################################
print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
########################################################################################################
if 'EVAL_DATA' in vars() or 'EVAL_DATA' in globals():
print('Evaluating on ' + EVAL_DATA + ' ...')
data = open(EVAL_DATA, "r", encoding='utf-8').read()
loss_table = np.zeros(ctx_len)
N_SAMPLE = 1000
for iii in range(N_SAMPLE):
pos = np.random.randint(0, len(data) - ctx_len-1)
context = data[pos:pos+ctx_len+1]
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
model.clear()
for i in range(1, ctx_len+1):
x = ctx[:i]
out = model.run(x)
prob = F.softmax(torch.tensor(out), dim=-1)
loss_table[i-1] += -math.log(prob[ctx[i]])
print(f'Tested {iii+1} samples: avg_loss over ctx_len =',
np.mean(loss_table) / (iii+1))
exit(0)
########################################################################################################
context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
t_begin = time.time_ns()
src_len = len(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(('-' * 30) + context, end='')
model.clear()
if TRIAL == 0:
init_state = types.SimpleNamespace()
for i in range(src_len):
x = ctx[:i+1]
if i == src_len - 1:
init_state.out = model.run(x)
else:
model.run(x)
model.save(init_state)
else:
model.load(init_state)
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
x = ctx[:i+1]
x = x[-ctx_len:]
if i == src_len:
out = copy.deepcopy(init_state.out)
else:
out = model.run(x)
if DEBUG_DEBUG:
print('model', np.array(x), '==>', np.array(
out), np.max(out), np.min(out))
char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
top_p_usual=top_p, top_p_newline=top_p_newline)
char = char.item()
print(tokenizer.itos[int(char)], end='', flush=True)
ctx += [char]
t_end = time.time_ns()
print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')

@ -1,349 +0,0 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.cpp_extension import load
import math
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
########################################################################################################
# CUDA Kernel
########################################################################################################
T_MAX = 1024 # increase this if your ctx_len > 1024
B_GROUP_FORWARD = 4 # set to 8 for best performance
B_GROUP_BACKWARD = 2 # set to 2 for best performance
timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'])
class TimeX(torch.autograd.Function):
@staticmethod
def forward(ctx, w, k, B, C, T, eps):
ctx.B = B
ctx.C = C
ctx.T = T
assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
w = w.contiguous()
k = k.contiguous()
ctx.save_for_backward(w, k)
wk = torch.empty((B, C, T), device='cuda',
memory_format=torch.contiguous_format)
timex_cuda.forward(w, k, wk, eps, B, C, T)
return wk
@staticmethod
def backward(ctx, gwk):
assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
w, k = ctx.saved_tensors
gw = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
memory_format=torch.contiguous_format)
gk = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
memory_format=torch.contiguous_format)
timex_cuda.backward(w, k, gwk.contiguous(), gw,
gk, ctx.B, ctx.C, ctx.T)
return (gw.sum(dim=0), gk, None, None, None, None)
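# A minimal pure-PyTorch reference for what the TimeX kernel computes: a per-channel causal
# correlation of k against the time-curve w, plus eps. This is a hypothetical helper (not in
# the original file), assuming w is (C, T) and k is (B, C, T), as holds at the training call
# sites where T == ctx_len; it is handy for checking the CUDA kernel against a CPU result.
def timex_reference(w, k, B, C, T, eps):
    w = w[:, -T:].unsqueeze(1)  # (C, 1, T): one causal filter per channel
    pad = nn.ZeroPad2d((T - 1, 0, 0, 0))  # left-pad the time axis so position t only sees u <= t
    return F.conv1d(pad(k), w, groups=C) + eps  # (B, C, T), matching TimeX.apply(w, k, B, C, T, eps)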
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
RWKV_K_CLAMP = 60 # e^60 = 1e26
RWKV_K_EPS = 1e-16
RWKV_HEAD_QK_DIM = 256
def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
with torch.no_grad():
name = '[unknown weight]'
for name, parameter in module.named_parameters(): # find the name of the weight
if id(m.weight) == id(parameter):
break
shape = m.weight.data.shape
gain = 1.0
scale = 1.0 # extra scale for gain
if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
scale = 1e-4
else:
scale = 0
if isinstance(m, nn.Linear):
if m.bias is not None:
m.bias.data.zero_()
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
scale = 0.5
if hasattr(m, 'scale_init'):
scale = m.scale_init
# print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
gain *= scale
if scale == -999:
nn.init.eye_(m.weight)
elif gain == 0:
# zero init is great for some RWKV matrices
nn.init.zeros_(m.weight)
elif gain > 0:
nn.init.orthogonal_(m.weight, gain=gain)
else:
nn.init.normal_(m.weight, mean=0.0, std=-scale)
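# A commented sanity check (toy sizes, not from the repo; needs `import types`):
#   cfg = types.SimpleNamespace(vocab_size=100, n_embd=32)
#   toy = nn.Sequential(nn.Embedding(100, 32), nn.Linear(32, 100, bias=False))
#   RWKV_Init(toy, cfg)
# The (vocab_size, n_embd) embedding hits the scale=1e-4 branch (a tiny orthogonal init) and the
# (vocab_size, n_embd) head hits scale=0.5, matching the comments above.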
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_embd = config.n_embd
attn_sz = config.n_embd
############# fancy init of time_w curves ###################################
f1_begin = 3.0
f1_end = 1.2
f2_begin = 0.65
f2_end = 0.4
with torch.no_grad(): # initial time_w curves for better convergence
decay_speed = torch.ones(attn_sz, 1)
first_sa_layer_id = 1
for h in range(attn_sz):
f1 = f1_begin + (layer_id-first_sa_layer_id) / \
(config.n_layer-1-first_sa_layer_id) * (f1_end - f1_begin)
f2 = f2_begin + (layer_id-first_sa_layer_id) / \
(config.n_layer-1-first_sa_layer_id) * (f2_end - f2_begin)
if layer_id == first_sa_layer_id:
f1 += 0.5
if layer_id == config.n_layer-2:
f2 = 0.4
if layer_id == config.n_layer-1:
f2 = 0.37
decay_speed[h][0] = math.pow(f2, h / (attn_sz-1) * 7) * f1
self.time_decay = nn.Parameter(torch.log(decay_speed)) # will use exp(self.time_decay) to ensure time_decay > 0
self.time_curve = torch.tensor(
[-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0)
self.time_curve = self.time_curve.to('cuda')
self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3))
#############################################################################
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # init to "shift half of the channels"
ww = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd // 2):
ww[0, 0, i] = 0
self.time_mix = nn.Parameter(ww)
self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
def forward(self, x):
B, T, C = x.size()
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
k = self.key(x).transpose(-1, -2)
v = self.value(x).transpose(-1, -2)
r = self.receptance(x)
# RWKV_K_CLAMP can be removed if the CUDA kernel subtracts the correct k_max for each k (I will do this later)
k = torch.clamp(k, max=RWKV_K_CLAMP)
k = torch.exp(k)
kv = k * v
self.time_w = torch.cat(
[torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1)
w = torch.exp(self.time_w)
wkv = TimeX.apply(w, kv, B, C, T, 0)
# RWKV_K_EPS can be removed if the CUDA kernel sets 0/0 = 0 (I will do this later)
wk = TimeX.apply(w, k, B, C, T, RWKV_K_EPS)
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
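# In formula form (matching the code above), per channel:
#   rwkv_t = sigmoid(r_t) * ( sum_{u<=t} W_{t-u} exp(k_u) v_u ) / ( sum_{u<=t} W_{t-u} exp(k_u) + RWKV_K_EPS )
# with W_0 = exp(time_first) and W_d = exp(-(d-1) * exp(time_decay)) for d >= 1.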
class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # init to "shift half of the channels"
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd // 2):
x[0, 0, i] = 0
self.time_mix = nn.Parameter(x)
hidden_sz = 4 * config.n_embd
self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
self.value.scale_init = 0
self.receptance.scale_init = 0
def forward(self, x):
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
k = self.key(x)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(x)) * kv
return rkv
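# i.e. a gated squared-ReLU FFN: rkv = sigmoid(R x_mixed) * (V (relu(K x_mixed))^2), with hidden size 4*n_embd.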
########################################################################################################
# The GPT Model with our blocks
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
self.vocab_size = vocab_size
self.ctx_len = ctx_len
for k, v in kwargs.items():
setattr(self, k, v)
class Block(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
else:
self.att = RWKV_TimeMix(config, layer_id)
self.ffn = RWKV_ChannelMix(config, layer_id)
def forward(self, x):
x = self.ln1(x)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(x) # better in some cases
else:
x = x + self.att(x)
x = self.ln2(x)
x = x + self.ffn(x)
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.step = 0
self.config = config
self.emb = nn.Embedding(config.vocab_size, config.n_embd)
self.blocks = nn.Sequential(*[Block(config, i)
for i in range(config.n_layer)])
self.ln_out = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(config.ctx_len, config.ctx_len)))
self.ctx_len = config.ctx_len
RWKV_Init(self, config)
logger.info("number of parameters: %e", sum(p.numel()
for p in self.parameters()))
def get_ctx_len(self):
return self.ctx_len
def _init_weights(self, module):
if isinstance(module, (nn.Linear)):
module.weight.data.normal_(mean=0.0, std=0.01)
if isinstance(module, (nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=1e-5)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def configure_optimizers(self, train_config):
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
for mn, m in self.named_modules(): # here we disable weight_decay
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
no_decay.add(fpn)
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(
inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
optim_groups = [
{"params": [param_dict[pn]
for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.Adam(
optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
def forward(self, idx, targets=None):
self.step += 1
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).float()
x = self.head(x) + c
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss

@ -1,143 +0,0 @@
import types
import copy
import torch
from torch.nn import functional as F
RWKV_K_CLAMP = 60
RWKV_K_EPS = 1e-16
RWKV_HEAD_QK_DIM = 256
DEBUG_TIME = False # True False - show trained time-coeffs
class RWKV_RNN():
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
self.RUN_DEVICE = RUN_DEVICE
self.model_type = model_type
self.n_layer = n_layer
self.n_embd = n_embd
self.ctx_len = ctx_len
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth',
map_location=torch.device(RUN_DEVICE))
for x in w.keys():
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = torch.exp(-torch.exp(w[x]))
if '.time_first' in x:
w[x] = torch.exp(w[x])
if DEBUG_TIME and '.time_' in x:
print(x, w[x].squeeze().cpu().numpy())
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
self.hk = None
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
target.hk = copy.deepcopy(self.hk)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
self.hk = copy.deepcopy(target.hk)
def LN(self, xx, w):
return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ x)
k = torch.square(torch.relu(w.key.weight @ x))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ x)
k = torch.exp(torch.clamp(w.key.weight @ x, max=RWKV_K_CLAMP))
v = w.value.weight @ x
kv = k * v
a = self.aa[name] + w.time_first * kv
b = self.bb[name] + w.time_first * k
self.aa[name] = w.time_decay * self.aa[name] + kv
self.bb[name] = w.time_decay * self.bb[name] + k
rwkv = r * a / (b + RWKV_K_EPS)
return w.output.weight @ rwkv
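# Running-state view: after t steps, aa = sum_{u<t} decay^(t-1-u) kv_u and bb is the same sum over k_u,
# so a/b reproduce the training-time W-curve one token at a time (decay = exp(-exp(time_decay)),
# precomputed in __init__ above).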
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
for i in range(self.n_layer):
x = self.LN(x, w.blocks[i].ln1)
if i == 0 and self.model_type == 'RWKV-ffnPre':
x = x + self.FF(x, w.blocks[i].ffnPre, f'ffnPre.{i}')
else:
x = x + self.SA(x, w.blocks[i].att, f'att.{i}')
x = self.LN(x, w.blocks[i].ln2)
x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
if self.hk is None:
self.hk = (w.head_k.weight @ x).unsqueeze(0)
else:
self.hk = torch.cat(
[self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
if self.hk.shape[0] > self.ctx_len:
self.hk = self.hk[-self.ctx_len:, :]
q = w.head_q.weight @ x
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
c = (self.hk @ q) / RWKV_HEAD_QK_DIM
for i in range(len(c)):
x[ctx[i]] += c[i]
return x

@ -1,170 +0,0 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm.auto import tqdm
import numpy as np
import logging
import os
import datetime
import sys
import math
# import wandb # comment this if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')
logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
log_file = open("mylog.txt", "a")
class TrainerConfig:
max_epochs = 10
batch_size = 64
learning_rate = 4e-4
betas = (0.9, 0.99)
eps = 1e-8
grad_norm_clip = 1.0
lr_decay = True # linear warmup followed by cosine decay
warmup_tokens = 0
final_tokens = 0
epoch_save_frequency = 0
epoch_save_path = 'trained-'
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class Trainer:
def __init__(self, model, train_dataset, test_dataset, config):
self.model = model
self.train_dataset = train_dataset
self.test_dataset = test_dataset
self.config = config
self.avg_loss = -1
self.steps = 0
if 'wandb' in sys.modules:
cfg = model.config
for k in config.__dict__:
setattr(cfg, k, config.__dict__[k]) # combine cfg
wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' +
datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
self.device = 'cpu'
if torch.cuda.is_available(): # take over whatever gpus are on the system
self.device = torch.cuda.current_device()
def get_run_name(self):
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
cfg = raw_model.config
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name
def train(self):
model, config = self.model, self.config
raw_model = model.module if hasattr(self.model, "module") else model
optimizer = raw_model.configure_optimizers(config)
def run_epoch(split):
is_train = split == 'train'
model.train(is_train)
data = self.train_dataset if is_train else self.test_dataset
if config.num_workers > 0:
loader = DataLoader(data, shuffle=False, pin_memory=True,
batch_size=config.batch_size,
num_workers=config.num_workers)
else:
loader = DataLoader(data, shuffle=False,
batch_size=config.batch_size,
num_workers=config.num_workers)
pbar = tqdm(enumerate(loader), total=len(
loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
for it, (x, y) in pbar:
x = x.to(self.device) # place data on the correct device
y = y.to(self.device)
with torch.set_grad_enabled(is_train):
_, loss = model(x, y) # forward the model
if is_train: # backprop and update the parameters
model.zero_grad()
loss.backward()
if config.grad_norm_clip > 0:
torch.nn.utils.clip_grad_norm_(
model.parameters(), config.grad_norm_clip)
optimizer.step()
if config.lr_decay: # decay the learning rate based on our progress
# number of tokens processed this step (i.e. label is not -100)
self.tokens += (y >= 0).sum()
lr_final_factor = config.lr_final / config.learning_rate
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = lr_final_factor + \
(1 - lr_final_factor) * float(self.tokens) / \
float(config.warmup_tokens)
progress = 0
else:
# cosine learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(
max(1, config.final_tokens - config.warmup_tokens))
lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor /
2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1
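# endpoints: lr_mult = 1 at progress 0 and lr_final_factor at progress 1, so the lr
# sweeps from learning_rate down to lr_final over final_tokens.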
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
param_group['lr'] = lr
else:
lr = config.learning_rate
now_loss = loss.item() # report progress
self.lr = lr
if 'wandb' in sys.modules:
wandb.log({"loss": now_loss},
step=self.steps * self.config.batch_size)
self.steps += 1
if self.avg_loss < 0:
self.avg_loss = now_loss
else:
factor = 1 / (it + 1)
self.avg_loss = self.avg_loss * \
(1.0 - factor) + now_loss * factor
pbar.set_description(
f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):
run_epoch('train')
log_file.write(
f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n')
log_file.flush()
if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
# DataParallel wrappers keep raw model object in .module
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
torch.save(raw_model.state_dict(),
self.config.epoch_save_path + str(epoch+1) + '.pth')

@ -1,122 +0,0 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json
import random
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
class Dataset(Dataset):
def __init__(self, data, ctx_len, epoch_length_fixed):
print('building token list...', end=' ')
unique = sorted(list(set(data)))
# print()
# for u in unique:
# print(u, end=' ')
# print('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open('vocab.json', "w", encoding="utf-16") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
data_size, vocab_size = len(data), len(unique)
print('data has %d tokens, %d unique.' % (data_size, vocab_size))
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
self.ctx_len = ctx_len
self.epoch_length_fixed = epoch_length_fixed
self.vocab_size = vocab_size
self.data = data
def __len__(self):
return self.epoch_length_fixed
def __getitem__(self, idx):
# cheat: pick a random spot in dataset
i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
chunk = self.data[i:i+self.ctx_len+1]
dix = [self.stoi[s] for s in chunk]
x = torch.tensor(dix[:-1], dtype=torch.long,
device=torch.device('cuda'))
y = torch.tensor(dix[1:], dtype=torch.long,
device=torch.device('cuda'))
return x, y
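# x is chunk[:-1] and y is chunk[1:]: standard next-character targets, shifted by one position.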
class TOKENIZER():
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
self.word_table = json.load(result_file)
self.vocab_size = len(self.word_table)
self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(torch.tensor(out), dim=-1)
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
sorted_probs, s_index = torch.sort(probs, descending=True)
# for j in range(30):
# pp = sorted_probs[j].item()
# if pp < 0.005:
# break
# ss = self.itos[int(s_index[j])].replace('\n','_')
# print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
# print('')
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
# print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
return torch.multinomial(probs, num_samples=1)[0]
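# Worked example (illustrative numbers): with sorted probs [0.5, 0.3, 0.1, 0.1] and top_p = 0.7,
# the cumulative sums are [0.5, 0.8, 0.9, 1.0]; argmax of (cum > 0.7) is index 1, so cutoff = 0.3
# and only tokens with prob >= 0.3 survive into torch.multinomial.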
def to_float(x):
return x.cpu().detach().numpy().flatten()[0].astype(float)
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

@ -1,98 +0,0 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import logging
import datetime
import json
from src.model import GPT, GPTConfig
from src.trainer import Trainer, TrainerConfig
from src.utils import Dataset
import torch
import numpy as np
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
### Step 1: set training data ##########################################################################
datafile = "enwik8"
datafile_encoding = 'utf-8'
# datafile_encoding = 'utf-16le'
### Step 2: set model size #############################################################################
ctx_len = 1024 # ===> increase T_MAX in model.py if your ctx_len > 1024
n_layer = 6
n_embd = 512
# 'RWKV' (better for char-level English) or 'RWKV-ffnPre' (better in some cases)
model_type = 'RWKV'
### Step 3: set batch size #############################################################################
# ===> batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD in model.py
# For example, if your batch_size = 20, you can set B_GROUP_FORWARD = 4, B_GROUP_BACKWARD = 2
# If you see "CUDA out of memory", reduce it. Use GPU-Z to find the highest value for your VRAM.
batch_size = 12
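# A cheap guard mirroring the constraint above (not in the original; 4 and 2 are the
# B_GROUP_FORWARD / B_GROUP_BACKWARD defaults in model.py):
assert batch_size % 4 == 0 and batch_size % 2 == 0, 'batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD'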
### Step 4: set learning rate, training mini-epochs #######################################################
lr_init = 6e-4
lr_final = 1e-5
# the mini-epoch is very short and of fixed length (ctx_len * epoch_length_fixed tokens)
n_epoch = 500
# 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, etc.
epoch_save_frequency = 30
epoch_save_path = 'trained-'
epoch_length_fixed = 10000
########################################################################################################
# import src.utils
# src.utils.set_seed(42) # remember to change seed if you load a model
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
grad_norm_clip = 1.0
warmup_tokens = 0
betas = (0.9, 0.99)
eps = 4e-9
num_workers = 0
########################################################################################################
# Load data
########################################################################################################
print('loading data... ' + datafile)
train_dataset = Dataset(open(
datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)
########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
n_layer=n_layer, n_embd=n_embd)).cuda()
### load a trained model. Remember to change the random seed.
# m2 = torch.load('trained-61.pth')
# model.load_state_dict(m2)
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip,
warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()
torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() +
'-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')


@ -1,172 +0,0 @@
#include <stdio.h>
// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax, BF and BB are passed by the compiler)
#define F4(A, B) ((float4 *)(A))[(B) >> 2]
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
const F eps, const int B, const int C, const int T) {
const int i = blockIdx.y;
const int ij = (B * C) / BF;
const int t = threadIdx.x << 2;
__shared__ F ww[Tmax];
__shared__ F kk[Tmax * BF];
F4(ww, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BF; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
}
__syncthreads();
float4 s[BF];
#pragma unroll
for (int j = 0; j < BF; j++) {
s[j] = {eps, eps, eps, eps};
}
const F *__restrict__ const w = ww + T - t - 4;
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BF; j++) {
const F x = kk[u + Tmax * j];
s[j].x += w[u + 3] * x;
s[j].y += w[u + 2] * x;
s[j].z += w[u + 1] * x;
s[j].w += w[u + 0] * x;
}
}
#pragma unroll
for (int j = 0; j < BF; j++) {
const F *__restrict__ const k = kk + Tmax * j;
s[j].y += w[t + 3] * k[t + 1];
s[j].z += w[t + 2] * k[t + 1];
s[j].z += w[t + 3] * k[t + 2];
s[j].w += w[t + 1] * k[t + 1];
s[j].w += w[t + 2] * k[t + 2];
s[j].w += w[t + 3] * k[t + 3];
F4(x, t + T * (i + ij * j)) = s[j];
}
}
template <typename F>
__global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
const int i = blockIdx.y;
const int t = threadIdx.x << 2;
__shared__ F k[Tmax];
__shared__ F gg[Tmax];
F4(k, t) = F4(__k, t + T * i);
F4(gg, t) = F4(__gwk, t + T * i);
__syncthreads();
float4 s = {0, 0, 0, 0};
const F *__restrict__ const g = gg + T - t - 4;
for (int u = 0; u <= t; u++) {
F x = k[u];
s.x += g[u + 3] * x;
s.y += g[u + 2] * x;
s.z += g[u + 1] * x;
s.w += g[u + 0] * x;
}
s.y += g[t + 3] * k[t + 1];
s.z += g[t + 2] * k[t + 1];
s.z += g[t + 3] * k[t + 2];
s.w += g[t + 1] * k[t + 1];
s.w += g[t + 2] * k[t + 2];
s.w += g[t + 3] * k[t + 3];
F4(gw, t + T * i) = s;
}
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
dim3 gridDim(1, B * C / BF);
dim3 blockDim(T >> 2);
kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
}
template <typename F>
__global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
const int i = blockIdx.y;
const int ij = (B * C) / BB;
const int t = threadIdx.x << 2;
__shared__ F w[Tmax];
__shared__ F kk[Tmax * BB];
__shared__ F gg[Tmax * BB];
F4(w, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BB; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
}
__syncthreads();
float4 s[BB];
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
F x = kk[u + Tmax * j];
s[j].x += g[u + 3] * x;
s[j].y += g[u + 2] * x;
s[j].z += g[u + 1] * x;
s[j].w += g[u + 0] * x;
}
}
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const k = kk + Tmax * j;
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
s[j].y += g[t + 3] * k[t + 1];
s[j].z += g[t + 2] * k[t + 1];
s[j].z += g[t + 3] * k[t + 2];
s[j].w += g[t + 1] * k[t + 1];
s[j].w += g[t + 2] * k[t + 2];
s[j].w += g[t + 3] * k[t + 3];
F4(gw, t + T * (i + ij * j)) = s[j];
}
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = t + 3; u < T; u++) {
F x = w[u];
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - u] * x;
s[j].y += g[3 - u] * x;
s[j].z += g[4 - u] * x;
s[j].w += g[5 - u] * x;
}
}
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - t] * w[t + 0];
s[j].x += g[1 - t] * w[t + 1];
s[j].x += g[0 - t] * w[t + 2];
s[j].y += g[2 - t] * w[t + 1];
s[j].y += g[1 - t] * w[t + 2];
s[j].z += g[2 - t] * w[t + 2];
F4(gk, t + T * (i + ij * j)) = s[j];
}
}
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
dim3 gridDim(1, B * C / BB);
dim3 blockDim(T >> 2);
kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
}

@ -1,21 +0,0 @@
#include <torch/extension.h>
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);
void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T);
}
void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "timex forward");
m.def("backward", &backward, "timex backward");
}
TORCH_LIBRARY(timex, m) {
m.def("forward", forward);
m.def("backward", backward);
}

@ -1,98 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
### Step 1: set model ##################################################################################
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
# your trained model
MODEL_NAME = 'trained-1'
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False # True False - show softmax output
### Step 2: set context ################################################################################
context = "\nIn the" # ==> this is your prompt
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 500
TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9
########################################################################################################
print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
########################################################################################################
context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using the RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n')
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
t_begin = time.time_ns()
src_len = len(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(('-' * 30) + context, end='')
model.clear()
if TRIAL == 0:
init_state = types.SimpleNamespace()
for i in range(src_len):
x = ctx[:i+1]
if i == src_len - 1:
init_state.out = model.run(x)
else:
model.run(x)
model.save(init_state)
else:
model.load(init_state)
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
x = ctx[:i+1]
x = x[-ctx_len:]
if i == src_len:
out = copy.deepcopy(init_state.out)
else:
out = model.run(x)
if DEBUG_DEBUG:
print('model', np.array(x), '==>', np.array(
out), np.max(out), np.min(out))
char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
top_p_usual=top_p, top_p_newline=top_p_newline)
char = char.item()
print(tokenizer.itos[int(char)], end='', flush=True)
ctx += [char]
t_end = time.time_ns()
print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')

@ -1,363 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.cpp_extension import load
import math
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
RWKV_K_CLAMP = 60 # e^60 = 1e26
RWKV_K_EPS = 1e-8
RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
########################################################################################################
# CUDA Kernel
########################################################################################################
T_MAX = 1024 # increase this if your ctx_len > 1024
B_GROUP_FORWARD = 4 # set to 8 for best performance
B_GROUP_BACKWARD = 2 # set to 2 for best performance (sometimes 8 is faster)
timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'])
class TimeX(torch.autograd.Function):
@staticmethod
def forward(ctx, w, k, B, C, T, eps):
ctx.B = B
ctx.C = C
ctx.T = T
assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
w = w.contiguous()
k = k.contiguous()
ctx.save_for_backward(w, k)
wk = torch.empty((B, C, T), device='cuda',
memory_format=torch.contiguous_format)
timex_cuda.forward(w, k, wk, eps, B, C, T)
return wk
@staticmethod
def backward(ctx, gwk):
assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
w, k = ctx.saved_tensors
gw = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
memory_format=torch.contiguous_format)
gk = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
memory_format=torch.contiguous_format)
timex_cuda.backward(w, k, gwk.contiguous(), gw,
gk, ctx.B, ctx.C, ctx.T)
return (gw.sum(dim=0), gk, None, None, None, None)
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
with torch.no_grad():
name = '[unknown weight]'
for name, parameter in module.named_parameters(): # find the name of the weight
if id(m.weight) == id(parameter):
break
shape = m.weight.data.shape
gain = 1.0
scale = 1.0 # extra scale for gain
if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
scale = 1e-4
else:
scale = 0
if isinstance(m, nn.Linear):
if m.bias is not None:
m.bias.data.zero_()
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
scale = 0.5
if hasattr(m, 'scale_init'):
scale = m.scale_init
# print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
gain *= scale
if scale == -999:
nn.init.eye_(m.weight)
elif gain == 0:
# zero init is great for some RWKV matrices
nn.init.zeros_(m.weight)
elif gain > 0:
nn.init.orthogonal_(m.weight, gain=gain)
else:
nn.init.normal_(m.weight, mean=0.0, std=-scale)
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_embd = config.n_embd
attn_sz = config.n_embd
with torch.no_grad(): # fancy init
self.time_curve = torch.tensor([-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0)
self.time_curve = self.time_curve.to('cuda')
ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1
ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0
# fancy time_decay
decay_speed = torch.ones(attn_sz, 1)
for h in range(attn_sz):
decay_speed[h][0] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)
self.time_decay = nn.Parameter(decay_speed)
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
# fancy time_first
zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5).unsqueeze(1)
self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3) + zigzag)
# fancy time_mix
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd):
x[0, 0, i] = i / config.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# Use xk, xv, xr to produce k, v, r
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
# RWKV_K_CLAMP can be removed if the CUDA kernel subtracts the correct k_max for each k (I will do this later)
k = torch.clamp(k, max=RWKV_K_CLAMP) # clamp k to avoid overflow
k = torch.exp(k)
kv = k * v
# Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)]
self.time_w = torch.cat(
[torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1)
w = torch.exp(self.time_w)
# Use W to mix kv and k respectively. Add K_EPS to wk to avoid divide-by-zero
wkv = TimeX.apply(w, kv, B, C, T, 0)
# RWKV_K_EPS can be removed if the CUDA kernel sets 0/0 = 0 (I will do this later)
wk = TimeX.apply(w, k, B, C, T, RWKV_K_EPS)
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # fancy init of time_mix
ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd):
x[0, 0, i] = i / config.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
hidden_sz = 4 * config.n_embd
self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
self.value.scale_init = 0
self.receptance.scale_init = 0
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
########################################################################################################
# The GPT Model with our blocks
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
self.vocab_size = vocab_size
self.ctx_len = ctx_len
for k, v in kwargs.items():
setattr(self, k, v)
class Block(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
else:
self.att = RWKV_TimeMix(config, layer_id)
self.ffn = RWKV_ChannelMix(config, layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(self.ln1(x)) # better in some cases
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.step = 0
self.config = config
self.emb = nn.Embedding(config.vocab_size, config.n_embd)
self.blocks = nn.Sequential(*[Block(config, i)
for i in range(config.n_layer)])
self.ln_out = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
if RWKV_HEAD_QK_DIM > 0:
self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(config.ctx_len, config.ctx_len)))
self.ctx_len = config.ctx_len
RWKV_Init(self, config)
logger.info("number of parameters: %e", sum(p.numel()
for p in self.parameters()))
def get_ctx_len(self):
return self.ctx_len
def _init_weights(self, module):
if isinstance(module, (nn.Linear)):
module.weight.data.normal_(mean=0.0, std=0.01)
if isinstance(module, (nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=1e-5)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def configure_optimizers(self, train_config):
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
for mn, m in self.named_modules(): # here we disable weight_decay
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
no_decay.add(fpn)
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(
inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
optim_groups = [
{"params": [param_dict[pn]
for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.Adam(
optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
def forward(self, idx, targets=None):
self.step += 1
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
if RWKV_HEAD_QK_DIM > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).float()
x = self.head(x) + c
else:
x = self.head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss

@ -1,319 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types
import copy
import torch
import math
from torch.nn import functional as F
import torch.nn as nn
RWKV_K_CLAMP = 60
RWKV_K_EPS = 1e-8
RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
DEBUG_TIME = False # True False - show trained time-coeffs
############################################################################################################
RWKV_CFG = types.SimpleNamespace()
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
hidden_sz = 4 * RWKV_CFG.n_embd
self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
class RWKV_TimeMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1))
self.time_curve = torch.tensor([-(RWKV_CFG.ctx_len - 2 - i) for i in range(RWKV_CFG.ctx_len-1)]).unsqueeze(0)
self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1) * math.log(0.3))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
def forward(self, x):
B, T, C = x.size()
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
k = torch.clamp(k, max=RWKV_K_CLAMP)
k = torch.exp(k)
kv = k * v
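# The next few lines reproduce, in pure PyTorch (F.conv1d over left-padded sequences),
# what the TimeX CUDA kernel computes in the training-time model.py.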
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)
w = w[:,-T:].unsqueeze(1)
wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + RWKV_K_EPS
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
class Block(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)
if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(layer_id+1000)
else:
self.att = RWKV_TimeMix(layer_id)
self.ffn = RWKV_ChannelMix(layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(self.ln1(x))
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class RWKV_GPT(nn.Module):
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
global RWKV_CFG
super().__init__()
RWKV_CFG.RUN_DEVICE = RUN_DEVICE
RWKV_CFG.model_type = model_type
RWKV_CFG.vocab_size = vocab_size
RWKV_CFG.n_layer = n_layer
RWKV_CFG.n_embd = n_embd
RWKV_CFG.ctx_len = ctx_len
print('\nloading RWKV-GPT', MODEL_NAME)
self.emb = nn.Embedding(vocab_size, n_embd)
self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
self.ln_out = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
if RWKV_HEAD_QK_DIM > 0:
self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(ctx_len, ctx_len)))
self.ctx_len = ctx_len
self.eval()
self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
self.eval()
def forward(self, idx):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
if RWKV_HEAD_QK_DIM > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float()
x = self.head(x) + c
else:
x = self.head(x)
return x
############################################################################################################
class RWKV_RNN():
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
self.RUN_DEVICE = RUN_DEVICE
self.model_type = model_type
self.n_layer = n_layer
self.n_embd = n_embd
self.ctx_len = ctx_len
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth',
map_location=torch.device(RUN_DEVICE))
for x in w.keys():
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = torch.exp(-torch.exp(w[x]))
if '.time_first' in x:
w[x] = torch.exp(w[x])
if DEBUG_TIME and '.time_' in x:
print(x, w[x].squeeze().cpu().numpy())
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
self.hk = None
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
target.hk = copy.deepcopy(self.hk)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
self.hk = copy.deepcopy(target.hk)
def LN(self, xx, w):
return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.square(torch.relu(w.key.weight @ xk))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.exp(torch.clamp(w.key.weight @ xk, max=RWKV_K_CLAMP))
v = w.value.weight @ xv
kv = k * v
a = self.aa[name] + w.time_first * kv
b = self.bb[name] + w.time_first * k
self.aa[name] = w.time_decay * self.aa[name] + kv
self.bb[name] = w.time_decay * self.bb[name] + k
rwkv = r * a / (b + RWKV_K_EPS)
return w.output.weight @ rwkv
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
for i in range(self.n_layer):
if i == 0:
x = self.LN(x, w.blocks[i].ln0)
if i == 0 and self.model_type == 'RWKV-ffnPre':
x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
else:
x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
if RWKV_HEAD_QK_DIM > 0:
if self.hk is None:
self.hk = (w.head_k.weight @ x).unsqueeze(0)
else:
self.hk = torch.cat(
[self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
if self.hk.shape[0] > self.ctx_len:
self.hk = self.hk[-self.ctx_len:, :]
q = w.head_q.weight @ x
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
c = (self.hk @ q) / RWKV_HEAD_QK_DIM
for i in range(len(c)):
x[ctx[i]] += c[i]
else:
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
return x

@ -1,171 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm.auto import tqdm
import numpy as np
import logging
import os
import datetime
import sys
import math
# import wandb # comment this if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')
logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
log_file = open("mylog.txt", "a")
class TrainerConfig:
max_epochs = 10
batch_size = 64
learning_rate = 4e-4
betas = (0.9, 0.99)
eps = 1e-8
grad_norm_clip = 1.0
lr_decay = True # linear warmup followed by exponential decay (see run_epoch below)
warmup_tokens = 0
final_tokens = 0
epoch_save_frequency = 0
epoch_save_path = 'trained-'
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class Trainer:
def __init__(self, model, train_dataset, test_dataset, config):
self.model = model
self.train_dataset = train_dataset
self.test_dataset = test_dataset
self.config = config
self.avg_loss = -1
self.steps = 0
if 'wandb' in sys.modules:
cfg = model.config
for k in config.__dict__:
setattr(cfg, k, config.__dict__[k]) # combine cfg
wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' +
datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
self.device = 'cpu'
if torch.cuda.is_available(): # take over whatever gpus are on the system
self.device = torch.cuda.current_device()
def get_run_name(self):
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
cfg = raw_model.config
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name
def train(self):
model, config = self.model, self.config
raw_model = model.module if hasattr(self.model, "module") else model
optimizer = raw_model.configure_optimizers(config)
def run_epoch(split):
is_train = split == 'train'
model.train(is_train)
data = self.train_dataset if is_train else self.test_dataset
if config.num_workers > 0:
loader = DataLoader(data, shuffle=False, pin_memory=True,
batch_size=config.batch_size,
num_workers=config.num_workers)
else:
loader = DataLoader(data, shuffle=False,
batch_size=config.batch_size,
num_workers=config.num_workers)
pbar = tqdm(enumerate(loader), total=len(
loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
for it, (x, y) in pbar:
x = x.to(self.device) # place data on the correct device
y = y.to(self.device)
with torch.set_grad_enabled(is_train):
_, loss = model(x, y) # forward the model
if is_train: # backprop and update the parameters
model.zero_grad()
loss.backward()
if config.grad_norm_clip > 0:
torch.nn.utils.clip_grad_norm_(
model.parameters(), config.grad_norm_clip)
optimizer.step()
if config.lr_decay: # decay the learning rate based on our progress
# number of tokens processed this step (i.e. label is not -100)
self.tokens += (y >= 0).sum()
lr_final_factor = config.lr_final / config.learning_rate
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = lr_final_factor + \
(1 - lr_final_factor) * float(self.tokens) / \
float(config.warmup_tokens)
progress = 0
else:
# exponential learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
if progress >= 1:
lr_mult = lr_final_factor
else:
lr_mult = math.exp(math.log(lr_final_factor) * progress)  # equivalent to lr_final_factor ** progress
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
param_group['lr'] = lr
else:
    lr = config.learning_rate
    progress = 0  # keeps the progress readout in the pbar below well-defined when lr_decay is off
now_loss = loss.item() # report progress
self.lr = lr
if 'wandb' in sys.modules:
wandb.log({"loss": now_loss},
step=self.steps * self.config.batch_size)
self.steps += 1
if self.avg_loss < 0:
self.avg_loss = now_loss
else:
factor = 1 / (it + 1)
self.avg_loss = self.avg_loss * \
(1.0 - factor) + now_loss * factor
pbar.set_description(
f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):
run_epoch('train')
log_file.write(
f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n')
log_file.flush()
if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
# DataParallel wrappers keep raw model object in .module
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
torch.save(raw_model.state_dict(),
self.config.epoch_save_path + str(epoch+1) + '.pth')
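As a standalone restatement (a hypothetical helper, not part of the repo), the learning-rate logic inside run_epoch above amounts to:

import math

def lr_schedule(tokens, lr_init, lr_final, warmup_tokens, final_tokens):
    # Linear warmup from lr_final up to lr_init, then exponential decay
    # back toward lr_final as the token count approaches final_tokens.
    f = lr_final / lr_init
    if tokens < warmup_tokens:
        lr_mult = f + (1 - f) * tokens / warmup_tokens
    else:
        progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
        lr_mult = f if progress >= 1 else math.exp(math.log(f) * progress)
    return lr_init * lr_mult

Since exp(log(f) * progress) = f ** progress, the decay is a geometric interpolation from lr_init down to lr_final.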

@ -1,122 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json
import random
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
class Dataset(Dataset):  # note: reuses the name of torch's Dataset imported above; train.py imports this class from src.utils
def __init__(self, data, ctx_len, epoch_length_fixed):
print('building token list...', end=' ')
unique = sorted(list(set(data)))
# print()
# for u in unique:
# print(u, end=' ')
# print('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open('vocab.json', "w", encoding="utf-16") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
data_size, vocab_size = len(data), len(unique)
print('data has %d tokens, %d unique.' % (data_size, vocab_size))
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
self.ctx_len = ctx_len
self.epoch_length_fixed = epoch_length_fixed
self.vocab_size = vocab_size
self.data = data
def __len__(self):
return self.epoch_length_fixed
def __getitem__(self, idx):
# cheat: pick a random spot in dataset
i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
chunk = self.data[i:i+self.ctx_len+1]
dix = [self.stoi[s] for s in chunk]
x = torch.tensor(dix[:-1], dtype=torch.long,
device=torch.device('cuda'))
y = torch.tensor(dix[1:], dtype=torch.long,
device=torch.device('cuda'))
return x, y
class TOKENIZER():
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
self.word_table = json.load(result_file)
self.vocab_size = len(self.word_table)
self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(torch.tensor(out), dim=-1)
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
sorted_probs, s_index = torch.sort(probs, descending=True)
# for j in range(30):
# pp = sorted_probs[j].item()
# if pp < 0.005:
# break
# ss = self.itos[int(s_index[j])].replace('\n','_')
# print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
# print('')
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
# print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
return torch.multinomial(probs, num_samples=1)[0]
def to_float(x):
return x.cpu().detach().numpy().flatten()[0].astype(float)
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
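The nucleus (top-p) cutoff buried in sample_logits above can be isolated as follows; a minimal sketch of the same convention (keep the smallest set of highest-probability tokens whose cumulative mass exceeds top_p, zero out the rest):

import numpy as np
import torch

def top_p_filter(probs, top_p):
    # probs: 1-D CPU tensor of probabilities summing to 1
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1).numpy()
    # the first position where cumulative mass exceeds top_p sets the cutoff
    cutoff = float(sorted_probs[np.argmax(cumulative > top_p)])
    filtered = probs.clone()
    filtered[filtered < cutoff] = 0
    return filtered

torch.multinomial renormalizes the surviving mass internally, which is why sample_logits never divides by the remaining total.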

@ -1,118 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os
# if False: # True False ---> Set to False if you don't understand it
# print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n")
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# import src.utils
# src.utils.set_seed(42) # make training deterministic (including dataloader). if you are doing this, remember to change seed when you load a model (otherwise the dataloader loads old samples)
import logging
import datetime
from src.model import GPT, GPTConfig
from src.trainer import Trainer, TrainerConfig
from src.utils import Dataset
import torch
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
### Step 1: set training data ##########################################################################
datafile = "../data/enwik8" # your data
datafile_encoding = 'utf-8'
# datafile_encoding = 'utf-16le'
### Step 2: set model size #############################################################################
# ----> test deeper models (n_layer at least 12) to see the advantage of RWKV-3 over RWKV-2
ctx_len = 1024 # increase T_MAX in model.py if your ctx_len > 1024
n_layer = 6
n_embd = 512
# 'RWKV' (better for English) or 'RWKV-ffnPre' (better in some cases)
model_type = 'RWKV'
# ---> there is a RWKV_HEAD_QK_DIM in model.py and model_run.py
# set it to 256 to use my headQK trick (similar to a tiny attention) to improve loss
# set it to 0 for a pure RNN (attention-free)
### Step 3: set batch size #############################################################################
# ---> batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD in model.py
# for example, if your batch_size = 20, you can set B_GROUP_FORWARD = 4, B_GROUP_BACKWARD = 2
# if you see "CUDA out of memory", reduce batch_size. Use nvidia-smi to find the highest value for your GPU.
batch_size = 12
### Step 4: set learning rate, number of mini-epochs #######################################################
# By default we are using exponential LR decay.
#
# Here are my suggestions for training a good model.
# Let's say you will train an L6-D512 model.
# 1) Set lr_init = lr_final = 8e-4. Let it run for some mini-epochs, until the loss improvement becomes slow.
# 2) Check epoch_save_frequency and make sure the partially-trained model is saved. Ctrl+C to stop the run.
# 3) Set lr_init = 8e-4, lr_final = 1e-5, warmup_tokens = ctx_len * batch_size * 50, betas = (0.9, 0.999).
# 4) Search for "torch.load" here and modify it to load the partially-trained model. Continue the training.
#
# For L12-D768, set lr_init = 6e-4. For L24-D1024, set lr_init = 4e-4. For L24-D2048, set lr_init = 3e-4.
lr_init = 8e-4 # we can use larger lr because of preLN
lr_final = 1e-5
# the mini-epoch is very short and of fixed length (length = ctx_len * epoch_length_fixed tokens)
n_epoch = 500
epoch_length_fixed = 10000
# 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, ...
epoch_save_frequency = 10
epoch_save_path = 'trained-'
########################################################################################################
grad_norm_clip = 1.0
warmup_tokens = ctx_len * batch_size * 0
betas = (0.9, 0.99)
eps = 4e-9
num_workers = 0
########################################################################################################
# Load data
########################################################################################################
print('loading data... ' + datafile)
train_dataset = Dataset(open(
datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)
########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
n_layer=n_layer, n_embd=n_embd)).cuda()
### ---> load a trained model <---
# m2 = torch.load('trained-61.pth')
# model.load_state_dict(m2)
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip,
warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()
torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() +
'-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
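For orientation, the token counts implied by the defaults above (pure arithmetic; len(train_dataset) equals epoch_length_fixed):

# Worked numbers for the defaults above (arithmetic only, no new behavior):
ctx_len, batch_size, epoch_length_fixed, n_epoch = 1024, 12, 10000, 500
tokens_per_miniepoch = ctx_len * epoch_length_fixed    # 10,240,000 tokens per mini-epoch
final_tokens = n_epoch * epoch_length_fixed * ctx_len  # ~5.12e9 tokens: the LR decay horizon
warmup_suggested = ctx_len * batch_size * 50           # 614,400 tokens, per suggestion 3) above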

@ -1,65 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# this is for verifying the results of different models and making sure they agree with each other
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
RUN_DEVICE = 'cuda'
import torch
from src.model_run import RWKV_RNN, RWKV_GPT
from src.model import GPT, GPTConfig
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV'
model_name = 'trained-1'
from src.utils import TOKENIZER
tokenizer = TOKENIZER('vocab', UNKNOWN_CHAR=' ')
########################################################################################################
model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda()
print('loading ' + model_name)
m2 = torch.load(model_name + '.pth', map_location=RUN_DEVICE)
model_train.load_state_dict(m2)
model_rnn = RWKV_RNN(model_name, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
model_gpt = RWKV_GPT(model_name, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda()
########################################################################################################
context = '\nIn a'
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(f'input len {len(ctx)} data {ctx}')
########################################################################################################
print('\nRWKV-GPT output')
out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy()
print(out)
print('\nRWKV-RNN output')
model_rnn.clear()
src_len = len(ctx)
for i in range(src_len):
x = ctx[:i+1]
out = model_rnn.run(x)
if i < 3 or i >= src_len - 3:
print(torch.tensor(out).detach().cpu().numpy())
if i == 2:
print('...')
print('\nRWKV-train output')
ctx += [0] * (ctx_len - src_len) # pad to ctx_len
ctx = [ctx] * 4 # increase batch size (to make it work with B_GROUP_FORWARD & B_GROUP_BACKWARD)
out = model_train.forward(torch.tensor(ctx).cuda())[0][0][:src_len].detach().cpu().numpy()
print(out, '\n')

@ -1,125 +0,0 @@
#include <stdio.h>
#include <assert.h>
#define MIN_VALUE (-1e38)
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
F *__restrict__ const _y) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
F *__restrict__ const y = _y + _offset;
F p = 0, q = 0, o = MIN_VALUE;
// p and q are running sums divided by exp(o) (to avoid overflows)
for (int i = 0; i < T; i++) {
const int ii = i * C;
F no = max(o, u + k[ii]);
F A = exp(o - no);
F B = exp(u + k[ii] - no);
y[ii] = (A * p + B * v[ii]) / (A * q + B);
no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
}
}
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
const F *__restrict__ const gy = _gy + _offset;
F *__restrict__ const gk = _gk + _offset;
F *__restrict__ const gv = _gv + _offset;
F y[Tmax], z[Tmax], zexp[Tmax];
F gw = 0, gu = 0;
F p = 0, q = 0;
F dpdw = 0, dqdw = 0;
F o = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
F no = max(o, k[ii] + u);
F A = exp(o - no);
F B = exp(k[ii] + u - no);
F num = A * p + B * v[ii];
F iden = 1 / (A * q + B);
y[i] = num * iden;
z[i] = iden;
zexp[i] = k[ii] + u - no;
gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
gu += gy[ii] * (v[ii] - y[i]) * B * iden;
no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
dpdw = A * (p + dpdw);
dqdw = A * (q + dqdw);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
}
F gp = 0, gq = 0;
o = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
F A = gy[ii] * z[i] * exp(zexp[i]);
F B = exp(k[ii] + o);
gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
gv[ii] = A + B * gp;
F no = max(w + o, zexp[i] - k[ii] - u);
A = exp(w + o - no);
B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
gp = A * gp + B;
gq = A * gq - B * y[i];
o = no;
}
// Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] += gw * _w[_c];
_gu[_offsetBC] += gu;
}
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}
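The "p and q are running sums divided by exp(o)" comment is the whole stability trick; spelled out, the kernel stores (p, q, o) such that the true numerator and denominator are e^o p and e^o q, and every step first moves to a fresh reference exponent so no exp() ever sees a large positive argument:

$$y_t=\frac{e^{o-o'}\,p+e^{u+k_t-o'}\,v_t}{e^{o-o'}\,q+e^{u+k_t-o'}},\qquad o'=\max(o,\ u+k_t)$$

and the state advance (with w <= 0, since model.py preprocesses w = -exp(time_decay) before calling the kernel):

$$p\leftarrow e^{w+o-o''}\,p+e^{k_t-o''}\,v_t,\qquad q\leftarrow e^{w+o-o''}\,q+e^{k_t-o''},\qquad o''=\max(w+o,\ k_t)$$

Every exponent is non-positive by construction, so the recurrence cannot overflow in fp32.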

@ -1,21 +0,0 @@
#include <torch/extension.h>
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv forward");
m.def("backward", &backward, "wkv backward");
}
TORCH_LIBRARY(wkv, m) {
m.def("forward", forward);
m.def("backward", backward);
}

@ -1,149 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math, os
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
########################################################################################################
# Step 1: set model
#
# Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch.
#
# Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models.
########################################################################################################
TOKEN_MODE = 'char' # char / bpe / pile
n_layer = 6
n_embd = 512
ctx_len = 1024
if TOKEN_MODE == 'char':
MODEL_NAME = 'trained-500' # your trained model
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
# set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown tokens in your prompt will be denoted by it
UNKNOWN_CHAR = ' ' # here we just set it to ' ' for simplicity
elif TOKEN_MODE == 'bpe':
MODEL_NAME = 'trained-500' # your trained model
WORD_NAME = ['model-vocab.json', 'model-merges.txt'] # [vocab, merge] for your BPE model
UNKNOWN_CHAR = None
elif TOKEN_MODE == 'pile':
WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
UNKNOWN_CHAR = None
#---> you can set MODEL_NAME to your fine-tuned model <---
MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023'
# MODEL_NAME = 'trained-11'
n_layer = 12
n_embd = 768
ctx_len = 1024
# MODEL_NAME = 'RWKV-4-Pile-430M-20220808-8066'
# n_layer = 24
# n_embd = 1024
# ctx_len = 1024
# MODEL_NAME = 'RWKV-4-Pile-1B5-20220903-8040'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024
os.environ['RWKV_FLOAT_MODE'] = 'fp32' # 'bf16' / 'fp16' / 'fp32' (note: only using fp32 at this moment)
os.environ['RWKV_RUN_DEVICE'] = 'cpu' # 'cpu' (already very fast) or 'cuda'
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
########################################################################################################
# Step 2: set prompt & sampling stuffs
########################################################################################################
# context = 'A'
# context = "\nIn the"
# context = '\nSugar:'
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333
TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9 # only used in TOKEN_MODE = char
DEBUG_DEBUG = False # True False --> show softmax output
########################################################################################################
print(f'Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN
model = RWKV_RNN(MODEL_NAME, os.environ['RWKV_RUN_DEVICE'], model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
########################################################################################################
if tokenizer.charMode:
context = tokenizer.refine_context(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
ctx = tokenizer.tokenizer.encode(context)
src_len = len(ctx)
src_ctx = ctx.copy()
print('\nYour prompt has ' + str(src_len) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n')
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
t_begin = time.time_ns()
print(('-' * 30) + context, end='')
ctx = src_ctx.copy()
model.clear()
if TRIAL == 0:
init_state = types.SimpleNamespace()
for i in range(src_len):
x = ctx[:i+1]
if i == src_len - 1:
init_state.out = model.run(x)
else:
model.run(x)
model.save(init_state)
else:
model.load(init_state)
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
x = ctx[:i+1]
x = x[-ctx_len:]
if i == src_len:
out = copy.deepcopy(init_state.out)
else:
out = model.run(x)
if DEBUG_DEBUG:
print('model', np.array(x), '==>', np.array(
out), np.max(out), np.min(out))
if TOKEN_MODE == 'pile':
out[0] = -999999999 # disable <|endoftext|>
char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
top_p_usual=top_p, top_p_newline=top_p_newline)
char = char.item()
if tokenizer.charMode:
print(tokenizer.itos[int(char)], end='', flush=True)
else:
print(tokenizer.tokenizer.decode(int(char)), end='', flush=True)
ctx += [char]
t_end = time.time_ns()
print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')

@ -1,216 +0,0 @@
import os
import torch
import numpy as np
import shutil
import struct
from functools import lru_cache
from itertools import accumulate
def print_rank_0(*message):
"""If distributed is initialized print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(*message, flush=True)
else:
print(*message, flush=True)
def _warmup_mmap_file(path):
pass
# with open(path, "rb") as stream:
# while stream.read(100 * 1024 * 1024):
# pass
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: float,
7: np.double,
8: np.uint16,
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + ".idx"
def data_file_path(prefix_path):
return prefix_path + ".bin"
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b"MMIDIDX\x00\x00"
def __init__(self, path, skip_warmup=False):
with open(path, "rb") as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
# Little endian unsigned 64 Bit integer
version = struct.unpack("<Q", stream.read(8))
assert (1,) == version
# Little endian unsigned 8 Bit integer
(dtype_code,) = struct.unpack("<B", stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack("<Q", stream.read(8))[0]
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
offset = stream.tell()
if not skip_warmup:
print_rank_0(" warming up index mmap file...")
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print_rank_0(" reading sizes...")
self._sizes = np.frombuffer(
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
)
print_rank_0(" reading pointers...")
self._pointers = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._len,
offset=offset + self._sizes.nbytes,
)
print_rank_0(" reading document index...")
self._doc_idx = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@property
def doc_idx(self):
return self._doc_idx
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path, skip_warmup=False):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path, skip_warmup)
def __getstate__(self):
return self._path
def __setstate__(self, state):
self._do_init(state)
def _do_init(self, path, skip_warmup):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)
if not skip_warmup:
print_rank_0(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print_rank_0(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(
data_file_path(self._path), mode="r", order="C"
)
print_rank_0(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index
def __len__(self):
return len(self._index)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
ptr, size = self._index[idx]
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError(
"Slices into indexed_dataset must be contiguous")
ptr = self._index._pointers[start]
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
)
sents = np.split(np_array, offsets[:-1])
return sents
def get(self, idx, offset=0, length=None):
"""Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
)
return np_array
@property
def sizes(self):
return self._index.sizes
@property
def doc_idx(self):
return self._index.doc_idx
def get_doc_idx(self):
return self._index._doc_idx
def set_doc_idx(self, doc_idx_):
self._index._doc_idx = doc_idx_
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return os.path.exists(index_file_path(path)) and os.path.exists(
data_file_path(path)
)
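Reading the Index constructor above back out as a writer pins down the on-disk .idx layout. The following is a hypothetical helper (write_minimal_idx is not part of the repo) that emits a file the class would accept; it reuses the dtypes table defined above:

import struct
import numpy as np

def write_minimal_idx(path, sizes, dtype_code=8):
    # Layout parsed by MMapIndexedDataset.Index: magic, version, dtype code,
    # item count, document count, then the sizes / pointers / doc_idx arrays.
    itemsize = np.dtype(dtypes[dtype_code]).itemsize
    sizes = np.asarray(sizes, dtype=np.int32)
    pointers = np.zeros(len(sizes), dtype=np.int64)
    pointers[1:] = np.cumsum(sizes[:-1].astype(np.int64) * itemsize)
    doc_idx = np.arange(len(sizes) + 1, dtype=np.int64)  # here: one document per item
    with open(path, "wb") as f:
        f.write(b"MMIDIDX\x00\x00")               # _HDR_MAGIC
        f.write(struct.pack("<Q", 1))             # version
        f.write(struct.pack("<B", dtype_code))    # 8 -> np.uint16
        f.write(struct.pack("<Q", len(sizes)))    # becomes self._len
        f.write(struct.pack("<Q", len(doc_idx))) # becomes self._doc_count
        f.write(sizes.tobytes(order="C"))
        f.write(pointers.tobytes(order="C"))
        f.write(doc_idx.tobytes(order="C"))

The matching .bin file is just the token arrays concatenated in the same order, which is what the pointer offsets index into.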

@ -1,414 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import math, os
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
try:
from deepspeed.ops.adam import FusedAdam
except:
pass # some poor Windows users can't install deepspeed
logger = logging.getLogger(__name__)
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
class L2Wrap(torch.autograd.Function):
@staticmethod
def forward(ctx, loss, y):
ctx.save_for_backward(y)
return loss
@staticmethod
def backward(ctx, grad_output):
y = ctx.saved_tensors[0]
# to encourage the logits to be close to 0
factor = 1e-4 / (y.shape[0] * y.shape[1])
maxx, ids = torch.max(y, -1, keepdim=True)
gy = torch.zeros_like(y)
gy.scatter_(-1, ids, maxx * factor)
return (grad_output, gy)
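# Note: treating the argmax index as constant, the backward above is equivalent to
# adding an auxiliary loss L_aux = (1e-4 / (2*B*T)) * sum over (b, t) of (max_i y[b, t, i])**2,
# whose gradient at each position's argmax is (1e-4 / (B*T)) * max_i y -- exactly the gy
# scattered above. It nudges the largest logit toward 0 without changing which token wins.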
########################################################################################################
# CUDA Kernel
########################################################################################################
T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
if '32' in os.environ['RWKV_FLOAT_MODE']:
w = -torch.exp(w.contiguous())
u = u.contiguous()
k = k.contiguous()
v = v.contiguous()
else:
w = -torch.exp(w.float().contiguous())
u = u.float().contiguous()
k = k.float().contiguous()
v = v.float().contiguous()
ctx.save_for_backward(w, u, k, v)
y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
if '32' in os.environ['RWKV_FLOAT_MODE']:
return y
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
return y.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
return y.bfloat16()
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
w, u, k, v = ctx.saved_tensors
gw = torch.zeros((B, C), device='cuda').contiguous()
gu = torch.zeros((B, C), device='cuda').contiguous()
gk = torch.zeros((B, T, C), device='cuda').contiguous()
gv = torch.zeros((B, T, C), device='cuda').contiguous()
if '32' in os.environ['RWKV_FLOAT_MODE']:
wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
else:
wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
if '32' in os.environ['RWKV_FLOAT_MODE']:
return (None, None, None, gw, gu, gk, gv)
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
def RUN_CUDA(B, T, C, w, u, k, v):
return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
def RWKV_Init(model, args): # fancy initialization of all lin & emb layer in the model
print("\n[--> first run, init model params (very slow for large models) <--]")
print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n")
for mm in model.modules():
if "RecursiveScriptModule" in str(type(mm)):
if mm.original_name not in ["Linear"]:
continue
ww = None
for name, param in mm.named_parameters():
if name == "weight":
ww = param
else:
m = mm
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
ww = m.weight
with torch.no_grad():
name = "[unknown weight]"
for name, parameter in model.named_parameters(): # find the name of the weight
if id(ww) == id(parameter):
break
shape = ww.shape
gain = 1.0
scale = 1.0 # extra scale for gain
if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == args.vocab_size and shape[1] == args.n_embd: # token emb?
scale = 1e-4
else:
scale = 0
if isinstance(m, nn.Linear):
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
if shape[0] == args.vocab_size and shape[1] == args.n_embd: # final projection?
scale = 0.5
if hasattr(m, "scale_init"):
scale = m.scale_init
# print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {name}")
gain *= scale
if scale == -999:
nn.init.eye_(ww)
elif gain == 0:
# zero init is great for some RWKV matrices
nn.init.zeros_(ww)
elif gain > 0:
nn.init.orthogonal_(ww, gain=gain)
else:
nn.init.normal_(ww, mean=0.0, std=-scale)
class RWKV_TimeMix(torch.jit.ScriptModule):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_embd = config.n_embd
attn_sz = config.n_embd
with torch.no_grad(): # fancy init
ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1
ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0
# fancy time_decay
decay_speed = torch.ones(attn_sz)
for h in range(attn_sz):
decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)
self.time_decay = nn.Parameter(decay_speed)
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
# fancy time_first
zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5)
self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
# fancy time_mix
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd):
x[0, 0, i] = i / config.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
@torch.jit.script_method
def jit_func(self, x):
# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# Use xk, xv, xr to produce k, v, r
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
sr = torch.sigmoid(r)
return sr, k, v
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v = self.jit_func(x)
rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
rwkv = self.output(rwkv)
return rwkv
class RWKV_ChannelMix(torch.jit.ScriptModule):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # fancy init of time_mix
ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd):
x[0, 0, i] = i / config.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
hidden_sz = 4 * config.n_embd
self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
self.value.scale_init = 0
self.receptance.scale_init = 0
@torch.jit.script_method
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
########################################################################################################
# The GPT Model with our blocks
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
self.vocab_size = vocab_size
self.ctx_len = ctx_len
for k, v in kwargs.items():
setattr(self, k, v)
class Block(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(config, 0)
else:
self.att = RWKV_TimeMix(config, layer_id)
self.ffn = RWKV_ChannelMix(config, layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(self.ln1(x)) # better in some cases
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.step = 0
self.config = config
self.emb = nn.Embedding(config.vocab_size, config.n_embd)
self.blocks = nn.Sequential(*[Block(config, i)
for i in range(config.n_layer)])
self.ln_out = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
if RWKV_HEAD_QK_DIM > 0:
self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(config.ctx_len, config.ctx_len)))
self.ctx_len = config.ctx_len
try:
if os.environ['RWKV_LOAD_MODEL'] == str(False):
RWKV_Init(self, config)
except:
pass
logger.info("number of parameters: %e", sum(p.numel()
for p in self.parameters()))
def get_ctx_len(self):
return self.ctx_len
def _init_weights(self, module):
if isinstance(module, (nn.Linear)):
module.weight.data.normal_(mean=0.0, std=0.01)
if isinstance(module, (nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=1e-5)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def configure_optimizers(self, train_config):
no_decay = set()
for mn, m in self.named_modules(): # here we disable weight_decay
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
no_decay.add(fpn)
param_dict = {pn: p for pn, p in self.named_parameters()}
optim_groups = [
{"params": [param_dict[pn]
for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
try:
optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
except:
print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n')
optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
def forward(self, idx, targets=None):
idx = idx.to(self.emb.weight.device)
self.step += 1
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
if RWKV_HEAD_QK_DIM > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
if '32' in os.environ['RWKV_FLOAT_MODE']:
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size)
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16()
x = self.head(x) + c
else:
x = self.head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))
return L2Wrap.apply(loss, x)
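A note on the headQK branch in forward above (inactive here since RWKV_HEAD_QK_DIM = 0): masking q k^T with the lower-triangular copy_mask and multiplying by one_hot(idx) makes it a learned copy mechanism. In symbols, with D = RWKV_HEAD_QK_DIM, the correction added to the logits is

$$c_{b,t,v}=\frac{1}{D}\sum_{s\le t}\bigl(q_{b,t}\cdot k_{b,s}\bigr)\,\mathbf{1}\!\left[\mathrm{idx}_{b,s}=v\right]$$

so each token already present in the context gets its logit boosted (or suppressed) according to how strongly the current query attends to the position where it occurred.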

@ -1,392 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types
import copy
import torch
import math, os
from torch.nn import functional as F
import torch.nn as nn
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
DEBUG_TIME = False # True False - show trained time-coeffs
########################################################################################################
# CUDA Kernel
########################################################################################################
if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
if '32' in os.environ['RWKV_FLOAT_MODE']:
w = -torch.exp(w.contiguous())
u = u.contiguous()
k = k.contiguous()
v = v.contiguous()
else:
w = -torch.exp(w.float().contiguous())
u = u.float().contiguous()
k = k.float().contiguous()
v = v.float().contiguous()
ctx.save_for_backward(w, u, k, v)
y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
if '32' in os.environ['RWKV_FLOAT_MODE']:
return y
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
return y.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
return y.bfloat16()
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
w, u, k, v = ctx.saved_tensors
gw = torch.zeros((B, C), device='cuda').contiguous()
gu = torch.zeros((B, C), device='cuda').contiguous()
gk = torch.zeros((B, T, C), device='cuda').contiguous()
gv = torch.zeros((B, T, C), device='cuda').contiguous()
if '32' in os.environ['RWKV_FLOAT_MODE']:
wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
else:
wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
if '32' in os.environ['RWKV_FLOAT_MODE']:
return (None, None, None, gw, gu, gk, gv)
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
def RUN_CUDA(B, T, C, w, u, k, v):
return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
############################################################################################################
RWKV_CFG = types.SimpleNamespace()
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
hidden_sz = 4 * RWKV_CFG.n_embd
self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
class RWKV_TimeMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd))
self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd) * math.log(0.3))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
def forward(self, x):
B, T, C = x.size()
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
rwkv = self.output(rwkv)
return rwkv
class Block(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)
if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(layer_id+1000)
else:
self.att = RWKV_TimeMix(layer_id)
self.ffn = RWKV_ChannelMix(layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(self.ln1(x))
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class RWKV_GPT(nn.Module):
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
global RWKV_CFG
super().__init__()
RWKV_CFG.RUN_DEVICE = RUN_DEVICE
RWKV_CFG.model_type = model_type
RWKV_CFG.vocab_size = vocab_size
RWKV_CFG.n_layer = n_layer
RWKV_CFG.n_embd = n_embd
RWKV_CFG.ctx_len = ctx_len
print('\nloading RWKV-GPT', MODEL_NAME)
self.emb = nn.Embedding(vocab_size, n_embd)
self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
self.ln_out = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
if RWKV_HEAD_QK_DIM > 0:
self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(ctx_len, ctx_len)))
self.ctx_len = ctx_len
self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
self.eval()
def forward(self, idx):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
if RWKV_HEAD_QK_DIM > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
if '32' in os.environ['RWKV_FLOAT_MODE']:
c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size)
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16()
x = self.head(x) + c
else:
x = self.head(x)
return x
############################################################################################################
class RWKV_RNN(): # this is running in FP32 at this moment
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
self.RUN_DEVICE = RUN_DEVICE
self.model_type = model_type
self.n_layer = n_layer
self.n_embd = n_embd
self.ctx_len = ctx_len
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth',
map_location=torch.device(RUN_DEVICE))
for x in w.keys():
w[x] = w[x].float()
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = -torch.exp(w[x])
if DEBUG_TIME and '.time_' in x:
print(x, w[x].squeeze().cpu().numpy())
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
self.pp = {}
self.hk = None
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
target.pp = copy.deepcopy(self.pp)
target.hk = copy.deepcopy(self.hk)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
self.pp = copy.deepcopy(target.pp)
self.hk = copy.deepcopy(target.hk)
def LN(self, xx, w):
return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.square(torch.relu(w.key.weight @ xk))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = w.key.weight @ xk
v = w.value.weight @ xv
pp = self.pp[name]
aa = self.aa[name]
bb = self.bb[name]
ww = w.time_first + k
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
a = e1 * aa + e2 * v
b = e1 * bb + e2
ww = pp + w.time_decay
p = torch.maximum(ww, k)
e1 = torch.exp(ww - p)
e2 = torch.exp(k - p)
self.aa[name] = e1 * aa + e2 * v
self.bb[name] = e1 * bb + e2
self.pp[name] = p
rwkv = r * a / b
return w.output.weight @ rwkv
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
for i in range(self.n_layer):
if i == 0:
x = self.LN(x, w.blocks[i].ln0)
if i == 0 and self.model_type == 'RWKV-ffnPre':
x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
else:
x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
if RWKV_HEAD_QK_DIM > 0:
if self.hk is None:
self.hk = (w.head_k.weight @ x).unsqueeze(0)
else:
self.hk = torch.cat(
[self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
if self.hk.shape[0] > self.ctx_len:
self.hk = self.hk[-self.ctx_len:, :]
q = w.head_q.weight @ x
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
c = (self.hk @ q) / RWKV_HEAD_QK_DIM
for i in range(len(c)):
x[ctx[i]] += c[i]
else:
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
return x

@ -1,187 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os
NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])
USE_WANDB = (int(os.environ['USE_WANDB']) == 1)
from torch.utils.data.dataloader import DataLoader
import torch
from tqdm.auto import tqdm
import logging
import datetime
import math
from pytorch_lightning.lite import LightningLite
import gc
logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
if os.environ['RWKV_FLOAT_MODE'] == 'fp32':
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
class TrainerConfig:
batch_size = 64
learning_rate = 4e-4
betas = (0.9, 0.99)
eps = 1e-8
grad_norm_clip = 1.0
warmup_tokens = 0
final_tokens = 0
epoch_save_frequency = 0
epoch_save_path = 'trained-'
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
from src.model import GPT, GPTConfig
class Trainer(LightningLite):
def get_run_name(self):
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
cfg = raw_model.config
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name
def run(self, m_cfg, train_dataset, test_dataset, config):
self.cuda_id = int(str(self.device).split(':')[-1])  # e.g. 'cuda:0' -> 0
print('[0]')
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=m_cfg.model_type,
n_layer=m_cfg.n_layer, n_embd=m_cfg.n_embd))
print('[1]')
with torch.no_grad():
if m_cfg.LOAD_MODEL:
print('loading', m_cfg.MODEL_NAME)
m2 = torch.load(m_cfg.MODEL_NAME + '.pth', map_location='cpu')
model.load_state_dict(m2)
del m2
model.to(self.device)
self.model = model
self.train_dataset = train_dataset
self.test_dataset = test_dataset
self.config = config
self.avg_loss = -1
self.EPOCH_BEGIN = m_cfg.EPOCH_BEGIN
self.steps = self.EPOCH_BEGIN * (len(self.train_dataset) // (config.batch_size // NUM_GPUS))
if self.cuda_id == 0:
log_file = open("mylog.txt", "a")
if USE_WANDB:
print('logging to wandb... (comment it if you don\'t have wandb)')
import wandb # comment this if you don't have wandb
cfg = model.config
for k in config.__dict__:
setattr(cfg, k, config.__dict__[k]) # combine cfg
wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
model, config = self.model, self.config
raw_model = model.module if hasattr(self.model, "module") else model
optimizer = raw_model.configure_optimizers(config)
model, optimizer = self.setup(model, optimizer)
print('[3]')
def run_epoch(split):
is_train = split == 'train'
model.train(is_train)
data = self.train_dataset if is_train else self.test_dataset
data.idx_begin = self.steps * config.batch_size + 1
data.cuda_id = self.cuda_id
if config.num_workers > 0:
loader = DataLoader(data, shuffle=False, pin_memory=True,
batch_size=config.batch_size // NUM_GPUS,
num_workers=config.num_workers)
else:
loader = DataLoader(data, shuffle=False,
batch_size=config.batch_size // NUM_GPUS,
num_workers=config.num_workers)
pbar = tqdm(enumerate(loader), total=len(
loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
loader = self.setup_dataloaders(loader)
gc.collect()
torch.cuda.empty_cache()
for it, (x, y) in pbar:
with torch.set_grad_enabled(is_train):
loss = model(x, y) # forward the model
if os.environ['RWKV_DEEPSPEED'] == '0':
all_loss = [loss.clone()]
else:
all_loss = [loss.clone() for _ in range(NUM_GPUS)]
torch.distributed.all_gather(all_loss, loss)
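# with DeepSpeed, all_gather collects each rank's loss so the average reported below reflects all GPUs (each rank only sees its own shard of the batch)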
if is_train: # backprop and update the parameters
model.zero_grad()
self.backward(loss)
# deepspeed will handle gradient_clipping
optimizer.step()
# decay the learning rate based on our progress
self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
lr_final_factor = config.lr_final / config.learning_rate
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = lr_final_factor + \
(1 - lr_final_factor) * float(self.tokens) / \
float(config.warmup_tokens)
progress = 0
else:
# exponential learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
if progress >= 1:
lr_mult = lr_final_factor
else:
lr_mult = math.exp(math.log(lr_final_factor) * progress)
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
param_group['lr'] = lr
self.lr = lr
self.steps += 1
now_loss = 0
for gg in range(NUM_GPUS):
now_loss += all_loss[gg].item()
now_loss = now_loss / NUM_GPUS # report progress
if USE_WANDB and self.cuda_id == 0:
wandb.log({"loss": now_loss}, step = self.steps)
if self.avg_loss < 0:
self.avg_loss = now_loss
else:
factor = 1 / (it + 1)
self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor
pbar.set_description(f"miniE {epoch+1+self.EPOCH_BEGIN} s {self.steps} prog {progress*100.0:.2f}% : ppl {math.exp(self.avg_loss):.6f} loss {self.avg_loss:.6f} lr {lr:e}")
self.tokens = 0 # counter used for learning rate decay
for epoch in range(99999999):
run_epoch('train')
if math.isnan(self.avg_loss):
exit(0)
if self.cuda_id == 0:
log_file.write(f'{epoch+1+self.EPOCH_BEGIN} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} {epoch+1} \n')
log_file.flush()
if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
raw_model = self.model.module if hasattr(self.model, "module") else self.model
torch.save(raw_model.state_dict(), self.config.epoch_save_path + str(epoch+1+self.EPOCH_BEGIN) + '.pth')

@ -1,153 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os
try:
NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])
except KeyError:
NUM_GPUS = 1
import json
import random
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset
class Dataset(Dataset): # reuses (and shadows) the torch Dataset name; imported elsewhere as src.utils.Dataset
def __init__(self, data, ctx_len, epoch_length_fixed):
self.ctx_len = ctx_len
self.epoch_length_fixed = epoch_length_fixed
self.data = data
if 'MMapIndexedDataset' in str(type(self.data)):
self.vocab_size = int(os.environ['VOCAB_SIZE'])
print('current vocab size =', self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data._bin_buffer) // 2
print(f'data has {self.data_size} tokens.')
elif 'numpy' in str(type(self.data)):
self.vocab_size = int(os.environ['VOCAB_SIZE'])
print('current vocab size =', self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data)
print(f'data has {self.data_size} tokens.')
else:
print('building token list...', end=' ')
unique = sorted(list(set(data)))
self.vocab_size = len(unique)
# print()
# for u in unique:
# print(u, end=' ')
# print('\n\n')
xxObj = {i: u for i, u in enumerate(unique)}
with open('vocab.json', "w", encoding="utf-16") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
self.data_size = len(self.data)
print('data has %d tokens, %d unique.' % (self.data_size, self.vocab_size))
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
def __len__(self):
return self.epoch_length_fixed // NUM_GPUS
def __getitem__(self, idx):
#
# we are cheating: pick a random spot in dataset
#
i = np.random.randint(0, self.data_size - (self.ctx_len + 1))
if 'MMapIndexedDataset' in str(type(self.data)):
dix = self.data.get(idx=0, offset=i, length=self.ctx_len + 1).astype(int)
elif 'numpy' in str(type(self.data)):
dix = self.data[i:i+self.ctx_len+1]
else:
dix = [self.stoi[s] for s in self.data[i:i+self.ctx_len+1]]
x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long)
return x, y
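# Usage sketch: each sample is a random window of ctx_len + 1 tokens, split into
# inputs x = tokens[:-1] and next-token targets y = tokens[1:], e.g. for "hello"
# with ctx_len = 3 one draw could give x = "hel" and y = "ell" (as token ids).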
class TOKENIZER():
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
if 'list' in str(type(WORD_NAME)):
self.charMode = False
if WORD_NAME[0] == WORD_NAME[1]:
from transformers import PreTrainedTokenizerFast
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
else:
from transformers import GPT2TokenizerFast
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
self.vocab_size = len(self.tokenizer)
else:
self.charMode = True
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
self.word_table = json.load(result_file)
self.vocab_size = len(self.word_table)
self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(torch.tensor(out), dim=-1)
if self.charMode:
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
else:
top_p = top_p_usual
sorted_probs, s_index = torch.sort(probs, descending=True)
# for j in range(30):
# pp = sorted_probs[j].item()
# if pp < 0.005:
# break
# ss = self.itos[int(s_index[j])].replace('\n','_')
# print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
# print('')
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
# print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
return torch.multinomial(probs, num_samples=1)[0]
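# Worked example of the top-p cutoff above: sorted probs [0.5, 0.3, 0.1, 0.1] with
# top_p = 0.85 give cumulative sums [0.5, 0.8, 0.9, 1.0]; the first sum above 0.85
# lands on the 0.1 entry, so cutoff = 0.1 and every token with probability < 0.1 is
# zeroed before temperature is applied (torch.multinomial renormalizes the rest).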
def to_float(x):
return x.cpu().detach().numpy().flatten()[0].astype(float)
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

@ -1,280 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os
import logging, types
from src.utils import Dataset
import torch
import numpy as np
from src.binidx import MMapIndexedDataset
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
# if False: # True False ---> Set to False if you don't understand it
# print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n")
# import src.utils
# src.utils.set_seed(42) # make training deterministic (including dataloader). if you are doing this, remember to change seed when you load a model (otherwise the dataloader loads old samples)
########################################################################################################
# Step 1: set training data & cfg
########################################################################################################
EXPRESS_PILE_MODE = False # True: express mode for fine-tuning a pile model // False: usual training
EXPRESS_PILE_MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023'
EXPRESS_PILE_MODEL_TYPE = 'RWKV-4-Pile-169M'
# EXPRESS_PILE_MODEL_NAME = 'RWKV-4-Pile-430M-20220808-8066'
# EXPRESS_PILE_MODEL_TYPE = 'RWKV-4-Pile-430M'
# EXPRESS_PILE_MODEL_NAME = 'RWKV-4-Pile-1B5-20220903-8040'
# EXPRESS_PILE_MODEL_TYPE = 'RWKV-4-Pile-1B5'
########################################################################################################
datafile = "../data/enwik8" # your data
datafile_encoding = 'utf-8' # 'utf-8' / 'utf-16le' / 'numpy' (for fine-tuning pile models) / 'binidx' (the Megatron-LM 'binidx' format)
# datafile = 'my-gpt_seq_document'
# datafile_encoding = 'binidx'
if EXPRESS_PILE_MODE:
datafile = 'train.npy' # use 'prepare-data.py' in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/tree/main/RWKV-v3 to tokenize .txt into .npy
datafile_encoding = 'numpy'
#
# set VOCAB_SIZE = 0 (auto-compute) if you are training a char-level LM from scratch
# set VOCAB_SIZE = 50277 for fine-tuning pile models
# set VOCAB_SIZE = your_vocab_size for 'binidx' data
#
os.environ['VOCAB_SIZE'] = '0'
if EXPRESS_PILE_MODE:
os.environ['VOCAB_SIZE'] = '50277'
#
# Currently it's slow to initialize a new model. Hence I suggest this procedure for multi-GPU training:
# 1) set RWKV_NUM_GPUS = '1' and let it run for 1 miniEpoch and it will save a trained-1.pth
# 2) set RWKV_NUM_GPUS = '8' (or your #GPU), batch_size = single_gpu_batchsz * RWKV_NUM_GPUS,
# EPOCH_BEGIN = 1, LOAD_MODEL = True, and it will load 'trained-1.pth' and continue the training from it
#
os.environ['RWKV_NUM_GPUS'] = '1' # num of GPUs to use
#
# 'bf16' (fast & stable)
# 'fp16' (fast, but can overflow after training a large model for a very long time; may be fixable in the future)
# 'tf32' (decent speed & stable)
# 'fp32' (!!!very slow!!! only for verification)
os.environ['RWKV_FLOAT_MODE'] = 'bf16'
os.environ['RWKV_DEEPSPEED'] = '1' # Use DeepSpeed? 0 = False, 1 = True
if int(os.environ['RWKV_NUM_GPUS']) == 1: # Usually you don't need DeepSpeed for 1-GPU training.
os.environ['RWKV_DEEPSPEED'] = '0' # However, DeepSpeed sometimes saves VRAM even with 1 GPU, so it's worth trying.
os.environ['USE_WANDB'] = '0' # wandb logging. 0 = False, 1 = True
########################################################################################################
# Step 2: set model details
########################################################################################################
EPOCH_BEGIN = 0 # begins with miniEpoch = EPOCH_BEGIN
LOAD_MODEL = False # shall we load the #EPOCH_BEGIN model and continue the training from it?
n_layer = 6
n_embd = 512
ctx_len = 1024 # increase T_MAX in src/model.py if your ctx_len is longer
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' (sometimes better)
# there is also a RWKV_HEAD_QK_DIM in model.py and model_run.py
# set it to 256, then it's using my headQK trick (a tiny attention) to improve loss
# set it to 0, then it's a pure RNN (attention-free)
if EXPRESS_PILE_MODE:
LOAD_MODEL = True
if EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-169M':
n_layer = 12
n_embd = 768
ctx_len = 1024
elif EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-430M':
n_layer = 24
n_embd = 1024
ctx_len = 1024
elif EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-1B5':
n_layer = 24
n_embd = 2048
ctx_len = 1024
########################################################################################################
# Step 3: set batch size & learning rate etc.
########################################################################################################
# if you see "CUDA out of memory", reduce batch_size. Use nvidia-smi to find the highest value for your GPU.
batch_size = 12 * int(os.environ['RWKV_NUM_GPUS'])
assert (batch_size % int(os.environ['RWKV_NUM_GPUS']) == 0)
# By default we are using exponential LR decay.
# Here are my suggestions for training.
# Let's say you are training a L6-D512 model.
# 1) Set lr_init = lr_final = 8e-4. Let it run for some mini-epochs, until you feel like reducing LR.
# 2) Check epoch_save_frequency and make sure the partially-trained model is saved. Ctrl+C to stop the run.
# 3) Set lr_init = 8e-4, lr_final = 1e-5, betas = (0.9, 0.999).
# 4) Set EPOCH_BEGIN & LOAD_MODEL to load the partially-trained model. Continue the training.
#
# For L12-D768, set lr_init = 6e-4. For L24-D1024, set lr_init = 4e-4. For L24-D2048, set lr_init = 3e-4.
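# A sketch of the resulting schedule (this mirrors the decay logic in src/trainer.py):
#   lr(progress) = lr_init * exp(log(lr_final / lr_init) * progress), progress in [0, 1]
# e.g. with lr_init = 8e-4 and lr_final = 1e-5, at progress = 0.5 the LR is
#   8e-4 * (1e-5 / 8e-4) ** 0.5 ≈ 8.9e-5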
lr_init = 8e-4
lr_final = 1e-5
# the mini-epoch is very short and of fixed length (length = ctx_len * epoch_length_fixed tokens)
n_epoch = 500
epoch_length_fixed = (10000 // batch_size) * batch_size # feel free to increase it if you have lots of GPUs
# epoch_save_frequency 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, ...
epoch_save_frequency = 10
epoch_save_path = 'trained-'
if EXPRESS_PILE_MODE:
if EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-169M':
lr_init = 2e-5
else:
lr_init = 1e-5
lr_final = 1e-5
n_epoch = 100000
### misc stuff ########################################################################################
NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])
if LOAD_MODEL and EPOCH_BEGIN > 0: # we are not saving gradients, so let's have some warmup if we load a model
warmup_tokens = 50 * ctx_len * batch_size // NUM_GPUS
else:
warmup_tokens = 0
betas = (0.9, 0.99) # set betas = (0.9, 0.999) if your model has been trained for a while
eps = 1e-8
num_workers = 1 # DataLoader worker. I only tested num_workers = 1
os.environ['RWKV_LOAD_MODEL'] = str(LOAD_MODEL)
MODEL_NAME = epoch_save_path + str(EPOCH_BEGIN)
if EXPRESS_PILE_MODE:
betas = (0.9, 0.999)
MODEL_NAME = EXPRESS_PILE_MODEL_NAME
torch.backends.cudnn.benchmark = True
if os.environ['RWKV_FLOAT_MODE'] == 'fp32':
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
########################################################################################################
# Load data
########################################################################################################
print(f'loading {datafile_encoding} data... ' + datafile)
if datafile_encoding == 'binidx':
train_dataset = Dataset(MMapIndexedDataset(datafile), ctx_len, epoch_length_fixed)
elif datafile_encoding == 'numpy':
train_dataset = Dataset(np.load(datafile).astype('int'), ctx_len, epoch_length_fixed)
else:
train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)
########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':
from src.trainer import Trainer, TrainerConfig
print('\nmodel', model_type, os.environ['RWKV_FLOAT_MODE'], 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, '\n')
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
m_cfg = types.SimpleNamespace()
m_cfg.model_type = model_type
m_cfg.n_layer = n_layer
m_cfg.n_embd = n_embd
m_cfg.EPOCH_BEGIN = EPOCH_BEGIN
m_cfg.LOAD_MODEL = LOAD_MODEL
m_cfg.MODEL_NAME = MODEL_NAME
if os.environ['RWKV_DEEPSPEED'] == '0':
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16)
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16')
elif '32' in os.environ['RWKV_FLOAT_MODE']:
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32)
else:
from pytorch_lightning.strategies import DeepSpeedStrategy
DEEPSPEED_CFG = {
"zero_allow_untested_optimizer":True,
"zero_optimization":{
"stage":2,
"contiguous_gradients":True,
"overlap_comm":True,
"allgather_partitions":True,
"reduce_scatter":True,
"allgather_bucket_size":200000000,
"reduce_bucket_size":200000000,
"sub_group_size":1000000000000
},
"activation_checkpointing":{
"partition_activations":False,
"cpu_checkpointing":False,
"contiguous_memory_optimization":False,
"synchronize_checkpoint_boundary":False
},
"aio":{
"block_size":1048576,
"queue_depth":8,
"single_submit":False,
"overlap_events":True,
"thread_count":1
},
"gradient_clipping": 1.0,
"gradient_accumulation_steps": 1,
}
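# ZeRO stage 2 (above) shards optimizer states and gradients across ranks; the stage-1 override below (for a single GPU) shards only optimizer states, which still saves some VRAM.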
if NUM_GPUS == 1:
DEEPSPEED_CFG['zero_optimization'] = {
"stage":1, # saves some VRAM
"contiguous_gradients":False,
"overlap_comm":False,
"allgather_partitions":False,
"reduce_scatter":False,
"allgather_bucket_size":200000000,
"reduce_bucket_size":200000000,
"sub_group_size":1000000000000
}
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
DEEPSPEED_CFG["fp16"] = {
"fp16": True,
"enabled": True,
"loss_scale": 0,
"initial_scale_power": 12,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
}
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=16)
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
DEEPSPEED_CFG["bf16"] = {
"enabled": True
}
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')
elif '32' in os.environ['RWKV_FLOAT_MODE']:
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32)
print(trainer._strategy.config)
trainer.run(m_cfg, train_dataset, None, tconf)

@ -1,90 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# this is for verifying the results of different models and making sure they agree with each other
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (can overflow after training a large model for a very long time; may be fixable in the future)
os.environ['RWKV_RUN_DEVICE'] = 'cuda'
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
import torch
from src.model_run import RWKV_RNN, RWKV_GPT
from src.model import GPT, GPTConfig
TOKEN_MODE = 'pile' # char / pile
if TOKEN_MODE == 'char':
MODEL_NAME = 'trained-1'
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
ctx_len = 1024
n_layer = 6
n_embd = 512
UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
elif TOKEN_MODE == 'pile':
WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023'
ctx_len = 1024
n_layer = 12
n_embd = 768
UNKNOWN_CHAR = None
model_type = 'RWKV'
from src.utils import TOKENIZER
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == 'pile':
tokenizer.vocab_size = 50277
########################################################################################################
model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda()
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
model_train = model_train.bfloat16()
print('loading ' + MODEL_NAME)
m2 = torch.load(MODEL_NAME + '.pth', map_location=RUN_DEVICE)
model_train.load_state_dict(m2)
model_rnn = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
model_gpt = RWKV_GPT(MODEL_NAME, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda()
########################################################################################################
# context = '\nIn a'
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'
if TOKEN_MODE == 'char':
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
elif TOKEN_MODE == 'pile':
ctx = tokenizer.tokenizer.encode(context)
print(f'input len {len(ctx)} data {ctx}')
########################################################################################################
print('\nRWKV-GPT output')
out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy()
print(out)
print('\nRWKV-RNN output')
model_rnn.clear()
src_len = len(ctx)
for i in range(src_len):
x = ctx[:i+1]
out = model_rnn.run(x)
if i < 3 or i >= src_len - 3:
print(torch.tensor(out).detach().cpu().numpy())
if i == 2:
print('...')
print('\nRWKV-train output')
out = model_train.forward(torch.tensor([ctx]).cuda())[0][0].detach().cpu().float().numpy()
print(out, '\n')

File diff suppressed because it is too large

@ -1,361 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
print('Loading...')
from src.model_run import RWKV_RNN
import numpy as np
import os, copy, types, gc, sys
import torch
from src.utils import TOKENIZER
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except IndexError:
pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
CHAT_LANG = 'English' # English Chinese
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
args = types.SimpleNamespace()
args.RUN_DEVICE = "cuda" # 'cpu' (already very fast) // 'cuda'
args.FLOAT_MODE = "fp16" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate)
args.vocab_size = 50277
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230108-5170'
args.n_layer = 40
args.n_embd = 5120
args.ctx_len = 1024
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
# args.n_layer = 32
# args.n_embd = 4096
# args.ctx_len = 1024
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
# args.n_layer = 32
# args.n_embd = 2560
# args.ctx_len = 1024
if CHAT_LANG == 'English':
user = "User"
bot = "Bot"
interface = ":"
# The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
# The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth.
init_prompt = f'''
The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.

{user}{interface} french revolution what year

{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799.

{user}{interface} 3+5=?

{bot}{interface} The answer is 8.

{user}{interface} guess i marry who ?

{bot}{interface} Only if you tell me more about yourself - what are your interests?

{user}{interface} solve for a: 9-a=2

{bot}{interface} The answer is a = 7, because 9 - 7 = 2.

{user}{interface} wat is lhc

{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.

'''
HELP_MSG = '''Commands:
say something --> chat with bot. use \\n for new line.
+alt --> alternate chat reply
+reset --> reset chat
+gen YOUR PROMPT --> free generation with any prompt. use \\n for new line.
+qa YOUR QUESTION --> free generation - ask any question (just ask the question). use \\n for new line.
+more --> continue last free generation (only for +gen / +qa)
+retry --> retry last free generation (only for +gen / +qa)
Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results.
This is not instruct-tuned for conversation yet, so don't expect good quality. It's better to use +gen for free generation.
'''
elif CHAT_LANG == 'Chinese':
args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run3z/rwkv-293'
args.n_layer = 32
args.n_embd = 4096
args.ctx_len = 1024
user = "Q"
bot = "A"
interface = ":"
init_prompt = '''
Q: 企鹅会飞吗？

A: 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡，而不是飞行。

Q: 西瓜是什么

A: 西瓜是一种常见的水果，是一种多年生蔓生藤本植物。西瓜的果实呈圆形或卵形，通常是绿色的，里面有红色或黄色的肉和很多的籽。西瓜味甜，多吃可以增加水分，是夏季非常受欢迎的水果之一。

'''
HELP_MSG = '''指令:

直接输入内容 --> 和机器人聊天（用\\n代表换行）
+alt --> 让机器人换个回答
+reset --> 重置对话
+gen 某某内容 --> 续写任何中英文内容（用\\n代表换行）
+qa 某某问题 --> 问独立的问题（忽略上下文，用\\n代表换行）
+more --> 继续 +gen / +qa 的回答
+retry --> 换个 +gen / +qa 的回答

现在可以输入内容和机器人聊天（注意它不怎么懂中文，它可能更懂英文）。请经常使用 +reset 重置机器人记忆。
'''
# Load Model
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
MODEL_NAME = args.MODEL_NAME
print(f'loading... {MODEL_NAME}')
model = RWKV_RNN(args)
model_tokens = []
current_state = None
########################################################################################################
def run_rnn(tokens, newline_adj = 0):
global model_tokens, current_state
for i in range(len(tokens)):
model_tokens += [int(tokens[i])]
if i == len(tokens) - 1:
out, current_state = model.forward(model_tokens, current_state)
else:
current_state = model.forward(model_tokens, current_state, preprocess_only = True)
# print(f'### model ###\n[{tokenizer.tokenizer.decode(model_tokens)}]')
out[0] = -999999999 # disable <|endoftext|>
out[187] += newline_adj
# if newline_adj > 0:
# out[15] += newline_adj / 2 # '.'
return out
all_state = {}
def save_all_stat(srv, name, last_out):
n = f'{name}_{srv}'
all_state[n] = {}
all_state[n]['out'] = last_out
all_state[n]['rnn'] = copy.deepcopy(current_state)
all_state[n]['token'] = copy.deepcopy(model_tokens)
def load_all_stat(srv, name):
global model_tokens, current_state
n = f'{name}_{srv}'
current_state = copy.deepcopy(all_state[n]['rnn'])
model_tokens = copy.deepcopy(all_state[n]['token'])
return all_state[n]['out']
########################################################################################################
# Run inference
print(f'\nRun prompt...')
out = run_rnn(tokenizer.tokenizer.encode(init_prompt))
gc.collect()
torch.cuda.empty_cache()
save_all_stat('', 'chat_init', out)
srv_list = ['dummy_server']
for s in srv_list:
save_all_stat(s, 'chat', out)
print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n')
def reply_msg(msg):
print(f'{bot}{interface} {msg}\n')
def on_message(message):
global model_tokens, current_state
srv = 'dummy_server'
msg = message.replace('\\n','\n').strip()
if len(msg) > 1000:
reply_msg('your message is too long (max 1000 characters)')
return
x_temp = 1.0
x_top_p = 0.85
if ("-temp=" in msg):
x_temp = float(msg.split("-temp=")[1].split(" ")[0])
msg = msg.replace("-temp="+f'{x_temp:g}', "")
# print(f"temp: {x_temp}")
if ("-top_p=" in msg):
x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
msg = msg.replace("-top_p="+f'{x_top_p:g}', "")
# print(f"top_p: {x_top_p}")
if x_temp <= 0.2:
x_temp = 0.2
if x_temp >= 5:
x_temp = 5
if x_top_p <= 0:
x_top_p = 0
if msg == '+reset':
out = load_all_stat('', 'chat_init')
save_all_stat(srv, 'chat', out)
reply_msg("Chat reset.")
return
elif msg[:5].lower() == '+gen ' or msg[:4].lower() == '+qa ' or msg.lower() == '+more' or msg.lower() == '+retry':
if msg[:5].lower() == '+gen ':
new = '\n' + msg[5:].strip()
# print(f'### prompt ###\n[{new}]')
current_state = None
out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out)
elif msg[:4].lower() == '+qa ':
out = load_all_stat('', 'chat_init')
real_msg = msg[4:].strip()
new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
# print(f'### qa ###\n[{new}]')
out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out)
# new = f"\nThe following is an excellent Q&A session consists of detailed and factual information.\n\nQ: What is 3+5?\nA: The answer is 8.\n\nQ: {msg[9:].strip()}\nA:"
# print(f'### prompt ###\n[{new}]')
# current_state = None
# out = run_rnn(tokenizer.tokenizer.encode(new))
# save_all_stat(srv, 'gen_0', out)
elif msg.lower() == '+more':
try:
out = load_all_stat(srv, 'gen_1')
save_all_stat(srv, 'gen_0', out)
except KeyError:
return
elif msg.lower() == '+retry':
try:
out = load_all_stat(srv, 'gen_0')
except KeyError:
return
begin = len(model_tokens)
out_last = begin
for i in range(150):
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p_usual=x_top_p,
top_p_newline=x_top_p,
)
if msg[:4].lower() == '+qa ':
out = run_rnn([token], newline_adj=-1)
else:
out = run_rnn([token])
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx:
print(xxx, end='', flush=True)
out_last = begin + i + 1
print('\n')
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'gen_1', out)
else:
if msg.lower() == '+alt':
try:
out = load_all_stat(srv, 'chat_pre')
except KeyError:
return
else:
out = load_all_stat(srv, 'chat')
new = f"{user}{interface} {msg}\n\n{bot}{interface}"
# print(f'### add ###\n[{new}]')
out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999)
save_all_stat(srv, 'chat_pre', out)
begin = len(model_tokens)
out_last = begin
print(f'{bot}{interface}', end='', flush=True)
for i in range(999):
if i <= 0:
newline_adj = -999999999
elif i <= 30:
newline_adj = (i - 30) / 10
elif i <= 130:
newline_adj = 0
else:
newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION
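# net effect of the schedule above on the '\n' logit: forbidden for the first token,
# penalized (ramping from -2.9 up to 0) for tokens 1-30, neutral up to token 130,
# then increasingly rewarded (+0.25 per extra token) so the reply terminates.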
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p_usual=x_top_p,
top_p_newline=x_top_p,
)
out = run_rnn([token], newline_adj=newline_adj)
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx:
print(xxx, end='', flush=True)
out_last = begin + i + 1
send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
if '\n\n' in send_msg:
send_msg = send_msg.strip()
break
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!!
# send_msg = send_msg[:-len(f'{user}{interface}')].strip()
# break
# if send_msg.endswith(f'{bot}{interface}'):
# send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
# break
# print(f'{model_tokens}')
# print(f'[{tokenizer.tokenizer.decode(model_tokens)}]')
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'chat', out)
print(HELP_MSG)
while True:
msg = input(f'{user}{interface} ')
if len(msg.strip()) > 0:
on_message(msg)
else:
print('Error: please say something')

@ -1,133 +0,0 @@
#include <stdio.h>
#include <assert.h>
#define MIN_VALUE (-1e38)
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
F *__restrict__ const _y) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
F *__restrict__ const y = _y + _offset;
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
F aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const F kk = k[ii];
const F vv = v[ii];
F ww = u + kk;
F p = max(pp, ww);
F e1 = exp(pp - p);
F e2 = exp(ww - p);
y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
}
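// In closed form (with w already mapped to -exp(w) on the Python side, see the
// note in kernel_backward), the loop above computes, per (batch, channel):
//   y[t] = ( sum_{i<t} exp((t-1-i)*w + k[i]) * v[i] + exp(u + k[t]) * v[t] )
//        / ( sum_{i<t} exp((t-1-i)*w + k[i])        + exp(u + k[t]) )
// pp tracks the running maximum exponent so e1/e2 never overflow.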
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
const F *__restrict__ const _y, const F *__restrict__ const _gy,
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
const F *__restrict__ const y = _y + _offset;
const F *__restrict__ const gy = _gy + _offset;
F *__restrict__ const gk = _gk + _offset;
F *__restrict__ const gv = _gv + _offset;
F q[Tmax], r[Tmax]; // Tmax is injected at compile time (-DTmax=ctx_len)
F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const F kk = k[ii];
const F vv = v[ii];
const F yy = y[ii];
F ww = u + kk;
F p = max(pp, ww);
F e1 = exp(pp - p);
F e2 = exp(ww - p);
const F qq = gy[ii] / (e1 * bb + e2);
gw += (ga - gb * yy) * e1 * qq;
gu += (vv - yy) * e2 * qq;
q[i] = qq;
r[i] = ww - p;
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
ga = e1 * (aa + ga);
gb = e1 * (bb + gb);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
_gu[_offsetBC] = gu;
aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
const F kk = k[ii];
const F vv = v[ii];
const F yy = y[ii];
const F qq = q[i];
const F rr = r[i];
F e1 = qq * exp(rr);
F e2 = exp(kk + pp);
gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
gv[ii] = e1 + e2 * aa;
const F ww = w + pp;
const F www = rr - u - kk;
const F p = max(ww, www);
e1 = exp(ww - p);
e2 = qq * exp(www - p);
aa = e1 * aa + e2;
bb = e1 * bb - e2 * yy;
pp = p;
}
}
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}

@ -1,132 +0,0 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#define MIN_VALUE (-1e38)
typedef at::BFloat16 bf16;
__global__ void kernel_forward(const int B, const int T, const int C,
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
bf16 *__restrict__ const _y) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
float u = float(_u[_c]);
float w = _w[_c];
const bf16 *__restrict__ const k = _k + _offset;
const bf16 *__restrict__ const v = _v + _offset;
bf16 *__restrict__ const y = _y + _offset;
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
float aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);
float ww = u + kk;
float p = max(pp, ww);
float e1 = exp(pp - p);
float e2 = exp(ww - p);
y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
}
__global__ void kernel_backward(const int B, const int T, const int C,
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy,
bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
float u = float(_u[_c]);
float w = _w[_c];
const bf16 *__restrict__ const k = _k + _offset;
const bf16 *__restrict__ const v = _v + _offset;
const bf16 *__restrict__ const y = _y + _offset;
const bf16 *__restrict__ const gy = _gy + _offset;
bf16 *__restrict__ const gk = _gk + _offset;
bf16 *__restrict__ const gv = _gv + _offset;
float q[Tmax], r[Tmax]; // Tmax is injected at compile time (-DTmax=ctx_len)
float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);
const float yy = float(y[ii]);
float ww = u + kk;
float p = max(pp, ww);
float e1 = exp(pp - p);
float e2 = exp(ww - p);
const float qq = float(gy[ii]) / (e1 * bb + e2);
gw += (ga - gb * yy) * e1 * qq;
gu += (vv - yy) * e2 * qq;
q[i] = qq;
r[i] = ww - p;
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
ga = e1 * (aa + ga);
gb = e1 * (bb + gb);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
_gu[_offsetBC] = bf16(gu);
aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);
const float yy = float(y[ii]);
const float qq = q[i];
const float rr = r[i];
float e1 = qq * exp(rr);
float e2 = exp(kk + pp);
gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
gv[ii] = bf16(e1 + e2 * aa);
const float ww = w + pp;
const float www = rr - u - kk;
const float p = max(ww, www);
e1 = exp(ww - p);
e2 = qq * exp(www - p);
aa = e1 * aa + e2;
bb = e1 * bb - e2 * yy;
pp = p;
}
}
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}

@ -1,21 +0,0 @@
#include <torch/extension.h>
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv forward");
m.def("backward", &backward, "wkv backward");
}
TORCH_LIBRARY(wkv, m) {
m.def("forward", forward);
m.def("backward", backward);
}
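// registering the ops with both PYBIND11_MODULE and TORCH_LIBRARY makes them
// callable both from plain Python and from TorchScript/torch.ops graphs.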

@ -1,25 +0,0 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv forward");
m.def("backward", &backward, "wkv backward");
}
TORCH_LIBRARY(wkv, m) {
m.def("forward", forward);
m.def("backward", backward);
}

@ -1,165 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import torch, types, os
import numpy as np
from PIL import Image
import torch.nn as nn
from torch.nn import functional as F
import torchvision as vision
import torchvision.transforms as transforms
np.set_printoptions(precision=4, suppress=True, linewidth=200)
print(f'loading...')
########################################################################################################
model_prefix = 'test/image_trained/out-v7c_d8_256-224-13bit-OB32x0.5-201'
input_img = 'test/img_ae_test/test0.png'
########################################################################################################
class ToBinary(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return torch.floor(x + 0.5) # no need for noise when we have plenty of data
@staticmethod
def backward(ctx, grad_output):
return grad_output.clone() # pass-through
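# together this is a straight-through estimator: hard 0/1 quantization in the
# forward pass, identity gradient in the backward pass.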
class R_ENCODER(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.Bxx = nn.BatchNorm2d(dd*64)
self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*4)
self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*64)
self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
def forward(self, img):
ACT = F.mish
x = self.CIN(img)
xx = self.Bxx(F.pixel_unshuffle(x, 8))
x = x + self.Cx1(ACT(self.Cx0(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
x = self.COUT(x + xx)
return torch.sigmoid(x)
class R_DECODER(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*64)
self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*4)
self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
def forward(self, code):
ACT = F.mish
x = self.CIN(code)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.Cx1(ACT(self.Cx0(x)))
x = self.COUT(x)
return torch.sigmoid(x)
########################################################################################################
print(f'building model...')
args = types.SimpleNamespace()
args.my_img_bit = 13
encoder = R_ENCODER(args).eval().cuda()
decoder = R_DECODER(args).eval().cuda()
zpow = torch.tensor([2**i for i in range(0,13)]).reshape(13,1,1).cuda().long()
encoder.load_state_dict(torch.load(f'{model_prefix}-E.pth'))
decoder.load_state_dict(torch.load(f'{model_prefix}-D.pth'))
########################################################################################################
print(f'test image...')
img_transform = transforms.Compose([
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Resize((224, 224))
])
with torch.no_grad():
img = img_transform(Image.open(input_img)).unsqueeze(0).cuda()
z = encoder(img)
z = ToBinary.apply(z)
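# pack the 13 binary bit-planes into one integer code per spatial position (plane i weighted 2**i)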
zz = torch.sum(z.squeeze().long() * zpow, dim=0)
print(f'Code shape = {zz.shape}\n{zz.cpu().numpy()}\n')
out = decoder(z)
vision.utils.save_image(out, f"{input_img.split('.')[0]}-out-13bit.jpg")

@ -1,237 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math, os, sys, types, time, gc
import torch
from src.utils import TOKENIZER
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except IndexError:
pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()
########################################################################################################
# Step 1: set model & config (use v4 to run your trained-from-scratch models. v4 and v4neo are compatible)
########################################################################################################
args.RUN_DEVICE = "cuda" # 'cuda' // 'cpu' (already fast)
args.FLOAT_MODE = "fp16" # fp16 (good for GPU, does not work for CPU) // fp32 (good for CPU) // bf16 (less accurate, but works for CPU)
# if args.RUN_DEVICE == "cuda":
# os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
os.environ["RWKV_JIT_ON"] = '1' # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!!
TOKEN_MODE = "pile"
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
vocab_size = 50277
# Download Pile models: https://huggingface.co/BlinkDL
# or, set MODEL_NAME to your fine-tuned model
# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
# n_layer = 12
# n_embd = 768
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
# n_layer = 24
# n_embd = 1024
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
# n_layer = 32
# n_embd = 2560
# ctx_len = 1024
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
n_layer = 32
n_embd = 4096
ctx_len = 1024
args.MODEL_NAME = MODEL_NAME
args.n_layer = n_layer
args.n_embd = n_embd
args.ctx_len = ctx_len
args.vocab_size = vocab_size
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
########################################################################################################
# Step 2: set prompt & sampling stuffs
########################################################################################################
# context = 'A'
# context = "\nIn the"
# context = '\nSugar:'
context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
# context = "\n深圳是" # test Chinese
# context = "\n東京は" # test Japanese
# ###### A good prompt for Q&A ######
# context = '''
# Questions & Helpful Answers
# Ask Research Experts
# Question:
# Can penguins fly?
# Full Answer:
# '''
# ###### A good prompt for chatbot ######
# context = '''
# The following is a conversation between a highly knowledgeable and intelligent AI assistant called Bot, and a human user called User. In the following interactions, User and Bot converse in natural language, and Bot always answer User's questions. Bot is very smart, polite and humorous. Bot knows a lot, and always tells the truth. The conversation begins.
# User: who is president of usa?
# Bot: Its Joe Biden; he was sworn in earlier this year.
# User: french revolution what year
# Bot: It started in 1789, but it lasted 10 years until 1799.
# User: guess i marry who ?
# Bot: Only if you tell me more about yourself - what are your interests?
# User: wat is lhc
# Bot: Its a large and very expensive piece of science equipment. If I understand correctly, its a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
# User:''' # type your question here
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333
TEMPERATURE = 1.0
top_p = 0.8
top_p_newline = 0.9 # only used in TOKEN_MODE = char
DEBUG_DEBUG = False # True False --> show softmax output
########################################################################################################
print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN
model = RWKV_RNN(args)
print(f'\nOptimizing speed...')
out, _ = model.forward([187], None)
# print(out)
gc.collect()
torch.cuda.empty_cache()
# input(0)
print(f'\nLoading tokenizer {WORD_NAME}...')
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == "pile":
assert tokenizer.tokenizer.decode([187]) == '\n'
########################################################################################################
if tokenizer.charMode:
context = tokenizer.refine_context(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
ctx = tokenizer.tokenizer.encode(context)
src_len = len(ctx)
src_ctx = ctx.copy()
print("\nYour prompt has " + str(src_len) + " tokens.")
print(
"Note: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. Use GPT to build the hidden state for better speed.\n"
)
time_slot = {}
time_ref = time.time_ns()
def record_time(name):
if name not in time_slot:
time_slot[name] = 1e20
tt = (time.time_ns() - time_ref) / 1e9
if tt < time_slot[name]:
time_slot[name] = tt
init_state = None
init_out = None
state = None
out = None
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
print(("-" * 50) + '\n' + context, end="")
time_ref = time.time_ns()
ctx = src_ctx.copy()
if TRIAL == 0:
for i in range(src_len):
x = ctx[: i + 1]
if i == src_len - 1:
init_out, init_state = model.forward(x, init_state)
else:
init_state = model.forward(x, init_state, preprocess_only=True)
gc.collect()
torch.cuda.empty_cache()
record_time('preprocess')
out_last = src_len
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
x = ctx[: i + 1]
x = x[-ctx_len:]
if i == src_len:
out = init_out.clone()
state = init_state.clone()
else:
out, state = model.forward(x, state)
if DEBUG_DEBUG:
print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))
if TOKEN_MODE == "pile":
out[0] = -999999999 # disable <|endoftext|>
ttt = tokenizer.sample_logits(
out,
x,
ctx_len,
temperature=TEMPERATURE,
top_p_usual=top_p,
top_p_newline=top_p_newline,
)
ctx += [ttt]
if tokenizer.charMode:
char = tokenizer.itos[ttt]
print(char, end="", flush=True)
else:
char = tokenizer.tokenizer.decode(ctx[out_last:])
if '\ufffd' not in char: # is valid utf8 string?
print(char, end="", flush=True)
out_last = i+1
record_time('total')
# print(f'\n\n{time_slot}\n\n')
print(
f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = ''
)
print(("-" * 50) + '\n')

@ -1,269 +0,0 @@
import os
import torch
import numpy as np
import shutil
import struct
from functools import lru_cache
from itertools import accumulate
def print_rank_0(*message):
pass
# """If distributed is initialized print only on rank 0."""
# if torch.distributed.is_initialized():
# if torch.distributed.get_rank() == 0:
# print(*message, flush=True)
# else:
# print(*message, flush=True)
def _warmup_mmap_file(path):
pass
# with open(path, "rb") as stream:
# while stream.read(100 * 1024 * 1024):
# pass
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: float,
7: np.double,
8: np.uint16,
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + ".idx"
def data_file_path(prefix_path):
return prefix_path + ".bin"
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b"MMIDIDX\x00\x00"
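        # On-disk .idx layout (exactly what `writer` below emits and `__init__` reads):
        #   9 bytes     magic b"MMIDIDX\x00\x00"
        #   <Q          version (always 1)
        #   <B          dtype code (see `dtypes` above)
        #   <Q          number of sequences, len(sizes)
        #   <Q          number of documents, len(doc_idx)
        #   int32[...]  per-sequence sizes, in elements
        #   int64[...]  per-sequence byte offsets into the .bin file
        #   int64[...]  doc_idx: sequence index of each document boundary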
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, "wb")
                    # Write the magic string so we can check the file format when opening it again.
self._file.write(cls._HDR_MAGIC)
# Write version number
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", 1))
# Little endian unsigned 8 Bit integer
self._file.write(struct.pack("<B", code(dtype)))
return self
@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers
def write(self, sizes, doc_idx):
pointers = self._get_pointers(sizes)
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", len(sizes)))
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", len(doc_idx)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order="C"))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order="C"))
del pointers
doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order="C"))
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path, skip_warmup=False):
with open(path, "rb") as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
# Little endian unsigned 64 Bit integer
version = struct.unpack("<Q", stream.read(8))
assert (1,) == version
# Little endian unsigned 8 Bit integer
(dtype_code,) = struct.unpack("<B", stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack("<Q", stream.read(8))[0]
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
offset = stream.tell()
if not skip_warmup:
print_rank_0(" warming up index mmap file...")
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print_rank_0(" reading sizes...")
self._sizes = np.frombuffer(
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
)
print_rank_0(" reading pointers...")
self._pointers = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._len,
offset=offset + self._sizes.nbytes,
)
print_rank_0(" reading document index...")
self._doc_idx = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@property
def doc_idx(self):
return self._doc_idx
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path, skip_warmup=False):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path, skip_warmup)
def __getstate__(self):
return self._path
    def __setstate__(self, state):
        self._do_init(state, skip_warmup=True)  # state holds only the path; skip the warmup read on unpickle
def _do_init(self, path, skip_warmup):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)
if not skip_warmup:
print_rank_0(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print_rank_0(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(
data_file_path(self._path), mode="r", order="C"
)
print_rank_0(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index
def __len__(self):
return len(self._index)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
ptr, size = self._index[idx]
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError(
"Slices into indexed_dataset must be contiguous")
ptr = self._index._pointers[start]
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
)
sents = np.split(np_array, offsets[:-1])
return sents
def get(self, idx, offset=0, length=None):
"""Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
)
return np_array
@property
def sizes(self):
return self._index.sizes
@property
def doc_idx(self):
return self._index.doc_idx
def get_doc_idx(self):
return self._index._doc_idx
def set_doc_idx(self, doc_idx_):
self._index._doc_idx = doc_idx_
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return os.path.exists(index_file_path(path)) and os.path.exists(
data_file_path(path)
)

@ -1,240 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json, math, random, os, sys
import numpy as np
import torch
from torch.utils.data import Dataset
from pytorch_lightning.utilities import rank_zero_info
from .binidx import MMapIndexedDataset
from .utils import MaybeIsPrime
class MyDataset(Dataset):
def __init__(self, args):
self.args = args
if args.data_type == "binidx":
self.vocab_size = args.vocab_size
rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
if args.my_pile_version == 1:
self.data = MMapIndexedDataset(args.data_file)
self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
rank_zero_info(f"Data has {self.data_size} tokens.")
else:
data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n')
data_list = [i.strip().split(' ') for i in data_list]
self.data = []
self.data_size = int(data_list[-1][-1])
rank_zero_info(f"Data has {self.data_size} chunks.")
for d in data_list:
data = MMapIndexedDataset(d[0])
data_size = len(data._bin_buffer) // data._index._dtype_size
assert (data_size - args.ctx_len) == int(d[1])
self.data += [[int(d[-1]), int(d[1]), data]]
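                # List-file format, as parsed above: each whitespace-split line is
                # expected to hold a binidx prefix in d[0], a size in d[1] (checked to
                # equal the file's token count minus ctx_len), and a cumulative cutoff
                # in d[-1]; __getitem__ consumes self.data rows as [cutoff, size, data].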
# rank_zero_info(self.data)
if args.my_qa_mask > 0:
self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document')
# self.data_pile = MMapIndexedDataset('/fsx/pile_deduped/pile_0.87_deduped_text_document')
                self.data_pile_size = len(self.data_pile._bin_buffer) // self.data_pile._index._dtype_size
if args.my_pile_stage > 0:
# assert self.data_size == 332115325534 and self.vocab_size == 50277
self.samples_per_epoch = args.epoch_steps * args.real_bsz
assert self.samples_per_epoch == 40320
rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
dataset_slot = self.data_size // args.ctx_len
if args.my_pile_stage != 4:
assert MaybeIsPrime(args.magic_prime)
assert args.magic_prime % 3 == 2
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
elif args.data_type == "numpy":
self.data = np.load(args.data_file).astype("int")
self.vocab_size = args.vocab_size
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data)
rank_zero_info(f"Data has {self.data_size} tokens.")
elif args.data_type == "uint16":
self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
self.vocab_size = args.vocab_size
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = self.data.shape[0]
rank_zero_info(f"Data has {self.data_size} samples.")
elif args.data_type == "wds_img":
self.vocab_size = -1
self.data_size = -1
self.data = None
self.error_count = 0
else:
if args.data_type == "dummy":
rank_zero_info("Building dummy data...")
self.data = ""
for i in range(100000):
aa = (i) % 10000
bb = (i * i) % 10000
cc = aa + bb
self.data += f".{aa}+{bb}={cc}."
else:
self.data = open(args.data_file, "r", encoding=args.data_type).read()
rank_zero_info("Building token list...")
unique = sorted(list(set(self.data)))
self.vocab_size = len(unique)
# rank_zero_info()
# for u in unique:
# print(u, end=' ')
# rank_zero_info('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
self.data_size = len(self.data)
rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
def __len__(self):
return self.args.epoch_steps * self.args.micro_bsz
def __getitem__(self, idx):
args = self.args
rank = self.global_rank
epoch = self.real_epoch
world_size = self.world_size
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
if args.data_type == "wds_img":
def init_wds(self, bias=0):
def identity(x):
return x
import webdataset as wds
import torchvision.transforms as transforms
# img_transform = transforms.Compose(
# [transforms.CenterCrop(256)]
# )
img_transform = transforms.Compose([
transforms.CenterCrop(512),
transforms.Resize((args.my_img_size))
])
self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
for pp in self.data_raw.pipeline:
if 'Resampled' in str(pp):
pp.deterministic = True
                        def worker_seed():
                            return int(rank * 100000 + epoch + bias * 1e9)  # seeds must be ints
                        pp.worker_seed = worker_seed
self.data = iter(self.data_raw)
# print(f"WebDataset loaded for rank {rank} epoch {epoch}")
            if self.data is None:
init_wds(self)
trial = 0
while trial < 10:
try:
dd = next(self.data) # jpg, json, txt
break
                except Exception:
                    print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
                    self.error_count += 1
                    init_wds(self, self.error_count)
                    trial += 1
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
# with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
# tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
return dd[0], dd[2]
else:
if args.data_type == "uint16":
                i = np.random.randint(0, self.data_size)  # np.random.randint excludes the high end
dix = self.data[i]
x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long)
else:
ctx_len = args.ctx_len
req_len = ctx_len + 1
magic_prime = args.magic_prime
data = self.data
if args.my_pile_stage > 0 and args.my_pile_stage != 4:
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
if args.my_qa_mask > 0:
ii_orig = ii
if ii % 2 == 0:
ii = -1
data = self.data_pile
else:
ii = ii // 2
if ii < 0:
i = np.random.randint(0, self.data_pile_size - req_len)
else:
factor = (math.sqrt(5) - 1) / 2
factor = int(magic_prime * factor)
i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
i = i + args.my_pile_shift
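                        # A deterministic shuffle with no stored permutation: since
                        # magic_prime % 3 == 2 (asserted in __init__), x -> x**3 mod p
                        # is a bijection on [0, p), so cubing the golden-ratio-scaled
                        # step index visits every ctx_len-sized chunk exactly once.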
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
elif args.my_pile_stage == 4:
# cheat: pick a random spot in dataset
if args.my_pile_version == 1:
i = np.random.randint(0, self.data_size - req_len)
else:
i = np.random.randint(0, self.data_size)
else:
# cheat: pick a random spot in dataset
i = np.random.randint(0, self.data_size - req_len)
if args.data_type == "binidx":
if args.my_pile_version == 1:
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
else:
# self.data : cutoff, chunk_count, data
for j in range(len(data)):
if i < data[j][0]:
ii = i
i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1]
dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int)
# print(ii, j, i)
break
elif args.data_type == "numpy":
dix = data[i : i + req_len]
else:
dix = [self.stoi[s] for s in data[i : i + req_len]]
if args.my_qa_mask == 1:
if data == self.data_pile:
z = [1] * ctx_len
else:
z = [0] * ctx_len
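                    # Token 187 is '\n' in the Pile tokenizer (see the assert in the
                    # run script), so the scan below presumably matches "\n\n" plus an
                    # answer delimiter (tokens 34, 27); z[i] = 1 marks answer tokens and
                    # the loss is masked to them, while pile batches keep z = 1 everywhere.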
z_sum = 0
isGood = False
for i in range(3, ctx_len):
if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
isGood = True
if dix[i] == 0:
isGood = False
if isGood:
z[i] = 1
z_sum += 1
if z_sum == 0:
z = [1] * ctx_len
i = np.random.randint(0, self.data_pile_size - req_len)
dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
z = torch.tensor(z, dtype=torch.bfloat16)
x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long)
# if ii_orig < 50:
# # if rank == 1:
# print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
# else:
# exit(0)
if args.my_qa_mask == 1:
return x, y, z
return x, y

@ -1,610 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, math, gc, importlib
import torch
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
if importlib.util.find_spec('deepspeed'):
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
try:
print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
except KeyError:
os.environ["RWKV_MY_TESTING"] = ''
def __nop(ob):
return ob
MyModule = nn.Module
MyFunction = __nop
if os.environ["RWKV_JIT_ON"] == "1":
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
########################################################################################################
# CUDA Kernel
########################################################################################################
T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
from torch.utils.cpp_extension import load
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
w = -torch.exp(w.float().contiguous())
u = u.contiguous()
k = k.contiguous()
v = v.contiguous()
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
ctx.save_for_backward(w, u, k, v, y)
return y
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
w, u, k, v, y = ctx.saved_tensors
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
return (None, None, None, gw, gu, gk, gv)
else:
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
if "32" in os.environ["RWKV_FLOAT_MODE"]:
w = -torch.exp(w.contiguous())
u = u.contiguous()
k = k.contiguous()
v = v.contiguous()
else:
w = -torch.exp(w.float().contiguous())
u = u.float().contiguous()
k = k.float().contiguous()
v = v.float().contiguous()
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
ctx.save_for_backward(w, u, k, v, y)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
return y
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
return y.half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
return y.bfloat16()
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
w, u, k, v, y = ctx.saved_tensors
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
else:
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
return (None, None, None, gw, gu, gk, gv)
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
def RUN_CUDA(B, T, C, w, u, k, v):
return WKV.apply(B, T, C, w, u, k, v)
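# Serial reference for RUN_CUDA (a naive O(T^2) sketch for sanity-checking small T;
# RUN_SERIAL is not used elsewhere, and unlike the kernel it makes no attempt at
# numerical stability; the kernel and RWKV_RNN.SA use a running-max trick):
#   wkv[t] = (sum_{i<t} e^{(t-1-i)*w + k[i]} * v[i] + e^{u + k[t]} * v[t])
#          / (sum_{i<t} e^{(t-1-i)*w + k[i]}        + e^{u + k[t]}),  w = -exp(time_decay)
def RUN_SERIAL(B, T, C, w, u, k, v):
    w = -torch.exp(w.float())                   # (C,) per-channel decay, always < 0
    u, k, v = u.float(), k.float(), v.float()   # u: (C,); k, v: (B, T, C)
    y = torch.empty((B, T, C), dtype=torch.float32, device=k.device)
    for t in range(T):
        num = torch.exp(u + k[:, t]) * v[:, t]  # "bonus" term for the current token
        den = torch.exp(u + k[:, t])
        for i in range(t):
            coef = torch.exp((t - 1 - i) * w + k[:, i])
            num = num + coef * v[:, i]
            den = den + coef
        y[:, t] = num / den
    return y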
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
class RWKV_TimeMix(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.ctx_len = args.ctx_len
self.n_embd = args.n_embd
with torch.no_grad(): # fancy init
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
# fancy time_decay
decay_speed = torch.ones(args.dim_att)
for h in range(args.dim_att):
decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
self.time_decay = nn.Parameter(decay_speed)
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
# fancy time_first
zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)
# fancy time_mix
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
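        # nn.ZeroPad2d((0, 0, 1, -1)) pads one step at the start of the time axis and
        # crops one at the end, so time_shift(x)[:, t] == x[:, t-1] (zeros at t == 0).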
self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
if 'a' in os.environ["RWKV_MY_TESTING"]:
self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
d_qkv = args.n_embd // 16
self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
with torch.no_grad():
self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
if 'a' not in os.environ["RWKV_MY_TESTING"]:
@MyFunction
def jit_func(self, x):
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
sr = torch.sigmoid(r)
return sr, k, v
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v = self.jit_func(x)
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
return self.output(rwkv)
if 'a' in os.environ["RWKV_MY_TESTING"]:
@MyFunction
def QKV(self, q, k, v):
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.att_mask == 0, float('-inf'))
att = F.softmax(att, dim = -1)
x = att @ v
return x
@MyFunction
def jit_funcQKV(self, x):
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
sr = torch.sigmoid(r)
qq = self.qq(xqq)
kk = self.kk(xkk)
vv = self.vv(xvv)
return sr, k, v, qq, kk, vv
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
return rwkv
########################################################################################################
class RWKV_ChannelMix(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # fancy init of time_mix
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
return torch.sigmoid(self.receptance(xr)) * kv
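# Channel-mix is a gated FFN: sigmoid(R(x_r)) * V(relu(K(x_k))**2), with x_k and x_r
# interpolated between the current and previous token just like time-mix above.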
class MishGLU(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad():
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
x = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
x[0, 0, i] = i / args.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x):
xx = self.time_shift(x)
xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
a = self.aa(xa)
b = self.bb(xb)
return self.value(a * F.mish(b))
########################################################################################################
# The RWKV Model with our blocks
########################################################################################################
class Block(nn.Module):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(args.n_embd)
self.ln2 = nn.LayerNorm(args.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(args.n_embd)
if args.my_pos_emb > 0:
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
if self.layer_id == 0 and self.args.pre_ffn > 0:
self.ffnPre = RWKV_ChannelMix(args, 0)
else:
self.att = RWKV_TimeMix(args, layer_id)
if 'g' in os.environ["RWKV_MY_TESTING"]:
self.ffn = MishGLU(args, layer_id)
else:
self.ffn = RWKV_ChannelMix(args, layer_id)
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
self.tiny_ln = nn.LayerNorm(args.n_embd)
self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
def forward(self, x, x_emb=None):
args = self.args
B, T, C = x.size()
if self.layer_id == 0:
x = self.ln0(x)
if args.my_pos_emb > 0:
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
x = x + pos_emb
if self.layer_id == 0 and args.pre_ffn > 0:
x = x + self.ffnPre(self.ln1(x))
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
xx = self.tiny_ln(x)
q = self.tiny_q(xx)[:, :T, :]
k = self.tiny_k(xx)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
x = x + c @ self.tiny_v(x_emb)
return x
class L2Wrap(torch.autograd.Function):
@staticmethod
def forward(ctx, loss, y):
ctx.save_for_backward(y)
return loss
@staticmethod
def backward(ctx, grad_output):
y = ctx.saved_tensors[0]
# to encourage the logits to be close to 0
factor = 1e-4 / (y.shape[0] * y.shape[1])
maxx, ids = torch.max(y, -1, keepdim=True)
gy = torch.zeros_like(y)
gy.scatter_(-1, ids, maxx * factor)
return (grad_output, gy)
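# L2Wrap passes the loss through unchanged but injects factor * max(y) as extra
# gradient at each position's top logit, equivalent to adding an auxiliary penalty
# of (factor / 2) * max(y)**2 that keeps logits from drifting upward.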
class RWKV(pl.LightningModule):
def __init__(self, args):
super().__init__()
self.args = args
if not hasattr(args, 'dim_att'):
args.dim_att = args.n_embd
if not hasattr(args, 'dim_ffn'):
args.dim_ffn = args.n_embd * 4
if not hasattr(args, 'tiny_att_layer'):
args.tiny_att_layer = -1
if not hasattr(args, 'tiny_att_dim'):
args.tiny_att_dim = -1
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
self.ln_out = nn.LayerNorm(args.n_embd)
self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
if args.head_qk > 0:
self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
def configure_optimizers(self):
args = self.args
if args.layerwise_lr > 0:
lr_1x = set()
lr_2x = set()
lr_3x = set()
for n, p in self.named_parameters():
if "time_mix" in n:
if args.my_pile_stage == 2:
lr_2x.add(n)
else:
lr_1x.add(n)
elif "time_decay" in n:
if args.my_pile_stage == 2:
lr_3x.add(n)
else:
lr_2x.add(n)
elif "time_first" in n:
lr_3x.add(n)
else:
lr_1x.add(n)
lr_1x = sorted(list(lr_1x))
lr_2x = sorted(list(lr_2x))
lr_3x = sorted(list(lr_3x))
# print('1x', lr_1x)
# print('2x', lr_2x)
# print('3x', lr_3x)
param_dict = {n: p for n, p in self.named_parameters()}
if args.my_pile_stage == 2:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
]
else:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
else:
optim_groups = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
]
if self.deepspeed_offload:
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
@property
def deepspeed_offload(self) -> bool:
strategy = self.trainer.strategy
if isinstance(strategy, DeepSpeedStrategy):
cfg = strategy.config["zero_optimization"]
return cfg.get("offload_optimizer") or cfg.get("offload_param")
return False
def forward(self, idx):
args = self.args
B, T = idx.size()
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
x = self.emb(idx)
x_emb = x
if args.tiny_att_dim > 0:
for block in self.blocks:
if args.grad_cp == 1:
x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
else:
x = block(x, x_emb)
else:
for block in self.blocks:
if args.grad_cp == 1:
x = deepspeed.checkpointing.checkpoint(block, x)
else:
x = block(x)
x = self.ln_out(x)
if args.head_qk > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
c = c @ F.one_hot(idx, num_classes=args.vocab_size)
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
x = self.head(x) + c
else:
x = self.head(x)
return x
def training_step(self, batch, batch_idx):
args = self.args
if args.my_qa_mask != 1:
idx, targets = batch
logits = self(idx)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
else:
idx, targets, mask = batch
mask = mask.view(-1)
sum_mask = torch.sum(mask).item()
# if sum_mask == 0:
# return torch.tensor([0.0], requires_grad=True)
logits = self(idx)
if sum_mask == mask.shape[0]:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
# print('rank', self.global_rank, 'loss', loss.item())
else:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
# loss_raw = loss
loss = torch.sum(loss * mask) / sum_mask
# torch.set_printoptions(threshold=10000)
# if True: #self.global_rank == 1:
# tmp = ''
# sss = 0
# ccc = 0
# for i in range(mask.shape[0]):
# if mask[i] > 0:
# tmp += str(idx.view(-1)[i].item()) + ','
# sss += loss_raw.view(-1)[i].float().item()
# ccc += 1
# print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
return L2Wrap.apply(loss, logits)
def training_step_end(self, batch_parts):
        all_parts = self.all_gather(batch_parts)  # avoid shadowing the builtin `all`
        if self.trainer.is_global_zero:
            self.trainer.my_loss_all = all_parts
def generate_init_weight(self):
print(
f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
)
m = {}
for n in self.state_dict():
p = self.state_dict()[n]
shape = p.shape
gain = 1.0
scale = 1.0
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
m[n] = p
else:
if n == "emb.weight":
scale = -1 * self.args.lr_init
else:
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
if kk in n:
scale = 0
if n == "head.weight":
scale = 0.5
if "head_k." in n:
scale = 0.1
if "head_q." in n:
scale = 0
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
if self.args.accelerator.upper() == "GPU":
m[n] = torch.empty((shape[0], shape[1]), device="cuda")
else:
m[n] = torch.empty((shape[0], shape[1]))
if scale == 0:
nn.init.zeros_(m[n])
elif scale < 0:
nn.init.uniform_(m[n], a=scale, b=-scale)
else:
nn.init.orthogonal_(m[n], gain=gain * scale)
m[n] = m[n].cpu()
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
m[n] = m[n].half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
m[n] = m[n].bfloat16()
# if n == "emb.weight":
# print(m[n])
gc.collect()
torch.cuda.empty_cache()
return m

@ -1,446 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import os, math, gc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as vision
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from pytorch_msssim import MS_SSIM
def __nop(ob):
return ob
MyModule = torch.jit.ScriptModule
# MyFunction = __nop
MyFunction = torch.jit.script_method
import clip
from transformers import CLIPModel
class L2pooling(nn.Module):
def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
super(L2pooling, self).__init__()
self.padding = (filter_size - 2) // 2
self.stride = stride
self.channels = channels
a = np.hanning(filter_size)[1:-1]
g = torch.Tensor(a[:, None] * a[None, :])
g = g / torch.sum(g)
self.register_buffer(
"filter", g[None, None, :, :].repeat((self.channels, 1, 1, 1))
)
def forward(self, input):
input = input**2
out = F.conv2d(
input,
self.filter,
stride=self.stride,
padding=self.padding,
groups=input.shape[1],
)
return (out + 1e-12).sqrt()
class DISTS(torch.nn.Module):
def __init__(self, load_weights=True):
super(DISTS, self).__init__()
vgg_pretrained_features = vision.models.vgg16(
weights="VGG16_Weights.IMAGENET1K_V1"
).features
self.stage1 = torch.nn.Sequential()
self.stage2 = torch.nn.Sequential()
self.stage3 = torch.nn.Sequential()
self.stage4 = torch.nn.Sequential()
self.stage5 = torch.nn.Sequential()
for x in range(0, 4):
self.stage1.add_module(str(x), vgg_pretrained_features[x])
self.stage2.add_module(str(4), L2pooling(channels=64))
for x in range(5, 9):
self.stage2.add_module(str(x), vgg_pretrained_features[x])
self.stage3.add_module(str(9), L2pooling(channels=128))
for x in range(10, 16):
self.stage3.add_module(str(x), vgg_pretrained_features[x])
self.stage4.add_module(str(16), L2pooling(channels=256))
for x in range(17, 23):
self.stage4.add_module(str(x), vgg_pretrained_features[x])
self.stage5.add_module(str(23), L2pooling(channels=512))
for x in range(24, 30):
self.stage5.add_module(str(x), vgg_pretrained_features[x])
self.register_buffer(
"mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
)
self.register_buffer(
"std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
)
self.chns = [3, 64, 128, 256, 512, 512]
self.register_buffer(
"alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))
)
self.register_buffer("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
self.alpha.data.normal_(0.1, 0.01)
self.beta.data.normal_(0.1, 0.01)
weights = torch.load("test/DISTS_weights.pt")
self.alpha.data = weights["alpha"]
self.beta.data = weights["beta"]
for param in self.parameters():
param.requires_grad = False
def forward_once(self, x):
h = (x - self.mean) / self.std
h = self.stage1(h)
h_relu1_2 = h
h = self.stage2(h)
h_relu2_2 = h
h = self.stage3(h)
h_relu3_3 = h
h = self.stage4(h)
h_relu4_3 = h
h = self.stage5(h)
h_relu5_3 = h
return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
def forward(self, x, y, require_grad=False, batch_average=False):
if require_grad:
feats0 = self.forward_once(x)
feats1 = self.forward_once(y)
else:
with torch.no_grad():
feats0 = self.forward_once(x)
feats1 = self.forward_once(y)
dist1 = 0
dist2 = 0
c1 = 1e-6
c2 = 1e-6
w_sum = self.alpha.sum() + self.beta.sum()
alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
beta = torch.split(self.beta / w_sum, self.chns, dim=1)
for k in range(len(self.chns)):
x_mean = feats0[k].mean([2, 3], keepdim=True)
y_mean = feats1[k].mean([2, 3], keepdim=True)
S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True)
x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)
y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)
xy_cov = (feats0[k] * feats1[k]).mean(
[2, 3], keepdim=True
) - x_mean * y_mean
S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)
score = 1 - (dist1 + dist2).squeeze()
if batch_average:
return score.mean()
else:
return score
class ToBinary(torch.autograd.Function):
@staticmethod
def forward(ctx, x):#, noise_scale):
# if noise_scale > 0:
# noise_min = 0.5 - noise_scale / 2
# noise_max = 0.5 + noise_scale / 2
# return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
# else:
return torch.floor(x + 0.5) # no need for noise when we have plenty of data
@staticmethod
def backward(ctx, grad_output):
return grad_output.clone()#, None
########################################################################################################
class R_ENCODER(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.Bxx = nn.BatchNorm2d(dd*64)
self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*4)
self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*64)
self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.B21 = nn.BatchNorm2d(dd*64)
# self.C24 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C25 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.C26 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C27 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
@MyFunction
def forward(self, img):
ACT = F.mish
x = self.CIN(img)
xx = self.Bxx(F.pixel_unshuffle(x, 8))
x = x + self.Cx1(ACT(self.Cx0(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
# x = x + self.C25(ACT(self.C24(ACT(self.B21(x)))))
# x = x + self.C27(ACT(self.C26(x)))
x = self.COUT(x + xx)
return torch.sigmoid(x)
########################################################################################################
class R_DECODER(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*64)
self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.B01 = nn.BatchNorm2d(dd*64)
# self.C04 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C05 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.C06 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C07 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*4)
self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
@MyFunction
def forward(self, code):
ACT = F.mish
x = self.CIN(code)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
# x = x + self.C05(ACT(self.C04(ACT(self.B01(x)))))
# x = x + self.C07(ACT(self.C06(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.Cx1(ACT(self.Cx0(x)))
x = self.COUT(x)
return torch.sigmoid(x)
########################################################################################################
def cosine_loss(x, y):
x = F.normalize(x, dim=-1)
y = F.normalize(y, dim=-1)
return 1 - torch.einsum('ij,ij->i',[x,y])
class RWKV_IMG(pl.LightningModule):
def __init__(self, args):
super().__init__()
self.args = args
self.encoder = R_ENCODER(args)
self.decoder = R_DECODER(args)
self.clip_model = None
clip_name = args.my_img_clip
if clip_name == 'B32':
clip_name = 'ViT-B/32'
elif clip_name == 'B16':
clip_name = 'ViT-B/16'
elif clip_name == 'L14':
clip_name = 'ViT-L/14'
elif clip_name == 'OB32':
clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
self.clip_model = CLIPModel.from_pretrained(clip_name)
self.clip_model.encode_image = self.clip_model.get_image_features
if self.clip_model == None:
self.clip_model, _ = clip.load(clip_name, jit = True)
self.register_buffer(
"clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
)
self.register_buffer(
"clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
)
for n, p in self.named_parameters():
if 'clip_model' in n:
p.requires_grad = False
self.loss_dists = DISTS()
# self.loss_ssim = MS_SSIM(data_range=1, size_average=True, channel=3)
def configure_optimizers(self):
args = self.args
optim_groups = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
]
if self.deepspeed_offload:
return DeepSpeedCPUAdam(
optim_groups,
lr=self.args.lr_init,
betas=self.args.betas,
eps=self.args.adam_eps,
bias_correction=True,
adamw_mode=False,
weight_decay=0,
amsgrad=False,
)
return FusedAdam(
optim_groups,
lr=self.args.lr_init,
betas=self.args.betas,
eps=self.args.adam_eps,
bias_correction=True,
adam_w_mode=False,
weight_decay=0,
amsgrad=False,
)
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
@property
def deepspeed_offload(self) -> bool:
strategy = self.trainer.strategy
if isinstance(strategy, DeepSpeedStrategy):
config = strategy.config["zero_optimization"]
return config.get("offload_optimizer") or config.get("offload_param")
return False
def forward(self, img):
z = self.encoder(img)
z = ToBinary.apply(z)#, self.args.my_img_noise_scale)
out = self.decoder(z)
return out
def training_step(self, batch, batch_idx):
args = self.args
img, txt = batch
out = self(img)
if self.trainer.is_global_zero:
if (self.trainer.global_step + 1) % (100 * int(args.devices)) == 0:
img_dir = f"test/image_model/{args.run_name}"
if not os.path.exists(img_dir):
os.makedirs(img_dir)
vision.utils.save_image(
img[:4], f"{img_dir}/{self.trainer.global_step}-src.jpg"#, padding=0
)
vision.utils.save_image(
out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0
)
# loss_ssim = 1 - self.loss_ssim(out, img)
loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)
iii = self.clip_model.encode_image((img - self.clip_mean) / self.clip_std)
ooo = self.clip_model.encode_image((out - self.clip_mean) / self.clip_std)
loss_clip = torch.mean(cosine_loss(iii, ooo))
if args.my_img_l1_scale > 0:
loss_l1 = F.l1_loss(out, img)
return loss_dists + loss_clip * args.my_img_clip_scale + loss_l1 * args.my_img_l1_scale
else:
return loss_dists + loss_clip * args.my_img_clip_scale
def training_step_end(self, batch_parts):
        all_parts = self.all_gather(batch_parts)  # avoid shadowing the builtin `all`
        if self.trainer.is_global_zero:
            self.trainer.my_loss_all = all_parts
def generate_init_weight(self):
print(
f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
)
m = {}
for n in self.state_dict():
scale = 1
p = self.state_dict()[n]
shape = p.shape
ss = n.split('.')
# if ss[0] in ['encoder', 'decoder']:
# if ss[2] == 'bias':
# scale = 0
# # elif n == 'encoder.CIN.weight':
# # nn.init.dirac_(p)
# else:
# try:
# if ss[1][0] == 'C' and (int(ss[1][2]) % 2 == 1):
# scale = 0
# except:
# pass
# m[n] = p * scale
m[n] = p
m[n] = m[n].cpu()
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
m[n] = m[n].half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
m[n] = m[n].bfloat16()
gc.collect()
torch.cuda.empty_cache()
return m

@ -1,237 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types
import torch
import math, os, gc
from torch.nn import functional as F
import torch.nn as nn
from typing import List, Dict
MyModule = nn.Module
def __nop(ob):
return ob
MyFunction = __nop
# # try torchdynamo
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# try torch jit --> faster for fp32, slower for fp16 (why?)
if os.environ["RWKV_JIT_ON"] == "1":
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n')
DEBUG_TIME = False # True False - show trained time-coeffs
RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
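# Pairs with the weight rescaling below: att.output and ffn.value weights of block i
# are divided by 2**(i // RWKV_RESCALE_LAYER) at load time, and forward() halves x
# every RWKV_RESCALE_LAYER blocks, so fp16 activations stay in range while the
# LayerNorms absorb the overall scale.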
############################################################################################################
class RWKV_RNN(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
self.FLOAT_MODE = args.FLOAT_MODE
self.RUN_DEVICE = args.RUN_DEVICE
with torch.no_grad():
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
# refine weights and send to correct device
keys = list(w.keys())
if 'pos_emb_x' in keys:
w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
keys = list(w.keys())
print_need_newline = False
for x in keys:
block_id = 0
if 'blocks.' in x:
block_id = int(x.split('.')[1])
if 'att.output.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if 'ffn.value.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if '.time_' in x:
w[x] = w[x].squeeze()
if DEBUG_TIME:
print(x, w[x].numpy())
if '.time_decay' in x:
w[x] = w[x].float()
w[x] = -torch.exp(w[x])
elif '.time_first' in x:
w[x] = w[x].float()
else:
if self.FLOAT_MODE == "fp32":
w[x] = w[x].float()
elif self.FLOAT_MODE == "bf16":
w[x] = w[x].bfloat16()
elif self.FLOAT_MODE == "fp16":
w[x] = w[x].half()
w[x].requires_grad = False
if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
w[x] = w[x].cuda()
if ('blocks.' not in x) or ('blocks.0.' in x):
if print_need_newline:
print('\n', end = '')
print_need_newline = False
print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device)
else:
print_need_newline = True
print('.', end = '', flush = True)
# store weights in self.w
keys = list(w.keys())
self.w = types.SimpleNamespace()
for x in keys:
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.eval()
gc.collect()
torch.cuda.empty_cache()
def LN(self, x, w):
return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
# state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
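    # ffn_xx/att_xx cache the previous token's input to each mix; att_aa/att_bb are the
    # running numerator/denominator of the WKV weighted average, and att_pp tracks the
    # running max exponent that keeps the exp() calls in SA() numerically stable.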
@MyFunction
def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+0] = x.float()
elif self.FLOAT_MODE == "fp16":
xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r)
state[5*i+0] = x.float()
else:
xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
state[5*i+0] = x
r = torch.sigmoid(rw @ xr)
k = torch.square(torch.relu(kw @ xk))
kv = vw @ k
return r * kv
@MyFunction
def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+1] = x.float()
elif self.FLOAT_MODE == "fp16":
xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r)
state[5*i+1] = x.float()
else:
xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
state[5*i+1] = x
r = torch.sigmoid(rw @ xr)
k = kw @ xk
v = vw @ xv
if '16' in self.FLOAT_MODE:
kk = k.float()
vv = v.float()
else:
kk = k
vv = v
aa = state[5*i+2]
bb = state[5*i+3]
pp = state[5*i+4]
ww = time_first + kk
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
a = e1 * aa + e2 * vv
b = e1 * bb + e2
ww = pp + time_decay
p = torch.maximum(ww, kk)
e1 = torch.exp(ww - p)
e2 = torch.exp(kk - p)
state[5*i+2] = e1 * aa + e2 * vv
state[5*i+3] = e1 * bb + e2
state[5*i+4] = p
if self.FLOAT_MODE == "bf16":
wkv = (a / b).type(torch.bfloat16)
elif self.FLOAT_MODE == "fp16":
wkv = (a / b).half()
else:
wkv = a / b
return ow @ (r * wkv)
def forward(self, ctx, state, preprocess_only = False):
with torch.no_grad():
w = self.w
args = self.args
x = w.emb.weight[ctx[-1]]
if self.RUN_DEVICE == 'cuda':
x = x.cuda()
try:
pos_emb = w.pos_emb[len(ctx)-1]
x = x + pos_emb
            except Exception:
                pass  # model has no pos_emb (or ctx exceeds ctx_len)
if state == None:
state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
for i in range(args.n_layer):
state[5*i+4] -= 1e30
for i in range(args.n_layer):
if i == 0:
x = self.LN(x, w.blocks[i].ln0)
ww = w.blocks[i].att
x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i,
ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
ww = w.blocks[i].ffn
x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
ww.time_mix_k, ww.time_mix_r,
ww.key.weight, ww.value.weight, ww.receptance.weight)
if (i+1) % RWKV_RESCALE_LAYER == 0:
x = x / 2
if preprocess_only:
return state
x = self.LN(x, w.ln_out)
x = w.head.weight @ x
return x.float(), state

@ -1,190 +0,0 @@
import os, math, time, datetime, subprocess
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
def my_save(dd, ff):
if '14b-run1' not in ff:
torch.save(dd, ff)
else:
fn = ff.split('/')[-1]
fff = '/dev/shm/' + fn
torch.save(dd, fff)
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
class train_callback(pl.Callback):
def __init__(self, args):
super().__init__()
self.args = args
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
args = self.args
# if args.cuda_cleanup > 0:
# torch.cuda.empty_cache()
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
# LR schedule
w_step = args.warmup_steps
if args.lr_final == args.lr_init or args.epoch_count == 0:
lr = args.lr_init
else:
decay_step = real_step - args.my_pile_edecay * args.epoch_steps
decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
progress = (decay_step - w_step + 1) / (decay_total - w_step)
progress = min(1, max(0, progress))
if args.lr_final == 0 or args.lr_init == 0: # linear decay
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
else: # exp decay
                lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * progress)
if trainer.global_step < w_step:
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
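        # Net schedule: warmup scales lr from 0.2x to 1x over warmup_steps, then lr
        # decays from lr_init to lr_final, linearly if either endpoint is 0 and
        # exponentially otherwise, over (epoch_count - my_pile_edecay) * epoch_steps.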
# if trainer.is_global_zero:
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
for param_group in trainer.optimizers[0].param_groups:
if args.layerwise_lr > 0:
param_group["lr"] = lr * param_group["my_lr_scale"]
# print(param_group["lr"], param_group["my_lr_scale"])
else:
param_group["lr"] = lr
trainer.my_lr = lr
# rank_zero_info(f"{real_step} {lr}")
if trainer.global_step == 0:
if trainer.is_global_zero: # logging
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
try:
print(f"\n{trainer.strategy.config}\n")
trainer.my_log.write(f"{trainer.strategy.config}\n")
                except Exception:
                    pass
trainer.my_log.flush()
if len(args.wandb) > 0:
print("Login to wandb...")
import wandb
wandb.init(
project=args.wandb,
name=args.run_name + " " + args.my_timestamp,
config=args,
save_code=False,
)
trainer.my_wandb = wandb
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
args = self.args
if trainer.is_global_zero: # logging
t_now = time.time_ns()
token_per_step = args.ctx_len * args.real_bsz
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
kt_s = 0
try:
t_cost = (t_now - trainer.my_time_ns) / 1e9
kt_s = token_per_step / t_cost / 1000
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
            except Exception:
                pass
trainer.my_time_ns = t_now
trainer.my_loss = trainer.my_loss_all.float().mean().item()
trainer.my_loss_sum += trainer.my_loss
trainer.my_loss_count += 1
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
# self.log("s", real_step, prog_bar=True, on_step=True)
if len(args.wandb) > 0:
lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
if kt_s > 0:
lll["kt/s"] = kt_s
trainer.my_wandb.log(lll, step=int(real_step))
if args.magic_prime > 0:
expand_factor = 2 if args.my_qa_mask > 0 else 1
if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1:
to_save_dict = pl_module.state_dict()
my_save(
to_save_dict,
f"{args.proj_dir}/rwkv-final.pth",
)
def on_train_epoch_start(self, trainer, pl_module):
args = self.args
dataset = trainer.train_dataloader.dataset.datasets
assert "MyDataset" in str(dataset)
dataset.global_rank = trainer.global_rank
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
dataset.world_size = trainer.world_size
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
def on_train_epoch_end(self, trainer, pl_module):
args = self.args
if trainer.is_global_zero: # logging & save state_dict
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
if args.data_type == 'wds_img':
raw_dict = pl_module.state_dict()
to_save_dict = {}
for k in raw_dict:
if k.startswith('encoder.') or k.startswith('decoder.'):
to_save_dict[k] = raw_dict[k]
else:
to_save_dict = pl_module.state_dict()
try:
my_save(
to_save_dict,
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
)
except Exception as e:
print('Error\n\n', e, '\n\n')
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
trainer.my_log.flush()
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
@rank_zero_only
def generate_init_weight(model, init_weight_name):
mm = model.generate_init_weight()
if model.args.my_pile_stage == 1:
if len(model.args.load_model) > 0:
print(f"Combine weights from {model.args.load_model}...")
load_dict = torch.load(model.args.load_model, map_location="cpu")
for k in load_dict:
assert k in mm
src = load_dict[k]
try:
mm[k] = src.reshape(mm[k].shape)
except:
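                    # shape mismatch: rebuild the target tensor by 1-D linear
                    # interpolation of the source weights (see the loop below)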
tmp = mm[k].squeeze().clone()
print(k, src.shape, '-->', mm[k].shape)
ss = src.shape[0]
dd = tmp.shape[0]
for i in range(dd):
pos = i / dd * ss
if pos >= ss - 1:
tmp[i] = src[ss-1]
else:
p0 = int(math.floor(pos))
ii = pos - p0
tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
mm[k] = tmp.reshape(mm[k].shape)
sss = src.squeeze().float().cpu().numpy()
print(sss[:10], '...', sss[-10:])
mmm = mm[k].squeeze().float().cpu().numpy()
print(mmm[:10], '...', mmm[-10:])
print(f"Save to {init_weight_name}...")
torch.save(mm, init_weight_name)
if model.args.my_pile_stage == 1:
print("Done. Now go for stage 2.")
exit(0)

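# A minimal standalone sketch of the schedule implemented in on_train_batch_start
# above: a 0.2x -> 1x warmup ramp, then linear decay when either endpoint LR is 0,
# otherwise exponential decay from lr_init to lr_final. Function and argument
# names are illustrative, not part of the trainer API.
import math

def lr_at_step(step, lr_init, lr_final, warmup_steps, decay_total_steps):
    if lr_final == lr_init or decay_total_steps <= 0:
        lr = lr_init
    else:
        progress = (step - warmup_steps + 1) / (decay_total_steps - warmup_steps)
        progress = min(1.0, max(0.0, progress))
        if lr_final == 0 or lr_init == 0:  # linear decay
            lr = lr_init + (lr_final - lr_init) * progress
        else:  # exponential decay
            lr = lr_init * math.exp(math.log(lr_final / lr_init) * progress)
    if step < warmup_steps:
        lr = lr * (0.2 + 0.8 * step / warmup_steps)
    return lr

# e.g. lr_at_step(0, 6e-4, 1e-5, 50, 10000) == 1.2e-4 (start of the warmup ramp)
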
@ -1,130 +0,0 @@
import json, time, random, os
import numpy as np
import torch
from torch.nn import functional as F
time_slot = {}
time_ref = time.time_ns()
def record_time(name):
if name not in time_slot:
time_slot[name] = 1e20
tt = (time.time_ns() - time_ref) / 1e9
if tt < time_slot[name]:
time_slot[name] = tt
class TOKENIZER():
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        if isinstance(WORD_NAME, list):
self.charMode = False
if WORD_NAME[0] == WORD_NAME[1]:
from transformers import PreTrainedTokenizerFast
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
else:
from transformers import GPT2TokenizerFast
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
self.vocab_size = len(self.tokenizer)
else:
self.charMode = True
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
self.word_table = json.load(result_file)
self.vocab_size = len(self.word_table)
self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(out, dim=-1)
if self.charMode:
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
else:
top_p = top_p_usual
if os.environ["RWKV_RUN_DEVICE"] == "cpu":
probs = probs.numpy()
sorted_probs = np.sort(probs)[::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
            if temperature != 1.0:
                probs = np.power(probs, 1.0 / temperature)  # probs is a numpy array here, so .pow() would fail
probs = probs / np.sum(probs)
out = np.random.choice(a=len(probs), p=probs)
return out
else:
sorted_probs = torch.sort(probs, descending=True)[0]
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
out = torch.multinomial(probs, num_samples=1)[0]
return out
def MaybeIsPrime(number):
    return FermatPrimalityTest(number) and MillerRabinPrimalityTest(number)
def FermatPrimalityTest(number):
if number > 1:
        for _ in range(3):  # three trials; '_' avoids shadowing the imported time module
randomNumber = random.randint(2, number) - 1
if pow(randomNumber, number - 1, number) != 1:
return False
return True
else:
return False
def MillerRabinPrimalityTest(number):
if number == 2:
return True
elif number == 1 or number % 2 == 0:
return False
oddPartOfNumber = number - 1
timesTwoDividNumber = 0
while oddPartOfNumber % 2 == 0:
oddPartOfNumber = oddPartOfNumber // 2
timesTwoDividNumber = timesTwoDividNumber + 1
    for _ in range(3):  # likewise, avoid shadowing time
while True:
randomNumber = random.randint(2, number) - 1
if randomNumber != 0 and randomNumber != 1:
break
randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
iterationNumber = 1
while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
iterationNumber = iterationNumber + 1
if randomNumberWithPower != (number - 1):
return False
return True

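# An isolated numpy sketch of the nucleus (top-p) cutoff used by
# TOKENIZER.sample_logits above: keep the smallest set of tokens whose
# cumulative probability exceeds top_p, zero out the rest, renormalize, sample.
import numpy as np

def top_p_sample(probs, top_p=0.9, temperature=1.0, rng=np.random):
    sorted_probs = np.sort(probs)[::-1]
    cumulative = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative > top_p)])
    probs = np.where(probs < cutoff, 0.0, probs)
    if temperature != 1.0:
        probs = np.power(probs, 1.0 / temperature)
    probs = probs / probs.sum()
    return rng.choice(len(probs), p=probs)

# e.g. top_p_sample(np.array([0.5, 0.3, 0.1, 0.1]), top_p=0.7) only returns 0 or 1
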
@ -1,349 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
if __name__ == "__main__":
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
rank_zero_info("########## work in progress ##########")
########################################################################################################
#
# example: train a simple L12-D768 RWKV on dummy data
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "" --data_type "dummy" --vocab_size 0 \
# --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
# --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
# --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: train a simple L6-D512 RWKV from scratch on enwik8
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
# --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
# example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
# --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
parser = ArgumentParser()
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--random_seed", default="-1", type=int)
parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--dim_att", default=0, type=int)
parser.add_argument("--dim_ffn", default=0, type=int)
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=-1, type=int) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
parser.add_argument("--my_pile_edecay", default=0, type=int)
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
parser.add_argument("--my_img_version", default=0, type=str)
parser.add_argument("--my_img_size", default=0, type=int)
parser.add_argument("--my_img_bit", default=0, type=int)
parser.add_argument("--my_img_clip", default='x', type=str)
parser.add_argument("--my_img_clip_scale", default=1, type=float)
parser.add_argument("--my_img_l1_scale", default=0, type=float)
parser.add_argument("--my_img_encoder", default='x', type=str)
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
parser.add_argument("--my_sample_len", default=0, type=int)
parser.add_argument("--my_ffn_shift", default=1, type=int)
parser.add_argument("--my_att_shift", default=1, type=int)
parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
########################################################################################################
import os, warnings, math, datetime, sys, time, importlib
import numpy as np
import torch
from torch.utils.data import DataLoader
if "deepspeed" in args.strategy:
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
seed_everything(args.random_seed)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
# os.environ["WDS_SHOW_SEED"] = "1"
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
args.enable_checkpointing = False
args.replace_sampler_ddp = False
args.logger = False
args.gradient_clip_val = 1.0
args.num_sanity_val_steps = 0
args.check_val_every_n_epoch = int(1e20)
args.log_every_n_steps = int(1e20)
args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
os.environ["RWKV_MY_TESTING"] = args.my_testing
if args.dim_att <= 0:
args.dim_att = args.n_embd
if args.dim_ffn <= 0:
args.dim_ffn = args.n_embd * 4
if args.data_type == "wds_img":
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
args.proj_dir = f"{args.proj_dir}-{args.run_name}"
else:
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
if not os.path.exists(args.proj_dir):
os.makedirs(args.proj_dir)
if args.my_pile_stage > 0:
magic_prime_bak = args.magic_prime
if args.my_pile_version == 1:
if args.ctx_len == 1024:
args.magic_prime = 324331313
args.epoch_count = 8043
elif args.ctx_len == 2048:
args.magic_prime = 162165671
args.epoch_count = 4021
elif args.ctx_len == 4096:
args.magic_prime = 81082817
args.epoch_count = 2010
elif args.ctx_len == 8192:
args.magic_prime = 40541399
args.epoch_count = 1005
else:
if args.ctx_len == 1024:
args.magic_prime = 1694947181
args.epoch_count = 42036
elif args.ctx_len == 2048:
args.magic_prime = 847473509
args.epoch_count = 21017
elif args.ctx_len == 4096:
args.magic_prime = 423736637
args.epoch_count = 10508
elif args.ctx_len == 6144:
args.magic_prime = 282491051
args.epoch_count = 7005
elif args.ctx_len == 8192:
args.magic_prime = 211868243
args.epoch_count = 5253
if args.my_pile_shift < 0:
args.my_pile_shift = 0
if magic_prime_bak > 0:
args.magic_prime = magic_prime_bak
args.epoch_steps = 40320 // args.real_bsz
assert args.epoch_steps * args.real_bsz == 40320
if args.my_pile_stage == 2:
assert args.lr_final == args.lr_init
if args.my_pile_stage >= 2: # find latest saved model
list_p = []
for p in os.listdir(args.proj_dir):
if p.startswith("rwkv") and p.endswith(".pth"):
p = ((p.split("-"))[1].split("."))[0]
if p == "init":
p = -1
else:
p = int(p)
list_p += [p]
list_p.sort()
max_p = list_p[-1]
if len(list_p) > 1:
args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted
if max_p == -1:
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
else:
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
if args.warmup_steps < 0:
if args.my_pile_stage == 2:
args.warmup_steps = 10
else:
args.warmup_steps = 30
args.epoch_begin = max_p + 1
samples_per_epoch = args.epoch_steps * args.real_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len
rank_zero_info(
f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
#
############################################################################
"""
)
rank_zero_info(str(vars(args)) + "\n")
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32":
for i in range(10):
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
os.environ["RWKV_JIT_ON"] = "1"
if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT_ON"] = "0"
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
if args.precision == "fp32":
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if "32" in args.precision:
args.precision = 32
elif args.precision == "fp16":
args.precision = 16
else:
args.precision = "bf16"
########################################################################################################
from src.trainer import train_callback, generate_init_weight
from src.dataset import MyDataset
train_data = MyDataset(args)
args.vocab_size = train_data.vocab_size
if args.data_type == 'wds_img':
from src.model_img import RWKV_IMG
model = RWKV_IMG(args)
else:
from src.model import RWKV
model = RWKV(args)
if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
generate_init_weight(model, init_weight_name) # save initial weights
args.load_model = init_weight_name
rank_zero_info(f"########## Loading {args.load_model}... ##########")
try:
load_dict = torch.load(args.load_model, map_location="cpu")
except:
rank_zero_info(f"Bad checkpoint {args.load_model}")
if args.my_pile_stage >= 2: # try again using another checkpoint
max_p = args.my_pile_prev_p
if max_p == -1:
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
else:
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
args.epoch_begin = max_p + 1
rank_zero_info(f"Trying {args.load_model}")
load_dict = torch.load(args.load_model, map_location="cpu")
if args.load_partial == 1:
load_keys = load_dict.keys()
for k in model.state_dict():
if k not in load_keys:
load_dict[k] = model.state_dict()[k]
model.load_state_dict(load_dict)
trainer = Trainer.from_argparse_args(
args,
callbacks=[train_callback(args)],
)
if trainer.global_rank == 0:
for n in model.state_dict():
shape = model.state_dict()[n].shape
shape = [i for i in shape if i != 1]
if len(shape) > 1:
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
else:
print(f"{str(shape[0]).ljust(5)} {n}")
if "deepspeed" in args.strategy:
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
trainer.fit(model, data_loader)

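# Quick sanity math for the batch/epoch bookkeeping above (illustrative numbers):
# real_bsz = num_nodes * devices * micro_bsz, and in --my_pile_stage mode
# epoch_steps is forced to 40320 // real_bsz, so real_bsz must divide 40320
# (40320 = 8!, which is why batch sizes like 32, 64, 96 or 128 all work).
num_nodes, devices, micro_bsz = 1, 8, 8
real_bsz = num_nodes * devices * micro_bsz      # 64
epoch_steps = 40320 // real_bsz                 # 630
assert epoch_steps * real_bsz == 40320
ctx_len = 1024
samples_per_epoch = epoch_steps * real_bsz      # 40320
tokens_per_epoch = samples_per_epoch * ctx_len  # 41,287,680 tokens per "epoch"
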
@ -1,104 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# this verifies the results of different models and makes sure they agree with each other
import os, sys, types
import numpy as np
import torch
np.set_printoptions(precision=4, suppress=True, linewidth=200)
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # bf16 or fp32
os.environ['RWKV_RUN_DEVICE'] = 'cuda' # currently model_train requires CUDA
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
TOKEN_MODE = 'pile'
if TOKEN_MODE == 'pile':
WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
n_layer = 32
n_embd = 2560
ctx_len = 1024
UNKNOWN_CHAR = None
from src.utils import TOKENIZER
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == 'pile':
tokenizer.vocab_size = 50277
########################################################################################################
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_T_MAX"] = str(ctx_len)
from src.model_run import RWKV_RNN
from src.model import RWKV
args = types.SimpleNamespace()
args.vocab_size = tokenizer.vocab_size
args.ctx_len = ctx_len
args.n_embd = n_embd
args.n_layer = n_layer
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
model_train = RWKV(args).to(RUN_DEVICE)
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
model_train = model_train.bfloat16()
print('loading ' + MODEL_NAME)
m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu')
model_train.load_state_dict(m2)
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
model_train = model_train.bfloat16()
args.MODEL_NAME = MODEL_NAME
args.RUN_DEVICE = RUN_DEVICE
args.FLOAT_MODE = os.environ['RWKV_FLOAT_MODE']
model_rnn = RWKV_RNN(args)
########################################################################################################
print(f"\nVerifying {os.environ['RWKV_RUN_DEVICE']} {os.environ['RWKV_FLOAT_MODE']}")
# context = '\nIn a'
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'
if TOKEN_MODE == 'pile':
ctx = tokenizer.tokenizer.encode(context)
print(f'input len {len(ctx)} data {ctx}')
########################################################################################################
with torch.no_grad():
print('\nRWKV-train output')
out = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0].detach().cpu().float().numpy()
print(out, '\n')
print('\nRWKV-RNN output')
state = None
out = None
src_len = len(ctx)
for i in range(src_len):
x = ctx[:i+1]
out, state = model_rnn.forward(x, state)
if i < 3 or i >= src_len - 3:
print(out.detach().cpu().numpy())
if i == 2:
print('...')

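# A hedged sketch of turning the printed comparison above into an automatic
# check: run both models over the same context and compare the logits at the
# final position (reuses model_train, model_rnn, ctx and RUN_DEVICE from above).
with torch.no_grad():
    out_gpt = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0].float().cpu().numpy()
    state, out_rnn = None, None
    for i in range(len(ctx)):
        out_rnn, state = model_rnn.forward(ctx[:i+1], state)
    diff = np.max(np.abs(out_gpt[-1] - out_rnn.float().cpu().numpy()))
    print('max abs diff at final position:', diff)  # bf16 runs typically agree to ~1e-2
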
Binary file not shown.

@ -63,20 +63,26 @@ class RWKV_TimeMix(nn.Module):
self.head_size = config.n_attn // config.n_head
with torch.no_grad(): # initial time_w curves for better convergence
ww = torch.ones(config.n_head, config.ctx_len)
curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
ww = torch.zeros(config.n_head, config.ctx_len)
curve = torch.tensor([0.9 ** (config.ctx_len - 1 - i) for i in range(config.ctx_len)])
curve = curve * 2 + 0.7
for h in range(config.n_head):
if h < config.n_head - 1:
decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
if config.n_head > 1:
mix_strength = 1 - 1.2 * h / (config.n_head - 1) # mix_strength from 1 to -0.2
else:
decay_speed = 0.0
ww[h] = torch.exp(curve * decay_speed)
# print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy())
mix_strength = 0.5
ww[h] = (1 - mix_strength) + curve * mix_strength
# special tweaks because of time_shift
ww[h][config.ctx_len - 3] = (ww[h][config.ctx_len - 3] * 2 + 1) / 3
ww[h][config.ctx_len - 2] = (ww[h][config.ctx_len - 2] * 1 + 2) / 3
ww[h][config.ctx_len - 1] = 1
# print(h, mix_strength, ww[h])
self.time_w = nn.Parameter(ww)
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
@ -84,8 +90,8 @@ class RWKV_TimeMix(nn.Module):
self.value = nn.Linear(config.n_embd, config.n_attn)
self.receptance = nn.Linear(config.n_embd, config.n_attn)
# if config.rwkv_tiny_attn > 0:
# self.tiny_att = RWKV_TinyAttn(config)
if config.rwkv_tiny_attn > 0:
self.tiny_att = RWKV_TinyAttn(config)
self.output = nn.Linear(config.n_attn, config.n_embd)
@ -101,10 +107,12 @@ class RWKV_TimeMix(nn.Module):
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
w = w[:, :, TT-1:] # w is now a circulant matrix
w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
self.mask = self.mask[:T, :T]
w = w.masked_fill(self.mask == 0, 0)
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
# if hasattr(self, 'tiny_att'):
# tiny_att = self.tiny_att(x, self.mask)
if hasattr(self, 'tiny_att'):
tiny_att = self.tiny_att(x, self.mask)
k = self.key(x)
v = self.value(x)
@ -121,8 +129,8 @@ class RWKV_TimeMix(nn.Module):
rwkv = torch.sigmoid(r) * wkv / sum_k
rwkv = self.output(rwkv)
# if hasattr(self, 'tiny_att'):
# rwkv += tiny_att
if hasattr(self, 'tiny_att'):
rwkv += tiny_att
return rwkv * self.time_gamma[:T, :]
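# The reshape trick at the top of RWKV_TimeMix.forward above turns each head's
# length-TT weight vector into a TT x TT lower-triangular Toeplitz matrix (the
# "circulant matrix" in the comment), where entry (i, j) holds the weight for
# distance i - j. A toy verification, assuming the usual pad/tile prelude from
# this file:
import torch
import torch.nn.functional as F

TT = 4
time_w = torch.arange(1., TT + 1).unsqueeze(0)   # one head: [1, 2, 3, 4]
w = F.pad(time_w, (0, TT))                       # (1, 2*TT)
w = torch.tile(w, [TT])                          # (1, 2*TT*TT)
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
w = w[:, :, TT - 1:]                             # (1, TT, TT), causal
for i in range(TT):
    for j in range(i + 1):
        assert w[0, i, j] == time_w[0, TT - 1 - (i - j)]
print(w[0])  # rows: [4,0,0,0], [3,4,0,0], [2,3,4,0], [1,2,3,4]
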
@ -434,12 +442,6 @@ class GPT(nn.Module):
self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.head_q = nn.Linear(config.n_embd, 256)
self.head_q.scale_init = 0.01
self.head_k = nn.Linear(config.n_embd, 256)
self.head_k.scale_init = 0.01
self.register_buffer("copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.ctx_len = config.ctx_len
if self.config.model_type == 'RWKV':
@ -500,15 +502,8 @@ class GPT(nn.Module):
x = self.blocks(x)
x = self.ln_f(x)
q = self.head_q(x)[:,:T,:]
k = self.head_k(x)[:,:T,:]
c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
x = self.head(x) + c
x = self.head(x)
loss = None
if targets is not None:

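# The head_q / head_k lines deleted above are BlinkDL's "headQK" trick (see the
# --head_qk flag elsewhere in this repo): a small attention map over past
# positions whose scores are scattered onto the token ids actually seen there,
# giving the model a direct copy path. A shape-only sketch with hypothetical sizes:
import torch
import torch.nn.functional as F

B, T, C, V = 2, 8, 64, 100                      # batch, ctx, embd, vocab (made up)
x = torch.randn(B, T, C)                        # stand-in for the final hidden states
idx = torch.randint(V, (B, T))                  # the input token ids
head_q, head_k = torch.nn.Linear(C, 256), torch.nn.Linear(C, 256)
copy_mask = torch.tril(torch.ones(T, T))

q, k = head_q(x), head_k(x)
c = (q @ k.transpose(-2, -1)) * (1.0 / 256)     # (B, T, T) attention scores
c = c.masked_fill(copy_mask == 0, 0)            # causal: only look backwards
c = c @ F.one_hot(idx, num_classes=V).float()   # scatter scores onto seen ids
print(c.shape)  # (B, T, V): a logit bonus added to the LM head output
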
@ -8,8 +8,8 @@ from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
logger = logging.getLogger(__name__)
# print('logging to wandb... (comment it if you don\'t have wandb)')
# import wandb # comment this if you don't have wandb
print('logging to wandb... (comment it if you don\'t have wandb)')
import wandb # comment it if you don't have wandb
class TrainerConfig:
max_epochs = 10
@ -22,8 +22,7 @@ class TrainerConfig:
lr_decay = False # linear warmup followed by cosine decay
warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper
final_tokens = 260e9 # at which point do we reach lr_final
epoch_save_frequency = 0
epoch_save_path = 'trained-'
ckpt_path = None
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
@ -57,6 +56,11 @@ class Trainer:
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name
def save_checkpoint(self): # DataParallel wrappers keep raw model object in .module attribute
raw_model = self.model.module if hasattr(self.model, "module") else self.model
logger.info("saving %s", self.config.ckpt_path)
torch.save(raw_model.state_dict(), self.config.ckpt_path)
def train(self):
model, config = self.model, self.config
raw_model = model.module if hasattr(self.model, "module") else model
@ -73,11 +77,12 @@ class Trainer:
pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
for it, (x, y) in pbar:
x = x.to(self.device) # place data on the correct device
y = y.to(self.device)
with torch.set_grad_enabled(is_train):
_, loss = model(x, y) # forward the model
logits, loss = model(x, y) # forward the model
loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
if is_train: # backprop and update the parameters
@ -89,15 +94,14 @@ class Trainer:
if config.lr_decay: # decay the learning rate based on our progress
self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
lr_final_factor = config.lr_final / config.learning_rate
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = lr_final_factor + (1 - lr_final_factor) * float(self.tokens) / float(config.warmup_tokens)
lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
progress = 0
else:
# cosine learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
# progress = min(progress * 1.1, 1.0) # more fine-tuning with low LR
lr_final_factor = config.lr_final / config.learning_rate
lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
@ -114,17 +118,20 @@ class Trainer:
if self.avg_loss < 0:
self.avg_loss = now_loss
else:
# factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1))
factor = 1 / (it + 1)
factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1))
self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor
pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
while True:
best_loss = float('inf')
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):
run_epoch('train')
if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
raw_model = self.model.module if hasattr(self.model, "module") else self.model # DataParallel wrappers keep raw model object in .module
torch.save(raw_model, self.config.epoch_save_path + str(epoch+1) + '.pth')
if self.test_dataset is not None:
test_loss = run_epoch('test')
# supports early stopping based on the test loss, or just save always if no test set is provided
good_model = self.test_dataset is None or test_loss < best_loss
if self.config.ckpt_path is not None and good_model:
best_loss = test_loss
self.save_checkpoint()

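# The token-based schedule in run_epoch above, as a standalone function: linear
# warmup to 1x over warmup_tokens, then a cosine glide from 1x down to
# lr_final / learning_rate at final_tokens. Names are illustrative.
import math

def lr_mult_at(tokens, warmup_tokens, final_tokens, lr_final_factor):
    if tokens < warmup_tokens:
        return float(tokens) / float(max(1, warmup_tokens))
    progress = float(tokens - warmup_tokens) / float(max(1, final_tokens - warmup_tokens))
    return (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress)

# lr_mult_at(0, 375e6, 260e9, 0.1) == 0.0, rises to 1.0 at the end of warmup,
# and decays to 0.1 (= lr_final / learning_rate) when progress reaches 1.
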
@ -25,50 +25,36 @@ model_type = 'RWKV'
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt"
datafile_encoding = 'utf-8'
# datafile = u"D:\\NLP-Data\\ww100M.txt"
# datafile = u"D:\\NLP-Data\\__2019.txt"
# datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
# datafile = u"V:\\NLP\\enwik8-shift-300.bpe"
# datafile_encoding = 'utf-16'
# datafile = u"V:\\NLP\\simplebooks-shift-utf32.word"
# datafile_encoding = 'utf-32'
datafile_type = 0 # use 0 for char-level English, 1 for Chinese. only affects some RWKV hyperparameters
#################################### VERY IMPORTANT ####################################
epoch_save_frequency = 10 # 0 = never, 1 = every 'epoch', 2 = every two 'epoch', etc.
epoch_save_path = 'trained-'
batch_size = 32 # if you see "CUDA out of memory", reduce this.
# if you have good GPU, increase this.
# use GPU-Z to find the highest value for your VRAM.
n_epoch = 100 # the 'epoch' here is actually very short (and of fixed length)
########################################################################################
model_level = 'character' # 'character' (recommended) or 'word'
ctx_len = 256 # context length, try 512 or 1024 if you have good GPU
n_layer = 6 # try 12 for 100M, 24 for 300M
n_head = 8 # try 12 for 100M, 16 for 300M
ctx_len = 256 # context length
n_layer = 5
n_head = 8
n_embd = n_head * 64
n_attn = n_embd
n_ffn = n_embd
lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr. 8e-4 = 0.0008 4e-4 = 0.0004
lr_final = 4e-5
batch_size = 64
n_epoch = 50 # the 'epoch' here is actually very short (and of fixed length)
lr_init = 8e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
lr_final = 2e-4
betas = (0.9, 0.99) if model_type == 'RWKV' else (0.9, 0.99)
eps = 4e-9
betas = (0.9, 0.999) if model_type == 'RWKV' else (0.9, 0.99)
eps = 1e-8
weight_decay = 0 if model_type == 'RWKV' else 0.01 # wd is not useful when we have enough data
epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
######## special hyperparameters for RWKV model ########
rwkv_emb_scale = 0.4 # scale of initial embedding. 0.4 is a good choice
rwkv_tiny_attn = 0#64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
rwkv_tiny_head = 1 # 1 is good enough. 8 is slow
# n_side_proj = 512 # extra 'side projection', quite useful for BPE models
########################################################################################################
# Load data
@ -90,15 +76,6 @@ class Dataset(Dataset):
# for u in unique:
# print(u, end=' ')
# print('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open('vocab.json', "w", encoding="utf-16") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
data_size, vocab_size = len(data), len(unique)
print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
self.stoi = { ch:i for i,ch in enumerate(unique) }
@ -128,15 +105,63 @@ model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_typ
rwkv_emb_scale=rwkv_emb_scale, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head,
n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
# load a trained model
# model.load_state_dict(torch.load('trained-xxx.pth').state_dict())
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()
torch.save(model, 'trained-' + trainer.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
########################################################################################################
# Run model to generate text
########################################################################################################
from src.utils import sample_logits
NUM_OF_RUNS = 5
LENGTH_OF_EACH = 300
for run in range(NUM_OF_RUNS):
context = "it was"
if model_level == 'word':
x = np.array([train_dataset.stoi[s] for s in context.strip().lower().split(' ')], dtype=np.int64)
else:
x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
real_len = len(x)
if real_len < ctx_len:
x = np.pad(x, (0, ctx_len - real_len))
print_begin = 0
for i in range(LENGTH_OF_EACH):
if i == 0:
print(('-' * 80) + '\n' + context, end = '')
print_begin = real_len
with torch.no_grad():
xxx = torch.tensor(x[-ctx_len:], dtype=torch.long)[None,...].to("cuda:0")
out, _ = model(xxx)
pos = -1 if real_len >= ctx_len else real_len - 1
char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) # our special sampling method
if real_len < ctx_len:
x[real_len] = char
else:
x = np.append(x, char)
real_len += 1
if i % 10 == 9 or i == LENGTH_OF_EACH-1:
if model_level == 'word':
completion = ' ' + ' '.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
completion = completion.replace('\n ', '\n')
else:
completion = ''.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
print(completion, end = '')
print_begin = real_len
print()
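# sample_logits is called above with min_p_pow / min_p_ratio, the "special
# sampling method" mentioned in the comment. A hedged numpy reading of the idea
# (the exact src/utils implementation may differ): drop every token whose
# probability falls below max_prob ** min_p_pow * min_p_ratio, renormalize, sample.
import numpy as np

def min_p_sample(probs, min_p_pow=2.0, min_p_ratio=0.02, rng=np.random):
    limit = np.max(probs) ** min_p_pow * min_p_ratio
    probs = np.where(probs < limit, 0.0, probs)
    probs = probs / probs.sum()
    return rng.choice(len(probs), p=probs)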