Compare commits

...

124 Commits
4.00...main

Author SHA1 Message Date
PENG Bo 99a5933f54
Update README.md 3 years ago
PENG Bo 5b08ee1718
Update README.md 3 years ago
PENG Bo 4e962eb850
Update README.md 3 years ago
PENG Bo 8f428408a3
Update README.md 3 years ago
PENG Bo a9007581d0
Update README.md 3 years ago
BlinkDL 79915b3696 better 3 years ago
BlinkDL 0c7cd08255 fix 3 years ago
PENG Bo 3d43eaa1c8
Update README.md 3 years ago
PENG Bo 04a564d7d0
Update README.md 3 years ago
PENG Bo 5713df51ec
Update README.md 3 years ago
PENG Bo 14d21f5a00
Update README.md 3 years ago
BlinkDL 4ca274aad7 fix 3 years ago
BlinkDL 1945cb58ed pile v2 3 years ago
BlinkDL 3d2b04ba0c Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL 107f167b4a misc 3 years ago
PENG Bo ac4ba411e6
Add files via upload 3 years ago
PENG Bo 099919058b
Update README.md 3 years ago
PENG Bo 87fab90435
Update README.md 3 years ago
PENG Bo 13b8784502
Update README.md 3 years ago
PENG Bo 513d3eb552
Add files via upload 3 years ago
PENG Bo 8b615ccc74
Update README.md 3 years ago
PENG Bo 1e0dba0421
Update README.md 3 years ago
PENG Bo f8134fb96e
Update README.md 3 years ago
PENG Bo 62fba64244
Update README.md 3 years ago
PENG Bo 123536b2a7
Update README.md 3 years ago
PENG Bo decd8e29f5
Add files via upload 3 years ago
BlinkDL 6d4dec7288 better 3 years ago
PENG Bo 8e99ac1138
Update README.md 3 years ago
PENG Bo 4056bfeba7
Add files via upload 3 years ago
PENG Bo d16e25661c
Add files via upload 3 years ago
PENG Bo 0ff2170277
Update README.md 3 years ago
PENG Bo e615f1c718
Add files via upload 3 years ago
PENG Bo 1430d4edcf
Update README.md 3 years ago
PENG Bo 4378fe6b4f
Update README.md 3 years ago
PENG Bo f38b7e3574
Update README.md 3 years ago
BlinkDL 6739df885e saves VRAM 3 years ago
BlinkDL 9f557219c4 misc 3 years ago
BlinkDL 58e9d8d972 Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL 1d72a48db0 faster bf16 & saves VRAM 3 years ago
PENG Bo 52ac194d54
Update README.md 3 years ago
BlinkDL 93d671c287 better cuda kernel 3 years ago
PENG Bo 760db55fa6
Update README.md 3 years ago
PENG Bo 6b59d8fee1
Update README.md 3 years ago
PENG Bo fc047a20b1
Update README.md 3 years ago
PENG Bo 0c77cfbbee
Update README.md 3 years ago
PENG Bo 904de99a14
Update README.md 3 years ago
PENG Bo 5f6ffc987a
Update README.md 3 years ago
PENG Bo 02178d79c9
Update README.md 3 years ago
PENG Bo ad1836b27f
Update README.md 3 years ago
PENG Bo 404c593213
Update README.md 3 years ago
PENG Bo 9917078f93
Update README.md 3 years ago
PENG Bo e0dc08a2ce
Update README.md 3 years ago
PENG Bo a3e6156136
Update README.md 3 years ago
PENG Bo 55f63f0aeb
Update README.md 3 years ago
PENG Bo 114d677bc8
Update README.md 3 years ago
BlinkDL d008cc6d8e Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL c13879ab97 misc 3 years ago
PENG Bo 78579a00d2
Update README.md 3 years ago
PENG Bo e6d9e4979a
Update README.md 3 years ago
PENG Bo 8d72d882e4
Update README.md 3 years ago
PENG Bo 3f5ac97f77
Update README.md 3 years ago
PENG Bo 81aa6dda7b
Update README.md 3 years ago
PENG Bo 366b000ee6
Add files via upload 3 years ago
PENG Bo 11acd5e5b5
Update README.md 3 years ago
PENG Bo 374b086911
Add files via upload 3 years ago
BlinkDL 7476c69f32 fix 3 years ago
BlinkDL 6ed3a3db09 Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL e2ec7ae023 test 3 years ago
PENG Bo 71a46ca0f3
Update README.md 3 years ago
PENG Bo b97c25b9e7
Update README.md 3 years ago
PENG Bo 3d15f41a16
Update README.md 3 years ago
PENG Bo bbacb62b89
Add files via upload 3 years ago
BlinkDL c7b1900270 prepare for v4c 3 years ago
BlinkDL 038f06b996 rwkv-4b 3 years ago
PENG Bo f03efd0218
Update README.md 3 years ago
PENG Bo 9721b8f9c5
Update README.md 3 years ago
PENG Bo 79aa59ff2b
Add files via upload 3 years ago
PENG Bo 8e63b75f2c
Add files via upload 3 years ago
PENG Bo 5c8eda8595
Add files via upload 3 years ago
PENG Bo 13c6149205
Update README.md 3 years ago
PENG Bo f6cb1a1947
Update README.md 3 years ago
PENG Bo aeae6c8aac
Update README.md 3 years ago
PENG Bo b562097da1
Update README.md 3 years ago
BlinkDL f79d082053 testing 3 years ago
PENG Bo 8bf7061705
Update README.md 3 years ago
BlinkDL dd3845752a Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL b4925900e7 misc 3 years ago
PENG Bo fcb0b9819d
Update README.md 3 years ago
PENG Bo be18c53fec
Add files via upload 3 years ago
PENG Bo eac471da29
Update README.md 3 years ago
PENG Bo d1bb270fb3
Update README.md 3 years ago
PENG Bo 5837ee32c4
Update README.md 3 years ago
PENG Bo 8e1130e12a
Update README.md 3 years ago
BlinkDL b2a240d73d misc 3 years ago
PENG Bo bc47cb9f1a
Update README.md 3 years ago
PENG Bo 3461b2f6fb
Add files via upload 3 years ago
BlinkDL 295af9a517 info 3 years ago
BlinkDL dc26998708 torch jit 3 years ago
BlinkDL 75929cbbba torch jit (xx% faster inference) 3 years ago
PENG Bo 819f2730b2
Update README.md 3 years ago
PENG Bo 66c1dabb94
Add files via upload 3 years ago
PENG Bo 379c97890b
Update README.md 3 years ago
PENG Bo 83a4512b74
Update README.md 3 years ago
BlinkDL 59e6deeb58 better chat 3 years ago
BlinkDL cf340264dc better 3 years ago
PENG Bo 935d8d3e87
Update README.md 3 years ago
BlinkDL 0d0cedfcd9 better chatbot 3 years ago
PENG Bo 511b7adb4f
Add files via upload 3 years ago
BlinkDL aaf1341af7 better chat 3 years ago
BlinkDL e64ce9b0ff Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL 3e0f8054c6 better prompt 3 years ago
PENG Bo 0131543e48
Update README.md 3 years ago
PENG Bo 315ce82e38
Add files via upload 3 years ago
BlinkDL 9eb1a7b3d3 Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL 7a7c06aed3 misc 3 years ago
PENG Bo 2e5704097d
Update README.md 3 years ago
BlinkDL 529be15c67 Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL 5c73bccd5a bug fix 3 years ago
PENG Bo 7bdfd1cb64
Update README.md 3 years ago
BlinkDL eaac3f7e66 chatbot 3 years ago
PENG Bo 14544aea94
Update README.md 3 years ago
PENG Bo 3fc16a86ed
Update README.md 3 years ago
PENG Bo eb3ca86f0b
Add files via upload 3 years ago
BlinkDL 23f64aeebc misc improvements 3 years ago

@ -1,28 +1,65 @@
# The RWKV Language Model (and my LM tricks)
## RWKV: Parallelizable RNN with Transformer-level LLM Performance (pronounced as "RwaKuv", from 4 major params: R W K V)
RWKV is an RNN with Transformer-level LLM performance, which can also be directly trained like a GPT transformer (parallelizable). And it's 100% attention-free. You only need the hidden state at position t to compute the state at position t+1. You can use the "GPT" mode to quickly compute the hidden state for the "RNN" mode.
So it's combining the best of RNN and transformer - **great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding** (using the final hidden state).
**HuggingFace Gradio demo (14B ctx8192)**: https://huggingface.co/spaces/BlinkDL/ChatRWKV-gradio
Raven (7B finetuned on Alpaca) Demo: https://huggingface.co/spaces/BlinkDL/Raven-RWKV-7B
**ChatRWKV:** with "stream" and "split" strategies and INT8. **3G VRAM is enough to run RWKV 14B :)** https://github.com/BlinkDL/ChatRWKV
**RWKV pip package**: https://pypi.org/project/rwkv/
```python
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV # pip install rwkv
model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16')
out, state = model.forward([187, 510, 1563, 310, 247], None) # use 20B_tokenizer.json
print(out.detach().cpu().numpy()) # get logits
out, state = model.forward([187, 510], None)
out, state = model.forward([1563], state) # RNN has state (use deepcopy if you want to clone it)
out, state = model.forward([310, 247], state)
print(out.detach().cpu().numpy()) # same result as above
```
**Download RWKV-4 0.1/0.4/1.5/3/7/14B weights**: https://huggingface.co/BlinkDL
## Join Our Discord: https://discord.gg/bDSBUMeFpc (lots of developers)
**Twitter**: https://twitter.com/BlinkDL_AI
**RWKV in 150 lines** (model, inference, text generation): https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py
ChatRWKV with RWKV 14B ctx8192:
![RWKV-chat](RWKV-chat.png)
You are welcome to join the RWKV discord https://discord.gg/bDSBUMeFpc to build upon it. We have plenty of potential compute (A100 40Gs) now (thanks to Stability and EleutherAI), so if you have interesting ideas I can run them.
![RWKV-eval2](RWKV-eval2.png)
RWKV [loss vs token position] for 10000 ctx4k+ documents in Pile. RWKV 1B5-4k is mostly flat after ctx1500, but 3B-4k and 7B-4k and 14B-4k have some slopes, and they are getting better. This debunks the old view that RNNs cannot model long ctxlens. We can predict that RWKV 100B will be great, and RWKV 1T is probably all you need :)
![RWKV-ctxlen](RWKV-ctxlen.png)
I believe RNN is a better candidate for fundamental models, because: (1) It's more friendly for ASICs (no kv cache). (2) It's more friendly for RL. (3) When we write, our brain is more similar to RNN. (4) The universe is like an RNN too (because of locality). Transformers are non-local models.
I am training RWKV-4 14B on the Pile: https://wandb.ai/blinkdl/RWKV-v4-Pile RWKV-3 1.5B on A40 (tf32) = always 0.015 sec/token, tested using simple pytorch code (no CUDA), GPU utilization 45%, VRAM 7823M
GPT2-XL 1.3B on A40 (tf32) = 0.032 sec/token (for ctxlen 1000), tested using HF, GPU utilization 45% too (interesting), VRAM 9655M
Training speed: (new training code) RWKV-4 14B BF16 ctxlen4096 = 114K tokens/s on 8x8 A100 80G (ZERO2+CP). (old training code) RWKV-4 1.5B BF16 ctxlen1024 = 106K tokens/s on 8xA100 40G.
I am doing image experiments too (For example: https://huggingface.co/BlinkDL/clip-guided-binary-autoencoder) and RWKV will be able to do txt2img diffusion :) My idea: 256x256 rgb image -> 32x32x13bit latents -> apply RWKV to compute transition probability for each of the 32x32 grid -> pretend the grids are independent and "diffuse" using these probabilities.
Smooth training - no loss spikes! (lr & bsz change around 15G tokens)
![RWKV-loss](RWKV-loss.png)
![RWKV-eval](RWKV-eval.png)
@ -32,6 +69,8 @@ How it works: RWKV gathers information to a number of channels, which are also d
**RWKV is parallelizable because the time-decay of each channel is data-independent (and trainable)**. For example, in usual RNN you can adjust the time-decay of a channel from say 0.8 to 0.5 (these are called "gates"), while in RWKV you simply move the information from a W-0.8-channel to a W-0.5-channel to achieve the same effect. Moreover, you can fine-tune RWKV into a non-parallelizable RNN (then you can use outputs of later layers of the previous token) if you want extra performance.
![RWKV-formula](RWKV-formula.png)
Here are some of my TODOs. Let's work together :)
* HuggingFace integration (check https://github.com/huggingface/transformers/issues/17230)
@ -54,31 +93,64 @@ You can find me (BlinkDL) in the EleutherAI Discord too: https://www.eleuther.ai
## Quick start
Use https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v4neo (latest code, compatible with v4).
Here is a great prompt for testing Q&A of LLMs. Works for any model: (found by minimizing ChatGPT ppls for RWKV 1.5B)
```python
prompt = f'\nQ & A\n\nQuestion:\n{qq}\n\nDetailed Expert Answer:\n' # let the model generate after this
```
**Cool Community RWKV Projects (check them!)**:
https://pypi.org/project/rwkvstic/ a pip package (with 8bit & offload for low VRAM GPUs)
https://github.com/harrisonvanderbyl/rwkv_chatbot a chatbot
https://github.com/hizkifw/WebChatRWKVstic WebUI (WIP)
https://github.com/gururise/rwkv_gradio RWKV Gradio
https://github.com/cryscan/eloise RWKV QQ bot
https://github.com/Blealtan/RWKV-LM-LoRA LoRA fine-tuning
https://github.com/mrsteyk/RWKV-LM-jax
https://github.com/wozeparrot/tinyrwkv RWKV in tinygrad (nice simple DL framework)
https://github.com/huggingface/transformers/issues/17230 RWKV HF package (WIP)
https://github.com/ArEnSc/Production-RWKV RWKV HF package source
https://github.com/nlpodyssey/verbaflow RWKV in Go
https://github.com/nlpodyssey/rwkv RWKV in Go
https://github.com/mrsteyk/rwkvk-rs RWKV in Rust
https://github.com/josephrocca/rwkv-v4-web RWKV in browser
https://github.com/imxcstar/CSharp-RWKV-V4 RWKV in C#
https://github.com/mrsteyk/RWKV-LM-deepspeed Another training fork
https://github.com/resloved/RWKV-notebooks RWKV colab notebooks
https://colab.research.google.com/github/harrisonvanderbyl/rwkvstic/blob/master/notebooks/chatbot.ipynb RWKV chatbot colab notebook
https://github.com/Pathos14489/RWKVDistributedInference RWKV Distributed Inference
https://github.com/AXKuhta/rwkv-onnx-dml RWKV ONNX
### Inference
**Run RWKV-4 Pile models:** Download models from https://huggingface.co/BlinkDL. Set TOKEN_MODE = 'pile' in run.py and run it. It's fast even on CPU (the default mode).
**Colab for RWKV-4 Pile 1.5B**: https://colab.research.google.com/drive/1F7tZoPZaWJf1fsCmZ5tjw6sYHiFOYVWM
Run RWKV-4 Pile models in your browser (and onnx version): see this issue https://github.com/BlinkDL/RWKV-LM/issues/7
RWKV-4 Web Demo: https://josephrocca.github.io/rwkv-v4-web/demo/ (note: only greedy sampling for now)
For the old RWKV-2: see the release here for a 27M params model on enwik8 with 0.72 BPC(dev). Run run.py in https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN. You can even run it in your browser: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng https://blinkdl.github.io/AI-Writer/eng/ (this is using tf.js WASM single-thread mode).
@ -90,10 +162,86 @@ You will be training the "GPT" version because it's paralleziable and faster to
**Fine-tuning RWKV-4 Pile models:** use 'prepare-data.py' in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/tree/main/RWKV-v3 to tokenize .txt into train.npy data. Then set EXPRESS_PILE_MODE to True in train.py, and run it.
Read the inference code in src/model.py and try using the final hidden state (.xx .aa .bb) as a faithful sentence embedding for other tasks. Probably you should begin with .xx and .aa/.bb (.aa divided by .bb).
Colab for fine-tuning RWKV-4 Pile models: https://colab.research.google.com/github/resloved/RWKV-notebooks/blob/master/RWKV_v4_RNN_Pile_Fine_Tuning.ipynb
**Large corpus:** Use https://github.com/EleutherAI/gpt-neox to convert .jsonl into .bin and .idx
```
python tools/preprocess_data.py --input ./my_data.jsonl --output-prefix ./data/my_data --vocab ./20B_tokenizer.json --dataset-impl mmap --tokenizer-type HFTokenizer --append-eod
```
The jsonl format sample (one line for each document):
```
{"meta": {"ID": 101}, "text": "This is the first document."}
{"meta": {"ID": 102}, "text": "Hello\nWorld"}
{"meta": {"ID": 103}, "text": "1+1=2\n1+2=3\n2+2=4"}
```
generated by code like this:
```
ss = json.dumps({"meta": meta, "text": text}, ensure_ascii=False)
out.write(ss + "\n")
```
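A self-contained version of that snippet might look like this (the `docs` list and the output filename are just example values):
```python
import json

# hypothetical list of (meta, text) pairs to convert
docs = [({"ID": 101}, "This is the first document."),
        ({"ID": 102}, "Hello\nWorld")]

with open("my_data.jsonl", "w", encoding="utf-8") as out:
    for meta, text in docs:
        ss = json.dumps({"meta": meta, "text": text}, ensure_ascii=False)
        out.write(ss + "\n")  # one JSON object per line, as expected by preprocess_data.py
```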
## Towards RWKV-5 (just to record some new ideas)
### Some ideas
1. Now time decay is like 0.999^T (0.999 is learnable). Change it to something like (0.999^T + 0.1) where 0.1 is learnable too. The 0.1 part will be kept forever. Or, A^T + B^T + C = fast-decay + slow-decay + constant. Can even use different formulas (for example, K^2 instead of e^K for a decay component, or without normalization). See the sketch after this list.
2. Use complex-valued decay (so, rotation instead of decay) in some channels.
3. Inject some trainable and extrapolatable positional encoding?
4. Aside from 2d rotation, we can try other Lie groups such as 3d rotation ( SO(3) ). Non-abelian RWKV lol.
5. RWKV might be great on analog devices (search for Analog Matrix-vector multiplication & Photonic Matrix-vector multiplication). The RNN mode is very hardware-friendly (processing-in-memory). Can be a SNN too (https://github.com/ridgerchu/SpikeGPT). I wonder if it can be optimized for quantum computation.
6. Trainable initial hidden state (xx aa bb pp xx).
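For idea 1, a minimal sketch of the proposed mixed decay, assuming per-channel learnable scalars `A`, `B`, `C` (names are illustrative, not from the codebase):
```python
import torch

def mixed_decay(A, B, C, T):
    # w[t] = A^t + B^t + C : fast decay + slow decay + a constant part that is never forgotten
    t = torch.arange(T, dtype=torch.float32)
    return A ** t + B ** t + C

# e.g. A=0.5 (fast), B=0.999 (slow), C=0.1 (kept forever)
print(mixed_decay(0.5, 0.999, 0.1, 5))
```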
### Vision Tasks
1. I find it's good to add a 2d pos encoding:
```
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
...
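# pos_emb_x (1, P, C) and pos_emb_y (P, 1, C) broadcast to a (P, P, C) grid of positional offsets (P = args.my_pos_emb)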
x = x + pos_emb_x + pos_emb_y
```
2. In a BPE language model, it's best to use [tokenShift of 1 token] (you can mix more tokens in a char-level English model). However, you can try [tokenShift of N (or N-1) (or N+1) tokens] if the image size is N x N, because that will be like mixing [the token above the current position (or the token above the to-be-predicted position)] with [current token]. You can try different tokenShift styles for "ATT" & "FFN", or mix different tokenShift styles - such as mixing [token A] with [token A-1] and [token A-(N-1)], etc.
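A rough sketch of a tokenShift of N tokens for an N x N image (a plain shift along the flattened row-major token sequence; the learned mixing ratios of the real model are omitted):
```python
import torch
import torch.nn.functional as F

def token_shift_n(x, n):
    # x: (B, T, C) tokens of an N x N image flattened row-major; n = N mixes each
    # position with the token directly above it (n = 1 is the usual text tokenShift)
    return F.pad(x, (0, 0, n, 0))[:, :x.shape[1], :]

x = torch.randn(2, 16, 8)          # 4 x 4 image, 8 channels
x_above = token_shift_n(x, 4)      # token above the current position
mixed = 0.5 * x + 0.5 * x_above    # toy mixing; the real model uses learned ratios
```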
### Misc
I have an idea to improve tokenization. We can hardcode some channels to have meanings. Example:
Channel 0 = "space"
Channel 1 = "capitalize first letter"
Channel 2 = "capitalize all letters"
Therefore:
Embedding of "abc": [0, 0, 0, x0, x1, x2 , ..]
Embedding of " abc": [1, 0, 0, x0, x1, x2, ..]
Embedding of " Abc": [1, 1, 0, x0, x1, x2, ..]
Embedding of "ABC": [0, 0, 1, x0, x1, x2, ...]
......
so they will share most of the embedding. And we can rapidly compute the output probability of all variations of "abc".
Note: the above method is assuming that p(" xyz") / p("xyz") is the same for any "xyz", which can be wrong.
Better: define emb_space emb_capitalize_first emb_capitalize_all to be a function of emb.
Maybe the best: let 'abc' ' abc' etc. share the last 90% of their embeddings.
At this moment, all our tokenizers spend too many items to represent all variations of 'abc' ' abc' ' Abc' etc. Moreover the model cannot discover that these are actually similar if some of these variations are rare in the dataset. The method here can improve this. I plan to test this in a new version of RWKV.
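A toy sketch of this tokenization idea (the class name and the 3-flag layout are illustrative, nothing here is an implemented API):
```python
import torch
import torch.nn as nn

class SharedCaseEmbedding(nn.Module):
    # first 3 channels are hardcoded flags (space, capitalize-first, capitalize-all);
    # all case/space variants of a word share the remaining n_embd - 3 channels
    def __init__(self, n_base_tokens, n_embd):
        super().__init__()
        self.base = nn.Embedding(n_base_tokens, n_embd - 3)

    def forward(self, base_ids, space, cap_first, cap_all):
        flags = torch.stack([space, cap_first, cap_all], dim=-1).float()
        return torch.cat([flags, self.base(base_ids)], dim=-1)

emb = SharedCaseEmbedding(n_base_tokens=1000, n_embd=16)
ids = torch.tensor([42, 42])                       # same base token "abc"
v = emb(ids, torch.tensor([1, 0]), torch.tensor([1, 0]), torch.tensor([0, 0]))
# v[0] is " Abc", v[1] is "abc"; they differ only in the 3 flag channels
```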
## How it works
RWKV is inspired by Apple's AFT (https://arxiv.org/abs/2105.14103).
@ -289,6 +437,10 @@ I believe RWKV is performant because W is like repeatedly applying a diagonal ma
Moreover it's possible to turn it into a continuous ODE (a bit similar to State Space Models). I will write about it later.
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=BlinkDL/RWKV-LM&type=Date)](https://star-history.com/#BlinkDL/RWKV-LM&Date)
## Multimodal ideas
I have an idea for [text --> 32x32 RGB image] using a LM (transformer, RWKV, etc.). Will test it soon.
@ -311,7 +463,7 @@ Multi-task training might help too. I will try this dataset format:
[TxtFirst] [Desc of Img (txt tokens)] [Img] [img tokens]
and sometimes
[ImgFirst] [img tokens] [Txt] [Desc of Img (txt tokens)]
... the order of the imgs should be randomized in the DataLoader, and [TxtFirst] [ImgFirst] [Img] [Txt] are special tokens
and do random sampling of the full dataset. So sometimes the model will see the img tokens first and then the corresponding txt tokens, which is a [img -> txt] task. And the model will see some partial imgs and partial txts. I think a char-level LM might help the model to write correct text on images.
## How to sample a large dataset (for training)

Binary image files changed (5 added, 2 modified); binary contents are not shown.

@ -27,7 +27,7 @@ dtypes = {
3: np.int16,
4: np.int32,
5: np.int64,
6: float,
7: np.double,
8: np.uint16,
}

@ -0,0 +1,361 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
print('Loading...')
from src.model_run import RWKV_RNN
import numpy as np
import os, copy, types, gc, sys
import torch
from src.utils import TOKENIZER
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
CHAT_LANG = 'English' # English Chinese
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
args = types.SimpleNamespace()
args.RUN_DEVICE = "cuda" # 'cpu' (already very fast) // 'cuda'
args.FLOAT_MODE = "fp16" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate)
args.vocab_size = 50277
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230108-5170'
args.n_layer = 40
args.n_embd = 5120
args.ctx_len = 1024
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
# args.n_layer = 32
# args.n_embd = 4096
# args.ctx_len = 1024
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
# args.n_layer = 32
# args.n_embd = 2560
# args.ctx_len = 1024
if CHAT_LANG == 'English':
user = "User"
bot = "Bot"
interface = ":"
# The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
# The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth.
init_prompt = f'''
The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
{user}{interface} french revolution what year
{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799.
{user}{interface} 3+5=?
{bot}{interface} The answer is 8.
{user}{interface} guess i marry who ?
{bot}{interface} Only if you tell me more about yourself - what are your interests?
{user}{interface} solve for a: 9-a=2
{bot}{interface} The answer is a = 7, because 9 - 7 = 2.
{user}{interface} wat is lhc
{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
'''
HELP_MSG = '''Commands:
say something --> chat with bot. use \\n for new line.
+alt --> alternate chat reply
+reset --> reset chat
+gen YOUR PROMPT --> free generation with any prompt. use \\n for new line.
+qa YOUR QUESTION --> free generation - ask any question (just ask the question). use \\n for new line.
+more --> continue last free generation (only for +gen / +qa)
+retry --> retry last free generation (only for +gen / +qa)
Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results.
This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +gen for free generation.
'''
elif CHAT_LANG == 'Chinese':
args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run3z/rwkv-293'
args.n_layer = 32
args.n_embd = 4096
args.ctx_len = 1024
user = "Q"
bot = "A"
interface = ":"
init_prompt = '''
Q: 企鹅会飞吗
A: 企鹅是不会飞的它们的翅膀主要用于游泳和平衡而不是飞行
Q: 西瓜是什么
A: 西瓜是一种常见的水果是一种多年生蔓生藤本植物西瓜的果实呈圆形或卵形通常是绿色的里面有红色或黄色的肉和很多的籽西瓜味甜多吃可以增加水分是夏季非常受欢迎的水果之一
'''
HELP_MSG = '''指令:
直接输入内容 --> 和机器人聊天\\n代表换行
+alt --> 让机器人换个回答
+reset --> 重置对话
+gen 某某内容 --> 续写任何中英文内容\\n代表换行
+qa 某某问题 --> 问独立的问题忽略上下文\\n代表换行
+more --> 继续 +gen / +qa 的回答
+retry --> 换个 +gen / +qa 的回答
现在可以输入内容和机器人聊天注意它不怎么懂中文它可能更懂英文请经常使用 +reset 重置机器人记忆
'''
# Load Model
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
MODEL_NAME = args.MODEL_NAME
print(f'loading... {MODEL_NAME}')
model = RWKV_RNN(args)
model_tokens = []
current_state = None
########################################################################################################
def run_rnn(tokens, newline_adj = 0):
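# newline_adj is added to the logit of token 187 ('\n' in the 20B tokenizer) below, nudging the model toward or away from ending the line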
global model_tokens, current_state
for i in range(len(tokens)):
model_tokens += [int(tokens[i])]
if i == len(tokens) - 1:
out, current_state = model.forward(model_tokens, current_state)
else:
current_state = model.forward(model_tokens, current_state, preprocess_only = True)
# print(f'### model ###\n[{tokenizer.tokenizer.decode(model_tokens)}]')
out[0] = -999999999 # disable <|endoftext|>
out[187] += newline_adj
# if newline_adj > 0:
# out[15] += newline_adj / 2 # '.'
return out
all_state = {}
def save_all_stat(srv, name, last_out):
n = f'{name}_{srv}'
all_state[n] = {}
all_state[n]['out'] = last_out
all_state[n]['rnn'] = copy.deepcopy(current_state)
all_state[n]['token'] = copy.deepcopy(model_tokens)
def load_all_stat(srv, name):
global model_tokens, current_state
n = f'{name}_{srv}'
current_state = copy.deepcopy(all_state[n]['rnn'])
model_tokens = copy.deepcopy(all_state[n]['token'])
return all_state[n]['out']
########################################################################################################
# Run inference
print(f'\nRun prompt...')
out = run_rnn(tokenizer.tokenizer.encode(init_prompt))
gc.collect()
torch.cuda.empty_cache()
save_all_stat('', 'chat_init', out)
srv_list = ['dummy_server']
for s in srv_list:
save_all_stat(s, 'chat', out)
print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n')
def reply_msg(msg):
print(f'{bot}{interface} {msg}\n')
def on_message(message):
global model_tokens, current_state
srv = 'dummy_server'
msg = message.replace('\\n','\n').strip()
if len(msg) > 1000:
reply_msg('your message is too long (max 1000 tokens)')
return
x_temp = 1.0
x_top_p = 0.85
if ("-temp=" in msg):
x_temp = float(msg.split("-temp=")[1].split(" ")[0])
msg = msg.replace("-temp="+f'{x_temp:g}', "")
# print(f"temp: {x_temp}")
if ("-top_p=" in msg):
x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
msg = msg.replace("-top_p="+f'{x_top_p:g}', "")
# print(f"top_p: {x_top_p}")
if x_temp <= 0.2:
x_temp = 0.2
if x_temp >= 5:
x_temp = 5
if x_top_p <= 0:
x_top_p = 0
if msg == '+reset':
out = load_all_stat('', 'chat_init')
save_all_stat(srv, 'chat', out)
reply_msg("Chat reset.")
return
elif msg[:5].lower() == '+gen ' or msg[:4].lower() == '+qa ' or msg.lower() == '+more' or msg.lower() == '+retry':
if msg[:5].lower() == '+gen ':
new = '\n' + msg[5:].strip()
# print(f'### prompt ###\n[{new}]')
current_state = None
out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out)
elif msg[:4].lower() == '+qa ':
out = load_all_stat('', 'chat_init')
real_msg = msg[4:].strip()
new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
# print(f'### qa ###\n[{new}]')
out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out)
# new = f"\nThe following is an excellent Q&A session consists of detailed and factual information.\n\nQ: What is 3+5?\nA: The answer is 8.\n\nQ: {msg[9:].strip()}\nA:"
# print(f'### prompt ###\n[{new}]')
# current_state = None
# out = run_rnn(tokenizer.tokenizer.encode(new))
# save_all_stat(srv, 'gen_0', out)
elif msg.lower() == '+more':
try:
out = load_all_stat(srv, 'gen_1')
save_all_stat(srv, 'gen_0', out)
except:
return
elif msg.lower() == '+retry':
try:
out = load_all_stat(srv, 'gen_0')
except:
return
begin = len(model_tokens)
out_last = begin
for i in range(150):
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p_usual=x_top_p,
top_p_newline=x_top_p,
)
if msg[:4].lower() == '+qa ':
out = run_rnn([token], newline_adj=-1)
else:
out = run_rnn([token])
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx:
print(xxx, end='', flush=True)
out_last = begin + i + 1
print('\n')
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'gen_1', out)
else:
if msg.lower() == '+alt':
try:
out = load_all_stat(srv, 'chat_pre')
except:
return
else:
out = load_all_stat(srv, 'chat')
new = f"{user}{interface} {msg}\n\n{bot}{interface}"
# print(f'### add ###\n[{new}]')
out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999)
save_all_stat(srv, 'chat_pre', out)
begin = len(model_tokens)
out_last = begin
print(f'{bot}{interface}', end='', flush=True)
for i in range(999):
if i <= 0:
newline_adj = -999999999
elif i <= 30:
newline_adj = (i - 30) / 10
elif i <= 130:
newline_adj = 0
else:
newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p_usual=x_top_p,
top_p_newline=x_top_p,
)
out = run_rnn([token], newline_adj=newline_adj)
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx:
print(xxx, end='', flush=True)
out_last = begin + i + 1
send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
if '\n\n' in send_msg:
send_msg = send_msg.strip()
break
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!!
# send_msg = send_msg[:-len(f'{user}{interface}')].strip()
# break
# if send_msg.endswith(f'{bot}{interface}'):
# send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
# break
# print(f'{model_tokens}')
# print(f'[{tokenizer.tokenizer.decode(model_tokens)}]')
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'chat', out)
print(HELP_MSG)
while True:
msg = input(f'{user}{interface} ')
if len(msg.strip()) > 0:
on_message(msg)
else:
print('Error: please say something')

@ -18,28 +18,33 @@ __global__ void kernel_forward(const int B, const int T, const int C,
const F *__restrict__ const v = _v + _offset;
F *__restrict__ const y = _y + _offset;
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
F aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const F kk = k[ii];
const F vv = v[ii];
F ww = u + kk;
F p = max(pp, ww);
F e1 = exp(pp - p);
F e2 = exp(ww - p);
y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
}
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
const F *__restrict__ const _y, const F *__restrict__ const _gy,
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
@ -50,64 +55,67 @@ __global__ void kernel_backward(const int B, const int T, const int C,
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
const F *__restrict__ const y = _y + _offset;
const F *__restrict__ const gy = _gy + _offset;
F *__restrict__ const gk = _gk + _offset;
F *__restrict__ const gv = _gv + _offset;
F q[Tmax], r[Tmax];
F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const F kk = k[ii];
const F vv = v[ii];
const F yy = y[ii];
F ww = u + kk;
F p = max(pp, ww);
F e1 = exp(pp - p);
F e2 = exp(ww - p);
const F qq = gy[ii] / (e1 * bb + e2);
gw += (ga - gb * yy) * e1 * qq;
gu += (vv - yy) * e2 * qq;
q[i] = qq;
r[i] = ww - p;
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
ga = e1 * (aa + ga);
gb = e1 * (bb + gb);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
_gu[_offsetBC] = gu;
aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
const F kk = k[ii];
const F vv = v[ii];
const F yy = y[ii];
const F qq = q[i];
const F rr = r[i];
F e1 = qq * exp(rr);
F e2 = exp(kk + pp);
gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
gv[ii] = e1 + e2 * aa;
const F ww = w + pp;
const F www = rr - u - kk;
const F p = max(ww, www);
e1 = exp(ww - p);
e2 = qq * exp(www - p);
aa = e1 * aa + e2;
bb = e1 * bb - e2 * yy;
pp = p;
}
}
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
@ -117,9 +125,9 @@ void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, f
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
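For reference, a NumPy sketch that mirrors the math of the forward kernel above (per channel; `w` is assumed to already hold the preprocessed -exp(w) decay, as in the Python wrapper):
```python
import numpy as np

def wkv_forward_ref(w, u, k, v):
    # w, u: (C,) per-channel decay/bonus; k, v: (T, C); returns y: (T, C)
    T, C = k.shape
    y = np.empty((T, C))
    aa = np.zeros(C); bb = np.zeros(C); pp = np.full(C, -1e38)  # running sums divided by exp(pp)
    for t in range(T):
        ww = u + k[t]
        p = np.maximum(pp, ww)
        e1, e2 = np.exp(pp - p), np.exp(ww - p)
        y[t] = (e1 * aa + e2 * v[t]) / (e1 * bb + e2)
        ww = w + pp
        p = np.maximum(ww, k[t])
        e1, e2 = np.exp(ww - p), np.exp(k[t] - p)
        aa = e1 * aa + e2 * v[t]
        bb = e1 * bb + e2
        pp = p
    return y
```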

@ -0,0 +1,132 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#define MIN_VALUE (-1e38)
typedef at::BFloat16 bf16;
__global__ void kernel_forward(const int B, const int T, const int C,
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
bf16 *__restrict__ const _y) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
float u = float(_u[_c]);
float w = _w[_c];
const bf16 *__restrict__ const k = _k + _offset;
const bf16 *__restrict__ const v = _v + _offset;
bf16 *__restrict__ const y = _y + _offset;
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
float aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);
float ww = u + kk;
float p = max(pp, ww);
float e1 = exp(pp - p);
float e2 = exp(ww - p);
y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
}
__global__ void kernel_backward(const int B, const int T, const int C,
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy,
bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
float u = float(_u[_c]);
float w = _w[_c];
const bf16 *__restrict__ const k = _k + _offset;
const bf16 *__restrict__ const v = _v + _offset;
const bf16 *__restrict__ const y = _y + _offset;
const bf16 *__restrict__ const gy = _gy + _offset;
bf16 *__restrict__ const gk = _gk + _offset;
bf16 *__restrict__ const gv = _gv + _offset;
float q[Tmax], r[Tmax];
float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);
const float yy = float(y[ii]);
float ww = u + kk;
float p = max(pp, ww);
float e1 = exp(pp - p);
float e2 = exp(ww - p);
const float qq = float(gy[ii]) / (e1 * bb + e2);
gw += (ga - gb * yy) * e1 * qq;
gu += (vv - yy) * e2 * qq;
q[i] = qq;
r[i] = ww - p;
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
ga = e1 * (aa + ga);
gb = e1 * (bb + gb);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
_gu[_offsetBC] = bf16(gu);
aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);
const float yy = float(y[ii]);
const float qq = q[i];
const float rr = r[i];
float e1 = qq * exp(rr);
float e2 = exp(kk + pp);
gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
gv[ii] = bf16(e1 + e2 * aa);
const float ww = w + pp;
const float www = rr - u - kk;
const float p = max(ww, www);
e1 = exp(ww - p);
e2 = qq * exp(www - p);
aa = e1 * aa + e2;
bb = e1 * bb - e2 * yy;
pp = p;
}
}
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}

@ -1,13 +1,13 @@
#include <torch/extension.h>
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

@ -0,0 +1,25 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv forward");
m.def("backward", &backward, "wkv backward");
}
TORCH_LIBRARY(wkv, m) {
m.def("forward", forward);
m.def("backward", backward);
}

@ -17,14 +17,15 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()
########################################################################################################
# Step 1: set model & config (use v4 to run your trained-from-scratch models. v4 and v4neo are compatible)
########################################################################################################
args.RUN_DEVICE = "cuda" # 'cuda' // 'cpu' (already fast)
args.FLOAT_MODE = "fp16" # fp16 (good for GPU, does not work for CPU) // fp32 (good for CPU) // bf16 (less accurate, but works for CPU)
# if args.RUN_DEVICE == "cuda":
# os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
os.environ["RWKV_JIT_ON"] = '1' # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!!
TOKEN_MODE = "pile"
WORD_NAME = [
@ -85,27 +86,37 @@ context = "\nIn a shocking finding, scientist discovered a herd of dragons livin
# context = "\n深圳是" # test Chinese
# context = "\n東京は" # test Japanese
# ###### A good prompt for Q&A ######
# context = '''
# Questions & Helpful Answers
# Ask Research Experts
# Question:
# Can penguins fly?
# Full Answer:
# '''
# ###### A good prompt for chatbot ######
# context = '''
# The following is a conversation between a highly knowledgeable and intelligent AI assistant called Bot, and a human user called User. In the following interactions, User and Bot converse in natural language, and Bot always answer User's questions. Bot is very smart, polite and humorous. Bot knows a lot, and always tells the truth. The conversation begins.
# User: who is president of usa?
# Bot: It's Joe Biden; he was sworn in earlier this year.
# User: french revolution what year
# Bot: It started in 1789, but it lasted 10 years until 1799.
# User: guess i marry who ?
# Bot: Only if you tell me more about yourself - what are your interests?
# User: wat is lhc
# Bot: It's a large and very expensive piece of science equipment. If I understand correctly, it's a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
# User:''' # type your question here
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333
@ -213,7 +224,7 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
print(char, end="", flush=True)
else:
char = tokenizer.tokenizer.decode(ctx[out_last:])
if '\ufffd' not in char: # is valid utf8 string?
print(char, end="", flush=True)
out_last = i+1

@ -28,7 +28,7 @@ dtypes = {
3: np.int16,
4: np.int32,
5: np.int64,
6: float,
7: np.double,
8: np.uint16,
}
@ -49,6 +49,58 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b"MMIDIDX\x00\x00"
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, "wb")
# Write Magic string so we can check the file format when opening it again.
self._file.write(cls._HDR_MAGIC)
# Write version number
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", 1))
# Little endian unsigned 8 Bit integer
self._file.write(struct.pack("<B", code(dtype)))
return self
@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers
def write(self, sizes, doc_idx):
pointers = self._get_pointers(sizes)
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", len(sizes)))
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", len(doc_idx)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order="C"))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order="C"))
del pointers
doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order="C"))
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path, skip_warmup=False):
with open(path, "rb") as stream:
magic_test = stream.read(9)
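A hypothetical usage of the `Index.writer` added above (the path, sizes and doc_idx values are made up):
```python
import numpy as np
from src.binidx import MMapIndexedDataset  # assuming this module path

# write an .idx file describing two documents of 128 and 256 tokens
with MMapIndexedDataset.Index.writer("my_data.idx", np.uint16) as index:
    index.write(sizes=[128, 256], doc_idx=[0, 1, 2])
```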

@ -2,7 +2,7 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json, math, random, os, sys
import numpy as np
import torch
from torch.utils.data import Dataset
@ -16,33 +16,53 @@ class MyDataset(Dataset):
self.args = args
if args.data_type == "binidx":
self.vocab_size = args.vocab_size
rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
if args.my_pile_version == 1:
self.data = MMapIndexedDataset(args.data_file)
self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
rank_zero_info(f"Data has {self.data_size} tokens.")
else:
data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n')
data_list = [i.strip().split(' ') for i in data_list]
self.data = []
self.data_size = int(data_list[-1][-1])
rank_zero_info(f"Data has {self.data_size} chunks.")
for d in data_list:
data = MMapIndexedDataset(d[0])
data_size = len(data._bin_buffer) // data._index._dtype_size
assert (data_size - args.ctx_len) == int(d[1])
self.data += [[int(d[-1]), int(d[1]), data]]
# rank_zero_info(self.data)
if args.my_qa_mask > 0:
self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document')
# self.data_pile = MMapIndexedDataset('/fsx/pile_deduped/pile_0.87_deduped_text_document')
self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size
if args.my_pile_stage > 0: if args.my_pile_stage > 0:
assert self.data_size == 332115325534 and self.vocab_size == 50277 # assert self.data_size == 332115325534 and self.vocab_size == 50277
self.samples_per_epoch = args.epoch_steps * args.real_bsz self.samples_per_epoch = args.epoch_steps * args.real_bsz
assert self.samples_per_epoch == 40320 assert self.samples_per_epoch == 40320
print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########") rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
dataset_slot = self.data_size // args.ctx_len dataset_slot = self.data_size // args.ctx_len
assert MaybeIsPrime(args.magic_prime) if args.my_pile_stage != 4:
assert args.magic_prime % 3 == 2 assert MaybeIsPrime(args.magic_prime)
assert args.magic_prime / dataset_slot > 0.999999 and args.magic_prime / dataset_slot <= 1 assert args.magic_prime % 3 == 2
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
elif args.data_type == "numpy": elif args.data_type == "numpy":
self.data = np.load(args.data_file).astype("int") self.data = np.load(args.data_file).astype("int")
self.vocab_size = args.vocab_size self.vocab_size = args.vocab_size
print("Current vocab size =", self.vocab_size, "(make sure it's correct)") rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data) self.data_size = len(self.data)
print(f"Data has {self.data_size} tokens.") rank_zero_info(f"Data has {self.data_size} tokens.")
elif args.data_type == "uint16": elif args.data_type == "uint16":
self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len) self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
self.vocab_size = args.vocab_size self.vocab_size = args.vocab_size
print("Current vocab size =", self.vocab_size, "(make sure it's correct)") rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = self.data.shape[0] self.data_size = self.data.shape[0]
print(f"Data has {self.data_size} samples.") rank_zero_info(f"Data has {self.data_size} samples.")
elif args.data_type == "wds_img": elif args.data_type == "wds_img":
self.vocab_size = -1 self.vocab_size = -1
self.data_size = -1 self.data_size = -1
@ -50,7 +70,7 @@ class MyDataset(Dataset):
self.error_count = 0 self.error_count = 0
else: else:
if args.data_type == "dummy": if args.data_type == "dummy":
print("Building dummy data...") rank_zero_info("Building dummy data...")
self.data = "" self.data = ""
for i in range(100000): for i in range(100000):
aa = (i) % 10000 aa = (i) % 10000
@ -59,13 +79,13 @@ class MyDataset(Dataset):
self.data += f".{aa}+{bb}={cc}." self.data += f".{aa}+{bb}={cc}."
else: else:
self.data = open(args.data_file, "r", encoding=args.data_type).read() self.data = open(args.data_file, "r", encoding=args.data_type).read()
print("Building token list...") rank_zero_info("Building token list...")
unique = sorted(list(set(self.data))) unique = sorted(list(set(self.data)))
self.vocab_size = len(unique) self.vocab_size = len(unique)
# print() # rank_zero_info()
# for u in unique: # for u in unique:
# print(u, end=' ') # print(u, end=' ')
# print('\n\n') # rank_zero_info('\n\n')
xx = 0 xx = 0
xxObj = {} xxObj = {}
for u in unique: for u in unique:
@ -74,7 +94,7 @@ class MyDataset(Dataset):
with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file: with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
self.data_size = len(self.data) self.data_size = len(self.data)
print("Data has %d tokens, %d vocab size." % (self.data_size, self.vocab_size)) rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
self.stoi = {ch: i for i, ch in enumerate(unique)} self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)} self.itos = {i: ch for i, ch in enumerate(unique)}
@ -136,25 +156,85 @@ class MyDataset(Dataset):
else: else:
ctx_len = args.ctx_len ctx_len = args.ctx_len
req_len = ctx_len + 1 req_len = ctx_len + 1
magic_prime = args.magic_prime
data = self.data
if args.my_pile_stage > 0: if args.my_pile_stage > 0 and args.my_pile_stage != 4:
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
factor = (math.sqrt(5) - 1) / 2
factor = int(args.magic_prime * factor) if args.my_qa_mask > 0:
i = ((factor * ii * ii * ii) % args.magic_prime) * ctx_len ii_orig = ii
i = i + args.my_pile_shift if ii % 2 == 0:
ii = -1
data = self.data_pile
else:
ii = ii // 2
if ii < 0:
i = np.random.randint(0, self.data_pile_size - req_len)
else:
factor = (math.sqrt(5) - 1) / 2
factor = int(magic_prime * factor)
i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
i = i + args.my_pile_shift
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}") # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
elif args.my_pile_stage == 4:
# cheat: pick a random spot in dataset
if args.my_pile_version == 1:
i = np.random.randint(0, self.data_size - req_len)
else:
i = np.random.randint(0, self.data_size)
else: else:
# cheat: pick a random spot in dataset # cheat: pick a random spot in dataset
i = np.random.randint(0, self.data_size - req_len) i = np.random.randint(0, self.data_size - req_len)
if args.data_type == "binidx": if args.data_type == "binidx":
dix = self.data.get(idx=0, offset=i, length=req_len).astype(int) if args.my_pile_version == 1:
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
else:
# self.data : cutoff, chunk_count, data
for j in range(len(data)):
if i < data[j][0]:
ii = i
i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1]
dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int)
# print(ii, j, i)
break
elif args.data_type == "numpy": elif args.data_type == "numpy":
dix = self.data[i : i + req_len] dix = data[i : i + req_len]
else: else:
dix = [self.stoi[s] for s in self.data[i : i + req_len]] dix = [self.stoi[s] for s in data[i : i + req_len]]
if args.my_qa_mask == 1:
if data == self.data_pile:
z = [1] * ctx_len
else:
z = [0] * ctx_len
z_sum = 0
isGood = False
for i in range(3, ctx_len):
if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
isGood = True
if dix[i] == 0:
isGood = False
if isGood:
z[i] = 1
z_sum += 1
if z_sum == 0:
z = [1] * ctx_len
i = np.random.randint(0, self.data_pile_size - req_len)
dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
z = torch.tensor(z, dtype=torch.bfloat16)
x = torch.tensor(dix[:-1], dtype=torch.long) x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long) y = torch.tensor(dix[1:], dtype=torch.long)
# if ii_orig < 50:
# # if rank == 1:
# print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
# else:
# exit(0)
if args.my_qa_mask == 1:
return x, y, z
return x, y return x, y
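Why the sampler above can cover the dataset without repeats: for a prime p with p % 3 == 2, gcd(3, p - 1) = 1, so ii -> ii**3 mod p is a bijection on 0..p-1, and multiplying by the fixed golden-ratio factor keeps it one. A tiny self-check with a toy prime (my own illustration, not repo code):

    p = 11                                   # toy prime, 11 % 3 == 2
    factor = int(p * ((5 ** 0.5) - 1) / 2)   # golden-ratio multiplier, as in the code above
    slots = sorted(((factor * ii ** 3) % p) for ii in range(p))
    assert slots == list(range(p))           # every slot 0..p-1 is visited exactly once
    print(slots)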

@ -2,18 +2,25 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
######################################################################################################## ########################################################################################################
import os, math, gc import os, math, gc, importlib
import torch import torch
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed if importlib.util.find_spec('deepspeed'):
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
try:
print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
except:
os.environ["RWKV_MY_TESTING"] = ''
def __nop(ob): def __nop(ob):
return ob return ob
@ -35,61 +42,93 @@ T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
from torch.utils.cpp_extension import load from torch.utils.cpp_extension import load
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"]) if os.environ["RWKV_FLOAT_MODE"] == "bf16":
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
class WKV(torch.autograd.Function):
class WKV(torch.autograd.Function): @staticmethod
@staticmethod def forward(ctx, B, T, C, w, u, k, v):
def forward(ctx, B, T, C, w, u, k, v): ctx.B = B
ctx.B = B ctx.T = T
ctx.T = T ctx.C = C
ctx.C = C assert T <= T_MAX
assert T <= T_MAX assert B * C % min(C, 32) == 0
assert B * C % min(C, 32) == 0 w = -torch.exp(w.float().contiguous())
if "32" in os.environ["RWKV_FLOAT_MODE"]:
w = -torch.exp(w.contiguous())
u = u.contiguous() u = u.contiguous()
k = k.contiguous() k = k.contiguous()
v = v.contiguous() v = v.contiguous()
else: y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
w = -torch.exp(w.float().contiguous()) wkv_cuda.forward(B, T, C, w, u, k, v, y)
u = u.float().contiguous() ctx.save_for_backward(w, u, k, v, y)
k = k.float().contiguous()
v = v.float().contiguous()
ctx.save_for_backward(w, u, k, v)
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
return y return y
elif os.environ["RWKV_FLOAT_MODE"] == "fp16": @staticmethod
return y.half() def backward(ctx, gy):
elif os.environ["RWKV_FLOAT_MODE"] == "bf16": B = ctx.B
return y.bfloat16() T = ctx.T
C = ctx.C
@staticmethod assert T <= T_MAX
def backward(ctx, gy): assert B * C % min(C, 32) == 0
B = ctx.B w, u, k, v, y = ctx.saved_tensors
T = ctx.T gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
C = ctx.C gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
assert T <= T_MAX gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
assert B * C % min(C, 32) == 0 gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
w, u, k, v = ctx.saved_tensors wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
gw = torch.zeros((B, C), device=gy.device).contiguous() gw = torch.sum(gw, dim=0)
gu = torch.zeros((B, C), device=gy.device).contiguous() gu = torch.sum(gu, dim=0)
gk = torch.zeros((B, T, C), device=gy.device).contiguous()
gv = torch.zeros((B, T, C), device=gy.device).contiguous()
if "32" in os.environ["RWKV_FLOAT_MODE"]:
wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
else:
wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
return (None, None, None, gw, gu, gk, gv) return (None, None, None, gw, gu, gk, gv)
elif os.environ["RWKV_FLOAT_MODE"] == "fp16": else:
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
elif os.environ["RWKV_FLOAT_MODE"] == "bf16": class WKV(torch.autograd.Function):
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) @staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
if "32" in os.environ["RWKV_FLOAT_MODE"]:
w = -torch.exp(w.contiguous())
u = u.contiguous()
k = k.contiguous()
v = v.contiguous()
else:
w = -torch.exp(w.float().contiguous())
u = u.float().contiguous()
k = k.float().contiguous()
v = v.float().contiguous()
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
ctx.save_for_backward(w, u, k, v, y)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
return y
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
return y.half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
return y.bfloat16()
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
w, u, k, v, y = ctx.saved_tensors
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
else:
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
return (None, None, None, gw, gu, gk, gv)
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
def RUN_CUDA(B, T, C, w, u, k, v): def RUN_CUDA(B, T, C, w, u, k, v):
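For orientation between the two kernel variants above, here is a numerically naive O(T^2) PyTorch reference of what I understand the WKV kernel to compute (RWKV-4 formulation; w is the already-negated decay, u the bonus for the current token). A sketch for sanity-checking tiny shapes only, not the repo's implementation:

    import torch

    def wkv_reference(w, u, k, v):
        # w, u: (C,); k, v: (B, T, C). No log-space stabilization, so keep T and k small.
        B, T, C = k.shape
        y = torch.zeros_like(v)
        for t in range(T):
            num = torch.exp(u + k[:, t]) * v[:, t]
            den = torch.exp(u + k[:, t])
            for s in range(t):
                e = torch.exp(w * (t - 1 - s) + k[:, s])
                num = num + e * v[:, s]
                den = den + e
            y[:, t] = num / den
        return y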
@ -109,76 +148,140 @@ class RWKV_TimeMix(MyModule):
self.ctx_len = args.ctx_len self.ctx_len = args.ctx_len
self.n_embd = args.n_embd self.n_embd = args.n_embd
attn_sz = args.n_embd
with torch.no_grad(): # fancy init with torch.no_grad(): # fancy init
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1 ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
# fancy time_decay # fancy time_decay
decay_speed = torch.ones(attn_sz) decay_speed = torch.ones(args.dim_att)
for h in range(attn_sz): for h in range(args.dim_att):
decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1) decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
self.time_decay = nn.Parameter(decay_speed) self.time_decay = nn.Parameter(decay_speed)
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
# fancy time_first # fancy time_first
zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5 zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag) self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)
# fancy time_mix # fancy time_mix
x = torch.ones(1, 1, args.n_embd) self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
for i in range(args.n_embd): self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
x[0, 0, i] = i / args.n_embd self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
if 'a' in os.environ["RWKV_MY_TESTING"]:
self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
d_qkv = args.n_embd // 16
self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
with torch.no_grad():
self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
if 'a' not in os.environ["RWKV_MY_TESTING"]:
@MyFunction
def jit_func(self, x):
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
sr = torch.sigmoid(r)
return sr, k, v
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v = self.jit_func(x)
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
return self.output(rwkv)
if 'a' in os.environ["RWKV_MY_TESTING"]:
@MyFunction
def QKV(self, q, k, v):
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.att_mask == 0, float('-inf'))
att = F.softmax(att, dim = -1)
x = att @ v
return x
@MyFunction
def jit_funcQKV(self, x):
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
sr = torch.sigmoid(r)
qq = self.qq(xqq)
kk = self.kk(xkk)
vv = self.vv(xvv)
return sr, k, v, qq, kk, vv
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
return rwkv
self.key = nn.Linear(args.n_embd, attn_sz, bias=False) ########################################################################################################
self.value = nn.Linear(args.n_embd, attn_sz, bias=False)
self.receptance = nn.Linear(args.n_embd, attn_sz, bias=False)
self.output = nn.Linear(attn_sz, args.n_embd, bias=False) class RWKV_ChannelMix(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
@MyFunction with torch.no_grad(): # fancy init of time_mix
def jit_func(self, x): ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
# Mix x with the previous timestep to produce xk, xv, xr self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x):
xx = self.time_shift(x) xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# Use xk, xv, xr to produce k, v, r
k = self.key(xk) k = self.key(xk)
v = self.value(xv) k = torch.square(torch.relu(k))
r = self.receptance(xr) kv = self.value(k)
sr = torch.sigmoid(r) return torch.sigmoid(self.receptance(xr)) * kv
return sr, k, v
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v = self.jit_func(x)
rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
rwkv = self.output(rwkv)
return rwkv
class RWKV_ChannelMix(MyModule): class MishGLU(MyModule):
def __init__(self, args, layer_id): def __init__(self, args, layer_id):
super().__init__() super().__init__()
self.args = args self.args = args
self.layer_id = layer_id self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # fancy init of time_mix with torch.no_grad():
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
x = torch.ones(1, 1, args.n_embd) x = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd): for i in range(args.n_embd):
@ -186,25 +289,18 @@ class RWKV_ChannelMix(MyModule):
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
hidden_sz = 4 * args.n_embd self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.key = nn.Linear(args.n_embd, hidden_sz, bias=False) self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)
@MyFunction @MyFunction
def forward(self, x): def forward(self, x):
xx = self.time_shift(x) xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
a = self.aa(xa)
k = self.key(xk) b = self.bb(xb)
k = torch.square(torch.relu(k)) return self.value(a * F.mish(b))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
######################################################################################################## ########################################################################################################
# The RWKV Model with our blocks # The RWKV Model with our blocks
@ -231,7 +327,10 @@ class Block(nn.Module):
else: else:
self.att = RWKV_TimeMix(args, layer_id) self.att = RWKV_TimeMix(args, layer_id)
self.ffn = RWKV_ChannelMix(args, layer_id) if 'g' in os.environ["RWKV_MY_TESTING"]:
self.ffn = MishGLU(args, layer_id)
else:
self.ffn = RWKV_ChannelMix(args, layer_id)
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer: if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
self.tiny_ln = nn.LayerNorm(args.n_embd) self.tiny_ln = nn.LayerNorm(args.n_embd)
@ -286,6 +385,14 @@ class RWKV(pl.LightningModule):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
self.args = args self.args = args
if not hasattr(args, 'dim_att'):
args.dim_att = args.n_embd
if not hasattr(args, 'dim_ffn'):
args.dim_ffn = args.n_embd * 4
if not hasattr(args, 'tiny_att_layer'):
args.tiny_att_layer = -1
if not hasattr(args, 'tiny_att_dim'):
args.tiny_att_dim = -1
self.emb = nn.Embedding(args.vocab_size, args.n_embd) self.emb = nn.Embedding(args.vocab_size, args.n_embd)
@ -401,9 +508,38 @@ class RWKV(pl.LightningModule):
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
args = self.args args = self.args
idx, targets = batch if args.my_qa_mask != 1:
logits = self(idx) idx, targets = batch
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) logits = self(idx)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
else:
idx, targets, mask = batch
mask = mask.view(-1)
sum_mask = torch.sum(mask).item()
# if sum_mask == 0:
# return torch.tensor([0.0], requires_grad=True)
logits = self(idx)
if sum_mask == mask.shape[0]:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
# print('rank', self.global_rank, 'loss', loss.item())
else:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
# loss_raw = loss
loss = torch.sum(loss * mask) / sum_mask
# torch.set_printoptions(threshold=10000)
# if True: #self.global_rank == 1:
# tmp = ''
# sss = 0
# ccc = 0
# for i in range(mask.shape[0]):
# if mask[i] > 0:
# tmp += str(idx.view(-1)[i].item()) + ','
# sss += loss_raw.view(-1)[i].float().item()
# ccc += 1
# print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
return L2Wrap.apply(loss, logits) return L2Wrap.apply(loss, logits)
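The my_qa_mask == 1 branch above reduces to a masked cross-entropy; a minimal standalone sketch of that reduction (my own helper name, assuming logits (B, T, V), targets (B, T), mask (B, T) with 1 meaning "count this token"):

    import torch
    import torch.nn.functional as F

    def masked_ce(logits, targets, mask):
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none")
        mask = mask.view(-1).to(loss.dtype)
        denom = mask.sum().clamp(min=1.0)   # guard of my own; the branch above assumes sum_mask > 0
        return (loss * mask).sum() / denom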
def training_step_end(self, batch_parts): def training_step_end(self, batch_parts):
@ -428,7 +564,7 @@ class RWKV(pl.LightningModule):
gain = 1.0 gain = 1.0
scale = 1.0 scale = 1.0
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n: if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
m[n] = p m[n] = p
else: else:
if n == "emb.weight": if n == "emb.weight":
@ -436,7 +572,7 @@ class RWKV(pl.LightningModule):
else: else:
if shape[0] > shape[1]: if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1]) gain = math.sqrt(shape[0] / shape[1])
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q."]: for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
if kk in n: if kk in n:
scale = 0 scale = 0
if n == "head.weight": if n == "head.weight":

@ -9,16 +9,22 @@ from torch.nn import functional as F
import torch.nn as nn import torch.nn as nn
from typing import List, Dict from typing import List, Dict
# try: MyModule = nn.Module
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# except:
def __nop(ob): def __nop(ob):
return ob return ob
MyFunction = __nop MyFunction = __nop
# # try torchdynamo
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# try torch jit --> faster for fp32, slower for fp16 (why?)
if os.environ["RWKV_JIT_ON"] == "1":
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
RWKV_HEAD_QK_DIM = 0 RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n')
DEBUG_TIME = False # True False - show trained time-coeffs DEBUG_TIME = False # True False - show trained time-coeffs
@ -26,7 +32,7 @@ RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
############################################################################################################ ############################################################################################################
class RWKV_RNN(nn.Module): class RWKV_RNN(MyModule):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
@ -113,7 +119,7 @@ class RWKV_RNN(nn.Module):
# state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
@MyFunction @MyFunction
def FF(self, x, state, i, time_mix_k, time_mix_r, kw, vw, rw): def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
if self.FLOAT_MODE == "bf16": if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k) xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r) xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
@ -134,7 +140,7 @@ class RWKV_RNN(nn.Module):
return r * kv return r * kv
@MyFunction @MyFunction
def SA(self, x, state, i, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
if self.FLOAT_MODE == "bf16": if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k) xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v) xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
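On the i:int changes above: under torch.jit.script_method an untyped argument is assumed to be a Tensor, so the layer index is annotated as a Python int to make TorchScript treat expressions like state[5*i+0] as plain integer indexing. A minimal illustration (my own toy module, assuming RWKV_JIT_ON-style scripting):

    import torch

    class Demo(torch.jit.ScriptModule):
        @torch.jit.script_method
        def pick(self, state, i: int):   # drop ": int" and TorchScript types i as Tensor
            return state[5 * i + 0]

    print(Demo().pick(torch.arange(10.0), 1))   # tensor(5.)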

@ -1,9 +1,17 @@
import os, math, time, datetime import os, math, time, datetime, subprocess
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
def my_save(dd, ff):
if '14b-run1' not in ff:
torch.save(dd, ff)
else:
fn = ff.split('/')[-1]
fff = '/dev/shm/' + fn
torch.save(dd, fff)
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
class train_callback(pl.Callback): class train_callback(pl.Callback):
def __init__(self, args): def __init__(self, args):
@ -97,6 +105,15 @@ class train_callback(pl.Callback):
if kt_s > 0: if kt_s > 0:
lll["kt/s"] = kt_s lll["kt/s"] = kt_s
trainer.my_wandb.log(lll, step=int(real_step)) trainer.my_wandb.log(lll, step=int(real_step))
if args.magic_prime > 0:
expand_factor = 2 if args.my_qa_mask > 0 else 1
if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1:
to_save_dict = pl_module.state_dict()
my_save(
to_save_dict,
f"{args.proj_dir}/rwkv-final.pth",
)
def on_train_epoch_start(self, trainer, pl_module): def on_train_epoch_start(self, trainer, pl_module):
args = self.args args = self.args
@ -120,12 +137,12 @@ class train_callback(pl.Callback):
else: else:
to_save_dict = pl_module.state_dict() to_save_dict = pl_module.state_dict()
try: try:
torch.save( my_save(
to_save_dict, to_save_dict,
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth", f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
) )
except: except Exception as e:
pass print('Error\n\n', e, '\n\n')
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n") trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
trainer.my_log.flush() trainer.my_log.flush()
@ -138,14 +155,32 @@ def generate_init_weight(model, init_weight_name):
mm = model.generate_init_weight() mm = model.generate_init_weight()
if model.args.my_pile_stage == 1: if model.args.my_pile_stage == 1:
try: if len(model.args.load_model) > 0:
print(f"Combine weights from {model.args.load_model}...") print(f"Combine weights from {model.args.load_model}...")
load_dict = torch.load(model.args.load_model, map_location="cpu") load_dict = torch.load(model.args.load_model, map_location="cpu")
for k in load_dict: for k in load_dict:
assert k in mm assert k in mm
mm[k] = load_dict[k].reshape(mm[k].shape) src = load_dict[k]
except: try:
print(f"\n\n!!! FAIL !!!\n\n") mm[k] = src.reshape(mm[k].shape)
except:
tmp = mm[k].squeeze().clone()
print(k, src.shape, '-->', mm[k].shape)
ss = src.shape[0]
dd = tmp.shape[0]
for i in range(dd):
pos = i / dd * ss
if pos >= ss - 1:
tmp[i] = src[ss-1]
else:
p0 = int(math.floor(pos))
ii = pos - p0
tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
mm[k] = tmp.reshape(mm[k].shape)
sss = src.squeeze().float().cpu().numpy()
print(sss[:10], '...', sss[-10:])
mmm = mm[k].squeeze().float().cpu().numpy()
print(mmm[:10], '...', mmm[-10:])
print(f"Save to {init_weight_name}...") print(f"Save to {init_weight_name}...")
torch.save(mm, init_weight_name) torch.save(mm, init_weight_name)
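The fallback above, used when a loaded weight cannot simply be reshaped, linearly resamples it along its first axis; a standalone sketch of that resampling (my helper name, mirroring the loop in the diff):

    import math
    import torch

    def resample_1d(src, new_len):
        src = src.squeeze().float()
        ss = src.shape[0]
        out = torch.empty(new_len)
        for i in range(new_len):
            pos = i / new_len * ss
            if pos >= ss - 1:
                out[i] = src[ss - 1]
            else:
                p0 = int(math.floor(pos))
                frac = pos - p0
                out[i] = src[p0] * (1 - frac) + src[p0 + 1] * frac
        return out

    print(resample_1d(torch.tensor([0.0, 1.0, 2.0, 3.0]), 8))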

@ -5,8 +5,9 @@
if __name__ == "__main__": if __name__ == "__main__":
from argparse import ArgumentParser from argparse import ArgumentParser
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
print("########## work in progress ##########") rank_zero_info("########## work in progress ##########")
######################################################################################################## ########################################################################################################
# #
@ -66,6 +67,8 @@ if __name__ == "__main__":
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU) parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int) parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int) parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--dim_att", default=0, type=int)
parser.add_argument("--dim_ffn", default=0, type=int)
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better) parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
@ -73,12 +76,13 @@ if __name__ == "__main__":
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float) parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model parser.add_argument("--warmup_steps", default=-1, type=int) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float) parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float) parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
parser.add_argument("--my_pile_edecay", default=0, type=int) parser.add_argument("--my_pile_edecay", default=0, type=int)
@ -99,20 +103,23 @@ if __name__ == "__main__":
parser.add_argument("--my_att_shift", default=1, type=int) parser.add_argument("--my_att_shift", default=1, type=int)
parser.add_argument("--my_pos_emb", default=0, type=int) parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int) parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
parser = Trainer.add_argparse_args(parser) parser = Trainer.add_argparse_args(parser)
args = parser.parse_args() args = parser.parse_args()
######################################################################################################## ########################################################################################################
import os, warnings, math, datetime, sys, time import os, warnings, math, datetime, sys, time, importlib
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import deepspeed if "deepspeed" in args.strategy:
import deepspeed
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
if args.random_seed >= 0: if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3) print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
@ -135,6 +142,11 @@ if __name__ == "__main__":
args.betas = (args.beta1, args.beta2) args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
os.environ["RWKV_T_MAX"] = str(args.ctx_len) os.environ["RWKV_T_MAX"] = str(args.ctx_len)
os.environ["RWKV_MY_TESTING"] = args.my_testing
if args.dim_att <= 0:
args.dim_att = args.n_embd
if args.dim_ffn <= 0:
args.dim_ffn = args.n_embd * 4
if args.data_type == "wds_img": if args.data_type == "wds_img":
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}" args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
@ -145,22 +157,42 @@ if __name__ == "__main__":
os.makedirs(args.proj_dir) os.makedirs(args.proj_dir)
if args.my_pile_stage > 0: if args.my_pile_stage > 0:
if args.ctx_len == 1024: magic_prime_bak = args.magic_prime
args.magic_prime = 324331313
args.epoch_count = 8043 if args.my_pile_version == 1:
elif args.ctx_len == 2048: if args.ctx_len == 1024:
args.magic_prime = 162165671 args.magic_prime = 324331313
args.epoch_count = 4021 args.epoch_count = 8043
elif args.ctx_len == 4096: elif args.ctx_len == 2048:
args.magic_prime = 81082817 args.magic_prime = 162165671
args.epoch_count = 2010 args.epoch_count = 4021
if args.my_pile_shift < 0: elif args.ctx_len == 4096:
args.magic_prime = 81082817
args.epoch_count = 2010
elif args.ctx_len == 8192:
args.magic_prime = 40541399
args.epoch_count = 1005
else:
if args.ctx_len == 1024: if args.ctx_len == 1024:
args.my_pile_shift = 0 args.magic_prime = 1694947181
args.epoch_count = 42036
elif args.ctx_len == 2048: elif args.ctx_len == 2048:
args.my_pile_shift = 512 args.magic_prime = 847473509
args.epoch_count = 21017
elif args.ctx_len == 4096: elif args.ctx_len == 4096:
args.my_pile_shift = 768 args.magic_prime = 423736637
args.epoch_count = 10508
elif args.ctx_len == 6144:
args.magic_prime = 282491051
args.epoch_count = 7005
elif args.ctx_len == 8192:
args.magic_prime = 211868243
args.epoch_count = 5253
if args.my_pile_shift < 0:
args.my_pile_shift = 0
if magic_prime_bak > 0:
args.magic_prime = magic_prime_bak
args.epoch_steps = 40320 // args.real_bsz args.epoch_steps = 40320 // args.real_bsz
assert args.epoch_steps * args.real_bsz == 40320 assert args.epoch_steps * args.real_bsz == 40320
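For datasets other than the presets above, the asserts in dataset.py (MaybeIsPrime, magic_prime % 3 == 2, and magic_prime within (0.99, 1] of data_size // ctx_len) suggest picking magic_prime as the largest such prime no greater than data_size // ctx_len. A hedged helper sketch (mine, not repo code):

    def _is_prime(n):
        if n < 2:
            return False
        d = 2
        while d * d <= n:
            if n % d == 0:
                return False
            d += 1
        return True

    def find_magic_prime(data_size, ctx_len):
        p = data_size // ctx_len
        while p > 1:
            if p % 3 == 2 and _is_prime(p):
                return p
            p -= 1
        return None

    print(find_magic_prime(1_000_000, 1024))   # -> 971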
@ -184,10 +216,11 @@ if __name__ == "__main__":
args.load_model = f"{args.proj_dir}/rwkv-init.pth" args.load_model = f"{args.proj_dir}/rwkv-init.pth"
else: else:
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
if args.my_pile_stage == 2: if args.warmup_steps < 0:
args.warmup_steps = 10 if args.my_pile_stage == 2:
else: args.warmup_steps = 10
args.warmup_steps = 30 else:
args.warmup_steps = 30
args.epoch_begin = max_p + 1 args.epoch_begin = max_p + 1
samples_per_epoch = args.epoch_steps * args.real_bsz samples_per_epoch = args.epoch_steps * args.real_bsz
@ -208,9 +241,9 @@ if __name__ == "__main__":
# #
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
# #
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer # Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions) # Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer # Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
# #
############################################################################ ############################################################################
""" """
@ -225,7 +258,8 @@ if __name__ == "__main__":
assert args.precision in ["fp32", "tf32", "fp16", "bf16"] assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32": if args.precision == "fp32":
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n") for i in range(10):
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
if args.precision == "fp16": if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n") rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
@ -269,11 +303,11 @@ if __name__ == "__main__":
generate_init_weight(model, init_weight_name) # save initial weights generate_init_weight(model, init_weight_name) # save initial weights
args.load_model = init_weight_name args.load_model = init_weight_name
print(f"########## Loading {args.load_model}... ##########") rank_zero_info(f"########## Loading {args.load_model}... ##########")
try: try:
load_dict = torch.load(args.load_model, map_location="cpu") load_dict = torch.load(args.load_model, map_location="cpu")
except: except:
print(f"Bad checkpoint {args.load_model}") rank_zero_info(f"Bad checkpoint {args.load_model}")
if args.my_pile_stage >= 2: # try again using another checkpoint if args.my_pile_stage >= 2: # try again using another checkpoint
max_p = args.my_pile_prev_p max_p = args.my_pile_prev_p
if max_p == -1: if max_p == -1:
@ -281,7 +315,7 @@ if __name__ == "__main__":
else: else:
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
args.epoch_begin = max_p + 1 args.epoch_begin = max_p + 1
print(f"Trying {args.load_model}") rank_zero_info(f"Trying {args.load_model}")
load_dict = torch.load(args.load_model, map_location="cpu") load_dict = torch.load(args.load_model, map_location="cpu")
if args.load_partial == 1: if args.load_partial == 1:
@ -295,6 +329,16 @@ if __name__ == "__main__":
args, args,
callbacks=[train_callback(args)], callbacks=[train_callback(args)],
) )
if trainer.global_rank == 0:
for n in model.state_dict():
shape = model.state_dict()[n].shape
shape = [i for i in shape if i != 1]
if len(shape) > 1:
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
else:
print(f"{str(shape[0]).ljust(5)} {n}")
if "deepspeed" in args.strategy: if "deepspeed" in args.strategy:
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000

@ -0,0 +1,104 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# this is for verifying the results of different models and making sure they agree with each other
import os, sys, types
import numpy as np
import torch
np.set_printoptions(precision=4, suppress=True, linewidth=200)
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # bf16 or fp32
os.environ['RWKV_RUN_DEVICE'] = 'cuda' # currently model_train requires CUDA
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
TOKEN_MODE = 'pile'
if TOKEN_MODE == 'pile':
WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
n_layer = 32
n_embd = 2560
ctx_len = 1024
UNKNOWN_CHAR = None
from src.utils import TOKENIZER
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == 'pile':
tokenizer.vocab_size = 50277
########################################################################################################
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_T_MAX"] = str(ctx_len)
from src.model_run import RWKV_RNN
from src.model import RWKV
args = types.SimpleNamespace()
args.vocab_size = tokenizer.vocab_size
args.ctx_len = ctx_len
args.n_embd = n_embd
args.n_layer = n_layer
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
model_train = RWKV(args).to(RUN_DEVICE)
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
model_train = model_train.bfloat16()
print('loading ' + MODEL_NAME)
m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu')
model_train.load_state_dict(m2)
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
model_train = model_train.bfloat16()
args.MODEL_NAME = MODEL_NAME
args.RUN_DEVICE = RUN_DEVICE
args.FLOAT_MODE = os.environ['RWKV_FLOAT_MODE']
model_rnn = RWKV_RNN(args)
########################################################################################################
print(f"\nVerifying {os.environ['RWKV_RUN_DEVICE']} {os.environ['RWKV_FLOAT_MODE']}")
# context = '\nIn a'
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'
if TOKEN_MODE == 'pile':
ctx = tokenizer.tokenizer.encode(context)
print(f'input len {len(ctx)} data {ctx}')
########################################################################################################
with torch.no_grad():
print('\nRWKV-train output')
out = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0].detach().cpu().float().numpy()
print(out, '\n')
print('\nRWKV-RNN output')
state = None
out = None
src_len = len(ctx)
for i in range(src_len):
x = ctx[:i+1]
out, state = model_rnn.forward(x, state)
if i < 3 or i >= src_len - 3:
print(out.detach().cpu().numpy())
if i == 2:
print('...')
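The script above prints both outputs for eyeballing; a small hedged follow-up one could append to compare the last-position logits numerically (out_gpt and out_rnn are my names: the GPT-mode array saved before the RNN loop and the final out.detach().cpu().float().numpy() from it):

    import numpy as np

    def max_abs_diff(a, b):
        # a: GPT-mode logits for the last token, b: final RNN-mode output (both numpy arrays)
        return float(np.abs(np.asarray(a) - np.asarray(b)).max())

    # e.g. print(max_abs_diff(out_gpt[-1], out_rnn))  # expect a small gap under bf16, not exact equality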