Compare commits

383 Commits
0.02...main

Author SHA1 Message Date
PENG Bo 99a5933f54 Update README.md 3 years ago
PENG Bo 5b08ee1718 Update README.md 3 years ago
PENG Bo 4e962eb850 Update README.md 3 years ago
PENG Bo 8f428408a3 Update README.md 3 years ago
PENG Bo a9007581d0 Update README.md 3 years ago
BlinkDL 79915b3696 better 3 years ago
BlinkDL 0c7cd08255 fix 3 years ago
PENG Bo 3d43eaa1c8 Update README.md 3 years ago
PENG Bo 04a564d7d0 Update README.md 3 years ago
PENG Bo 5713df51ec Update README.md 3 years ago
PENG Bo 14d21f5a00 Update README.md 3 years ago
BlinkDL 4ca274aad7 fix 3 years ago
BlinkDL 1945cb58ed pile v2 3 years ago
BlinkDL 3d2b04ba0c Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL 107f167b4a misc 3 years ago
PENG Bo ac4ba411e6 Add files via upload 3 years ago
PENG Bo 099919058b Update README.md 3 years ago
PENG Bo 87fab90435 Update README.md 3 years ago
PENG Bo 13b8784502 Update README.md 3 years ago
PENG Bo 513d3eb552 Add files via upload 3 years ago
PENG Bo 8b615ccc74 Update README.md 3 years ago
PENG Bo 1e0dba0421 Update README.md 3 years ago
PENG Bo f8134fb96e Update README.md 3 years ago
PENG Bo 62fba64244 Update README.md 3 years ago
PENG Bo 123536b2a7 Update README.md 3 years ago
PENG Bo decd8e29f5 Add files via upload 3 years ago
BlinkDL 6d4dec7288 better 3 years ago
PENG Bo 8e99ac1138 Update README.md 3 years ago
PENG Bo 4056bfeba7 Add files via upload 3 years ago
PENG Bo d16e25661c Add files via upload 3 years ago
PENG Bo 0ff2170277 Update README.md 3 years ago
PENG Bo e615f1c718 Add files via upload 3 years ago
PENG Bo 1430d4edcf Update README.md 3 years ago
PENG Bo 4378fe6b4f Update README.md 3 years ago
PENG Bo f38b7e3574 Update README.md 3 years ago
BlinkDL 6739df885e saves VRAM 3 years ago
BlinkDL 9f557219c4 misc 3 years ago
BlinkDL 58e9d8d972 Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL 1d72a48db0 faster bf16 & saves VRAM 3 years ago
PENG Bo 52ac194d54 Update README.md 3 years ago
BlinkDL 93d671c287 better cuda kernel 3 years ago
PENG Bo 760db55fa6 Update README.md 3 years ago
PENG Bo 6b59d8fee1 Update README.md 3 years ago
PENG Bo fc047a20b1 Update README.md 3 years ago
PENG Bo 0c77cfbbee Update README.md 3 years ago
PENG Bo 904de99a14 Update README.md 3 years ago
PENG Bo 5f6ffc987a Update README.md 3 years ago
PENG Bo 02178d79c9 Update README.md 3 years ago
PENG Bo ad1836b27f Update README.md 3 years ago
PENG Bo 404c593213 Update README.md 3 years ago
PENG Bo 9917078f93 Update README.md 3 years ago
PENG Bo e0dc08a2ce Update README.md 3 years ago
PENG Bo a3e6156136 Update README.md 3 years ago
PENG Bo 55f63f0aeb Update README.md 3 years ago
PENG Bo 114d677bc8 Update README.md 3 years ago
BlinkDL d008cc6d8e Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL c13879ab97 misc 3 years ago
PENG Bo 78579a00d2 Update README.md 3 years ago
PENG Bo e6d9e4979a Update README.md 3 years ago
PENG Bo 8d72d882e4 Update README.md 3 years ago
PENG Bo 3f5ac97f77 Update README.md 3 years ago
PENG Bo 81aa6dda7b Update README.md 3 years ago
PENG Bo 366b000ee6 Add files via upload 3 years ago
PENG Bo 11acd5e5b5 Update README.md 3 years ago
PENG Bo 374b086911 Add files via upload 3 years ago
BlinkDL 7476c69f32 fix 3 years ago
BlinkDL 6ed3a3db09 Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL e2ec7ae023 test 3 years ago
PENG Bo 71a46ca0f3 Update README.md 3 years ago
PENG Bo b97c25b9e7 Update README.md 3 years ago
PENG Bo 3d15f41a16 Update README.md 3 years ago
PENG Bo bbacb62b89 Add files via upload 3 years ago
BlinkDL c7b1900270 prepare for v4c 3 years ago
BlinkDL 038f06b996 rwkv-4b 3 years ago
PENG Bo f03efd0218 Update README.md 3 years ago
PENG Bo 9721b8f9c5 Update README.md 3 years ago
PENG Bo 79aa59ff2b Add files via upload 3 years ago
PENG Bo 8e63b75f2c Add files via upload 3 years ago
PENG Bo 5c8eda8595 Add files via upload 3 years ago
PENG Bo 13c6149205 Update README.md 3 years ago
PENG Bo f6cb1a1947 Update README.md 3 years ago
PENG Bo aeae6c8aac Update README.md 3 years ago
PENG Bo b562097da1 Update README.md 3 years ago
BlinkDL f79d082053 testing 3 years ago
PENG Bo 8bf7061705 Update README.md 3 years ago
BlinkDL dd3845752a Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL b4925900e7 misc 3 years ago
PENG Bo fcb0b9819d Update README.md 3 years ago
PENG Bo be18c53fec Add files via upload 3 years ago
PENG Bo eac471da29 Update README.md 3 years ago
PENG Bo d1bb270fb3 Update README.md 3 years ago
PENG Bo 5837ee32c4 Update README.md 3 years ago
PENG Bo 8e1130e12a Update README.md 3 years ago
BlinkDL b2a240d73d misc 3 years ago
PENG Bo bc47cb9f1a Update README.md 3 years ago
PENG Bo 3461b2f6fb Add files via upload 3 years ago
BlinkDL 295af9a517 info 3 years ago
BlinkDL dc26998708 torch jit 3 years ago
BlinkDL 75929cbbba torch jit (xx% faster inference) 3 years ago
PENG Bo 819f2730b2 Update README.md 3 years ago
PENG Bo 66c1dabb94 Add files via upload 3 years ago
PENG Bo 379c97890b Update README.md 3 years ago
PENG Bo 83a4512b74 Update README.md 3 years ago
BlinkDL 59e6deeb58 better chat 3 years ago
BlinkDL cf340264dc better 3 years ago
PENG Bo 935d8d3e87 Update README.md 3 years ago
BlinkDL 0d0cedfcd9 better chatbot 3 years ago
PENG Bo 511b7adb4f Add files via upload 3 years ago
BlinkDL aaf1341af7 better chat 3 years ago
BlinkDL e64ce9b0ff Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL 3e0f8054c6 better prompt 3 years ago
PENG Bo 0131543e48 Update README.md 3 years ago
PENG Bo 315ce82e38 Add files via upload 3 years ago
BlinkDL 9eb1a7b3d3 Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL 7a7c06aed3 misc 3 years ago
PENG Bo 2e5704097d Update README.md 3 years ago
BlinkDL 529be15c67 Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM into main 3 years ago
BlinkDL 5c73bccd5a bug fix 3 years ago
PENG Bo 7bdfd1cb64 Update README.md 3 years ago
BlinkDL eaac3f7e66 chatbot 3 years ago
PENG Bo 14544aea94 Update README.md 3 years ago
PENG Bo 3fc16a86ed Update README.md 3 years ago
PENG Bo eb3ca86f0b Add files via upload 3 years ago
BlinkDL 23f64aeebc misc improvements 3 years ago
PENG Bo a39a175d09 Update README.md 3 years ago
PENG Bo 47f3eab0a7 Update README.md 3 years ago
PENG Bo 6528bbf850 Add files via upload 3 years ago
PENG Bo b4359937dc Update README.md 3 years ago
PENG Bo 0957b9b531 Update README.md 3 years ago
PENG Bo 729ab9abee Add files via upload 3 years ago
BlinkDL a268cd2e40 better tinyAtt 3 years ago
PENG Bo de8bae7778 Update README.md 3 years ago
BlinkDL e9b24370d9 RWKV-4a (tinyAtt) 3 years ago
BlinkDL 2567c8c904 rescale to avoid FP16 overflow 3 years ago
BlinkDL aef9f6f7ef fp16 3 years ago
PENG Bo 605637ca6f Update README.md 3 years ago
PENG Bo fbe39374d6 Delete RWKV-eval.png.png 3 years ago
PENG Bo 4dd2024286 Add files via upload 3 years ago
PENG Bo 641bd6da1d Add files via upload 3 years ago
PENG Bo a6a14f1bf9 Update README.md 3 years ago
PENG Bo 6fe2f798a1 Update README.md 3 years ago
PENG Bo 434ddb1f04 Update README.md 3 years ago
PENG Bo 54f28abafe Add files via upload 3 years ago
PENG Bo 72a307f207 Update README.md 3 years ago
PENG Bo 3fcc930e1e Add files via upload 3 years ago
PENG Bo 8d8ce9da7a Update README.md 3 years ago
PENG Bo 6b58ba8116 Add files via upload 3 years ago
PENG Bo 737fb7dc8e Update README.md 3 years ago
BlinkDL 8a4a41a3aa better 3 years ago
BlinkDL daed379db2 bf16 inference - 15G VRAM for 7b model 3 years ago
BlinkDL 1479315677 update 3 years ago
BlinkDL 40237495cc bugfix 3 years ago
PENG Bo a86b15a3da Update README.md 3 years ago
BlinkDL c7155525bb better 3 years ago
BlinkDL fc3bc1eb0e good for UTF-8 inference (such as CJK) 3 years ago
BlinkDL 77bcaa8247 +AE model 3 years ago
BlinkDL 9e325b88cb OpenCLIP 3 years ago
BlinkDL 74fedc0d86 CLIP-guided Binary AutoEncoder 3 years ago
BlinkDL bb59fffac1 img model 3 years ago
BlinkDL 40e91dd1d7 img models (wip) 3 years ago
BlinkDL 3a7e6a6aa3 faster inference 3 years ago
PENG Bo bd6803df76 Update README.md 3 years ago
PENG Bo 4d83dd2d53 Update README.md 3 years ago
PENG Bo 56956d7b99 Update README.md 3 years ago
PENG Bo 5163323d4e Create LICENSE 3 years ago
BlinkDL 44f07e44d2 better 3 years ago
PENG Bo c22c7dadca Update README.md 3 years ago
PENG Bo 4c946fac2e Add files via upload 3 years ago
PENG Bo 70b9075728 Update README.md 3 years ago
PENG Bo a16c6c3a75 Update README.md 3 years ago
BlinkDL 74ffd9bec7 multinode 3 years ago
BlinkDL 470ac7d1fa better 3 years ago
BlinkDL 24e3cf6e93 better 3 years ago
BlinkDL 1189c9e238 fix 3 years ago
BlinkDL 7b92a979d8 fix 3 years ago
BlinkDL ceafd4e7af better 3 years ago
BlinkDL 94300caba1 better 3 years ago
BlinkDL 99a3dff414 code for pile training 3 years ago
BlinkDL d1674732ed clean code 3 years ago
BlinkDL 2388bca7c3 dummy data example 3 years ago
BlinkDL 778c0a7f58 improvements 3 years ago
BlinkDL f81349f127 fix 3 years ago
BlinkDL 6ab2e71c25 finetuning 1.5B model using 16G VRAM 3 years ago
BlinkDL 23b0c74950 fix 3 years ago
BlinkDL cdb098c0e0 fix 3 years ago
BlinkDL 8abea9c08d wip 3 years ago
BlinkDL ba6e9e6264 rwkv-v4neo test 3 years ago
BlinkDL 50587bd65f fix for jit modules 3 years ago
BlinkDL 2b4539cd08 faster (use torch 1.12.1+cu116 or newer) 3 years ago
BlinkDL 2f33901c10 no message 3 years ago
BlinkDL 6ff859db80 no message 3 years ago
BlinkDL 09c76b185a no message 3 years ago
PENG Bo c49fd38ba1 Update README.md 3 years ago
PENG Bo f3cf7a5ad1 Update README.md 3 years ago
PENG Bo 5c29eb779c Add files via upload 3 years ago
BlinkDL 8cced0383e no message 3 years ago
PENG Bo f46cf49852 Update README.md 3 years ago
BlinkDL 2815260d83 better 3 years ago
BlinkDL dc7e0802d0 faster 3 years ago
BlinkDL c43a17cfb3 10% faster training 3 years ago
BlinkDL c84e8fd952 bugfix 3 years ago
BlinkDL 73b96705d7 + fp32 mode (slow but good for verification) 3 years ago
BlinkDL 94f618c52a Merge branch 'main' of https://github.com/BlinkDL/RWKV-LM 3 years ago
BlinkDL a1bf15ac40 no message 3 years ago
PENG Bo e05c69452d Update README.md 3 years ago
PENG Bo 5f1a473845 Update README.md 3 years ago
BlinkDL 6299c087a4 fixed VRAM consumpition 3 years ago
PENG Bo cb520e0f15 Update README.md 3 years ago
PENG Bo a01d915fcc Update README.md 3 years ago
PENG Bo 4c2080aadd Add files via upload 3 years ago
BlinkDL c1f7a72724 saves some VRAM for 1 GPU training 3 years ago
PENG Bo 69e7cfbf39 Update README.md 3 years ago
PENG Bo f757518277 Update README.md 3 years ago
PENG Bo 2bddd576cd Update README.md 3 years ago
PENG Bo 1949ed8619 Update README.md 3 years ago
BlinkDL 68c486ad10 supports RWKV-4 pile models 3 years ago
BlinkDL 61b7c429df no message 3 years ago
PENG Bo 8d4fed7128 Update README.md 3 years ago
BlinkDL 7cdc8d3164 no message 3 years ago
BlinkDL 13bb641007 no message 3 years ago
BlinkDL f79137b524 supports megatron bin+idx format 3 years ago
PENG Bo aa67870849 Update README.md 3 years ago
BlinkDL 4c8a1a467b no message 3 years ago
BlinkDL 083f9504c6 + bf16 mode (more stable) 3 years ago
PENG Bo 46eebd98ca Add files via upload 3 years ago
PENG Bo c0c4ffc7b4 Update README.md 3 years ago
BlinkDL 6667ad18c2 more training tips 3 years ago
PENG Bo 5f6e9356a2 Update README.md 3 years ago
BlinkDL 165dfd1b9e RWKV-4 with DeepSpeed & FP16 & Better CUDA Kernel 3 years ago
PENG Bo dfb75dd89d Update README.md 3 years ago
PENG Bo e9488edafd Update README.md 3 years ago
PENG Bo 3f158ac736 Update README.md 3 years ago
PENG Bo ec91ff1857 Update README.md 3 years ago
PENG Bo 276e322abf Update README.md 3 years ago
PENG Bo a4b0759bf6 Update README.md 3 years ago
PENG Bo 5bd56f1f2d Add files via upload 3 years ago
PENG Bo 3f699e6d2f Update README.md 4 years ago
PENG Bo f4cd0a5a58 Update README.md 4 years ago
PENG Bo 0ef922ebfc Update README.md 4 years ago
PENG Bo 19005e8067 Update README.md 4 years ago
PENG Bo 14abfebc94 Update README.md 4 years ago
PENG Bo 3ef305da6e Update README.md 4 years ago
PENG Bo 3de62b92c3 Update README.md 4 years ago
PENG Bo 456437021f Update README.md 4 years ago
PENG Bo b91cca965f Add files via upload 4 years ago
PENG Bo 4cb363e5aa Update README.md 4 years ago
BlinkDL 8d780208f2 typo fix 4 years ago
BlinkDL 8556d0fd3f no message 4 years ago
BlinkDL b6b5f4628f tips for training, with exponential lr decay 4 years ago
PENG Bo f28be63cd8 Update README.md 4 years ago
PENG Bo b6403a8aef RWKV-3 (test deeper models (n_layer >= 12) to see the advantage) 4 years ago
PENG Bo 1f6461b90b Update README.md 4 years ago
PENG Bo 041227cdad Update README.md 4 years ago
PENG Bo 940a67ec27 Update README.md 4 years ago
PENG Bo 0c6ffa710d Update README.md 4 years ago
PENG Bo 09132dea52 Update README.md 4 years ago
PENG Bo 1a49ec4eeb Add files via upload 4 years ago
PENG Bo f57d8e3e58 Add files via upload 4 years ago
PENG Bo 234aa8a5bb Update README.md 4 years ago
PENG Bo 803397945b Update README.md 4 years ago
PENG Bo f45198b1db Add files via upload 4 years ago
PENG Bo 48e5a5326d Update README.md 4 years ago
PENG Bo a2de885fdd Update README.md 4 years ago
PENG Bo bc85865775 Update README.md 4 years ago
PENG Bo c4e8057a63 Add files via upload 4 years ago
PENG Bo c103f0caa3 Update README.md 4 years ago
PENG Bo 1301d383bb Update README.md 4 years ago
PENG Bo 92da17d68f Update README.md 4 years ago
PENG Bo 8d63423a56 Update README.md 4 years ago
PENG Bo 5676b039ad Update README.md 4 years ago
PENG Bo 99089254e4 Update README.md 4 years ago
PENG Bo b54b204074 Update README.md 4 years ago
PENG Bo 6f7240e693 Update README.md 4 years ago
PENG Bo 3523a0ab72 Update README.md 4 years ago
PENG Bo 5bb0c1167f Update README.md 4 years ago
PENG Bo bff3b0359d Update README.md 4 years ago
PENG Bo 0fc4bdb247 Add files via upload 4 years ago
PENG Bo f8ed36f6d4 Add files via upload 4 years ago
PENG Bo a8e71e60af Update README.md 4 years ago
PENG Bo aa90f10547 Update README.md 4 years ago
PENG Bo a862195475 Add files via upload 4 years ago
PENG Bo 2d6d54d1db Add files via upload 4 years ago
PENG Bo 68fcf7e296 Update README.md 4 years ago
PENG Bo d7b8fc9169 Add files via upload 4 years ago
PENG Bo 65ae5835d4 Update README.md 4 years ago
PENG Bo 00fa276018 Add files via upload 4 years ago
PENG Bo aa8c170cfe Add files via upload 4 years ago
PENG Bo f20564a15d Update README.md 4 years ago
PENG Bo 53f1baeaad Update README.md 4 years ago
PENG Bo 16655d369d Update README.md 4 years ago
PENG Bo 1f67abe4e8 Update README.md 4 years ago
PENG Bo 927aa67e04 Add files via upload 4 years ago
PENG Bo 812a7d76cf Update README.md 4 years ago
PENG Bo a4d3a44e13 Update README.md 4 years ago
PENG Bo a8a947eda0 Update README.md 4 years ago
PENG Bo d6ff9a085f Update README.md 4 years ago
PENG Bo 9218a1b8a5 Update README.md 4 years ago
PENG Bo 8e2a779c4c Update README.md 4 years ago
PENG Bo af1ecec13d Update README.md 4 years ago
PENG Bo 4829bc17fb Update README.md 4 years ago
PENG Bo 676a59886a Update README.md 4 years ago
PENG Bo 16761a7f20 Update README.md 4 years ago
PENG Bo 29905eedb4 Update README.md 4 years ago
PENG Bo bbd6e2223b Update README.md 4 years ago
PENG Bo 2e9fcfc5f3 Update README.md 4 years ago
PENG Bo 601ade3e0d Update README.md 4 years ago
PENG Bo 24b30a83c6 Update README.md 4 years ago
PENG Bo ece62474d4 Update README.md 4 years ago
PENG Bo df2fa3fed4 Update README.md 4 years ago
PENG Bo 7a47bcd096 Update README.md 4 years ago
PENG Bo 1691141765 Update README.md 4 years ago
PENG Bo d3b7092249 Add files via upload 4 years ago
PENG Bo de8c54dfd4 Update README.md 4 years ago
PENG Bo 4bbee4bb1a Update README.md 4 years ago
PENG Bo 15db7d3e14 Update README.md 4 years ago
PENG Bo 78169ad5a9 Update README.md 4 years ago
PENG Bo f44ad5156e Update README.md 4 years ago
PENG Bo b15642717a Update README.md 4 years ago
PENG Bo e472f9ffbb Add files via upload 4 years ago
PENG Bo 089308fd92 Update README.md 4 years ago
PENG Bo b07127f299 Add files via upload 4 years ago
PENG Bo 29a21844ab Update README.md 4 years ago
PENG Bo cd19235c00 Add files via upload 4 years ago
PENG Bo 89c953b941 Update README.md 4 years ago
PENG Bo 9066053fe5 Update README.md 4 years ago
PENG Bo c5788dd30f Update README.md 4 years ago
PENG Bo 34c9b6fc0f Update README.md 4 years ago
PENG Bo 9e09fafd59 Update README.md 4 years ago
PENG Bo 1e1aacba54 Update README.md 4 years ago
PENG Bo caf480fac3 Add files via upload 4 years ago
PENG Bo 62d5b5d481 Update README.md 4 years ago
PENG Bo c2ad6e9d8d Update README.md 4 years ago
PENG Bo e03c0c9467 Update README.md 4 years ago
PENG Bo 64c9015dd1 Update README.md 4 years ago
PENG Bo 325ddb76f7 Add files via upload 4 years ago
PENG Bo 7212d75e7c Update README.md 4 years ago
PENG Bo 750c5149ab Update README.md 4 years ago
PENG Bo 2bc3912594 Update README.md 4 years ago
PENG Bo 31a0f3944c Update README.md 4 years ago
PENG Bo 12fe2002a5 Update README.md 4 years ago
PENG Bo 3fbdffb6c6 Update README.md 4 years ago
PENG Bo 214ebf9e47 Add files via upload 4 years ago
PENG Bo 2971dae91f Update README.md 4 years ago
PENG Bo e65448716d Update README.md 4 years ago
PENG Bo fffbcb6785 Update README.md 4 years ago
PENG Bo 2951b2895a Update README.md 4 years ago
PENG Bo e541d93b97 Update README.md 4 years ago
PENG Bo 6b1ba8a9bd Update README.md 4 years ago
PENG Bo 619c8add7a Update README.md 4 years ago
PENG Bo 438377af6d Add files via upload +results 4 years ago
PENG Bo 9cf873ebb2 Update README.md 4 years ago
PENG Bo af56c2446d Update README.md 4 years ago
PENG Bo ecaf1f98aa Update README.md 4 years ago
PENG Bo e2f3465fe6 Update README.md 4 years ago
PENG Bo f11b97ae48 Update README.md 4 years ago
BlinkDL 88e921bf10 +eval code for 27M ppl 1.65 BPC 0.72 enwik8 model 4 years ago
BlinkDL 71538e44a9 refactoring 4 years ago
BlinkDL 5817d265c3 no message 4 years ago
BlinkDL 0b6aec3da6 Default parameters for 8G VRAM 4 years ago
PENG Bo 6aefe59c3d Update README.md 4 years ago
PENG Bo da6f35f276 Update README.md 4 years ago
PENG Bo 72a6f28add Update README.md 4 years ago
PENG Bo d8234047e6 Update README.md 4 years ago
PENG Bo 5ac82f37be Update README.md 4 years ago
PENG Bo a86ac324de Update README.md 4 years ago
PENG Bo 9ca62d3a1e Update README.md 4 years ago
PENG Bo 4b1df60e94 Update README.md 4 years ago
PENG Bo 780bed4e19 Update README.md 4 years ago
BlinkDL 5f21ddf20d RWKV v2 RNN is here. Probably the strongest LM as of now. 4 years ago
PENG Bo 1f189a4034 Update README.md 4 years ago
PENG Bo 5ac265691f RWKV-v2-RNN introduction 4 years ago
PENG Bo b4fd1a7209 Update model.py 4 years ago
PENG Bo c8a751ed8b Update README.md 4 years ago
BlinkDL 0a0eae447d +headQK (compatible with 2022-02-15 AI-Writer) 4 years ago
BlinkDL b48aa1d430 no message 4 years ago
BlinkDL a19be54bf5 no message 4 years ago
BlinkDL fcd01f8851 no message 4 years ago
BlinkDL 76e241b71e saves vocab.json, and the model every X epoch 4 years ago
PENG Bo 689a6a924d Update train.py 4 years ago
PENG Bo 34fa2ec81b Update README.md 4 years ago
PENG Bo 58bdb908f9 Update README.md 4 years ago
PENG Bo 3d8d0373b4 Update README.md 4 years ago
BlinkDL 710d3e34b7 better init for RWKV 4 years ago

.gitignore vendored (6 changes)

@@ -5,6 +5,12 @@
 *.xlsx
 *.xls
 wandb/
+data/
+vocab.json
+*.sh
+*log/
+test/
+tools/
 # Byte-compiled / optimized / DLL files
 __pycache__/

LICENSE

@@ -1,25 +1,201 @@

Removed (the old BSD 2-Clause License):

BSD 2-Clause License

Copyright (c) 2021, PENG Bo
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

Added (the new Apache License 2.0):

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

README.md

@@ -1,4 +1,502 @@

# The RWKV Language Model (and my LM tricks)
## RWKV: Parallelizable RNN with Transformer-level LLM Performance (pronounced as "RwaKuv", from 4 major params: R W K V)
RWKV is an RNN with Transformer-level LLM performance, which can also be directly trained like a GPT transformer (parallelizable). And it's 100% attention-free. You only need the hidden state at position t to compute the state at position t+1. You can use the "GPT" mode to quickly compute the hidden state for the "RNN" mode.
So it's combining the best of RNN and transformer - **great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding** (using the final hidden state).
**HuggingFace Gradio demo (14B ctx8192)**: https://huggingface.co/spaces/BlinkDL/ChatRWKV-gradio
Raven (7B finetuned on Alpaca) Demo: https://huggingface.co/spaces/BlinkDL/Raven-RWKV-7B
**ChatRWKV:** with "stream" and "split" strategies and INT8. **3G VRAM is enough to run RWKV 14B :)** https://github.com/BlinkDL/ChatRWKV
**RWKV pip package**: https://pypi.org/project/rwkv/
```python
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV # pip install rwkv
model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16')
out, state = model.forward([187, 510, 1563, 310, 247], None) # use 20B_tokenizer.json
print(out.detach().cpu().numpy()) # get logits
out, state = model.forward([187, 510], None)
out, state = model.forward([1563], state) # RNN has state (use deepcopy if you want to clone it)
out, state = model.forward([310, 247], state)
print(out.detach().cpu().numpy()) # same result as above
```
**Download RWKV-4 0.1/0.4/1.5/3/7/14B weights**: https://huggingface.co/BlinkDL
## Join Our Discord: https://discord.gg/bDSBUMeFpc (lots of developers)
**Twitter**: https://twitter.com/BlinkDL_AI
**RWKV in 150 lines** (model, inference, text generation): https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py
ChatRWKV with RWKV 14B ctx8192:
![RWKV-chat](RWKV-chat.png)
You are welcome to join the RWKV discord https://discord.gg/bDSBUMeFpc to build upon it. We have plenty of potential compute (A100 40Gs) now (thanks to Stability and EleutherAI), so if you have interesting ideas I can run them.
![RWKV-eval2](RWKV-eval2.png)
RWKV [loss vs token position] for 10000 ctx4k+ documents in Pile. RWKV 1B5-4k is mostly flat after ctx1500, but 3B-4k and 7B-4k and 14B-4k have some slopes, and they are getting better. This debunks the old view that RNNs cannot model long ctxlens. We can predict that RWKV 100B will be great, and RWKV 1T is probably all you need :)
![RWKV-ctxlen](RWKV-ctxlen.png)
I believe RNN is a better candidate for fundamental models, because: (1) It's more friendly for ASICs (no kv cache). (2) It's more friendly for RL. (3) When we write, our brain is more similar to RNN. (4) The universe is like an RNN too (because of locality). Transformers are non-local models.
RWKV-3 1.5B on A40 (tf32) = always 0.015 sec/token, tested using simple pytorch code (no CUDA), GPU utilization 45%, VRAM 7823M
GPT2-XL 1.3B on A40 (tf32) = 0.032 sec/token (for ctxlen 1000), tested using HF, GPU utilization 45% too (interesting), VRAM 9655M
Training speed: (new training code) RWKV-4 14B BF16 ctxlen4096 = 114K tokens/s on 8x8 A100 80G (ZERO2+CP). (old training code) RWKV-4 1.5B BF16 ctxlen1024 = 106K tokens/s on 8xA100 40G.
I am doing image experiments too (For example: https://huggingface.co/BlinkDL/clip-guided-binary-autoencoder) and RWKV will be able to do txt2img diffusion :) My idea: 256x256 rgb image -> 32x32x13bit latents -> apply RWKV to compute transition probability for each of the 32x32 grid -> pretend the grids are independent and "diffuse" using these probabilities.
Smooth training - no loss spikes! (lr & bsz change around 15G tokens)
![RWKV-loss](RWKV-loss.png)
![RWKV-eval](RWKV-eval.png)
All of the trained models will be open-source. Inference is very fast (only matrix-vector multiplications, no matrix-matrix multiplications) even on CPUs, so you can even run a LLM on your phone.
How it works: RWKV gathers information to a number of channels, which are also decaying with different speeds as you move to the next token. It's very simple once you understand it.
**RWKV is parallelizable because the time-decay of each channel is data-independent (and trainable)**. For example, in a usual RNN you can adjust the time-decay of a channel from, say, 0.8 to 0.5 (these are called "gates"), while in RWKV you simply move the information from a W-0.8-channel to a W-0.5-channel to achieve the same effect. Moreover, you can fine-tune RWKV into a non-parallelizable RNN (then you can use outputs of later layers of the previous token) if you want extra performance.
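Here is a toy sketch (illustrative, not the actual repo code) of why a data-independent per-channel decay keeps the RNN-mode update a single elementwise multiply-add per token:
```python
import numpy as np

C, T = 4, 8
w = np.array([0.9, 0.8, 0.5, 0.99])  # per-channel decay: trainable, but fixed w.r.t. the data
k = np.random.rand(T, C)             # information each token writes into the channels
state = np.zeros(C)
for t in range(T):
    # a usual RNN "gate" would make w a function of the input here,
    # which is exactly what breaks parallel training
    state = w * state + k[t]
```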
![RWKV-formula](RWKV-formula.png)
Here are some of my TODOs. Let's work together :)
* HuggingFace integration (check https://github.com/huggingface/transformers/issues/17230), and optimized CPU & iOS & Android & WASM & WebGL inference. RWKV is a RNN and very friendly for edge devices. Let's make it possible to run a LLM on your phone.
* Test it on bidirectional & MLM tasks, and image & audio & video tokens. I think RWKV can support Encoder-Decoder via this: for each decoder token, use a learned mixture of [decoder previous hidden state] & [encoder final hidden state]. Hence all decoder tokens will have access to the encoder output.
* Now training RWKV-4a with one single tiny extra attention (just a few extra lines compared with RWKV-4) to further improve some difficult zero-shot tasks (such as LAMBADA) for smaller models. See https://github.com/BlinkDL/RWKV-LM/commit/a268cd2e40351ee31c30c5f8a5d1266d35b41829
User feedback:
> *I've so far toyed around the character-based model on our relatively small pre-training dataset (around 10GB of text), and the results are extremely good - similar ppl to models taking much, much longer to train.*
> *dear god rwkv is fast. i switched to another tab after starting training it from scratch & when i returned it was emitting plausible english & maori words, i left to go microwave some coffee & when i came back it was producing fully grammatically correct sentences.*
Tweet from Sepp Hochreiter (thank you!): https://twitter.com/HochreiterSepp/status/1524270961314484227
You can find me (BlinkDL) in the EleutherAI Discord too: https://www.eleuther.ai/get-involved/
![RWKV-demo](RWKV-demo.png)
## Quick start
Use https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v4neo (latest code, compatible with v4).
Here is a great prompt for testing Q&A of LLMs. Works for any model: (found by minimizing ChatGPT ppls for RWKV 1.5B)
```python
prompt = f'\nQ & A\n\nQuestion:\n{qq}\n\nDetailed Expert Answer:\n' # let the model generate after this
```
**Cool Community RWKV Projects (check them!)**:
https://pypi.org/project/rwkvstic/ a pip package (with 8bit & offload for low VRAM GPUs)
https://github.com/harrisonvanderbyl/rwkv_chatbot a chatbot
https://github.com/hizkifw/WebChatRWKVstic WebUI (WIP)
https://github.com/gururise/rwkv_gradio RWKV Gradio
https://github.com/cryscan/eloise RWKV QQ bot
https://github.com/Blealtan/RWKV-LM-LoRA LoRA fine-tuning
https://github.com/mrsteyk/RWKV-LM-jax
https://github.com/wozeparrot/tinyrwkv RWKV in tinygrad (nice simple DL framework)
https://github.com/huggingface/transformers/issues/17230 RWKV HF package (WIP)
https://github.com/ArEnSc/Production-RWKV RWKV HF package source
https://github.com/nlpodyssey/verbaflow RWKV in Go
https://github.com/nlpodyssey/rwkv RWKV in Go
https://github.com/mrsteyk/rwkvk-rs RWKV in Rust
https://github.com/josephrocca/rwkv-v4-web RWKV in browser
https://github.com/imxcstar/CSharp-RWKV-V4 RWKV in C#
https://github.com/mrsteyk/RWKV-LM-deepspeed Another training fork
https://github.com/resloved/RWKV-notebooks RWKV colab notebooks
https://colab.research.google.com/github/harrisonvanderbyl/rwkvstic/blob/master/notebooks/chatbot.ipynb RWKV chatbot colab notebook
https://github.com/Pathos14489/RWKVDistributedInference RWKV Distributed Inference
https://github.com/AXKuhta/rwkv-onnx-dml RWKV ONNX
### Inference
**Run RWKV-4 Pile models:** Download models from https://huggingface.co/BlinkDL. Set TOKEN_MODE = 'pile' in run.py and run it. It's fast even on CPU (the default mode).
**Colab for RWKV-4 Pile 1.5B**: https://colab.research.google.com/drive/1F7tZoPZaWJf1fsCmZ5tjw6sYHiFOYVWM
Run RWKV-4 Pile models in your browser (and onnx version): see this issue https://github.com/BlinkDL/RWKV-LM/issues/7
RWKV-4 Web Demo: https://josephrocca.github.io/rwkv-v4-web/demo/ (note: only greedy sampling for now)
For the old RWKV-2: see the release here for a 27M params model on enwik8 with 0.72 BPC(dev). Run run.py in https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN. You can even run it in your browser: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng https://blinkdl.github.io/AI-Writer/eng/ (this is using tf.js WASM single-thread mode).
### Training / Fine-tuning
**Training RWKV-4 from scratch:** run train.py, which by default is using the enwik8 dataset (unzip https://data.deepai.org/enwik8.zip).
You will be training the "GPT" version because it's parallelizable and faster to train. RWKV-4 can extrapolate, so training with ctxLen 1024 can work for ctxLen of 2500+. You can fine-tune the model with longer ctxLen and it can quickly adapt to longer ctxLens.
**Fine-tuning RWKV-4 Pile models:** use 'prepare-data.py' in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/tree/main/RWKV-v3 to tokenize .txt into train.npy data. Then set EXPRESS_PILE_MODE to True in train.py, and run it.
Read the inference code in src/model.py and try using the final hidden state (.xx .aa .bb) as a faithful sentence embedding for other tasks. Probably you should begin with .xx and .aa/.bb (.aa divided by .bb).
Colab for fine-tuning RWKV-4 Pile models: https://colab.research.google.com/github/resloved/RWKV-notebooks/blob/master/RWKV_v4_RNN_Pile_Fine_Tuning.ipynb
**Large corpus:** Use https://github.com/EleutherAI/gpt-neox to convert .jsonl into .bin and .idx
```
python tools/preprocess_data.py --input ./my_data.jsonl --output-prefix ./data/my_data --vocab ./20B_tokenizer.json --dataset-impl mmap --tokenizer-type HFTokenizer --append-eod
```
The jsonl format sample (one line for each document):
```
{"meta": {"ID": 101}, "text": "This is the first document."}
{"meta": {"ID": 102}, "text": "Hello\nWorld"}
{"meta": {"ID": 103}, "text": "1+1=2\n1+2=3\n2+2=4"}
```
generated by code like this:
```
import json

ss = json.dumps({"meta": meta, "text": text}, ensure_ascii=False)
out.write(ss + "\n")
```
## Towards RWKV-5 (just to record some new ideas)
### Some ideas
1. Now the time decay is like 0.999^T (0.999 is learnable). Change it to something like (0.999^T + 0.1) where 0.1 is learnable too. The 0.1 part will be kept forever. Or, A^T + B^T + C = fast-decay + slow-decay + constant (see the sketch after this list). Can even use different formulas (for example, K^2 instead of e^K for a decay component, or without normalization).
2. Use complex-valued decay (so, rotation instead of decay) in some channels.
3. Inject some trainable and extrapolatable positional encoding?
4. Aside from 2d rotation, we can try other Lie groups such as 3d rotation ( SO(3) ). Non-abelian RWKV lol.
5. RWKV might be great on analog devices (search for Analog Matrix-vector multiplication & Photonic Matrix-vector multiplication). The RNN mode is very hardware-friendly (processing-in-memory). Can be a SNN too (https://github.com/ridgerchu/SpikeGPT). I wonder if it can be optimized for quantum computation.
6. Trainable initial hidden state (xx aa bb pp xx).
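A quick sketch of idea 1 (the constants are made up for illustration; A, B, C would be learnable per channel):
```python
import numpy as np

A, B, C = 0.999, 0.9, 0.1
T = np.arange(100)
w = A**T + B**T + C   # fast-decay + slow-decay + a constant part that is kept forever
```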
### Vision Tasks
1. I find it's good to add a 2d pos encoding:
```
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
...
x = x + pos_emb_x + pos_emb_y
```
2. In a BPE language model, it's best to use [tokenShift of 1 token] (you can mix more tokens in a char-level English model). However, you can try [tokenShift of N (or N-1, or N+1) tokens] if the image size is N x N, because that will be like mixing [the token above the current position (or the token above the to-be-predicted position)] with [current token]. You can try different tokenShift styles for "ATT" & "FFN", or mixing different tokenShift styles - such as mixing [token A] with [token A-1] and [token A-(N-1)], etc.
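A minimal sketch of the N x N tokenShift idea, assuming a row-major flattened image and the same ZeroPad2d trick used for [tokenShift of 1 token] (the 0.5/0.5 mixing weights are placeholders):
```python
import torch
import torch.nn as nn

N = 32                                    # image side length (assumption)
shift_1 = nn.ZeroPad2d((0, 0, 1, -1))     # on (B, T, C): the previous token
shift_N = nn.ZeroPad2d((0, 0, N, -N))     # the token directly above in the N x N grid

x = torch.randn(1, N * N, 64)             # (Batch, Time, Channel) toy input
xx = 0.5 * shift_1(x) + 0.5 * shift_N(x)  # mix several shift styles, then blend with x as usual
```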
### Misc
I have an idea to improve tokenization. We can hardcode some channels to have meanings. Example:
Channel 0 = "space"
Channel 1 = "capitalize first letter"
Channel 2 = "capitalize all letters"
Therefore:
Embedding of "abc": [0, 0, 0, x0, x1, x2 , ..]
Embedding of " abc": [1, 0, 0, x0, x1, x2, ..]
Embedding of " Abc": [1, 1, 0, x0, x1, x2, ..]
Embedding of "ABC": [0, 0, 1, x0, x1, x2, ...]
......
so they will share most of the embedding. And we can rapidly compute the output probability of all variations of "abc".
Note: the above method is assuming that p(" xyz") / p("xyz") is the same for any "xyz", which can be wrong.
Better: define emb_space emb_capitalize_first emb_capitalize_all to be a function of emb.
Maybe the best: let 'abc', ' abc', etc. share the last 90% of their embeddings.
At this moment, all our tokenizers spend too many items representing all variations of 'abc', ' abc', ' Abc', etc. Moreover, the model cannot discover that these are actually similar if some of these variations are rare in the dataset. The method here can improve this. I plan to test it in a new version of RWKV.
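A rough sketch of the hardcoded-channel layout (all names and sizes here are hypothetical):
```python
import torch
import torch.nn as nn

vocab_size, n_embd, n_flags = 1000, 768, 3         # hypothetical sizes
emb = nn.Embedding(vocab_size, n_embd - n_flags)   # one shared entry per base word like "abc"

def embed(token_id, space, cap_first, cap_all):
    # channels 0..2 are the hardcoded flags; the rest is shared across all variations
    flags = torch.tensor([float(space), float(cap_first), float(cap_all)])
    return torch.cat([flags, emb(torch.tensor(token_id))])

e_abc  = embed(42, space=False, cap_first=False, cap_all=False)  # "abc"
e_sAbc = embed(42, space=True,  cap_first=True,  cap_all=False)  # " Abc", same base embedding
```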
## How it works
RWKV is inspired by Apple's AFT (https://arxiv.org/abs/2105.14103).
Moreover it's using a number of my tricks, such as:
* SmallInitEmb: https://github.com/BlinkDL/SmallInitEmb (applicable to all transformers) which helps the embedding quality, and stabilizes Post-LN (which is what I am using).
* Token-shift: https://github.com/BlinkDL/RWKV-LM#token-shift-time-shift-mixing (applicable to all transformers), especially helpful for char-level models.
* Head-QK: https://github.com/BlinkDL/RWKV-LM#the-head-qk-trick-learning-to-copy-and-avoid-tokens (applicable to all transformers). Note: it's helpful, but I disabled it in the Pile model to keep it 100% RNN.
* Extra R-gate in the FFN (applicable to all transformers). I am also using reluSquared from Primer.
* Better initialization: I init most of the matrices to ZERO (see RWKV_Init in https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v2-RNN/src/model.py).
* You can transfer some parameters from a small model to a large model (note: I sort & smooth them too), for faster and better convergence (see https://www.reddit.com/r/MachineLearning/comments/umq908/r_rwkvv2rnn_a_parallelizable_rnn_with/). A sketch of this appears after this list.
* My CUDA kernel: https://github.com/BlinkDL/RWKV-CUDA to speedup training.
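The "sort & smooth" transfer mentioned above could look roughly like this (my guess at the procedure, not the exact code):
```python
import torch

def transfer_vector(small: torch.Tensor, big_len: int) -> torch.Tensor:
    # sort the small model's per-channel values, then linearly interpolate
    # ("smooth") them up to the larger model's channel count
    s, _ = torch.sort(small)
    idx = torch.linspace(0, len(s) - 1, big_len)
    lo, hi = idx.floor().long(), idx.ceil().long()
    frac = idx - lo.float()
    return s[lo] * (1 - frac) + s[hi] * frac
```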
## The pseudocode (execution from top to bottom):
![RWKV-v2-RNN](RWKV-v2-RNN.png)
The a b c d factors work together to build a time-decay curve: [X, 1, W, W^2, W^3, ...].
Write out the formulas for "token at pos 2" and "token at pos 3" and you will get the idea:
* a and b: EMAs of kv and k.
* c and d: these are a and b combined with "self-attention".
kv / k is the memory mechanism. The token with high k can be remembered for a long duration, if W is close to 1 in the channel.
The R-gate is important for performance. k = info strength of this token (to be passed to future tokens). r = whether to apply the info to this token.
## RWKV-3 improvements
Use different trainable TimeMix factors for R / K / V in SA and FF layers. Example:
```python
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
```
Use preLN instead of postLN (more stable & faster convergence):
```python
if self.layer_id == 0:
    x = self.ln0(x)
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
```
## Explaining the code for RWKV-3 GPT mode
### The GPT mode - overview
The building blocks of RWKV-3 GPT mode are similar to that of a usual preLN GPT.
The only difference is an extra LN after embedding. Note you can absorb this LN into the embedding after finishing the training.
```python
x = self.emb(idx) # input: idx = token indices
x = self.ln_emb(x) # extra LN after embedding
x = x + self.att_0(self.ln_att_0(x)) # preLN
x = x + self.ffn_0(self.ln_ffn_0(x))
...
x = x + self.att_n(self.ln_att_n(x))
x = x + self.ffn_n(self.ln_ffn_n(x))
x = self.ln_head(x) # final LN before projection
x = self.head(x) # output: x = logits
```
It is important to initialize emb to tiny values, such as nn.init.uniform_(a=-1e-4, b=1e-4), to utilize my trick https://github.com/BlinkDL/SmallInitEmb.
For the 1.5B RWKV-3, I use Adam (no wd, no dropout) optimizer on 8 * A100 40G.
batchSz = 32 * 896, ctxLen = 896. I am using tf32 so the batchSz is a bit small.
For the first 15B tokens, LR is fixed at 3e-4, and beta=(0.9, 0.99).
Then I set beta=(0.9, 0.999), and do an exponential decay of LR, reaching 1e-5 at 332B tokens.
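As a sketch, the schedule can be written as a function of tokens seen (the exponential-decay shape is my reading of the description; the numbers are from the text):
```python
import math

def lr_at(tokens_b: float, lr0=3e-4, lr1=1e-5, t0=15.0, t1=332.0) -> float:
    # flat at lr0 for the first t0 billion tokens,
    # then exponential decay reaching lr1 at t1 billion tokens
    if tokens_b <= t0:
        return lr0
    frac = (tokens_b - t0) / (t1 - t0)
    return lr0 * math.exp(frac * math.log(lr1 / lr0))
```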
### The GPT mode - ATT block
The RWKV-3 does not have any attention in the usual sense, but we will call this block ATT anyway.
```python
B, T, C = x.size() # x = (Batch,Time,Channel)
# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# Use xk, xv, xr to produce k, v, r
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
k = torch.clamp(k, max=60) # clamp k to avoid overflow
k = torch.exp(k)
kv = k * v
# Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)]
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(x.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)
# Use W to mix kv and k respectively. Add K_EPS to wk to avoid divide-by-zero
if RUN_DEVICE == 'cuda':
    wkv = TimeX.apply(w, kv, B,C,T, 0)
    wk = TimeX.apply(w, k, B,C,T, K_EPS)
else:
    w = w[:,-T:].unsqueeze(1)
    wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
    wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + K_EPS
# The RWKV formula
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv) # final output projection
```
The self.key, self.receptance, self.output matrices are all initialized to zero.
The time_mix, time_decay, time_first vectors are transferred from a smaller trained model (note: I sort & smooth them too).
### The GPT mode - FFN block
The FFN block has three tricks compared with the usual GPT:
1. My time_mix trick.
2. The sqReLU from the Primer paper.
3. An extra receptance-gate (similar to the receptance-gate in ATT block).
```python
# Mix x with the previous timestep to produce xk, xr
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# The usual FFN operation
k = self.key(xk)
k = torch.square(torch.relu(k)) # from the Primer paper
kv = self.value(k)
# Apply an extra receptance-gate to kv
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
```
The self.value, self.receptance matrices are all initialized to zero.
## RWKV-4 improvements
![RWKV-v3-plan](RWKV-v3-plan.png)
## From GPT to RWKV (the formulas)
Let F[t] be the system state at t.
Let x[t] be the new external input at t.
In GPT, predicting F[t+1] requires considering F[0], F[1], .. F[t]. So it takes O(T^2) to generate a length T sequence.
The **simplified formula** for GPT:
![F[\mathrm{t}+1]=\frac{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}]) \cdot(\mathbf{V}F[\mathrm{i}])}{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}])}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5B%5Cmathrm%7Bt%7D%2B1%5D%3D%5Cfrac%7B%5Csum_%7B%5Cmathrm%7Bi%7D%3D0%7D%5E%7B%5Cmathrm%7Bt%7D%7D+%5Cexp+%28%5Cmathbf%7BQ%7Dx%5B%5Cmathrm%7Bt%7D%5D+%2A+%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+%5Ccdot%28%5Cmathbf%7BV%7DF%5B%5Cmathrm%7Bi%7D%5D%29%7D%7B%5Csum_%7B%5Cmathrm%7Bi%7D%3D0%7D%5E%7B%5Cmathrm%7Bt%7D%7D+%5Cexp+%28%5Cmathbf%7BQ%7Dx%5B%5Cmathrm%7Bt%7D%5D+%2A+%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29%7D)
It's very capable in theory; however, that **does not mean we can fully utilize its capability with usual optimizers**. I suspect the loss landscape is too difficult for our current methods.
Compare with the **simplified formula** for RWKV (the parallel mode, looks similar to Apple's AFT):
![F[\mathrm{t}+1]=\sigma(\mathbf{R}x[\mathrm{t}]) \cdot \frac{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K}F[\mathrm{i}]) \cdot(\mathbf{V}F[\mathrm{i}])}{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K }F[\mathrm{i}])}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5B%5Cmathrm%7Bt%7D%2B1%5D%3D%5Csigma%28%5Cmathbf%7BR%7Dx%5B%5Cmathrm%7Bt%7D%5D%29+%5Ccdot+%5Cfrac%7B%5Csum_%7B%5Cmathrm%7Bi%7D%3D0%7D%5E%7B%5Cmathrm%7Bt%7D%7D+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+%5Ccdot%28%5Cmathbf%7BV%7DF%5B%5Cmathrm%7Bi%7D%5D%29%7D%7B%5Csum_%7B%5Cmathrm%7Bi%7D%3D0%7D%5E%7B%5Cmathrm%7Bt%7D%7D+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B%5Cmathrm%7Bi%7D%5D%29%7D)
The R, K, V are trainable matrices, and W is a trainable vector (time-decay factor for each channel).
In GPT, the contribution of F[i] to F[t+1] is weighted by ![ \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}]) ](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle++%5Cexp+%28%5Cmathbf%7BQ%7Dx%5B%5Cmathrm%7Bt%7D%5D+%2A+%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+).
In RWKV-2, the contribution of F[i] to F[t+1] is weighted by ![\sigma(\mathbf{R}x[\mathrm{t}]) \cdot \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K}F[\mathrm{i}]) ](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Csigma%28%5Cmathbf%7BR%7Dx%5B%5Cmathrm%7Bt%7D%5D%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+).
* The ![\sigma](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Csigma) is a non-linearity and we can use sigmoid.
* Note ![\sigma(\mathbf{R}x[\mathrm{t}])](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Csigma%28%5Cmathbf%7BR%7Dx%5B%5Cmathrm%7Bt%7D%5D%29) is not in the denominator, and I call R the "receptance".
* The ![\exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i}))](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29) is the time-decay factor. I proposed the same idea (scaling the attention by distance) in Aug 2020 and called it the "time-weighting" (check the commit history of https://github.com/BlinkDL/minGPT-tuned).
Here comes the punchline: we can rewrite it into a RNN (recursive formula). Note:
![F[1]=\sigma(\mathbf{R }x[0]) \cdot \frac{ \exp (\mathbf{K }F[0]) \cdot(\mathbf{V }F[0])}{\exp (\mathbf{K }F[0])}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5B1%5D%3D%5Csigma%28%5Cmathbf%7BR+%7Dx%5B0%5D%29+%5Ccdot+%5Cfrac%7B+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B0%5D%29+%5Ccdot%28%5Cmathbf%7BV+%7DF%5B0%5D%29%7D%7B%5Cexp+%28%5Cmathbf%7BK+%7DF%5B0%5D%29%7D)
![F[2]=\sigma(\mathbf{R }x[1]) \cdot \frac{ \exp (\mathbf{K }F[1]) \cdot(\mathbf{V }F[1])+\exp (\mathbf{W} ) \cdot \exp (\mathbf{K }F[0]) \cdot(\mathbf{V }F[0])}{ \exp (\mathbf{K }F[1])+\exp (\mathbf{W} ) \cdot \exp (\mathbf{K }F[0])}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5B2%5D%3D%5Csigma%28%5Cmathbf%7BR+%7Dx%5B1%5D%29+%5Ccdot+%5Cfrac%7B+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B1%5D%29+%5Ccdot%28%5Cmathbf%7BV+%7DF%5B1%5D%29%2B%5Cexp+%28%5Cmathbf%7BW%7D+%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B0%5D%29+%5Ccdot%28%5Cmathbf%7BV+%7DF%5B0%5D%29%7D%7B+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B1%5D%29%2B%5Cexp+%28%5Cmathbf%7BW%7D+%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK+%7DF%5B0%5D%29%7D)
Therefore it's straightforward to verify:
![F[t+1]=\sigma(\mathbf{R }x[t]) \cdot \frac{\exp (\mathbf{K}F[\mathrm{t}]) \cdot(\mathbf{V}F[\mathrm{t}])+\exp (\mathbf{W}) \cdot A[\mathrm{t}]}{ \exp (\mathbf{K}F[\mathrm{t}])+\exp (\mathbf{W}) \cdot B[\mathrm{t}]}](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+F%5Bt%2B1%5D%3D%5Csigma%28%5Cmathbf%7BR+%7Dx%5Bt%5D%29+%5Ccdot+%5Cfrac%7B%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bt%7D%5D%29+%5Ccdot%28%5Cmathbf%7BV%7DF%5B%5Cmathrm%7Bt%7D%5D%29%2B%5Cexp+%28%5Cmathbf%7BW%7D%29+%5Ccdot+A%5B%5Cmathrm%7Bt%7D%5D%7D%7B+%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bt%7D%5D%29%2B%5Cexp+%28%5Cmathbf%7BW%7D%29+%5Ccdot+B%5B%5Cmathrm%7Bt%7D%5D%7D)
where A[t] and B[t] are the numerator and denominator of the previous step, respectively.
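In code, one step of this recursion looks like the following sketch (shapes and names are mine; exp and σ act elementwise):
```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def rwkv_step(x_t, F_t, A, B, R, K, V, W):
    # A, B carry the exp(W)-decayed numerator and denominator of all previous steps
    kF = np.exp(K @ F_t)
    num = kF * (V @ F_t) + np.exp(W) * A
    den = kF + np.exp(W) * B
    F_next = sigmoid(R @ x_t) * num / den
    return F_next, num, den
```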
I believe RWKV is performant because W is like repeatedly applying a diagonal matrix. Note (P^{-1} D P)^n = P^{-1} D^n P, so it is similar to repeatedly applying a general diagonalizable matrix.
Moreover it's possible to turn it into a continuous ODE (a bit similar to State Space Models). I will write about it later.
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=BlinkDL/RWKV-LM&type=Date)](https://star-history.com/#BlinkDL/RWKV-LM&Date)
## Multimodal ideas
I have an idea for [text --> 32x32 RGB image] using an LM (transformer, RWKV, etc.). Will test it soon.
Firstly, use LM loss (instead of L2 loss), so the image will not be blurry.
Secondly, color quantization. For example, only allowing 8 levels for R/G/B. Then the image vocab size is 8x8x8 = 512 (for each pixel), instead of 2^24.
Therefore, a 32x32 RGB image = a len1024 sequence of vocab512 (image tokens), which is a typical input for usual LMs.
(Later we can use diffusion models to upsample and generate RGB888 images. We might be able to use an LM for this too.)
Thirdly, 2D positional embeddings that are easy for the model to understand.
For example, add one-hot X & Y coords to the first 64(=32+32) channels. Say if the pixel is at x=8, y=20, then we will add 1 to channel 8 and channel 52 (=32+20).
Moreover, we can probably add the float X & Y coords (normalized to the 0~1 range) to another 2 channels, and other periodic positional encodings might help too (will test).
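For illustration, here is a minimal sketch of this coordinate encoding (the exact channel layout is my assumption):
```
import torch

def pixel_pos_channels(x, y, n_embd, img_size=32):
    # channels [0,32) = one-hot X, [32,64) = one-hot Y, 64..65 = float coords in 0~1
    pos = torch.zeros(n_embd)
    pos[x] = 1.0
    pos[img_size + y] = 1.0
    pos[2 * img_size] = x / (img_size - 1)
    pos[2 * img_size + 1] = y / (img_size - 1)
    return pos  # add this to the embedding of the pixel token at (x, y)
```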
Finally, use RandRound (randomized rounding) when doing the color quantization in the DataLoader.
For example, if the float level is 4.578, then there is a 57.8% chance to use 5, and (1-57.8%) chance to use 4.
And we can allow both 4 and 5 in the prediction, but the loss will be higher if the prediction is 4.
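A minimal sketch of such randomized rounding (the function name is mine):
```
import torch

def rand_round(levels):
    # levels: float tensor of color levels, e.g. 4.578 -> 5 with prob 0.578, else 4
    floor = torch.floor(levels)
    up = torch.rand_like(levels) < (levels - floor)
    return (floor + up.float()).long()
```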
Multi-task training might help too. I will try this dataset format:
[TxtFirst] [Desc of Img (txt tokens)] [Img] [img tokens]
and sometimes
[ImgFirst] [img tokens] [Txt] [Desc of Img (txt tokens)]
... the order of the imgs should be randomized in the DataLoader, and [TxtFirst] [ImgFirst] [Img] [Txt] are special tokens
and do random sampling of the full dataset. So sometimes the model will see the img tokens first and then the corresponding txt tokens, which is an [img -> txt] task. The model will also see partial imgs and partial txts. I think a char-level LM might help the model to write correct text on images.
## How to sample a large dataset (for training)
I am using a trick to sample the Pile deterministically yet randomly enough.
Let's say the Pile has x chunks (a chunk = ctx_len tokens).
Pick a prime number p just less than x, and make sure p = 2 (mod 3): then i --> i^3 mod p is a bijection on [0, p), so every chunk is visited exactly once per cycle.
Use (step * step * step) mod p to sample it. Add some bias to step for extra randomness.
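A minimal sketch of this scheme (helper names are mine):
```
def is_prime(n):
    if n < 2:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True

def pick_p(x):
    # largest prime p < x with p % 3 == 2, so i --> i**3 % p is a bijection
    p = x - 1
    while not (is_prime(p) and p % 3 == 2):
        p -= 1
    return p

def chunk_index(step, p, bias=12345):
    s = step + bias  # add some bias to step for extra randomness
    return (s * s * s) % p  # deterministic, but scrambled order
```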
## The top-p-x sampling method (for inference)
We propose a new sampling method called top-p-x:
It's like top-p, and the only difference is that you also keep all tokens whose prob > x.
Try x = 0.01 first.
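A minimal sketch (my own illustration for a 1-D logits tensor, not the repo's exact code):
```
import torch
import torch.nn.functional as F

def sample_top_p_x(logits, top_p=0.7, x=0.01):
    probs = F.softmax(logits, dim=-1)
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    cutoff = sorted_probs[(cumulative >= top_p).nonzero()[0]]  # usual top-p threshold
    probs[(probs < cutoff) & (probs <= x)] = 0  # keep the nucleus, plus all tokens with prob > x
    return torch.multinomial(probs, num_samples=1).item()
```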
## Better Learning Rate Schedule via Variational Method of Loss Curve
I propose a simple new method to find better LR schedules. The method is cost-efficient and practical for large LMs. The takeaway is that we can model the loss curve dynamics (phenomenology) w.r.t. the LR, and a nice closed-form LR curve can be directly computed from it using the variational method. Moreover we can predict the final loss with reasonable accuracy.
UPDATE: In "Conclusion 1.", use the best-fitting regime (ignore the initial steps where our approximations break down) to fit the parameters.
Try this: fixed lr for 1 hr, then exponential decay to 0.2 * lr in 12 hrs, and choose the t=[1hr, 13hr] segment.
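A tiny sketch of that probe schedule (my reading of the recipe above, not code from the repo):
```
import math

def probe_lr(t_hours, base_lr):
    # constant for the first hour, then exponential decay to 0.2 * base_lr at t = 13 hrs
    if t_hours <= 1.0:
        return base_lr
    return base_lr * 0.2 ** ((t_hours - 1.0) / 12.0)
```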
In the last three plots, black = predicted loss curve of the new LR schedule, blue = original (unoptimized) real loss curve, orange = new LR schedule.
![better_lr_schedule](Research/better_lr_schedule.png)
# RWKV v1
We propose the RWKV language model, with alternating time-mix and channel-mix layers:
![\text{softmax}_t(\text{K}_{u,c}) = \frac{\exp(\text{K}_{u,c})}{\sum_{v \leq t}\exp(\text{K}_{v,c})}](https://render.githubusercontent.com/render/math?math=%5Cdisplaystyle+%5Ctext%7Bsoftmax%7D_t%28%5Ctext%7BK%7D_%7Bu%2Cc%7D%29+%3D+%5Cfrac%7B%5Cexp%28%5Ctext%7BK%7D_%7Bu%2Cc%7D%29%7D%7B%5Csum_%7Bv+%5Cleq+t%7D%5Cexp%28%5Ctext%7BK%7D_%7Bv%2Cc%7D%29%7D)
**(UPDATE: We are using the original AFT normalization in v2)**
Initialize K and R matrices (and the output projection matrix) to ZERO for fast & stable convergence.
(2) We decompose W_{t,u,c} and introduce multi-head W (here h is the corresponding head of c):
W_{t,u,c} = f_h(t-u) \cdot \alpha_h(u) \cdot \beta_h(t)
Moreover we multiply the final output of the Time-mix layer by γ(t). The reason for the α β γ factors is that the effective context size is smaller when t is small, and this can be compensated for using the α β γ factors.
**(UPDATE: We remove α β γ factors in v2-RNN and restrict W to be of a simple form and hence able to rewrite it as RNN)**
* The Channel-mix is similar to GeGLU (https://arxiv.org/abs/2002.05202) with an extra R factor. Initialize R and W matrices to ZERO for fast & stable convergence.
* Finally, we add extra token-shift (time-shift mixing) as in (https://github.com/BlinkDL/minGPT-tuned).
# Token-shift (time-shift mixing)
The token-shift explicitly uses (half the channels of this token) & (half the channels of prev token) to generate all vectors (QKV, RWKV, ...).
```
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
```
Dividing channels by 2 and shift-1 works great for char-level English and char-level Chinese LM.
However for BPE-level English LM, it's only effective if your embedding is large enough (at least 1024 - so the usual small L12-D768 model is not enough).
My theory on the effectiveness of token-shift:
When we train a GPT, the hidden representation of a token has to accomplish two different objectives: (1) predict the next token, and (2) collect all the context info so that later tokens can use it.
The shifted channels can focus on (2), so we have good propagation of info. It's like some kind of residual connection, or a small RNN inside the transformer.
You can use token-shift in usual QKV self-attention too. I looked at the weights, and found V really likes the shifted channels, less so for Q. Makes sense if you think about it. I also found you may want to use less mixing in higher layers.
p.s. There is a MHA_pro model in this repo with strong performance. Give it a try :)
# The Head-QK Trick: learning to copy and avoid tokens
In the usual transformer, a small model has difficulty copying tokens (such as person names) from the context. We add extra Q & K heads to the final output so that the model can directly copy (or avoid) tokens in the context. Afterwards, the model teaches itself NER (named entity recognition), as you can see if you inspect the learned weights.
```
q = self.head_q(x)[:,:T,:] # projecting to 256-d
k = self.head_k(x)[:,:T,:] # projecting to 256-d
c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
x = self.head(x) + c
```
Note: when a token occurs multiple times in the context, it might be better to use max(prob) instead of sum(prob).
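A sketch of that max(prob) variant (my own helper, using scatter_reduce_, which needs PyTorch >= 1.12):
```
import torch

def copy_scores_max(c, idx, vocab_size):
    # c: (B, T, T) masked copy weights, idx: (B, T) token ids, as in the snippet above
    B, T = idx.shape
    out = torch.zeros(B, T, vocab_size, device=c.device)
    # for each vocab id, keep the max weight over its occurrences instead of summing them
    out.scatter_reduce_(2, idx.unsqueeze(1).expand(B, T, T), c, reduce="amax", include_self=False)
    return out
```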
# The top-a sampling method
We also propose a new sampling method called top-a (as in src/utils.py):
(1) Find the max probability p_max after softmax.
(2) Remove all entries whose probability is lower than 0.2 * pow(p_max, 2). So it's adaptive, hence "top-a".
(3) Feel free to tune the 0.2 and 2 factor. Tune 0.2 first.
The idea of top-a:
1. If max_prob=0.9, then remove all tokens with prob < 0.162 (so, removing all alternatives)
2. If max_prob=0.5, then remove all tokens with prob < 0.05 (so, allowing more choices)
3. If max_prob=0.1, then remove all tokens with prob < 0.002 (so, allowing lots of possibilities)
```
probs = F.softmax(logits, dim=-1)
limit = torch.pow(torch.max(probs), 2) * 0.2
logits[probs < limit] = -float('Inf')
```
# Performance


@ -63,26 +63,20 @@ class RWKV_TimeMix(nn.Module):
        self.head_size = config.n_attn // config.n_head
        with torch.no_grad(): # initial time_w curves for better convergence
            ww = torch.ones(config.n_head, config.ctx_len)
            curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
            for h in range(config.n_head):
                if h < config.n_head - 1:
                    decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
                else:
                    decay_speed = 0.0
                ww[h] = torch.exp(curve * decay_speed)
                # print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy())
        self.time_w = nn.Parameter(ww)
        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
@ -90,8 +84,8 @@ class RWKV_TimeMix(nn.Module):
        self.value = nn.Linear(config.n_embd, config.n_attn)
        self.receptance = nn.Linear(config.n_embd, config.n_attn)
        # if config.rwkv_tiny_attn > 0:
        #     self.tiny_att = RWKV_TinyAttn(config)
        self.output = nn.Linear(config.n_attn, config.n_embd)
@ -107,12 +101,10 @@ class RWKV_TimeMix(nn.Module):
        w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
        w = w[:, :, TT-1:] # w is now a circulant matrix
        w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
        x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
        # if hasattr(self, 'tiny_att'):
        #     tiny_att = self.tiny_att(x, self.mask)
        k = self.key(x)
        v = self.value(x)
@ -129,8 +121,8 @@ class RWKV_TimeMix(nn.Module):
        rwkv = torch.sigmoid(r) * wkv / sum_k
        rwkv = self.output(rwkv)
        # if hasattr(self, 'tiny_att'):
        #     rwkv += tiny_att
        return rwkv * self.time_gamma[:T, :]
@ -442,6 +434,12 @@ class GPT(nn.Module):
        self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.head_q = nn.Linear(config.n_embd, 256)
        self.head_q.scale_init = 0.01
        self.head_k = nn.Linear(config.n_embd, 256)
        self.head_k.scale_init = 0.01
        self.register_buffer("copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
        self.ctx_len = config.ctx_len
        if self.config.model_type == 'RWKV':
@ -502,8 +500,15 @@ class GPT(nn.Module):
        x = self.blocks(x)
        x = self.ln_f(x)
        q = self.head_q(x)[:,:T,:]
        k = self.head_k(x)[:,:T,:]
        c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
        c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
        c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
        x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
        x = self.head(x) + c
        loss = None
        if targets is not None:

@ -8,8 +8,8 @@ from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
logger = logging.getLogger(__name__)
# print('logging to wandb... (comment it if you don\'t have wandb)')
# import wandb # comment this if you don't have wandb
class TrainerConfig:
    max_epochs = 10
@ -22,7 +22,8 @@ class TrainerConfig:
    lr_decay = False # linear warmup followed by cosine decay
    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper
    final_tokens = 260e9 # at which point do we reach lr_final
    epoch_save_frequency = 0
    epoch_save_path = 'trained-'
    num_workers = 0 # for DataLoader
    def __init__(self, **kwargs):
@ -56,11 +57,6 @@ class Trainer:
        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
        return run_name
    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
@ -77,12 +73,11 @@ class Trainer:
            pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
            for it, (x, y) in pbar:
                x = x.to(self.device) # place data on the correct device
                y = y.to(self.device)
                with torch.set_grad_enabled(is_train):
                    _, loss = model(x, y) # forward the model
                    loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
                if is_train: # backprop and update the parameters
@ -94,14 +89,15 @@ class Trainer:
                if config.lr_decay: # decay the learning rate based on our progress
                    self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
                    lr_final_factor = config.lr_final / config.learning_rate
                    if self.tokens < config.warmup_tokens:
                        # linear warmup
                        lr_mult = lr_final_factor + (1 - lr_final_factor) * float(self.tokens) / float(config.warmup_tokens)
                        progress = 0
                    else:
                        # cosine learning rate decay
                        progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                        # progress = min(progress * 1.1, 1.0) # more fine-tuning with low LR
                        lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1
                    lr = config.learning_rate * lr_mult
                    for param_group in optimizer.param_groups:
@ -118,20 +114,17 @@ class Trainer:
                    if self.avg_loss < 0:
                        self.avg_loss = now_loss
                    else:
                        # factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1))
                        factor = 1 / (it + 1)
                        self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor
                    pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
        self.tokens = 0 # counter used for learning rate decay
        for epoch in range(config.max_epochs):
            run_epoch('train')
            if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
                raw_model = self.model.module if hasattr(self.model, "module") else self.model # DataParallel wrappers keep raw model object in .module
                torch.save(raw_model, self.config.epoch_save_path + str(epoch+1) + '.pth')

@ -25,36 +25,50 @@ model_type = 'RWKV'
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt"
datafile_encoding = 'utf-8'
# datafile = u"D:\\NLP-Data\\ww100M.txt"
# datafile = u"D:\\NLP-Data\\__2019.txt"
# datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
# datafile = u"V:\\NLP\\enwik8-shift-300.bpe"
# datafile_encoding = 'utf-16'
# datafile = u"V:\\NLP\\simplebooks-shift-utf32.word"
# datafile_encoding = 'utf-32'
datafile_type = 0 # use 0 for char-level english. use 1 for chinese. only affects some RWKV hyperparameters
#################################### VERY IMPORTANT ####################################
epoch_save_frequency = 10 # 0 = never, 1 = every 'epoch', 2 = every two 'epoch', etc.
epoch_save_path = 'trained-'
batch_size = 32 # if you see "CUDA out of memory", reduce this.
                # if you have good GPU, increase this.
                # use GPU-Z to find the highest value for your VRAM.
n_epoch = 100 # the 'epoch' here is actually very short (and of fixed length)
########################################################################################
model_level = 'character' # 'character' (recommended) or 'word'
ctx_len = 256 # context length, try 512 or 1024 if you have good GPU
n_layer = 6 # try 12 for 100M, 24 for 300M
n_head = 8 # try 12 for 100M, 16 for 300M
n_embd = n_head * 64
n_attn = n_embd
n_ffn = n_embd
lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr. 8e-4 = 0.0008 4e-4 = 0.0004
lr_final = 4e-5
betas = (0.9, 0.99) if model_type == 'RWKV' else (0.9, 0.99)
eps = 4e-9
weight_decay = 0 if model_type == 'RWKV' else 0.01 # wd is not useful when we have enough data
epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
######## special hyperparameters for RWKV model ########
rwkv_emb_scale = 0.4 # scale of initial embedding. 0.4 is a good choice
rwkv_tiny_attn = 0 #64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
rwkv_tiny_head = 1 # 1 is good enough. 8 is slow
# n_side_proj = 512 # extra 'side projection', quite useful for BPE models
########################################################################################################
# Load data
@ -76,6 +90,15 @@ class Dataset(Dataset):
        # for u in unique:
        #     print(u, end=' ')
        # print('\n\n')
        xx = 0
        xxObj = {}
        for u in unique:
            xxObj[xx] = u
            xx += 1
        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
        data_size, vocab_size = len(data), len(unique)
        print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
        self.stoi = { ch:i for i,ch in enumerate(unique) }
@ -105,63 +128,15 @@
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type,
                      rwkv_emb_scale=rwkv_emb_scale, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head,
                      n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
# load a trained model
# model.load_state_dict(torch.load('trained-xxx.pth').state_dict())
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
                      learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
                      warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()
torch.save(model, 'trained-' + trainer.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
########################################################################################################
# Run model to generate text
########################################################################################################
from src.utils import sample_logits
NUM_OF_RUNS = 5
LENGTH_OF_EACH = 300
for run in range(NUM_OF_RUNS):
context = "it was"
if model_level == 'word':
x = np.array([train_dataset.stoi[s] for s in context.strip().lower().split(' ')], dtype=np.int64)
else:
x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
real_len = len(x)
if real_len < ctx_len:
x = np.pad(x, (0, ctx_len - real_len))
print_begin = 0
for i in range(LENGTH_OF_EACH):
if i == 0:
print(('-' * 80) + '\n' + context, end = '')
print_begin = real_len
with torch.no_grad():
xxx = torch.tensor(x[-ctx_len:], dtype=torch.long)[None,...].to("cuda:0")
out, _ = model(xxx)
pos = -1 if real_len >= ctx_len else real_len - 1
char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) # our special sampling method
if real_len < ctx_len:
x[real_len] = char
else:
x = np.append(x, char)
real_len += 1
if i % 10 == 9 or i == LENGTH_OF_EACH-1:
if model_level == 'word':
completion = ' ' + ' '.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
completion = completion.replace('\n ', '\n')
else:
completion = ''.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
print(completion, end = '')
print_begin = real_len
print()


@ -0,0 +1,172 @@
#include <stdio.h>
// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax and BF and BB are passed by compiler)
#define F4(A, B) ((float4 *)(A))[(B) >> 2]
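// F4(A, B) reads/writes A[B..B+3] as one float4, so each thread handles 4 timesteps (hence t = threadIdx.x << 2)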
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
const F eps, const int B, const int C, const int T) {
const int i = blockIdx.y;
const int ij = (B * C) / BF;
const int t = threadIdx.x << 2;
__shared__ F ww[Tmax];
__shared__ F kk[Tmax * BF];
F4(ww, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BF; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
}
__syncthreads();
float4 s[BF];
#pragma unroll
for (int j = 0; j < BF; j++) {
s[j] = {eps, eps, eps, eps};
}
const F *__restrict__ const w = ww + T - t - 4;
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BF; j++) {
const F x = kk[u + Tmax * j];
s[j].x += w[u + 3] * x;
s[j].y += w[u + 2] * x;
s[j].z += w[u + 1] * x;
s[j].w += w[u + 0] * x;
}
}
#pragma unroll
for (int j = 0; j < BF; j++) {
const F *__restrict__ const k = kk + Tmax * j;
s[j].y += w[t + 3] * k[t + 1];
s[j].z += w[t + 2] * k[t + 1];
s[j].z += w[t + 3] * k[t + 2];
s[j].w += w[t + 1] * k[t + 1];
s[j].w += w[t + 2] * k[t + 2];
s[j].w += w[t + 3] * k[t + 3];
F4(x, t + T * (i + ij * j)) = s[j];
}
}
template <typename F>
__global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
const int i = blockIdx.y;
const int t = threadIdx.x << 2;
__shared__ F k[Tmax];
__shared__ F gg[Tmax];
F4(k, t) = F4(__k, t + T * i);
F4(gg, t) = F4(__gwk, t + T * i);
__syncthreads();
float4 s = {0, 0, 0, 0};
const F *__restrict__ const g = gg + T - t - 4;
for (int u = 0; u <= t; u++) {
F x = k[u];
s.x += g[u + 3] * x;
s.y += g[u + 2] * x;
s.z += g[u + 1] * x;
s.w += g[u + 0] * x;
}
s.y += g[t + 3] * k[t + 1];
s.z += g[t + 2] * k[t + 1];
s.z += g[t + 3] * k[t + 2];
s.w += g[t + 1] * k[t + 1];
s.w += g[t + 2] * k[t + 2];
s.w += g[t + 3] * k[t + 3];
F4(gw, t + T * i) = s;
}
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
dim3 gridDim(1, B * C / BF);
dim3 blockDim(T >> 2);
kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
}
template <typename F>
__global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
const int i = blockIdx.y;
const int ij = (B * C) / BB;
const int t = threadIdx.x << 2;
__shared__ F w[Tmax];
__shared__ F kk[Tmax * BB];
__shared__ F gg[Tmax * BB];
F4(w, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BB; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
}
__syncthreads();
float4 s[BB];
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
F x = kk[u + Tmax * j];
s[j].x += g[u + 3] * x;
s[j].y += g[u + 2] * x;
s[j].z += g[u + 1] * x;
s[j].w += g[u + 0] * x;
}
}
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const k = kk + Tmax * j;
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
s[j].y += g[t + 3] * k[t + 1];
s[j].z += g[t + 2] * k[t + 1];
s[j].z += g[t + 3] * k[t + 2];
s[j].w += g[t + 1] * k[t + 1];
s[j].w += g[t + 2] * k[t + 2];
s[j].w += g[t + 3] * k[t + 3];
F4(gw, t + T * (i + ij * j)) = s[j];
}
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = t + 3; u < T; u++) {
F x = w[u];
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - u] * x;
s[j].y += g[3 - u] * x;
s[j].z += g[4 - u] * x;
s[j].w += g[5 - u] * x;
}
}
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - t] * w[t + 0];
s[j].x += g[1 - t] * w[t + 1];
s[j].x += g[0 - t] * w[t + 2];
s[j].y += g[2 - t] * w[t + 1];
s[j].y += g[1 - t] * w[t + 2];
s[j].z += g[2 - t] * w[t + 2];
F4(gk, t + T * (i + ij * j)) = s[j];
}
}
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
dim3 gridDim(1, B * C / BB);
dim3 blockDim(T >> 2);
kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
}

@ -0,0 +1,21 @@
#include <torch/extension.h>
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);
void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T);
}
void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "timex forward");
m.def("backward", &backward, "timex backward");
}
TORCH_LIBRARY(timex, m) {
m.def("forward", forward);
m.def("backward", backward);
}


@ -0,0 +1,133 @@
# -*- coding:utf-8 -*-
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
### Step 1: set model ##################################################################################
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
# your trained model
MODEL_NAME = 'trained-31'
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
# ########## Uncomment these to test my 27M params enwik8 model ##########
# MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
# WORD_NAME = 'enwik8-vocab'
# EVAL_DATA = 'enwik8' # uncomment this for EVAL MODE (no text generation)
# ########################################################################
# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False # True False - show softmax output
### Step 2: set context ################################################################################
context = "\nIn the" # ==> this is your prompt
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 500
TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9
########################################################################################################
print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
########################################################################################################
if 'EVAL_DATA' in vars() or 'EVAL_DATA' in globals():
print('Evaluating on ' + EVAL_DATA + ' ...')
data = open(EVAL_DATA, "r", encoding='utf-8').read()
loss_table = np.zeros(ctx_len)
N_SAMPLE = 1000
for iii in range(N_SAMPLE):
pos = np.random.randint(0, len(data) - ctx_len-1)
context = data[pos:pos+ctx_len+1]
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
model.clear()
for i in range(1, ctx_len+1):
x = ctx[:i]
out = model.run(x)
prob = F.softmax(torch.tensor(out), dim=-1)
loss_table[i-1] += -math.log(prob[ctx[i]])
print(f'Tested {iii+1} samples: avg_loss over ctx_len =',
np.mean(loss_table) / (iii+1))
exit(0)
########################################################################################################
context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
t_begin = time.time_ns()
src_len = len(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(('-' * 30) + context, end='')
model.clear()
if TRIAL == 0:
init_state = types.SimpleNamespace()
for i in range(src_len):
x = ctx[:i+1]
if i == src_len - 1:
init_state.out = model.run(x)
else:
model.run(x)
model.save(init_state)
else:
model.load(init_state)
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
x = ctx[:i+1]
x = x[-ctx_len:]
if i == src_len:
out = copy.deepcopy(init_state.out)
else:
out = model.run(x)
if DEBUG_DEBUG:
print('model', np.array(x), '==>', np.array(
out), np.max(out), np.min(out))
char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
top_p_usual=top_p, top_p_newline=top_p_newline)
char = char.item()
print(tokenizer.itos[int(char)], end='', flush=True)
ctx += [char]
t_end = time.time_ns()
print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')

@ -0,0 +1,349 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.cpp_extension import load
import math
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
########################################################################################################
# CUDA Kernel
########################################################################################################
T_MAX = 1024 # increase this if your ctx_len > 1024
B_GROUP_FORWARD = 4 # set to 8 for best performance
B_GROUP_BACKWARD = 2 # set to 2 for best performance
timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'])
class TimeX(torch.autograd.Function):
@staticmethod
def forward(ctx, w, k, B, C, T, eps):
ctx.B = B
ctx.C = C
ctx.T = T
assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
w = w.contiguous()
k = k.contiguous()
ctx.save_for_backward(w, k)
wk = torch.empty((B, C, T), device='cuda',
memory_format=torch.contiguous_format)
timex_cuda.forward(w, k, wk, eps, B, C, T)
return wk
@staticmethod
def backward(ctx, gwk):
assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
w, k = ctx.saved_tensors
gw = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
memory_format=torch.contiguous_format)
gk = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
memory_format=torch.contiguous_format)
timex_cuda.backward(w, k, gwk.contiguous(), gw,
gk, ctx.B, ctx.C, ctx.T)
return (gw.sum(dim=0), gk, None, None, None, None)
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
RWKV_K_CLAMP = 60 # e^60 = 1e26
RWKV_K_EPS = 1e-16
RWKV_HEAD_QK_DIM = 256
def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
with torch.no_grad():
name = '[unknown weight]'
for name, parameter in module.named_parameters(): # find the name of the weight
if id(m.weight) == id(parameter):
break
shape = m.weight.data.shape
gain = 1.0
scale = 1.0 # extra scale for gain
if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
scale = 1e-4
else:
scale = 0
if isinstance(m, nn.Linear):
if m.bias is not None:
m.bias.data.zero_()
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
scale = 0.5
if hasattr(m, 'scale_init'):
scale = m.scale_init
# print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
gain *= scale
if scale == -999:
nn.init.eye_(m.weight)
elif gain == 0:
# zero init is great for some RWKV matrices
nn.init.zeros_(m.weight)
elif gain > 0:
nn.init.orthogonal_(m.weight, gain=gain)
else:
nn.init.normal_(m.weight, mean=0.0, std=-scale)
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_embd = config.n_embd
attn_sz = config.n_embd
############# fancy init of time_w curves ###################################
f1_begin = 3.0
f1_end = 1.2
f2_begin = 0.65
f2_end = 0.4
with torch.no_grad(): # initial time_w curves for better convergence
decay_speed = torch.ones(attn_sz, 1)
first_sa_layer_id = 1
for h in range(attn_sz):
f1 = f1_begin + (layer_id-first_sa_layer_id) / \
(config.n_layer-1-first_sa_layer_id) * (f1_end - f1_begin)
f2 = f2_begin + (layer_id-first_sa_layer_id) / \
(config.n_layer-1-first_sa_layer_id) * (f2_end - f2_begin)
if layer_id == first_sa_layer_id:
f1 += 0.5
if layer_id == config.n_layer-2:
f2 = 0.4
if layer_id == config.n_layer-1:
f2 = 0.37
decay_speed[h][0] = math.pow(f2, h / (attn_sz-1) * 7) * f1
self.time_decay = nn.Parameter(torch.log(decay_speed)) # will use exp(self.time_decay) to ensure time_decay > 0
self.time_curve = torch.tensor(
[-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0)
self.time_curve = self.time_curve.to('cuda')
self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3))
#############################################################################
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # init to "shift half of the channels"
ww = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd // 2):
ww[0, 0, i] = 0
self.time_mix = nn.Parameter(ww)
self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
def forward(self, x):
B, T, C = x.size()
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
k = self.key(x).transpose(-1, -2)
v = self.value(x).transpose(-1, -2)
r = self.receptance(x)
# RWKV_K_CLAMP can be removed if the CUDA kernel subtracts the correct k_max for each k (I will do this later)
k = torch.clamp(k, max=RWKV_K_CLAMP)
k = torch.exp(k)
kv = k * v
self.time_w = torch.cat(
[torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1)
w = torch.exp(self.time_w)
wkv = TimeX.apply(w, kv, B, C, T, 0)
# RWKV_K_EPS can be removed if the CUDA kernel sets 0/0 = 0 (I will do this later)
wk = TimeX.apply(w, k, B, C, T, RWKV_K_EPS)
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # init to "shift half of the channels"
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd // 2):
x[0, 0, i] = 0
self.time_mix = nn.Parameter(x)
hidden_sz = 4 * config.n_embd
self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
self.value.scale_init = 0
self.receptance.scale_init = 0
def forward(self, x):
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
k = self.key(x)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(x)) * kv
return rkv
########################################################################################################
# The GPT Model with our blocks
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
self.vocab_size = vocab_size
self.ctx_len = ctx_len
for k, v in kwargs.items():
setattr(self, k, v)
class Block(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
else:
self.att = RWKV_TimeMix(config, layer_id)
self.ffn = RWKV_ChannelMix(config, layer_id)
def forward(self, x):
x = self.ln1(x)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(x) # better in some cases
else:
x = x + self.att(x)
x = self.ln2(x)
x = x + self.ffn(x)
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.step = 0
self.config = config
self.emb = nn.Embedding(config.vocab_size, config.n_embd)
self.blocks = nn.Sequential(*[Block(config, i)
for i in range(config.n_layer)])
self.ln_out = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(config.ctx_len, config.ctx_len)))
self.ctx_len = config.ctx_len
RWKV_Init(self, config)
logger.info("number of parameters: %e", sum(p.numel()
for p in self.parameters()))
def get_ctx_len(self):
return self.ctx_len
def _init_weights(self, module):
if isinstance(module, (nn.Linear)):
module.weight.data.normal_(mean=0.0, std=0.01)
if isinstance(module, (nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=1e-5)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def configure_optimizers(self, train_config):
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
for mn, m in self.named_modules(): # here we disable weight_decay
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
no_decay.add(fpn)
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(
inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
optim_groups = [
{"params": [param_dict[pn]
for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.Adam(
optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
def forward(self, idx, targets=None):
self.step += 1
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).float()
x = self.head(x) + c
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss

@ -0,0 +1,143 @@
import types
import copy
import torch
from torch.nn import functional as F
RWKV_K_CLAMP = 60
RWKV_K_EPS = 1e-16
RWKV_HEAD_QK_DIM = 256
DEBUG_TIME = False # True False - show trained time-coeffs
class RWKV_RNN():
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
self.RUN_DEVICE = RUN_DEVICE
self.model_type = model_type
self.n_layer = n_layer
self.n_embd = n_embd
self.ctx_len = ctx_len
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth',
map_location=torch.device(RUN_DEVICE))
for x in w.keys():
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = torch.exp(-torch.exp(w[x]))
if '.time_first' in x:
w[x] = torch.exp(w[x])
if DEBUG_TIME and '.time_' in x:
print(x, w[x].squeeze().cpu().numpy())
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
self.hk = None
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
target.hk = copy.deepcopy(self.hk)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
self.hk = copy.deepcopy(target.hk)
def LN(self, xx, w):
return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ x)
k = torch.square(torch.relu(w.key.weight @ x))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ x)
k = torch.exp(torch.clamp(w.key.weight @ x, max=RWKV_K_CLAMP))
v = w.value.weight @ x
kv = k * v
a = self.aa[name] + w.time_first * kv
b = self.bb[name] + w.time_first * k
self.aa[name] = w.time_decay * self.aa[name] + kv
self.bb[name] = w.time_decay * self.bb[name] + k
rwkv = r * a / (b + RWKV_K_EPS)
return w.output.weight @ rwkv
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
for i in range(self.n_layer):
x = self.LN(x, w.blocks[i].ln1)
if i == 0 and self.model_type == 'RWKV-ffnPre':
x = x + self.FF(x, w.blocks[i].ffnPre, f'ffnPre.{i}')
else:
x = x + self.SA(x, w.blocks[i].att, f'att.{i}')
x = self.LN(x, w.blocks[i].ln2)
x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
if self.hk is None:
self.hk = (w.head_k.weight @ x).unsqueeze(0)
else:
self.hk = torch.cat(
[self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
if self.hk.shape[0] > self.ctx_len:
self.hk = self.hk[-self.ctx_len:, :]
q = w.head_q.weight @ x
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
c = (self.hk @ q) / RWKV_HEAD_QK_DIM
for i in range(len(c)):
x[ctx[i]] += c[i]
return x

@ -0,0 +1,170 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm.auto import tqdm
import numpy as np
import logging
import os
import datetime
import sys
import math
# import wandb # comment this if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')
logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
log_file = open("mylog.txt", "a")
class TrainerConfig:
max_epochs = 10
batch_size = 64
learning_rate = 4e-4
betas = (0.9, 0.99)
eps = 1e-8
grad_norm_clip = 1.0
lr_decay = True # linear warmup followed by cosine decay
warmup_tokens = 0
final_tokens = 0
epoch_save_frequency = 0
epoch_save_path = 'trained-'
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class Trainer:
def __init__(self, model, train_dataset, test_dataset, config):
self.model = model
self.train_dataset = train_dataset
self.test_dataset = test_dataset
self.config = config
self.avg_loss = -1
self.steps = 0
if 'wandb' in sys.modules:
cfg = model.config
for k in config.__dict__:
setattr(cfg, k, config.__dict__[k]) # combine cfg
wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' +
datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
self.device = 'cpu'
if torch.cuda.is_available(): # take over whatever gpus are on the system
self.device = torch.cuda.current_device()
def get_run_name(self):
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
cfg = raw_model.config
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name
def train(self):
model, config = self.model, self.config
raw_model = model.module if hasattr(self.model, "module") else model
optimizer = raw_model.configure_optimizers(config)
def run_epoch(split):
is_train = split == 'train'
model.train(is_train)
data = self.train_dataset if is_train else self.test_dataset
if config.num_workers > 0:
loader = DataLoader(data, shuffle=False, pin_memory=True,
batch_size=config.batch_size,
num_workers=config.num_workers)
else:
loader = DataLoader(data, shuffle=False,
batch_size=config.batch_size,
num_workers=config.num_workers)
pbar = tqdm(enumerate(loader), total=len(
loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
for it, (x, y) in pbar:
x = x.to(self.device) # place data on the correct device
y = y.to(self.device)
with torch.set_grad_enabled(is_train):
_, loss = model(x, y) # forward the model
if is_train: # backprop and update the parameters
model.zero_grad()
loss.backward()
if config.grad_norm_clip > 0:
torch.nn.utils.clip_grad_norm_(
model.parameters(), config.grad_norm_clip)
optimizer.step()
if config.lr_decay: # decay the learning rate based on our progress
# number of tokens processed this step (i.e. label is not -100)
self.tokens += (y >= 0).sum()
lr_final_factor = config.lr_final / config.learning_rate
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = lr_final_factor + \
(1 - lr_final_factor) * float(self.tokens) / \
float(config.warmup_tokens)
progress = 0
else:
# cosine learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(
max(1, config.final_tokens - config.warmup_tokens))
lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor /
2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1
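# endpoints: progress=0 gives lr_mult=1.0 (full learning_rate), progress=1 gives lr_mult=lr_final_factor (i.e. lr_final)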
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
param_group['lr'] = lr
else:
lr = config.learning_rate
progress = 0 # keep progress defined for the pbar status line below
now_loss = loss.item() # report progress
self.lr = lr
if 'wandb' in sys.modules:
wandb.log({"loss": now_loss},
step=self.steps * self.config.batch_size)
self.steps += 1
if self.avg_loss < 0:
self.avg_loss = now_loss
else:
factor = 1 / (it + 1)
self.avg_loss = self.avg_loss * \
(1.0 - factor) + now_loss * factor
pbar.set_description(
f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):
run_epoch('train')
log_file.write(
f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n')
log_file.flush()
if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
# DataParallel wrappers keep raw model object in .module
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
torch.save(raw_model.state_dict(),
self.config.epoch_save_path + str(epoch+1) + '.pth')

@@ -0,0 +1,122 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json
import random
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
class Dataset(Dataset):
def __init__(self, data, ctx_len, epoch_length_fixed):
print('building token list...', end=' ')
unique = sorted(list(set(data)))
# print()
# for u in unique:
# print(u, end=' ')
# print('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open('vocab.json', "w", encoding="utf-16") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
data_size, vocab_size = len(data), len(unique)
print('data has %d tokens, %d unique.' % (data_size, vocab_size))
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
self.ctx_len = ctx_len
self.epoch_length_fixed = epoch_length_fixed
self.vocab_size = vocab_size
self.data = data
def __len__(self):
return self.epoch_length_fixed
def __getitem__(self, idx):
# cheat: pick a random spot in dataset
i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
chunk = self.data[i:i+self.ctx_len+1]
dix = [self.stoi[s] for s in chunk]
x = torch.tensor(dix[:-1], dtype=torch.long,
device=torch.device('cuda'))
y = torch.tensor(dix[1:], dtype=torch.long,
device=torch.device('cuda'))
return x, y
class TOKENIZER():
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
self.word_table = json.load(result_file)
self.vocab_size = len(self.word_table)
self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(torch.tensor(out), dim=-1)
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
sorted_probs, s_index = torch.sort(probs, descending=True)
# for j in range(30):
# pp = sorted_probs[j].item()
# if pp < 0.005:
# break
# ss = self.itos[int(s_index[j])].replace('\n','_')
# print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
# print('')
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
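# nucleus (top-p) cutoff: keep the smallest set of top tokens whose cumulative probability exceeds top_p; anything rarer than the last kept token is zeroed before sampling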
# print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
return torch.multinomial(probs, num_samples=1)[0]
def to_float(x):
return x.cpu().detach().numpy().flatten()[0].astype(float)
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

@@ -0,0 +1,98 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import logging
import datetime
import json
from src.model import GPT, GPTConfig
from src.trainer import Trainer, TrainerConfig
from src.utils import Dataset
import torch
import numpy as np
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
### Step 1: set training data ##########################################################################
datafile = "enwik8"
datafile_encoding = 'utf-8'
# datafile_encoding = 'utf-16le'
### Step 2: set model size #############################################################################
ctx_len = 1024 # ===> increase T_MAX in model.py if your ctx_len > 1024
n_layer = 6
n_embd = 512
# 'RWKV' (better for char-level English) or 'RWKV-ffnPre' (better in some cases)
model_type = 'RWKV'
### Step 3: set batch size #############################################################################
# ===> batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD in model.py
# For example, if your batch_size = 20, you can set B_GROUP_FORWARD = 4, B_GROUP_BACKWARD = 2
# If you see "CUDA out of memory", reduce it. Use GPU-Z to find the highest value for your VRAM.
batch_size = 12
### Step 4: set learning rate, training mini-epochs #######################################################
lr_init = 6e-4
lr_final = 1e-5
# the mini-epoch is very short and of fixed length (ctx_len * epoch_length_fixed tokens)
n_epoch = 500
# 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, etc.
epoch_save_frequency = 30
epoch_save_path = 'trained-'
epoch_length_fixed = 10000
########################################################################################################
# import src.utils
# src.utils.set_seed(42) # remember to change seed if you load a model
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
grad_norm_clip = 1.0
warmup_tokens = 0
betas = (0.9, 0.99)
eps = 4e-9
num_workers = 0
########################################################################################################
# Load data
########################################################################################################
print('loading data... ' + datafile)
train_dataset = Dataset(open(
datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)
########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
n_layer=n_layer, n_embd=n_embd)).cuda()
# # # load a trained model. remember to change random seed
# m2 = torch.load('trained-61.pth')
# model.load_state_dict(m2)
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip,
warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
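# note: final_tokens = n_epoch * len(train_dataset) * ctx_len spans the whole run (len(train_dataset) == epoch_length_fixed), so the cosine decay reaches lr_final exactly at the end of training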
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()
torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() +
'-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')

Binary file not shown. (added image, 121 KiB)

Binary file not shown. (added image, 321 KiB)

@@ -0,0 +1,172 @@
#include <stdio.h>
// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax and BF and BB are passed by compiler)
#define F4(A, B) ((float4 *)(A))[(B) >> 2]
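// What kernel_forward computes, per series i and output position t:
//   x[t] = eps + sum_{u=0}^{t} w[T-1-(t-u)] * k[u]
// i.e. a causal correlation of k with the time-weight curve w (w[T-1] weights the
// current token); each thread produces 4 consecutive outputs via float4 vectorization.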
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
const F eps, const int B, const int C, const int T) {
const int i = blockIdx.y;
const int ij = (B * C) / BF;
const int t = threadIdx.x << 2;
__shared__ F ww[Tmax];
__shared__ F kk[Tmax * BF];
F4(ww, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BF; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
}
__syncthreads();
float4 s[BF];
#pragma unroll
for (int j = 0; j < BF; j++) {
s[j] = {eps, eps, eps, eps};
}
const F *__restrict__ const w = ww + T - t - 4;
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BF; j++) {
const F x = kk[u + Tmax * j];
s[j].x += w[u + 3] * x;
s[j].y += w[u + 2] * x;
s[j].z += w[u + 1] * x;
s[j].w += w[u + 0] * x;
}
}
#pragma unroll
for (int j = 0; j < BF; j++) {
const F *__restrict__ const k = kk + Tmax * j;
s[j].y += w[t + 3] * k[t + 1];
s[j].z += w[t + 2] * k[t + 1];
s[j].z += w[t + 3] * k[t + 2];
s[j].w += w[t + 1] * k[t + 1];
s[j].w += w[t + 2] * k[t + 2];
s[j].w += w[t + 3] * k[t + 3];
F4(x, t + T * (i + ij * j)) = s[j];
}
}
template <typename F>
__global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
const int i = blockIdx.y;
const int t = threadIdx.x << 2;
__shared__ F k[Tmax];
__shared__ F gg[Tmax];
F4(k, t) = F4(__k, t + T * i);
F4(gg, t) = F4(__gwk, t + T * i);
__syncthreads();
float4 s = {0, 0, 0, 0};
const F *__restrict__ const g = gg + T - t - 4;
for (int u = 0; u <= t; u++) {
F x = k[u];
s.x += g[u + 3] * x;
s.y += g[u + 2] * x;
s.z += g[u + 1] * x;
s.w += g[u + 0] * x;
}
s.y += g[t + 3] * k[t + 1];
s.z += g[t + 2] * k[t + 1];
s.z += g[t + 3] * k[t + 2];
s.w += g[t + 1] * k[t + 1];
s.w += g[t + 2] * k[t + 2];
s.w += g[t + 3] * k[t + 3];
F4(gw, t + T * i) = s;
}
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
dim3 gridDim(1, B * C / BF);
dim3 blockDim(T >> 2);
kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
}
template <typename F>
__global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
const int i = blockIdx.y;
const int ij = (B * C) / BB;
const int t = threadIdx.x << 2;
__shared__ F w[Tmax];
__shared__ F kk[Tmax * BB];
__shared__ F gg[Tmax * BB];
F4(w, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BB; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
}
__syncthreads();
float4 s[BB];
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
F x = kk[u + Tmax * j];
s[j].x += g[u + 3] * x;
s[j].y += g[u + 2] * x;
s[j].z += g[u + 1] * x;
s[j].w += g[u + 0] * x;
}
}
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const k = kk + Tmax * j;
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
s[j].y += g[t + 3] * k[t + 1];
s[j].z += g[t + 2] * k[t + 1];
s[j].z += g[t + 3] * k[t + 2];
s[j].w += g[t + 1] * k[t + 1];
s[j].w += g[t + 2] * k[t + 2];
s[j].w += g[t + 3] * k[t + 3];
F4(gw, t + T * (i + ij * j)) = s[j];
}
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = t + 3; u < T; u++) {
F x = w[u];
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - u] * x;
s[j].y += g[3 - u] * x;
s[j].z += g[4 - u] * x;
s[j].w += g[5 - u] * x;
}
}
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - t] * w[t + 0];
s[j].x += g[1 - t] * w[t + 1];
s[j].x += g[0 - t] * w[t + 2];
s[j].y += g[2 - t] * w[t + 1];
s[j].y += g[1 - t] * w[t + 2];
s[j].z += g[2 - t] * w[t + 2];
F4(gk, t + T * (i + ij * j)) = s[j];
}
}
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
dim3 gridDim(1, B * C / BB);
dim3 blockDim(T >> 2);
kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
}

@@ -0,0 +1,21 @@
#include <torch/extension.h>
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);
void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T);
}
void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "timex forward");
m.def("backward", &backward, "timex backward");
}
TORCH_LIBRARY(timex, m) {
m.def("forward", forward);
m.def("backward", backward);
}

@@ -0,0 +1,98 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
### Step 1: set model ##################################################################################
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
# your trained model
MODEL_NAME = 'trained-1'
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False # set to True to print the model's raw output
### Step 2: set context ################################################################################
context = "\nIn the" # ==> this is your prompt
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 500
TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9
########################################################################################################
print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
########################################################################################################
context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, because the RNN processes the prompt token by token. Use the GPT mode to build the hidden state for better speed. <--\n')
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
t_begin = time.time_ns()
src_len = len(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(('-' * 30) + context, end='')
model.clear()
if TRIAL == 0:
init_state = types.SimpleNamespace()
for i in range(src_len):
x = ctx[:i+1]
if i == src_len - 1:
init_state.out = model.run(x)
else:
model.run(x)
model.save(init_state)
else:
model.load(init_state)
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
x = ctx[:i+1]
x = x[-ctx_len:]
if i == src_len:
out = copy.deepcopy(init_state.out)
else:
out = model.run(x)
if DEBUG_DEBUG:
print('model', np.array(x), '==>', np.array(
out), np.max(out), np.min(out))
char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
top_p_usual=top_p, top_p_newline=top_p_newline)
char = char.item()
print(tokenizer.itos[int(char)], end='', flush=True)
ctx += [char]
t_end = time.time_ns()
print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')

@@ -0,0 +1,363 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.cpp_extension import load
import math
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
RWKV_K_CLAMP = 60 # e^60 = 1e26
RWKV_K_EPS = 1e-8
RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
########################################################################################################
# CUDA Kernel
########################################################################################################
T_MAX = 1024 # increase this if your ctx_len > 1024
B_GROUP_FORWARD = 4 # set to 8 for best performance
B_GROUP_BACKWARD = 2 # set to 2 for best performance (sometimes 8 is faster)
timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'])
class TimeX(torch.autograd.Function):
@staticmethod
def forward(ctx, w, k, B, C, T, eps):
ctx.B = B
ctx.C = C
ctx.T = T
assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
w = w.contiguous()
k = k.contiguous()
ctx.save_for_backward(w, k)
wk = torch.empty((B, C, T), device='cuda',
memory_format=torch.contiguous_format)
timex_cuda.forward(w, k, wk, eps, B, C, T)
return wk
@staticmethod
def backward(ctx, gwk):
assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
w, k = ctx.saved_tensors
gw = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
memory_format=torch.contiguous_format)
gk = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
memory_format=torch.contiguous_format)
timex_cuda.backward(w, k, gwk.contiguous(), gw,
gk, ctx.B, ctx.C, ctx.T)
return (gw.sum(dim=0), gk, None, None, None, None)
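# w is shared across the batch (the kernel indexes it with i % C), so its gradient
# is reduced with gw.sum(dim=0) to match w's (C, T) shape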
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
with torch.no_grad():
name = '[unknown weight]'
for name, parameter in module.named_parameters(): # find the name of the weight
if id(m.weight) == id(parameter):
break
shape = m.weight.data.shape
gain = 1.0
scale = 1.0 # extra scale for gain
if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
scale = 1e-4
else:
scale = 0
if isinstance(m, nn.Linear):
if m.bias is not None:
m.bias.data.zero_()
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
scale = 0.5
if hasattr(m, 'scale_init'):
scale = m.scale_init
# print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
gain *= scale
if scale == -999:
nn.init.eye_(m.weight)
elif gain == 0:
# zero init is great for some RWKV matrices
nn.init.zeros_(m.weight)
elif gain > 0:
nn.init.orthogonal_(m.weight, gain=gain)
else:
nn.init.normal_(m.weight, mean=0.0, std=-scale)
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_embd = config.n_embd
attn_sz = config.n_embd
with torch.no_grad(): # fancy init
self.time_curve = torch.tensor([-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0)
self.time_curve = self.time_curve.to('cuda')
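# time_curve = [-(ctx_len-2), ..., -1, 0]: relative positions used as exponents for the W-curve below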
ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1
ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0
# fancy time_decay
decay_speed = torch.ones(attn_sz, 1)
for h in range(attn_sz):
decay_speed[h][0] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)
self.time_decay = nn.Parameter(decay_speed)
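# decay_speed spans -5..3 across channels, so exp(time_decay) ranges from ~0.007 (very slow decay, long memory) to ~20 (near-instant forgetting)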
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
# fancy time_first
zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5).unsqueeze(1)
self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3) + zigzag)
# fancy time_mix
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd):
x[0, 0, i] = i / config.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# Use xk, xv, xr to produce k, v, r
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
# RWKV_K_CLAMP can be removed if the CUDA kernel subtracts the correct k_max for each k (I will do this later)
k = torch.clamp(k, max=RWKV_K_CLAMP) # clamp k to avoid overflow
k = torch.exp(k)
kv = k * v
# Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)]
self.time_w = torch.cat(
[torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1)
w = torch.exp(self.time_w)
# Use W to mix kv and k respectively. Add K_EPS to wk to avoid divide-by-zero
wkv = TimeX.apply(w, kv, B, C, T, 0)
# RWKV_K_EPS can be removed if the CUDA kernel sets 0/0 = 0 (I will do this later)
wk = TimeX.apply(w, k, B, C, T, RWKV_K_EPS)
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
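# wkv / wk is a per-channel weighted average of past v, with weights W-curve * e^k, gated by sigmoid(r)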
rwkv = self.output(rwkv)
return rwkv
class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # fancy init of time_mix
ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd):
x[0, 0, i] = i / config.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
hidden_sz = 4 * config.n_embd
self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
self.value.scale_init = 0
self.receptance.scale_init = 0
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
########################################################################################################
# The GPT Model with our blocks
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
self.vocab_size = vocab_size
self.ctx_len = ctx_len
for k, v in kwargs.items():
setattr(self, k, v)
class Block(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
else:
self.att = RWKV_TimeMix(config, layer_id)
self.ffn = RWKV_ChannelMix(config, layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(self.ln1(x)) # better in some cases
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.step = 0
self.config = config
self.emb = nn.Embedding(config.vocab_size, config.n_embd)
self.blocks = nn.Sequential(*[Block(config, i)
for i in range(config.n_layer)])
self.ln_out = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
if RWKV_HEAD_QK_DIM > 0:
self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(config.ctx_len, config.ctx_len)))
self.ctx_len = config.ctx_len
RWKV_Init(self, config)
logger.info("number of parameters: %e", sum(p.numel()
for p in self.parameters()))
def get_ctx_len(self):
return self.ctx_len
def _init_weights(self, module):
if isinstance(module, (nn.Linear)):
module.weight.data.normal_(mean=0.0, std=0.01)
if isinstance(module, (nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=1e-5)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def configure_optimizers(self, train_config):
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
for mn, m in self.named_modules(): # here we disable weight_decay
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
no_decay.add(fpn)
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(
inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
optim_groups = [
{"params": [param_dict[pn]
for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.Adam(
optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
def forward(self, idx, targets=None):
self.step += 1
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
if RWKV_HEAD_QK_DIM > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).float()
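# c adds attention-weighted score to the logits of tokens already seen in the context, a cheap copy mechanism on top of the RWKV backbone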
x = self.head(x) + c
else:
x = self.head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss

@@ -0,0 +1,319 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types
import copy
import torch
import math
from torch.nn import functional as F
import torch.nn as nn
RWKV_K_CLAMP = 60
RWKV_K_EPS = 1e-8
RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
DEBUG_TIME = False # set to True to show the trained time-coeffs
############################################################################################################
RWKV_CFG = types.SimpleNamespace()
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
hidden_sz = 4 * RWKV_CFG.n_embd
self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
class RWKV_TimeMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1))
self.time_curve = torch.tensor([-(RWKV_CFG.ctx_len - 2 - i) for i in range(RWKV_CFG.ctx_len-1)]).unsqueeze(0)
self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1) * math.log(0.3))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
def forward(self, x):
B, T, C = x.size()
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
k = torch.clamp(k, max=RWKV_K_CLAMP)
k = torch.exp(k)
kv = k * v
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)
w = w[:,-T:].unsqueeze(1)
wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + RWKV_K_EPS
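# causal depthwise correlation: wkv[t] = sum_{u<=t} w[T-1-(t-u)] * kv[u], the same quantity the TimeX CUDA kernel computes during training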
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
class Block(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)
if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(layer_id+1000)
else:
self.att = RWKV_TimeMix(layer_id)
self.ffn = RWKV_ChannelMix(layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(self.ln1(x))
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class RWKV_GPT(nn.Module):
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
global RWKV_CFG
super().__init__()
RWKV_CFG.RUN_DEVICE = RUN_DEVICE
RWKV_CFG.model_type = model_type
RWKV_CFG.vocab_size = vocab_size
RWKV_CFG.n_layer = n_layer
RWKV_CFG.n_embd = n_embd
RWKV_CFG.ctx_len = ctx_len
print('\nloading RWKV-GPT', MODEL_NAME)
self.emb = nn.Embedding(vocab_size, n_embd)
self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
self.ln_out = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
if RWKV_HEAD_QK_DIM > 0:
self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(ctx_len, ctx_len)))
self.ctx_len = ctx_len
self.eval()
self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
self.eval()
def forward(self, idx):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
if RWKV_HEAD_QK_DIM > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float()
x = self.head(x) + c
else:
x = self.head(x)
return x
############################################################################################################
class RWKV_RNN():
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
self.RUN_DEVICE = RUN_DEVICE
self.model_type = model_type
self.n_layer = n_layer
self.n_embd = n_embd
self.ctx_len = ctx_len
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth',
map_location=torch.device(RUN_DEVICE))
for x in w.keys():
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = torch.exp(-torch.exp(w[x]))
if '.time_first' in x:
w[x] = torch.exp(w[x])
if DEBUG_TIME and '.time_' in x:
print(x, w[x].squeeze().cpu().numpy())
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
self.hk = None
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
target.hk = copy.deepcopy(self.hk)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
self.hk = copy.deepcopy(target.hk)
def LN(self, xx, w):
return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.square(torch.relu(w.key.weight @ xk))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.exp(torch.clamp(w.key.weight @ xk, max=RWKV_K_CLAMP))
v = w.value.weight @ xv
kv = k * v
a = self.aa[name] + w.time_first * kv
b = self.bb[name] + w.time_first * k
self.aa[name] = w.time_decay * self.aa[name] + kv
self.bb[name] = w.time_decay * self.bb[name] + k
rwkv = r * a / (b + RWKV_K_EPS)
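# single-step form of the parallel W-curve mixing: aa/bb carry the decayed sums of e^k * v and e^k, and time_first (already exp'd at load time) adds the extra weight for the current token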
return w.output.weight @ rwkv
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
for i in range(self.n_layer):
if i == 0:
x = self.LN(x, w.blocks[i].ln0)
if i == 0 and self.model_type == 'RWKV-ffnPre':
x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
else:
x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
if RWKV_HEAD_QK_DIM > 0:
if self.hk is None:
self.hk = (w.head_k.weight @ x).unsqueeze(0)
else:
self.hk = torch.cat(
[self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
if self.hk.shape[0] > self.ctx_len:
self.hk = self.hk[-self.ctx_len:, :]
q = w.head_q.weight @ x
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
c = (self.hk @ q) / RWKV_HEAD_QK_DIM
for i in range(len(c)):
x[ctx[i]] += c[i]
else:
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
return x

@@ -0,0 +1,171 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm.auto import tqdm
import numpy as np
import logging
import os
import datetime
import sys
import math
# import wandb  # uncomment to enable wandb logging (optional: the trainer checks sys.modules first)
# print('logging to wandb...')
logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
log_file = open("mylog.txt", "a")
class TrainerConfig:
max_epochs = 10
batch_size = 64
learning_rate = 4e-4
betas = (0.9, 0.99)
eps = 1e-8
grad_norm_clip = 1.0
lr_decay = True # linear warmup followed by exponential decay
warmup_tokens = 0
final_tokens = 0
epoch_save_frequency = 0
epoch_save_path = 'trained-'
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class Trainer:
def __init__(self, model, train_dataset, test_dataset, config):
self.model = model
self.train_dataset = train_dataset
self.test_dataset = test_dataset
self.config = config
self.avg_loss = -1
self.steps = 0
if 'wandb' in sys.modules:
cfg = model.config
for k in config.__dict__:
setattr(cfg, k, config.__dict__[k]) # combine cfg
wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' +
datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
self.device = 'cpu'
if torch.cuda.is_available(): # take over whatever gpus are on the system
self.device = torch.cuda.current_device()
def get_run_name(self):
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
cfg = raw_model.config
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name
def train(self):
model, config = self.model, self.config
raw_model = model.module if hasattr(self.model, "module") else model
optimizer = raw_model.configure_optimizers(config)
def run_epoch(split):
is_train = split == 'train'
model.train(is_train)
data = self.train_dataset if is_train else self.test_dataset
if config.num_workers > 0:
loader = DataLoader(data, shuffle=False, pin_memory=True,
batch_size=config.batch_size,
num_workers=config.num_workers)
else:
loader = DataLoader(data, shuffle=False,
batch_size=config.batch_size,
num_workers=config.num_workers)
pbar = tqdm(enumerate(loader), total=len(
loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
for it, (x, y) in pbar:
x = x.to(self.device) # place data on the correct device
y = y.to(self.device)
with torch.set_grad_enabled(is_train):
_, loss = model(x, y) # forward the model
if is_train: # backprop and update the parameters
model.zero_grad()
loss.backward()
if config.grad_norm_clip > 0:
torch.nn.utils.clip_grad_norm_(
model.parameters(), config.grad_norm_clip)
optimizer.step()
if config.lr_decay: # decay the learning rate based on our progress
# number of tokens processed this step (i.e. label is not -100)
self.tokens += (y >= 0).sum()
lr_final_factor = config.lr_final / config.learning_rate
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = lr_final_factor + \
(1 - lr_final_factor) * float(self.tokens) / \
float(config.warmup_tokens)
progress = 0
else:
# exponential learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
if progress >= 1:
lr_mult = lr_final_factor
else:
lr_mult = lr_final_factor ** progress # == exp(log(lr_final_factor) * progress): decay from 1.0 down to lr_final_factor
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
param_group['lr'] = lr
else:
lr = config.learning_rate
progress = 0 # keep progress defined for the pbar status line below
now_loss = loss.item() # report progress
self.lr = lr
if 'wandb' in sys.modules:
wandb.log({"loss": now_loss},
step=self.steps * self.config.batch_size)
self.steps += 1
if self.avg_loss < 0:
self.avg_loss = now_loss
else:
factor = 1 / (it + 1)
self.avg_loss = self.avg_loss * \
(1.0 - factor) + now_loss * factor
pbar.set_description(
f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):
run_epoch('train')
log_file.write(
f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n')
log_file.flush()
if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
# DataParallel wrappers keep raw model object in .module
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
torch.save(raw_model.state_dict(),
self.config.epoch_save_path + str(epoch+1) + '.pth')

@@ -0,0 +1,122 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json
import random
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
class Dataset(Dataset):
def __init__(self, data, ctx_len, epoch_length_fixed):
print('building token list...', end=' ')
unique = sorted(list(set(data)))
# print()
# for u in unique:
# print(u, end=' ')
# print('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open('vocab.json', "w", encoding="utf-16") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
data_size, vocab_size = len(data), len(unique)
print('data has %d tokens, %d unique.' % (data_size, vocab_size))
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
self.ctx_len = ctx_len
self.epoch_length_fixed = epoch_length_fixed
self.vocab_size = vocab_size
self.data = data
def __len__(self):
return self.epoch_length_fixed
def __getitem__(self, idx):
# cheat: pick a random spot in dataset
i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
chunk = self.data[i:i+self.ctx_len+1]
dix = [self.stoi[s] for s in chunk]
x = torch.tensor(dix[:-1], dtype=torch.long,
device=torch.device('cuda'))
y = torch.tensor(dix[1:], dtype=torch.long,
device=torch.device('cuda'))
return x, y
class TOKENIZER():
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
self.word_table = json.load(result_file)
self.vocab_size = len(self.word_table)
self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(torch.tensor(out), dim=-1)
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
sorted_probs, s_index = torch.sort(probs, descending=True)
# for j in range(30):
# pp = sorted_probs[j].item()
# if pp < 0.005:
# break
# ss = self.itos[int(s_index[j])].replace('\n','_')
# print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
# print('')
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
# print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
return torch.multinomial(probs, num_samples=1)[0]
def to_float(x):
return x.cpu().detach().numpy().flatten()[0].astype(float)
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

@@ -0,0 +1,118 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os
# if False: # True False ---> Set to False if you don't understand it
# print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n")
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# import src.utils
# src.utils.set_seed(42) # make training deterministic (including dataloader). if you are doing this, remember to change seed when you load a model (otherwise the dataloader loads old samples)
import logging
import datetime
from src.model import GPT, GPTConfig
from src.trainer import Trainer, TrainerConfig
from src.utils import Dataset
import torch
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
### Step 1: set training data ##########################################################################
datafile = "../data/enwik8" # your data
datafile_encoding = 'utf-8'
# datafile_encoding = 'utf-16le'
### Step 2: set model size #############################################################################
# ----> test deeper models (n_layer at least 12) to see the advantage of RWKV-3 over RWKV-2
ctx_len = 1024 # increase T_MAX in model.py if your ctx_len > 1024
n_layer = 6
n_embd = 512
# 'RWKV' (better for English) or 'RWKV-ffnPre' (better in some cases)
model_type = 'RWKV'
# ---> there is a RWKV_HEAD_QK_DIM in model.py and model_run.py
# set it to 256, then it's using my headQK trick (similar to a tiny attention) to improve loss
# set it to 0, then it's a pure RNN (attention-free)
### Step 3: set batch size #############################################################################
# ---> batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD in model.py
# for example, if your batch_size = 20, you can set B_GROUP_FORWARD = 4, B_GROUP_BACKWARD = 2
# if you see "CUDA out of memory", reduce batch_size. Use nvidia-smi to find the highest value for your GPU.
batch_size = 12
### Step 4: set learning rate, number of mini-epochs #######################################################
# By default we are using exponential LR decay.
#
# Here are my suggestions for training a good model.
# Let's say you will train a L6-D512 model.
# 1) Set lr_init = lr_final = 8e-4. Let it run for some mini-epochs, until the improvement of the loss becomes slow.
# 2) Check epoch_save_frequency and make sure the partially-trained model is saved. Ctrl+C to stop the run.
# 3) Set lr_init = 8e-4, lr_final = 1e-5, warmup_tokens = ctx_len * batch_size * 50, betas = (0.9, 0.999).
# 4) Search for "torch.load" here and modify it to load the partially-trained model. Continue the training.
#
# For L12-D768, set lr_init = 6e-4. For L24-D1024, set lr_init = 4e-4. For L24-D2048, set lr_init = 3e-4.
lr_init = 8e-4 # we can use larger lr because of preLN
lr_final = 1e-5
# the mini-epoch is very short and of fixed length (length = ctx_len * epoch_length_fixed tokens)
n_epoch = 500
epoch_length_fixed = 10000
# 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, ...
epoch_save_frequency = 10
epoch_save_path = 'trained-'
########################################################################################################
grad_norm_clip = 1.0
warmup_tokens = ctx_len * batch_size * 0
betas = (0.9, 0.99)
eps = 4e-9
num_workers = 0
########################################################################################################
# Load data
########################################################################################################
print('loading data... ' + datafile)
train_dataset = Dataset(open(
datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)
########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
n_layer=n_layer, n_embd=n_embd)).cuda()
### ---> load a trained model <---
# m2 = torch.load('trained-61.pth')
# model.load_state_dict(m2)
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip,
warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()
torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() +
'-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')

@@ -0,0 +1,65 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# this is for verifying the outputs of the different model implementations and making sure they agree with each other
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
RUN_DEVICE = 'cuda'
import torch
from src.model_run import RWKV_RNN, RWKV_GPT
from src.model import GPT, GPTConfig
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV'
model_name = 'trained-1'
from src.utils import TOKENIZER
tokenizer = TOKENIZER('vocab', UNKNOWN_CHAR=' ')
########################################################################################################
model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda()
print('loading ' + model_name)
m2 = torch.load(model_name + '.pth', map_location=RUN_DEVICE)
model_train.load_state_dict(m2)
model_rnn = RWKV_RNN(model_name, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
model_gpt = RWKV_GPT(model_name, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda()
########################################################################################################
context = '\nIn a'
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(f'input len {len(ctx)} data {ctx}')
########################################################################################################
print('\nRWKV-GPT output')
out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy()
print(out)
print('\nRWKV-RNN output')
model_rnn.clear()
src_len = len(ctx)
for i in range(src_len):
x = ctx[:i+1]
out = model_rnn.run(x)
if i < 3 or i >= src_len - 3:
print(torch.tensor(out).detach().cpu().numpy())
if i == 2:
print('...')
print('\nRWKV-train output')
ctx += [0] * (ctx_len - src_len) # pad to ctx_len
ctx = [ctx] * 4 # increase batch size (to make it work with B_GROUP_FORWARD & B_GROUP_BACKWARD)
out = model_train.forward(torch.tensor(ctx).cuda())[0][0][:src_len].detach().cpu().numpy()
print(out, '\n')

Binary file not shown. (added image, 70 KiB)

File diff suppressed because it is too large.

@@ -0,0 +1,125 @@
#include <stdio.h>
#include <assert.h>
#define MIN_VALUE (-1e38)
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
F *__restrict__ const _y) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
F *__restrict__ const y = _y + _offset;
F p = 0, q = 0, o = MIN_VALUE;
// p and q are running sums divided by exp(o) (to avoid overflows)
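// Invariant before step i: p = sum_{j<i} exp(k[j] + w*(i-1-j) - o) * v[j] and q is the same
// sum without v (w is the negative per-channel decay, see the note in kernel_backward).
// Hence y[i] = (sum_{j<i} e^{k_j + w(i-1-j)} v_j + e^{u+k_i} v_i) / (matching denominator),
// evaluated log-sum-exp style with running max o so no exp() can overflow.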
for (int i = 0; i < T; i++) {
const int ii = i * C;
F no = max(o, u + k[ii]);
F A = exp(o - no);
F B = exp(u + k[ii] - no);
y[ii] = (A * p + B * v[ii]) / (A * q + B);
no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
}
}
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
const F *__restrict__ const gy = _gy + _offset;
F *__restrict__ const gk = _gk + _offset;
F *__restrict__ const gv = _gv + _offset;
F y[Tmax], z[Tmax], zexp[Tmax];
F gw = 0, gu = 0;
F p = 0, q = 0;
F dpdw = 0, dqdw = 0;
F o = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
F no = max(o, k[ii] + u);
F A = exp(o - no);
F B = exp(k[ii] + u - no);
F num = A * p + B * v[ii];
F iden = 1 / (A * q + B);
y[i] = num * iden;
z[i] = iden;
zexp[i] = k[ii] + u - no;
gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
gu += gy[ii] * (v[ii] - y[i]) * B * iden;
no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
dpdw = A * (p + dpdw);
dqdw = A * (q + dqdw);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
}
F gp = 0, gq = 0;
o = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
F A = gy[ii] * z[i] * exp(zexp[i]);
F B = exp(k[ii] + o);
gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
gv[ii] = A + B * gp;
F no = max(w + o, zexp[i] - k[ii] - u);
A = exp(w + o - no);
B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
gp = A * gp + B;
gq = A * gq - B * y[i];
o = no;
}
// Multiply by w because the host applies the w -> -exp(w) preprocessing before launching the kernel, so the chain-rule factor dw/dw_param = -exp(w_param) = w belongs here in the backward pass (the forward kernel already receives the transformed w)
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] += gw * _w[_c];
_gu[_offsetBC] += gu;
}
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}

@@ -0,0 +1,21 @@
#include <torch/extension.h>
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv forward");
m.def("backward", &backward, "wkv backward");
}
TORCH_LIBRARY(wkv, m) {
m.def("forward", forward);
m.def("backward", backward);
}

@@ -0,0 +1,149 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math, os
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
########################################################################################################
# Step 1: set model
#
# Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch.
#
# Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models.
########################################################################################################
TOKEN_MODE = 'char' # char / bpe / pile
n_layer = 6
n_embd = 512
ctx_len = 1024
if TOKEN_MODE == 'char':
MODEL_NAME = 'trained-500' # your trained model
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
# set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown tokens in your prompt will be denoted by it
UNKNOWN_CHAR = ' ' # here we just set it to ' ' for simplicity
elif TOKEN_MODE == 'bpe':
MODEL_NAME = 'trained-500' # your trained model
WORD_NAME = ['model-vocab.json', 'model-merges.txt'] # [vocab, merge] for your BPE model
UNKNOWN_CHAR = None
elif TOKEN_MODE == 'pile':
WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
UNKNOWN_CHAR = None
#---> you can set MODEL_NAME to your fine-tuned model <---
MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023'
# MODEL_NAME = 'trained-11'
n_layer = 12
n_embd = 768
ctx_len = 1024
# MODEL_NAME = 'RWKV-4-Pile-430M-20220808-8066'
# n_layer = 24
# n_embd = 1024
# ctx_len = 1024
# MODEL_NAME = 'RWKV-4-Pile-1B5-20220903-8040'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024
os.environ['RWKV_FLOAT_MODE'] = 'fp32' # 'bf16' / 'fp16' / 'fp32' (note: only using fp32 at this moment)
os.environ['RWKV_RUN_DEVICE'] = 'cpu' # 'cpu' (already very fast) or 'cuda'
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
########################################################################################################
# Step 2: set prompt & sampling stuffs
########################################################################################################
# context = 'A'
# context = "\nIn the"
# context = '\nSugar:'
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333
TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9 # only used in TOKEN_MODE = char
DEBUG_DEBUG = False # True False --> show softmax output
########################################################################################################
print(f'Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN
model = RWKV_RNN(MODEL_NAME, os.environ['RWKV_RUN_DEVICE'], model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
########################################################################################################
if tokenizer.charMode:
context = tokenizer.refine_context(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
ctx = tokenizer.tokenizer.encode(context)
src_len = len(ctx)
src_ctx = ctx.copy()
print('\nYour prompt has ' + str(src_len) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n')
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
t_begin = time.time_ns()
print(('-' * 30) + context, end='')
ctx = src_ctx.copy()
model.clear()
if TRIAL == 0:
init_state = types.SimpleNamespace()
for i in range(src_len):
x = ctx[:i+1]
if i == src_len - 1:
init_state.out = model.run(x)
else:
model.run(x)
model.save(init_state)
else:
model.load(init_state)
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
x = ctx[:i+1]
x = x[-ctx_len:]
if i == src_len:
out = copy.deepcopy(init_state.out)
else:
out = model.run(x)
if DEBUG_DEBUG:
print('model', np.array(x), '==>', np.array(
out), np.max(out), np.min(out))
if TOKEN_MODE == 'pile':
out[0] = -999999999 # disable <|endoftext|>
char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
top_p_usual=top_p, top_p_newline=top_p_newline)
char = char.item()
if tokenizer.charMode:
print(tokenizer.itos[int(char)], end='', flush=True)
else:
print(tokenizer.tokenizer.decode(int(char)), end='', flush=True)
ctx += [char]
t_end = time.time_ns()
print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')

@@ -0,0 +1,216 @@
import os
import torch
import numpy as np
import shutil
import struct
from functools import lru_cache
from itertools import accumulate
def print_rank_0(*message):
"""If distributed is initialized print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(*message, flush=True)
else:
print(*message, flush=True)
def _warmup_mmap_file(path):
pass
# with open(path, "rb") as stream:
# while stream.read(100 * 1024 * 1024):
# pass
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: float,
7: np.double,
8: np.uint16,
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + ".idx"
def data_file_path(prefix_path):
return prefix_path + ".bin"
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b"MMIDIDX\x00\x00"
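# .idx layout, as parsed below: 9-byte magic, <Q version (must be 1), <B dtype code,
# <Q sequence count, <Q document count, then three arrays: int32 sizes,
# int64 pointers (byte offsets into the .bin file), int64 doc_idx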
def __init__(self, path, skip_warmup=False):
with open(path, "rb") as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
# Little endian unsigned 64 Bit integer
version = struct.unpack("<Q", stream.read(8))
assert (1,) == version
# Little endian unsigned 8 Bit integer
(dtype_code,) = struct.unpack("<B", stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack("<Q", stream.read(8))[0]
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
offset = stream.tell()
if not skip_warmup:
print_rank_0(" warming up index mmap file...")
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print_rank_0(" reading sizes...")
self._sizes = np.frombuffer(
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
)
print_rank_0(" reading pointers...")
self._pointers = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._len,
offset=offset + self._sizes.nbytes,
)
print_rank_0(" reading document index...")
self._doc_idx = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@property
def doc_idx(self):
return self._doc_idx
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path, skip_warmup=False):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path, skip_warmup)
def __getstate__(self):
return self._path
def __setstate__(self, state):
self._do_init(state)
def _do_init(self, path, skip_warmup):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)
if not skip_warmup:
print_rank_0(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print_rank_0(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(
data_file_path(self._path), mode="r", order="C"
)
print_rank_0(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index
def __len__(self):
return len(self._index)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
ptr, size = self._index[idx]
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError(
"Slices into indexed_dataset must be contiguous")
ptr = self._index._pointers[start]
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
)
sents = np.split(np_array, offsets[:-1])
return sents
def get(self, idx, offset=0, length=None):
"""Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
)
return np_array
@property
def sizes(self):
return self._index.sizes
@property
def doc_idx(self):
return self._index.doc_idx
def get_doc_idx(self):
return self._index._doc_idx
def set_doc_idx(self, doc_idx_):
self._index._doc_idx = doc_idx_
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return os.path.exists(index_file_path(path)) and os.path.exists(
data_file_path(path)
)

@@ -0,0 +1,414 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import math, os
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
try:
from deepspeed.ops.adam import FusedAdam
except Exception:
pass # some poor Windows users can't install DeepSpeed
logger = logging.getLogger(__name__)
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
class L2Wrap(torch.autograd.Function):
@staticmethod
def forward(ctx, loss, y):
ctx.save_for_backward(y)
return loss
@staticmethod
def backward(ctx, grad_output):
y = ctx.saved_tensors[0]
# to encourage the logits to be close to 0
factor = 1e-4 / (y.shape[0] * y.shape[1])
maxx, ids = torch.max(y, -1, keepdim=True)
gy = torch.zeros_like(y)
gy.scatter_(-1, ids, maxx * factor)
return (grad_output, gy)
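# in effect: the loss gradient passes through unchanged, while each position's top
# logit gets a tiny pull toward 0 (factor 1e-4 / (B*T)), regularizing logit magnitudes
# without changing the forward loss value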
########################################################################################################
# CUDA Kernel
########################################################################################################
T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
# it's possible to go beyond this limit if you slice the ctx and carry the hidden state between slices
from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
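# time_decay is kept in log-space; w = -exp(w) below makes the per-step decay factor
# exp(w) fall in (0, 1), so the recurrence inside the kernel cannot blow up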
if '32' in os.environ['RWKV_FLOAT_MODE']:
w = -torch.exp(w.contiguous())
u = u.contiguous()
k = k.contiguous()
v = v.contiguous()
else:
w = -torch.exp(w.float().contiguous())
u = u.float().contiguous()
k = k.float().contiguous()
v = v.float().contiguous()
ctx.save_for_backward(w, u, k, v)
y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
if '32' in os.environ['RWKV_FLOAT_MODE']:
return y
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
return y.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
return y.bfloat16()
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
w, u, k, v = ctx.saved_tensors
gw = torch.zeros((B, C), device='cuda').contiguous()
gu = torch.zeros((B, C), device='cuda').contiguous()
gk = torch.zeros((B, T, C), device='cuda').contiguous()
gv = torch.zeros((B, T, C), device='cuda').contiguous()
if '32' in os.environ['RWKV_FLOAT_MODE']:
wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
else:
wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
if '32' in os.environ['RWKV_FLOAT_MODE']:
return (None, None, None, gw, gu, gk, gv)
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
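# shapes: w and u are per-channel vectors (C,); k, v and the returned y are (B, T, C)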
def RUN_CUDA(B, T, C, w, u, k, v):
return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
def RWKV_Init(model, args): # fancy initialization of all lin & emb layer in the model
print("\n[--> first run, init model params (very slow for large models) <--]")
print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n")
for mm in model.modules():
if "RecursiveScriptModule" in str(type(mm)):
if mm.original_name not in ["Linear"]:
continue
ww = None
for name, param in mm.named_parameters():
if name == "weight":
ww = param
else:
m = mm
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
ww = m.weight
with torch.no_grad():
name = "[unknown weight]"
for name, parameter in model.named_parameters(): # find the name of the weight
if id(ww) == id(parameter):
break
shape = ww.shape
gain = 1.0
scale = 1.0 # extra scale for gain
if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == args.vocab_size and shape[1] == args.n_embd: # token emb?
scale = 1e-4
else:
scale = 0
if isinstance(m, nn.Linear):
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
if shape[0] == args.vocab_size and shape[1] == args.n_embd: # final projection?
scale = 0.5
if hasattr(m, "scale_init"):
scale = m.scale_init
# print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {name}")
gain *= scale
if scale == -999:
nn.init.eye_(ww)
elif gain == 0:
# zero init is great for some RWKV matrices
nn.init.zeros_(ww)
elif gain > 0:
nn.init.orthogonal_(ww, gain=gain)
else:
nn.init.normal_(ww, mean=0.0, std=-scale)
class RWKV_TimeMix(torch.jit.ScriptModule):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_embd = config.n_embd
attn_sz = config.n_embd
with torch.no_grad(): # fancy init
ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1
ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0
# fancy time_decay
decay_speed = torch.ones(attn_sz)
for h in range(attn_sz):
decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)
self.time_decay = nn.Parameter(decay_speed)
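# roughly: decay_speed spans -5..3, and since time_decay becomes -exp(.) at run time,
# the per-step decay factor ranges from exp(-exp(-5)) ~ 0.993 (long-memory channels)
# down to exp(-exp(3)) ~ 2e-9 (channels that forget almost immediately)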
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
# fancy time_first
zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5)
self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
# fancy time_mix
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd):
x[0, 0, i] = i / config.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
@torch.jit.script_method
def jit_func(self, x):
# Mix x with the previous timestep to produce xk, xv, xr
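# (time_shift = ZeroPad2d((0, 0, 1, -1)) shifts the sequence one step: xx[:, t] = x[:, t-1], zeros at t = 0)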
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# Use xk, xv, xr to produce k, v, r
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
sr = torch.sigmoid(r)
return sr, k, v
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v = self.jit_func(x)
rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
rwkv = self.output(rwkv)
return rwkv
class RWKV_ChannelMix(torch.jit.ScriptModule):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # fancy init of time_mix
ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd):
x[0, 0, i] = i / config.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
hidden_sz = 4 * config.n_embd
self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
self.value.scale_init = 0
self.receptance.scale_init = 0
@torch.jit.script_method
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
########################################################################################################
# The GPT Model with our blocks
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
self.vocab_size = vocab_size
self.ctx_len = ctx_len
for k, v in kwargs.items():
setattr(self, k, v)
class Block(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(config, 0)
else:
self.att = RWKV_TimeMix(config, layer_id)
self.ffn = RWKV_ChannelMix(config, layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(self.ln1(x)) # better in some cases
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.step = 0
self.config = config
self.emb = nn.Embedding(config.vocab_size, config.n_embd)
self.blocks = nn.Sequential(*[Block(config, i)
for i in range(config.n_layer)])
self.ln_out = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
if RWKV_HEAD_QK_DIM > 0:
self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(config.ctx_len, config.ctx_len)))
self.ctx_len = config.ctx_len
try:
if os.environ['RWKV_LOAD_MODEL'] == str(False):
RWKV_Init(self, config)
except KeyError:
pass # RWKV_LOAD_MODEL not set
logger.info("number of parameters: %e", sum(p.numel()
for p in self.parameters()))
def get_ctx_len(self):
return self.ctx_len
def _init_weights(self, module):
if isinstance(module, (nn.Linear)):
module.weight.data.normal_(mean=0.0, std=0.01)
if isinstance(module, (nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=1e-5)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def configure_optimizers(self, train_config):
no_decay = set()
for mn, m in self.named_modules(): # here we disable weight_decay
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
no_decay.add(fpn)
param_dict = {pn: p for pn, p in self.named_parameters()}
optim_groups = [
{"params": [param_dict[pn]
for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
try:
optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
except Exception: # FusedAdam is undefined if the DeepSpeed import failed
print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n')
optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
def forward(self, idx, targets=None):
idx = idx.to(self.emb.weight.device)
self.step += 1
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
if RWKV_HEAD_QK_DIM > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
if '32' in os.environ['RWKV_FLOAT_MODE']:
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size)
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16()
x = self.head(x) + c
else:
x = self.head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))
return L2Wrap.apply(loss, x)

@@ -0,0 +1,392 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types
import copy
import torch
import math, os
from torch.nn import functional as F
import torch.nn as nn
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
DEBUG_TIME = False # True False - show trained time-coeffs
########################################################################################################
# CUDA Kernel
########################################################################################################
if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
# it's possible to go beyond this limit if you slice the ctx and carry the hidden state between slices
from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
if '32' in os.environ['RWKV_FLOAT_MODE']:
w = -torch.exp(w.contiguous())
u = u.contiguous()
k = k.contiguous()
v = v.contiguous()
else:
w = -torch.exp(w.float().contiguous())
u = u.float().contiguous()
k = k.float().contiguous()
v = v.float().contiguous()
ctx.save_for_backward(w, u, k, v)
y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
if '32' in os.environ['RWKV_FLOAT_MODE']:
return y
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
return y.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
return y.bfloat16()
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
w, u, k, v = ctx.saved_tensors
gw = torch.zeros((B, C), device='cuda').contiguous()
gu = torch.zeros((B, C), device='cuda').contiguous()
gk = torch.zeros((B, T, C), device='cuda').contiguous()
gv = torch.zeros((B, T, C), device='cuda').contiguous()
if '32' in os.environ['RWKV_FLOAT_MODE']:
wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
else:
wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
if '32' in os.environ['RWKV_FLOAT_MODE']:
return (None, None, None, gw, gu, gk, gv)
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
def RUN_CUDA(B, T, C, w, u, k, v):
return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
############################################################################################################
RWKV_CFG = types.SimpleNamespace()
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
hidden_sz = 4 * RWKV_CFG.n_embd
self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
class RWKV_TimeMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd))
self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd) * math.log(0.3))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
def forward(self, x):
B, T, C = x.size()
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
rwkv = self.output(rwkv)
return rwkv
class Block(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)
if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(layer_id+1000)
else:
self.att = RWKV_TimeMix(layer_id)
self.ffn = RWKV_ChannelMix(layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(self.ln1(x))
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class RWKV_GPT(nn.Module):
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
global RWKV_CFG
super().__init__()
RWKV_CFG.RUN_DEVICE = RUN_DEVICE
RWKV_CFG.model_type = model_type
RWKV_CFG.vocab_size = vocab_size
RWKV_CFG.n_layer = n_layer
RWKV_CFG.n_embd = n_embd
RWKV_CFG.ctx_len = ctx_len
print('\nloading RWKV-GPT', MODEL_NAME)
self.emb = nn.Embedding(vocab_size, n_embd)
self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
self.ln_out = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
if RWKV_HEAD_QK_DIM > 0:
self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(ctx_len, ctx_len)))
self.ctx_len = ctx_len
self.eval()
self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
self.eval()
def forward(self, idx):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
if RWKV_HEAD_QK_DIM > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
if '32' in os.environ['RWKV_FLOAT_MODE']:
c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size)
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16()
x = self.head(x) + c
else:
x = self.head(x)
return x
############################################################################################################
class RWKV_RNN(): # this is running in FP32 at this moment
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
self.RUN_DEVICE = RUN_DEVICE
self.model_type = model_type
self.n_layer = n_layer
self.n_embd = n_embd
self.ctx_len = ctx_len
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth',
map_location=torch.device(RUN_DEVICE))
for x in w.keys():
w[x] = w[x].float()
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = -torch.exp(w[x])
if DEBUG_TIME and '.time_' in x:
print(x, w[x].squeeze().cpu().numpy())
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
self.pp = {}
self.hk = None
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
target.pp = copy.deepcopy(self.pp)
target.hk = copy.deepcopy(self.hk)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
self.pp = copy.deepcopy(target.pp)
self.hk = copy.deepcopy(target.hk)
def LN(self, xx, w):
return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.square(torch.relu(w.key.weight @ xk))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30
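# per-layer state: aa ~ numerator sum of exp(k_i + accumulated decay) * v_i,
# bb ~ the matching denominator sum, both stored divided by exp(pp),
# where pp is the running max exponent (kept separate for numerical stability)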
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = w.key.weight @ xk
v = w.value.weight @ xv
pp = self.pp[name]
aa = self.aa[name]
bb = self.bb[name]
ww = w.time_first + k
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
a = e1 * aa + e2 * v
b = e1 * bb + e2
ww = pp + w.time_decay
p = torch.maximum(ww, k)
e1 = torch.exp(ww - p)
e2 = torch.exp(k - p)
self.aa[name] = e1 * aa + e2 * v
self.bb[name] = e1 * bb + e2
self.pp[name] = p
rwkv = r * a / b
return w.output.weight @ rwkv
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
for i in range(self.n_layer):
if i == 0:
x = self.LN(x, w.blocks[i].ln0)
if i == 0 and self.model_type == 'RWKV-ffnPre':
x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
else:
x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
if RWKV_HEAD_QK_DIM > 0:
if self.hk is None:
self.hk = (w.head_k.weight @ x).unsqueeze(0)
else:
self.hk = torch.cat(
[self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
if self.hk.shape[0] > self.ctx_len:
self.hk = self.hk[-self.ctx_len:, :]
q = w.head_q.weight @ x
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
c = (self.hk @ q) / RWKV_HEAD_QK_DIM
for i in range(len(c)):
x[ctx[i]] += c[i]
else:
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
return x

@@ -0,0 +1,187 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os
NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])
USE_WANDB = (int(os.environ['USE_WANDB']) == 1)
from torch.utils.data.dataloader import DataLoader
import torch
from tqdm.auto import tqdm
import logging
import datetime
import math
from pytorch_lightning.lite import LightningLite
import gc
logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
if os.environ['RWKV_FLOAT_MODE'] == 'fp32':
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
class TrainerConfig:
batch_size = 64
learning_rate = 4e-4
betas = (0.9, 0.99)
eps = 1e-8
grad_norm_clip = 1.0
warmup_tokens = 0
final_tokens = 0
epoch_save_frequency = 0
epoch_save_path = 'trained-'
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
from src.model import GPT, GPTConfig
class Trainer(LightningLite):
def get_run_name(self):
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
cfg = raw_model.config
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name
def run(self, m_cfg, train_dataset, test_dataset, config):
self.cuda_id = int(str(self.device).strip('cuda:'))
print('[0]')
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=m_cfg.model_type,
n_layer=m_cfg.n_layer, n_embd=m_cfg.n_embd))
print('[1]')
with torch.no_grad():
if m_cfg.LOAD_MODEL:
print('loading', m_cfg.MODEL_NAME)
m2 = torch.load(m_cfg.MODEL_NAME + '.pth', map_location='cpu')
model.load_state_dict(m2)
del m2
model.to(self.device)
self.model = model
self.train_dataset = train_dataset
self.test_dataset = test_dataset
self.config = config
self.avg_loss = -1
self.EPOCH_BEGIN = m_cfg.EPOCH_BEGIN
self.steps = self.EPOCH_BEGIN * (len(self.train_dataset) // (config.batch_size // NUM_GPUS))
if self.cuda_id == 0:
log_file = open("mylog.txt", "a")
if USE_WANDB:
print('logging to wandb... (comment it if you don\'t have wandb)')
import wandb # comment this if you don't have wandb
cfg = model.config
for k in config.__dict__:
setattr(cfg, k, config.__dict__[k]) # combine cfg
wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
model, config = self.model, self.config
raw_model = model.module if hasattr(self.model, "module") else model
optimizer = raw_model.configure_optimizers(config)
model, optimizer = self.setup(model, optimizer)
print('[3]')
def run_epoch(split):
is_train = split == 'train'
model.train(is_train)
data = self.train_dataset if is_train else self.test_dataset
data.idx_begin = self.steps * config.batch_size + 1
data.cuda_id = self.cuda_id
if config.num_workers > 0:
loader = DataLoader(data, shuffle=False, pin_memory=True,
batch_size=config.batch_size // NUM_GPUS,
num_workers=config.num_workers)
else:
loader = DataLoader(data, shuffle=False,
batch_size=config.batch_size // NUM_GPUS,
num_workers=config.num_workers)
pbar = tqdm(enumerate(loader), total=len(
loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
loader = self.setup_dataloaders(loader)
gc.collect()
torch.cuda.empty_cache()
for it, (x, y) in pbar:
with torch.set_grad_enabled(is_train):
loss = model(x, y) # forward the model
if os.environ['RWKV_DEEPSPEED'] == '0':
all_loss = [loss.clone()]
else:
all_loss = [loss.clone() for _ in range(NUM_GPUS)]
torch.distributed.all_gather(all_loss, loss)
if is_train: # backprop and update the parameters
model.zero_grad()
self.backward(loss)
# deepspeed will handle gradient_clipping
optimizer.step()
# decay the learning rate based on our progress
self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
lr_final_factor = config.lr_final / config.learning_rate
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = lr_final_factor + \
(1 - lr_final_factor) * float(self.tokens) / \
float(config.warmup_tokens)
progress = 0
else:
# exponential learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
if progress >= 1:
lr_mult = lr_final_factor
else:
lr_mult = math.exp(math.log(lr_final_factor) * pow(progress, 1))
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
param_group['lr'] = lr
self.lr = lr
self.steps += 1
now_loss = 0
for gg in range(NUM_GPUS):
now_loss += all_loss[gg].item()
now_loss = now_loss / NUM_GPUS # report progress
if USE_WANDB and self.cuda_id == 0:
wandb.log({"loss": now_loss}, step = self.steps)
if self.avg_loss < 0:
self.avg_loss = now_loss
else:
factor = 1 / (it + 1)
self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor
pbar.set_description(f"miniE {epoch+1+self.EPOCH_BEGIN} s {self.steps} prog {progress*100.0:.2f}% : ppl {math.exp(self.avg_loss):.6f} loss {self.avg_loss:.6f} lr {lr:e}")
self.tokens = 0 # counter used for learning rate decay
for epoch in range(99999999):
run_epoch('train')
if math.isnan(self.avg_loss):
exit(0)
if self.cuda_id == 0:
log_file.write(f'{epoch+1+self.EPOCH_BEGIN} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} {epoch+1} \n')
log_file.flush()
if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
raw_model = self.model.module if hasattr(self.model, "module") else self.model
torch.save(raw_model.state_dict(), self.config.epoch_save_path + str(epoch+1+self.EPOCH_BEGIN) + '.pth')

@@ -0,0 +1,153 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os
try:
NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])
except Exception:
NUM_GPUS = 1 # default when RWKV_NUM_GPUS is not set
import json
import random
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset
class Dataset(Dataset):
def __init__(self, data, ctx_len, epoch_length_fixed):
self.ctx_len = ctx_len
self.epoch_length_fixed = epoch_length_fixed
self.data = data
if 'MMapIndexedDataset' in str(type(self.data)):
self.vocab_size = int(os.environ['VOCAB_SIZE'])
print('current vocab size =', self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data._bin_buffer) // 2
print(f'data has {self.data_size} tokens.')
elif 'numpy' in str(type(self.data)):
self.vocab_size = int(os.environ['VOCAB_SIZE'])
print('current vocab size =', self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data)
print(f'data has {self.data_size} tokens.')
else:
print('building token list...', end=' ')
unique = sorted(list(set(data)))
self.vocab_size = len(unique)
# print()
# for u in unique:
# print(u, end=' ')
# print('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open('vocab.json', "w", encoding="utf-16") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
self.data_size = len(self.data)
print('data has %d tokens, %d unique.' % (self.data_size, self.vocab_size))
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
def __len__(self):
return self.epoch_length_fixed // NUM_GPUS
def __getitem__(self, idx):
#
# we are cheating: pick a random spot in dataset
#
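# each sample is ctx_len+1 consecutive tokens: x = tokens[:-1], y = tokens[1:]
# (next-token targets), so the loss is computed at every position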
i = np.random.randint(0, self.data_size - (self.ctx_len + 1))
if 'MMapIndexedDataset' in str(type(self.data)):
dix = self.data.get(idx=0, offset=i, length=self.ctx_len + 1).astype(int)
elif 'numpy' in str(type(self.data)):
dix = self.data[i:i+self.ctx_len+1]
else:
dix = [self.stoi[s] for s in self.data[i:i+self.ctx_len+1]]
x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long)
return x, y
class TOKENIZER():
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
if 'list' in str(type(WORD_NAME)):
self.charMode = False
if WORD_NAME[0] == WORD_NAME[1]:
from transformers import PreTrainedTokenizerFast
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
else:
from transformers import GPT2TokenizerFast
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
self.vocab_size = len(self.tokenizer)
else:
self.charMode = True
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
self.word_table = json.load(result_file)
self.vocab_size = len(self.word_table)
self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(torch.tensor(out), dim=-1)
if self.charMode:
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
else:
top_p = top_p_usual
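# nucleus (top-p) sampling: sort probs descending, keep the smallest prefix whose
# cumulative mass exceeds top_p, zero out the rest (torch.multinomial renormalizes)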
sorted_probs, s_index = torch.sort(probs, descending=True)
# for j in range(30):
# pp = sorted_probs[j].item()
# if pp < 0.005:
# break
# ss = self.itos[int(s_index[j])].replace('\n','_')
# print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
# print('')
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
# print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
return torch.multinomial(probs, num_samples=1)[0]
def to_float(x):
return x.cpu().detach().numpy().flatten()[0].astype(float)
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

@@ -0,0 +1,280 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os
import logging, types
from src.utils import Dataset
import torch
import numpy as np
from src.binidx import MMapIndexedDataset
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
# if False: # True False ---> Set to False if you don't understand it
# print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n")
# import src.utils
# src.utils.set_seed(42) # make training deterministic (including dataloader). if you are doing this, remember to change seed when you load a model (otherwise the dataloader loads old samples)
########################################################################################################
# Step 1: set training data & cfg
########################################################################################################
EXPRESS_PILE_MODE = False # True: express mode for fine-tuning a pile model // False: usual training
EXPRESS_PILE_MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023'
EXPRESS_PILE_MODEL_TYPE = 'RWKV-4-Pile-169M'
# EXPRESS_PILE_MODEL_NAME = 'RWKV-4-Pile-430M-20220808-8066'
# EXPRESS_PILE_MODEL_TYPE = 'RWKV-4-Pile-430M'
# EXPRESS_PILE_MODEL_NAME = 'RWKV-4-Pile-1B5-20220903-8040'
# EXPRESS_PILE_MODEL_TYPE = 'RWKV-4-Pile-1B5'
########################################################################################################
datafile = "../data/enwik8" # your data
datafile_encoding = 'utf-8' # 'utf-8' / 'utf-16le' / 'numpy' (for fine-tuning pile models) / 'binidx' (the Megatron-LM 'binidx' format)
# datafile = 'my-gpt_seq_document'
# datafile_encoding = 'binidx'
if EXPRESS_PILE_MODE:
datafile = 'train.npy' # use 'prepare-data.py' in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/tree/main/RWKV-v3 to tokenize .txt into .npy
datafile_encoding = 'numpy'
#
# set VOCAB_SIZE = 0 (auto-compute) if you are training a char-level LM from scratch
# set VOCAB_SIZE = 50277 for fine-tuning pile models
# set VOCAB_SIZE = your_vocab_size for 'binidx' data
#
os.environ['VOCAB_SIZE'] = '0'
if EXPRESS_PILE_MODE:
os.environ['VOCAB_SIZE'] = '50277'
#
# Currently it's slow to initialize a new model. Hence I suggest this procedure for multi-GPU training:
# 1) set RWKV_NUM_GPUS = '1' and let it run for 1 miniEpoch and it will save a trained-1.pth
# 2) set RWKV_NUM_GPUS = '8' (or your #GPU), batch_size = single_gpu_batchsz * RWKV_NUM_GPUS,
# EPOCH_BEGIN = 1, LOAD_MODEL = True, and it will load 'trained-1.pth' and continue the training from it
#
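# e.g. with a single-GPU batch size of 12 and 8 GPUs: RWKV_NUM_GPUS = '8', batch_size = 96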
os.environ['RWKV_NUM_GPUS'] = '1' # num of GPUs to use
#
# 'bf16' (fast & stable)
# 'fp16' (fast & will overflow after training a large model for very long. can be solved in the future)
# 'tf32' (decent speed & stable)
# 'fp32' (!!!very slow!!! only for verification)
os.environ['RWKV_FLOAT_MODE'] = 'bf16'
os.environ['RWKV_DEEPSPEED'] = '1' # Use DeepSpeed? 0 = False, 1 = True
if int(os.environ['RWKV_NUM_GPUS']) == 1: # Usually you don't need DeepSpeed for 1 GPU training.
os.environ['RWKV_DEEPSPEED'] = '0' # However, DeepSpeed can sometimes save VRAM even for single-GPU training, so it's worth trying.
os.environ['USE_WANDB'] = '0' # wandb logging. 0 = False, 1 = True
########################################################################################################
# Step 2: set model details
########################################################################################################
EPOCH_BEGIN = 0 # begins with miniEpoch = EPOCH_BEGIN
LOAD_MODEL = False # shall we load the #EPOCH_BEGIN model and continue the training from it?
n_layer = 6
n_embd = 512
ctx_len = 1024 # increase T_MAX in src/model.py if your ctx_len is longer
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' (sometimes better)
# there is also a RWKV_HEAD_QK_DIM in model.py and model_run.py
# set it to 256, then it's using my headQK trick (a tiny attention) to improve loss
# set it to 0, then it's a pure RNN (attention-free)
if EXPRESS_PILE_MODE:
LOAD_MODEL = True
if EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-169M':
n_layer = 12
n_embd = 768
ctx_len = 1024
elif EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-430M':
n_layer = 24
n_embd = 1024
ctx_len = 1024
elif EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-1B5':
n_layer = 24
n_embd = 2048
ctx_len = 1024
########################################################################################################
# Step 3: set batch size & learning rate etc.
########################################################################################################
# if you see "CUDA out of memory", reduce batch_size. Use nvidia-smi to find the highest value for your GPU.
batch_size = 12 * int(os.environ['RWKV_NUM_GPUS'])
assert (batch_size % int(os.environ['RWKV_NUM_GPUS']) == 0)
# By default we are using exponential LR decay.
# Here are my suggestions for training.
# Let's say you are training a L6-D512 model.
# 1) Set lr_init = lr_final = 8e-4. Let it run for some mini-epochs, until you feel like reducing LR.
# 2) Check epoch_save_frequency and make sure the partially-trained model is saved. Ctrl+C to stop the run.
# 3) Set lr_init = 8e-4, lr_final = 1e-5, betas = (0.9, 0.999).
# 4) Set EPOCH_BEGIN & LOAD_MODEL to load the partially-trained model. Continue the training.
#
# For L12-D768, set lr_init = 6e-4. For L24-D1024, set lr_init = 4e-4. For L24-D2048, set lr_init = 3e-4.
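# (the decay in src/trainer.py works out to lr = lr_init * (lr_final/lr_init)**progress,
# so with lr_init = 8e-4 and lr_final = 1e-5, lr at 50% progress is sqrt(8e-4 * 1e-5) ~ 8.9e-5)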
lr_init = 8e-4
lr_final = 1e-5
# the mini-epoch is very short and of fixed length (length = ctx_len * epoch_length_fixed tokens)
n_epoch = 500
epoch_length_fixed = (10000 // batch_size) * batch_size # feel free to increase it if you have lots of GPU
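# e.g. with batch_size = 12: epoch_length_fixed = 9996 samples ~ 10.2M tokens per mini-epoch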
# epoch_save_frequency 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, ...
epoch_save_frequency = 10
epoch_save_path = 'trained-'
if EXPRESS_PILE_MODE:
if EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-169M':
lr_init = 2e-5
else:
lr_init = 1e-5
lr_final = 1e-5
n_epoch = 100000
### misc stuffs ########################################################################################
if LOAD_MODEL and EPOCH_BEGIN > 0: # we are not saving gradients, so let's have some warmup if we load a model
warmup_tokens = 50 * ctx_len * batch_size // NUM_GPUS
else:
warmup_tokens = 0
betas = (0.9, 0.99) # set betas = (0.9, 0.999) if your model has been trained for a while
eps = 1e-8
num_workers = 1 # DataLoader worker. I only tested num_workers = 1
NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])
os.environ['RWKV_LOAD_MODEL'] = str(LOAD_MODEL)
MODEL_NAME = epoch_save_path + str(EPOCH_BEGIN)
if EXPRESS_PILE_MODE:
betas = (0.9, 0.999)
MODEL_NAME = EXPRESS_PILE_MODEL_NAME
torch.backends.cudnn.benchmark = True
if os.environ['RWKV_FLOAT_MODE'] == 'fp32':
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
########################################################################################################
# Load data
########################################################################################################
print(f'loading {datafile_encoding} data... ' + datafile)
if datafile_encoding == 'binidx':
train_dataset = Dataset(MMapIndexedDataset(datafile), ctx_len, epoch_length_fixed)
elif datafile_encoding == 'numpy':
train_dataset = Dataset(np.load(datafile).astype('int'), ctx_len, epoch_length_fixed)
else:
train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)
########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':
from src.trainer import Trainer, TrainerConfig
print('\nmodel', model_type, os.environ['RWKV_FLOAT_MODE'], 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, '\n')
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
m_cfg = types.SimpleNamespace()
m_cfg.model_type = model_type
m_cfg.n_layer = n_layer
m_cfg.n_embd = n_embd
m_cfg.EPOCH_BEGIN = EPOCH_BEGIN
m_cfg.LOAD_MODEL = LOAD_MODEL
m_cfg.MODEL_NAME = MODEL_NAME
if os.environ['RWKV_DEEPSPEED'] == '0':
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16)
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16')
elif '32' in os.environ['RWKV_FLOAT_MODE']:
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32)
else:
from pytorch_lightning.strategies import DeepSpeedStrategy
DEEPSPEED_CFG = {
"zero_allow_untested_optimizer":True,
"zero_optimization":{
"stage":2,
"contiguous_gradients":True,
"overlap_comm":True,
"allgather_partitions":True,
"reduce_scatter":True,
"allgather_bucket_size":200000000,
"reduce_bucket_size":200000000,
"sub_group_size":1000000000000
},
"activation_checkpointing":{
"partition_activations":False,
"cpu_checkpointing":False,
"contiguous_memory_optimization":False,
"synchronize_checkpoint_boundary":False
},
"aio":{
"block_size":1048576,
"queue_depth":8,
"single_submit":False,
"overlap_events":True,
"thread_count":1
},
"gradient_clipping": 1.0,
"gradient_accumulation_steps": 1,
}
if NUM_GPUS == 1:
DEEPSPEED_CFG['zero_optimization'] = {
"stage":1, # saves some VRAM
"contiguous_gradients":False,
"overlap_comm":False,
"allgather_partitions":False,
"reduce_scatter":False,
"allgather_bucket_size":200000000,
"reduce_bucket_size":200000000,
"sub_group_size":1000000000000
}
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
DEEPSPEED_CFG["fp16"] = {
"fp16": True,
"enabled": True,
"loss_scale": 0,
"initial_scale_power": 12,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
}
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=16)
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
DEEPSPEED_CFG["bf16"] = {
"enabled": True
}
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')
elif '32' in os.environ['RWKV_FLOAT_MODE']:
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32)
print(trainer._strategy.config)
trainer.run(m_cfg, train_dataset, None, tconf)
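# Worked example for the mini-epoch arithmetic in the config above (hypothetical
# batch_size / ctx_len / n_epoch; the real values are set near the top of this file):
#   batch_size, ctx_len, n_epoch = 12, 1024, 500
#   epoch_length_fixed = (10000 // 12) * 12 = 9996 samples per mini-epoch
#   tokens per mini-epoch = 9996 * 1024 = 10,235,904 (~10M)
#   final_tokens for lr decay = 500 * 9996 * 1024 = 5,117,952,000 (~5.1B)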

@ -0,0 +1,90 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# this is for verifying the results of different models and making sure they agree with each other
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (can overflow after training a large model for a long time; may be fixed in the future)
os.environ['RWKV_RUN_DEVICE'] = 'cuda'
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
import torch
from src.model_run import RWKV_RNN, RWKV_GPT
from src.model import GPT, GPTConfig
TOKEN_MODE = 'pile' # char / pile
if TOKEN_MODE == 'char':
MODEL_NAME = 'trained-1'
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
ctx_len = 1024
n_layer = 6
n_embd = 512
UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
elif TOKEN_MODE == 'pile':
WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023'
ctx_len = 1024
n_layer = 12
n_embd = 768
UNKNOWN_CHAR = None
model_type = 'RWKV'
from src.utils import TOKENIZER
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == 'pile':
tokenizer.vocab_size = 50277
########################################################################################################
model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda()
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
model_train = model_train.bfloat16()
print('loading ' + MODEL_NAME)
m2 = torch.load(MODEL_NAME + '.pth', map_location=RUN_DEVICE)
model_train.load_state_dict(m2)
model_rnn = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
model_gpt = RWKV_GPT(MODEL_NAME, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda()
########################################################################################################
# context = '\nIn a'
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'
if TOKEN_MODE == 'char':
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
elif TOKEN_MODE == 'pile':
ctx = tokenizer.tokenizer.encode(context)
print(f'input len {len(ctx)} data {ctx}')
########################################################################################################
print('\nRWKV-GPT output')
out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy()
print(out)
print('\nRWKV-RNN output')
model_rnn.clear()
src_len = len(ctx)
for i in range(src_len):
x = ctx[:i+1]
out = model_rnn.run(x)
if i < 3 or i >= src_len - 3:
print(torch.tensor(out).detach().cpu().numpy())
if i == 2:
print('...')
print('\nRWKV-train output')
out = model_train.forward(torch.tensor([ctx]).cuda())[0][0].detach().cpu().float().numpy()
print(out, '\n')
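# A hypothetical extra check (not in the original script): GPT mode and train
# mode compute the same function, so their full logit matrices should agree
# within floating-point tolerance (looser in fp16/bf16 than in fp32).
out_gpt = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().float().numpy()
out_train = model_train.forward(torch.tensor([ctx]).cuda())[0][0].detach().cpu().float().numpy()
print('max |RWKV-GPT - RWKV-train| =', np.abs(out_gpt - out_train).max())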

File diff suppressed because it is too large

@ -0,0 +1,361 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
print('Loading...')
from src.model_run import RWKV_RNN
import numpy as np
import os, copy, types, gc, sys
import torch
from src.utils import TOKENIZER
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
CHAT_LANG = 'English' # English Chinese
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
args = types.SimpleNamespace()
args.RUN_DEVICE = "cuda" # 'cpu' (already very fast) // 'cuda'
args.FLOAT_MODE = "fp16" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate)
args.vocab_size = 50277
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230108-5170'
args.n_layer = 40
args.n_embd = 5120
args.ctx_len = 1024
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
# args.n_layer = 32
# args.n_embd = 4096
# args.ctx_len = 1024
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
# args.n_layer = 32
# args.n_embd = 2560
# args.ctx_len = 1024
if CHAT_LANG == 'English':
user = "User"
bot = "Bot"
interface = ":"
# The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
# The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} does its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth.
init_prompt = f'''
The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
{user}{interface} french revolution what year
{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799.
{user}{interface} 3+5=?
{bot}{interface} The answer is 8.
{user}{interface} guess i marry who ?
{bot}{interface} Only if you tell me more about yourself - what are your interests?
{user}{interface} solve for a: 9-a=2
{bot}{interface} The answer is a = 7, because 9 - 7 = 2.
{user}{interface} wat is lhc
{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
'''
HELP_MSG = '''Commands:
say something --> chat with bot. use \\n for new line.
+alt --> alternate chat reply
+reset --> reset chat
+gen YOUR PROMPT --> free generation with any prompt. use \\n for new line.
+qa YOUR QUESTION --> free generation - ask any question (just ask the question). use \\n for new line.
+more --> continue last free generation (only for +gen / +qa)
+retry --> retry last free generation (only for +gen / +qa)
Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results.
This model is not instruct-tuned for conversation yet, so don't expect high quality. +gen free generation usually works better.
'''
elif CHAT_LANG == 'Chinese':
args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run3z/rwkv-293'
args.n_layer = 32
args.n_embd = 4096
args.ctx_len = 1024
user = "Q"
bot = "A"
interface = ":"
init_prompt = '''
Q: 企鹅会飞吗
A: 企鹅是不会飞的它们的翅膀主要用于游泳和平衡而不是飞行
Q: 西瓜是什么
A: 西瓜是一种常见的水果是一种多年生蔓生藤本植物西瓜的果实呈圆形或卵形通常是绿色的里面有红色或黄色的肉和很多的籽西瓜味甜多吃可以增加水分是夏季非常受欢迎的水果之一
'''
HELP_MSG = '''指令:
直接输入内容 --> 和机器人聊天\\n代表换行
+alt --> 让机器人换个回答
+reset --> 重置对话
+gen 某某内容 --> 续写任何中英文内容\\n代表换行
+qa 某某问题 --> 问独立的问题忽略上下文\\n代表换行
+more --> 继续 +gen / +qa 的回答
+retry --> 换个 +gen / +qa 的回答
现在可以输入内容和机器人聊天注意它不怎么懂中文它可能更懂英文请经常使用 +reset 重置机器人记忆
'''
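# (English gloss of the Chinese strings above. init_prompt: "Q: Can penguins fly?
# A: No -- their wings are mainly used for swimming and balance, not flight." /
# "Q: What is a watermelon? A: A common fruit from a vine plant; round or oval,
# usually green outside with red or yellow flesh and many seeds; sweet and very
# popular in summer." HELP_MSG lists the same commands as the English version:
# chat, +alt, +reset, +gen, +qa, +more, +retry, plus a note that the model does
# not understand Chinese very well and a reminder to +reset often.)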
# Load Model
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
MODEL_NAME = args.MODEL_NAME
print(f'loading... {MODEL_NAME}')
model = RWKV_RNN(args)
model_tokens = []
current_state = None
########################################################################################################
def run_rnn(tokens, newline_adj = 0):
global model_tokens, current_state
for i in range(len(tokens)):
model_tokens += [int(tokens[i])]
if i == len(tokens) - 1:
out, current_state = model.forward(model_tokens, current_state)
else:
current_state = model.forward(model_tokens, current_state, preprocess_only = True)
# print(f'### model ###\n[{tokenizer.tokenizer.decode(model_tokens)}]')
out[0] = -999999999 # disable <|endoftext|>
out[187] += newline_adj # 187 = '\n' in the 20B tokenizer; bias the newline probability
# if newline_adj > 0:
# out[15] += newline_adj / 2 # '.'
return out
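# How newline_adj steers sampling (hypothetical 2-token sketch, with index 1
# standing in for '\n'; the real newline token id is 187):
#   logits = np.array([2.0, 1.0]); logits[1] += newline_adj
#   p = np.exp(logits) / np.exp(logits).sum()
#   newline_adj = -2 -> p('\n') ~= 0.047 ; 0 -> ~0.269 ; +2 -> ~0.731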
all_state = {}
def save_all_stat(srv, name, last_out):
n = f'{name}_{srv}'
all_state[n] = {}
all_state[n]['out'] = last_out
all_state[n]['rnn'] = copy.deepcopy(current_state)
all_state[n]['token'] = copy.deepcopy(model_tokens)
def load_all_stat(srv, name):
global model_tokens, current_state
n = f'{name}_{srv}'
current_state = copy.deepcopy(all_state[n]['rnn'])
model_tokens = copy.deepcopy(all_state[n]['token'])
return all_state[n]['out']
########################################################################################################
# Run inference
print(f'\nRun prompt...')
out = run_rnn(tokenizer.tokenizer.encode(init_prompt))
gc.collect()
torch.cuda.empty_cache()
save_all_stat('', 'chat_init', out)
srv_list = ['dummy_server']
for s in srv_list:
save_all_stat(s, 'chat', out)
print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n')
def reply_msg(msg):
print(f'{bot}{interface} {msg}\n')
def on_message(message):
global model_tokens, current_state
srv = 'dummy_server'
msg = message.replace('\\n','\n').strip()
if len(msg) > 1000:
reply_msg('your message is too long (max 1000 characters)')
return
x_temp = 1.0
x_top_p = 0.85
if ("-temp=" in msg):
x_temp = float(msg.split("-temp=")[1].split(" ")[0])
msg = msg.replace("-temp="+f'{x_temp:g}', "")
# print(f"temp: {x_temp}")
if ("-top_p=" in msg):
x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
msg = msg.replace("-top_p="+f'{x_top_p:g}', "")
# print(f"top_p: {x_top_p}")
if x_temp <= 0.2:
x_temp = 0.2
if x_temp >= 5:
x_temp = 5
if x_top_p <= 0:
x_top_p = 0
if msg == '+reset':
out = load_all_stat('', 'chat_init')
save_all_stat(srv, 'chat', out)
reply_msg("Chat reset.")
return
elif msg[:5].lower() == '+gen ' or msg[:4].lower() == '+qa ' or msg.lower() == '+more' or msg.lower() == '+retry':
if msg[:5].lower() == '+gen ':
new = '\n' + msg[5:].strip()
# print(f'### prompt ###\n[{new}]')
current_state = None
out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out)
elif msg[:4].lower() == '+qa ':
out = load_all_stat('', 'chat_init')
real_msg = msg[4:].strip()
new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
# print(f'### qa ###\n[{new}]')
out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out)
# new = f"\nThe following is an excellent Q&A session consisting of detailed and factual information.\n\nQ: What is 3+5?\nA: The answer is 8.\n\nQ: {msg[9:].strip()}\nA:"
# print(f'### prompt ###\n[{new}]')
# current_state = None
# out = run_rnn(tokenizer.tokenizer.encode(new))
# save_all_stat(srv, 'gen_0', out)
elif msg.lower() == '+more':
try:
out = load_all_stat(srv, 'gen_1')
save_all_stat(srv, 'gen_0', out)
except:
return
elif msg.lower() == '+retry':
try:
out = load_all_stat(srv, 'gen_0')
except:
return
begin = len(model_tokens)
out_last = begin
for i in range(150):
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p_usual=x_top_p,
top_p_newline=x_top_p,
)
if msg[:4].lower() == '+qa ':
out = run_rnn([token], newline_adj=-1)
else:
out = run_rnn([token])
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx:
print(xxx, end='', flush=True)
out_last = begin + i + 1
print('\n')
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'gen_1', out)
else:
if msg.lower() == '+alt':
try:
out = load_all_stat(srv, 'chat_pre')
except:
return
else:
out = load_all_stat(srv, 'chat')
new = f"{user}{interface} {msg}\n\n{bot}{interface}"
# print(f'### add ###\n[{new}]')
out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999)
save_all_stat(srv, 'chat_pre', out)
begin = len(model_tokens)
out_last = begin
print(f'{bot}{interface}', end='', flush=True)
for i in range(999):
if i <= 0:
newline_adj = -999999999
elif i <= 30:
newline_adj = (i - 30) / 10
elif i <= 130:
newline_adj = 0
else:
newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p_usual=x_top_p,
top_p_newline=x_top_p,
)
out = run_rnn([token], newline_adj=newline_adj)
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx:
print(xxx, end='', flush=True)
out_last = begin + i + 1
send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
if '\n\n' in send_msg:
send_msg = send_msg.strip()
break
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!!
# send_msg = send_msg[:-len(f'{user}{interface}')].strip()
# break
# if send_msg.endswith(f'{bot}{interface}'):
# send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
# break
# print(f'{model_tokens}')
# print(f'[{tokenizer.tokenizer.decode(model_tokens)}]')
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'chat', out)
print(HELP_MSG)
while True:
msg = input(f'{user}{interface} ')
if len(msg.strip()) > 0:
on_message(msg)
else:
print('Error: please say something')

@ -0,0 +1,133 @@
#include <stdio.h>
#include <assert.h>
#define MIN_VALUE (-1e38)
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
F *__restrict__ const _y) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
F *__restrict__ const y = _y + _offset;
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
F aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const F kk = k[ii];
const F vv = v[ii];
F ww = u + kk;
F p = max(pp, ww);
F e1 = exp(pp - p);
F e2 = exp(ww - p);
y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
}
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
const F *__restrict__ const _y, const F *__restrict__ const _gy,
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
const F *__restrict__ const y = _y + _offset;
const F *__restrict__ const gy = _gy + _offset;
F *__restrict__ const gk = _gk + _offset;
F *__restrict__ const gv = _gv + _offset;
F q[Tmax], r[Tmax];
F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const F kk = k[ii];
const F vv = v[ii];
const F yy = y[ii];
F ww = u + kk;
F p = max(pp, ww);
F e1 = exp(pp - p);
F e2 = exp(ww - p);
const F qq = gy[ii] / (e1 * bb + e2);
gw += (ga - gb * yy) * e1 * qq;
gu += (vv - yy) * e2 * qq;
q[i] = qq;
r[i] = ww - p;
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
ga = e1 * (aa + ga);
gb = e1 * (bb + gb);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
_gu[_offsetBC] = gu;
aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
const F kk = k[ii];
const F vv = v[ii];
const F yy = y[ii];
const F qq = q[i];
const F rr = r[i];
F e1 = qq * exp(rr);
F e2 = exp(kk + pp);
gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
gv[ii] = e1 + e2 * aa;
const F ww = w + pp;
const F www = rr - u - kk;
const F p = max(ww, www);
e1 = exp(ww - p);
e2 = qq * exp(www - p);
aa = e1 * aa + e2;
bb = e1 * bb - e2 * yy;
pp = p;
}
}
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
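# Reference sketch in Python/NumPy (hypothetical helper, not part of the kernel
# build): the per-channel recurrence kernel_forward computes, in the same
# max-shifted form used above to keep exp() from overflowing. Here w is the
# decay after the Python-side transform w -> -exp(w), and u is the "time_first"
# bonus applied only to the current token.
import numpy as np
def wkv_forward_ref(w, u, k, v):
    T = len(k)
    y = np.empty(T)
    aa, bb, pp = 0.0, 0.0, -1e38  # running numerator/denominator and shared exponent
    for i in range(T):
        ww = u + k[i]
        p = max(pp, ww)
        e1, e2 = np.exp(pp - p), np.exp(ww - p)
        y[i] = (e1 * aa + e2 * v[i]) / (e1 * bb + e2)
        ww = w + pp
        p = max(ww, k[i])
        e1, e2 = np.exp(ww - p), np.exp(k[i] - p)
        aa, bb, pp = e1 * aa + e2 * v[i], e1 * bb + e2, p
    return y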

@ -0,0 +1,132 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#define MIN_VALUE (-1e38)
typedef at::BFloat16 bf16;
__global__ void kernel_forward(const int B, const int T, const int C,
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
bf16 *__restrict__ const _y) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
float u = float(_u[_c]);
float w = _w[_c];
const bf16 *__restrict__ const k = _k + _offset;
const bf16 *__restrict__ const v = _v + _offset;
bf16 *__restrict__ const y = _y + _offset;
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
float aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);
float ww = u + kk;
float p = max(pp, ww);
float e1 = exp(pp - p);
float e2 = exp(ww - p);
y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
}
__global__ void kernel_backward(const int B, const int T, const int C,
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy,
bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
float u = float(_u[_c]);
float w = _w[_c];
const bf16 *__restrict__ const k = _k + _offset;
const bf16 *__restrict__ const v = _v + _offset;
const bf16 *__restrict__ const y = _y + _offset;
const bf16 *__restrict__ const gy = _gy + _offset;
bf16 *__restrict__ const gk = _gk + _offset;
bf16 *__restrict__ const gv = _gv + _offset;
float q[Tmax], r[Tmax];
float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);
const float yy = float(y[ii]);
float ww = u + kk;
float p = max(pp, ww);
float e1 = exp(pp - p);
float e2 = exp(ww - p);
const float qq = float(gy[ii]) / (e1 * bb + e2);
gw += (ga - gb * yy) * e1 * qq;
gu += (vv - yy) * e2 * qq;
q[i] = qq;
r[i] = ww - p;
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
ga = e1 * (aa + ga);
gb = e1 * (bb + gb);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
_gu[_offsetBC] = bf16(gu);
aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);
const float yy = float(y[ii]);
const float qq = q[i];
const float rr = r[i];
float e1 = qq * exp(rr);
float e2 = exp(kk + pp);
gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
gv[ii] = bf16(e1 + e2 * aa);
const float ww = w + pp;
const float www = rr - u - kk;
const float p = max(ww, www);
e1 = exp(ww - p);
e2 = qq * exp(www - p);
aa = e1 * aa + e2;
bb = e1 * bb - e2 * yy;
pp = p;
}
}
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}

@ -0,0 +1,21 @@
#include <torch/extension.h>
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv forward");
m.def("backward", &backward, "wkv backward");
}
TORCH_LIBRARY(wkv, m) {
m.def("forward", forward);
m.def("backward", backward);
}

@ -0,0 +1,25 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv forward");
m.def("backward", &backward, "wkv backward");
}
TORCH_LIBRARY(wkv, m) {
m.def("forward", forward);
m.def("backward", backward);
}

@ -0,0 +1,165 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import torch, types, os
import numpy as np
from PIL import Image
import torch.nn as nn
from torch.nn import functional as F
import torchvision as vision
import torchvision.transforms as transforms
np.set_printoptions(precision=4, suppress=True, linewidth=200)
print(f'loading...')
########################################################################################################
model_prefix = 'test/image_trained/out-v7c_d8_256-224-13bit-OB32x0.5-201'
input_img = 'test/img_ae_test/test0.png'
########################################################################################################
class ToBinary(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return torch.floor(x + 0.5) # no need for noise when we have plenty of data
@staticmethod
def backward(ctx, grad_output):
return grad_output.clone() # pass-through
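# Hypothetical sanity check: ToBinary is a straight-through estimator --
# forward rounds to {0, 1}, backward passes gradients through unchanged.
_stb = torch.tensor([0.2, 0.7], requires_grad=True)
_b = ToBinary.apply(_stb)  # tensor([0., 1.])
_b.sum().backward()
assert torch.equal(_stb.grad, torch.ones(2))  # gradient of 1 passes straight through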
class R_ENCODER(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.Bxx = nn.BatchNorm2d(dd*64)
self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*4)
self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*64)
self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
def forward(self, img):
ACT = F.mish
x = self.CIN(img)
xx = self.Bxx(F.pixel_unshuffle(x, 8))
x = x + self.Cx1(ACT(self.Cx0(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
x = self.COUT(x + xx)
return torch.sigmoid(x)
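# Hypothetical smoke test: three pixel_unshuffle(2) steps take a 224x224 image
# down to 28x28, so the binary code is my_img_bit x 28 x 28 -- one code vector
# per 8x8 patch of the input.
_demo_args = types.SimpleNamespace(my_img_bit=13)
_demo_enc = R_ENCODER(_demo_args).eval()
with torch.no_grad():
    print(_demo_enc(torch.zeros(1, 3, 224, 224)).shape)  # torch.Size([1, 13, 28, 28])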
class R_DECODER(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*64)
self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*4)
self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
def forward(self, code):
ACT = F.mish
x = self.CIN(code)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.Cx1(ACT(self.Cx0(x)))
x = self.COUT(x)
return torch.sigmoid(x)
########################################################################################################
print(f'building model...')
args = types.SimpleNamespace()
args.my_img_bit = 13
encoder = R_ENCODER(args).eval().cuda()
decoder = R_DECODER(args).eval().cuda()
zpow = torch.tensor([2**i for i in range(0,13)]).reshape(13,1,1).cuda().long()
encoder.load_state_dict(torch.load(f'{model_prefix}-E.pth'))
decoder.load_state_dict(torch.load(f'{model_prefix}-D.pth'))
########################################################################################################
print(f'test image...')
img_transform = transforms.Compose([
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Resize((224, 224))
])
with torch.no_grad():
img = img_transform(Image.open(input_img)).unsqueeze(0).cuda()
z = encoder(img)
z = ToBinary.apply(z)
zz = torch.sum(z.squeeze().long() * zpow, dim=0)
print(f'Code shape = {zz.shape}\n{zz.cpu().numpy()}\n')
out = decoder(z)
vision.utils.save_image(out, f"{input_img.split('.')[0]}-out-13bit.jpg")

@ -0,0 +1,237 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math, os, sys, types, time, gc
import torch
from src.utils import TOKENIZER
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()
########################################################################################################
# Step 1: set model & config (use v4 to run your trained-from-scratch models. v4 and v4neo are compatible)
########################################################################################################
args.RUN_DEVICE = "cuda" # 'cuda' // 'cpu' (already fast)
args.FLOAT_MODE = "fp16" # fp16 (good for GPU, does not work for CPU) // fp32 (good for CPU) // bf16 (less accurate, but works for CPU)
# if args.RUN_DEVICE == "cuda":
# os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
os.environ["RWKV_JIT_ON"] = '1' # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!!
TOKEN_MODE = "pile"
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
vocab_size = 50277
# Download Pile models: https://huggingface.co/BlinkDL
# or, set MODEL_NAME to your fine-tuned model
# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
# n_layer = 12
# n_embd = 768
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
# n_layer = 24
# n_embd = 1024
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
# n_layer = 32
# n_embd = 2560
# ctx_len = 1024
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
n_layer = 32
n_embd = 4096
ctx_len = 1024
args.MODEL_NAME = MODEL_NAME
args.n_layer = n_layer
args.n_embd = n_embd
args.ctx_len = ctx_len
args.vocab_size = vocab_size
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
########################################################################################################
# Step 2: set prompt & sampling stuffs
########################################################################################################
# context = 'A'
# context = "\nIn the"
# context = '\nSugar:'
context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
# context = "\n深圳是" # test Chinese
# context = "\n東京は" # test Japanese
# ###### A good prompt for Q&A ######
# context = '''
# Questions & Helpful Answers
# Ask Research Experts
# Question:
# Can penguins fly?
# Full Answer:
# '''
# ###### A good prompt for chatbot ######
# context = '''
# The following is a conversation between a highly knowledgeable and intelligent AI assistant called Bot, and a human user called User. In the following interactions, User and Bot converse in natural language, and Bot always answer User's questions. Bot is very smart, polite and humorous. Bot knows a lot, and always tells the truth. The conversation begins.
# User: who is president of usa?
# Bot: Its Joe Biden; he was sworn in earlier this year.
# User: french revolution what year
# Bot: It started in 1789, but it lasted 10 years until 1799.
# User: guess i marry who ?
# Bot: Only if you tell me more about yourself - what are your interests?
# User: wat is lhc
# Bot: Its a large and very expensive piece of science equipment. If I understand correctly, its a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
# User:''' # type your question here
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333
TEMPERATURE = 1.0
top_p = 0.8
top_p_newline = 0.9 # only used in TOKEN_MODE = char
DEBUG_DEBUG = False # True False --> show softmax output
########################################################################################################
print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN
model = RWKV_RNN(args)
print(f'\nOptimizing speed...')
out, _ = model.forward([187], None)
# print(out)
gc.collect()
torch.cuda.empty_cache()
# input(0)
print(f'\nLoading tokenizer {WORD_NAME}...')
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == "pile":
assert tokenizer.tokenizer.decode([187]) == '\n'
########################################################################################################
if tokenizer.charMode:
context = tokenizer.refine_context(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
ctx = tokenizer.tokenizer.encode(context)
src_len = len(ctx)
src_ctx = ctx.copy()
print("\nYour prompt has " + str(src_len) + " tokens.")
print(
"Note: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. Use GPT to build the hidden state for better speed.\n"
)
time_slot = {}
time_ref = time.time_ns()
def record_time(name):
if name not in time_slot:
time_slot[name] = 1e20
tt = (time.time_ns() - time_ref) / 1e9
if tt < time_slot[name]:
time_slot[name] = tt
init_state = None
init_out = None
state = None
out = None
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
print(("-" * 50) + '\n' + context, end="")
time_ref = time.time_ns()
ctx = src_ctx.copy()
if TRIAL == 0:
for i in range(src_len):
x = ctx[: i + 1]
if i == src_len - 1:
init_out, init_state = model.forward(x, init_state)
else:
init_state = model.forward(x, init_state, preprocess_only=True)
gc.collect()
torch.cuda.empty_cache()
record_time('preprocess')
out_last = src_len
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
x = ctx[: i + 1]
x = x[-ctx_len:]
if i == src_len:
out = init_out.clone()
state = init_state.clone()
else:
out, state = model.forward(x, state)
if DEBUG_DEBUG:
print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))
if TOKEN_MODE == "pile":
out[0] = -999999999 # disable <|endoftext|>
ttt = tokenizer.sample_logits(
out,
x,
ctx_len,
temperature=TEMPERATURE,
top_p_usual=top_p,
top_p_newline=top_p_newline,
)
ctx += [ttt]
if tokenizer.charMode:
char = tokenizer.itos[ttt]
print(char, end="", flush=True)
else:
char = tokenizer.tokenizer.decode(ctx[out_last:])
if '\ufffd' not in char: # is valid utf8 string?
print(char, end="", flush=True)
out_last = i+1
record_time('total')
# print(f'\n\n{time_slot}\n\n')
print(
f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = ''
)
print(("-" * 50) + '\n')

@ -0,0 +1,269 @@
import os
import torch
import numpy as np
import shutil
import struct
from functools import lru_cache
from itertools import accumulate
def print_rank_0(*message):
pass
# """If distributed is initialized print only on rank 0."""
# if torch.distributed.is_initialized():
# if torch.distributed.get_rank() == 0:
# print(*message, flush=True)
# else:
# print(*message, flush=True)
def _warmup_mmap_file(path):
pass
# with open(path, "rb") as stream:
# while stream.read(100 * 1024 * 1024):
# pass
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: float,
7: np.double,
8: np.uint16,
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + ".idx"
def data_file_path(prefix_path):
return prefix_path + ".bin"
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b"MMIDIDX\x00\x00"
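# On-disk .idx layout (little endian), as written by the writer below and
# parsed by __init__:
#   bytes 0-8    : magic "MMIDIDX\x00\x00"
#   uint64       : format version (always 1)
#   uint8        : dtype code (see `dtypes` above)
#   uint64       : number of sequences
#   uint64       : number of documents
#   int32[seq]   : sizes (length of each sequence, in tokens)
#   int64[seq]   : pointers (byte offset of each sequence in the .bin file)
#   int64[doc]   : doc_idx (sequence index where each document starts)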
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, "wb")
# Write the magic string so we can check the file format when opening it again.
self._file.write(cls._HDR_MAGIC)
# Write version number
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", 1))
# Little endian unsigned 8 Bit integer
self._file.write(struct.pack("<B", code(dtype)))
return self
@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers
def write(self, sizes, doc_idx):
pointers = self._get_pointers(sizes)
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", len(sizes)))
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", len(doc_idx)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order="C"))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order="C"))
del pointers
doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order="C"))
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path, skip_warmup=False):
with open(path, "rb") as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
# Little endian unsigned 64 Bit integer
version = struct.unpack("<Q", stream.read(8))
assert (1,) == version
# Little endian unsigned 8 Bit integer
(dtype_code,) = struct.unpack("<B", stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack("<Q", stream.read(8))[0]
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
offset = stream.tell()
if not skip_warmup:
print_rank_0(" warming up index mmap file...")
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print_rank_0(" reading sizes...")
self._sizes = np.frombuffer(
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
)
print_rank_0(" reading pointers...")
self._pointers = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._len,
offset=offset + self._sizes.nbytes,
)
print_rank_0(" reading document index...")
self._doc_idx = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@property
def doc_idx(self):
return self._doc_idx
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path, skip_warmup=False):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path, skip_warmup)
def __getstate__(self):
return self._path
def __setstate__(self, state):
self._do_init(state)
def _do_init(self, path, skip_warmup):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)
if not skip_warmup:
print_rank_0(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print_rank_0(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(
data_file_path(self._path), mode="r", order="C"
)
print_rank_0(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index
def __len__(self):
return len(self._index)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
ptr, size = self._index[idx]
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError(
"Slices into indexed_dataset must be contiguous")
ptr = self._index._pointers[start]
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
)
sents = np.split(np_array, offsets[:-1])
return sents
def get(self, idx, offset=0, length=None):
"""Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
)
return np_array
@property
def sizes(self):
return self._index.sizes
@property
def doc_idx(self):
return self._index.doc_idx
def get_doc_idx(self):
return self._index._doc_idx
def set_doc_idx(self, doc_idx_):
self._index._doc_idx = doc_idx_
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return os.path.exists(index_file_path(path)) and os.path.exists(
data_file_path(path)
)
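# Minimal usage sketch (hypothetical 'data' prefix): a corpus tokenized into
# data.bin / data.idx can be sliced without loading it into RAM.
if __name__ == "__main__" and MMapIndexedDataset.exists("data"):
    ds = MMapIndexedDataset("data")
    print(len(ds), ds[0][:10])                 # sequence count, first tokens of sequence 0
    print(ds.get(idx=0, offset=5, length=16))  # 16 tokens of sequence 0, starting at offset 5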

@ -0,0 +1,240 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json, math, random, os, sys
import numpy as np
import torch
from torch.utils.data import Dataset
from pytorch_lightning.utilities import rank_zero_info
from .binidx import MMapIndexedDataset
from .utils import MaybeIsPrime
class MyDataset(Dataset):
def __init__(self, args):
self.args = args
if args.data_type == "binidx":
self.vocab_size = args.vocab_size
rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
if args.my_pile_version == 1:
self.data = MMapIndexedDataset(args.data_file)
self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
rank_zero_info(f"Data has {self.data_size} tokens.")
else:
data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n')
data_list = [i.strip().split(' ') for i in data_list]
self.data = []
self.data_size = int(data_list[-1][-1])
rank_zero_info(f"Data has {self.data_size} chunks.")
for d in data_list:
data = MMapIndexedDataset(d[0])
data_size = len(data._bin_buffer) // data._index._dtype_size
assert (data_size - args.ctx_len) == int(d[1])
self.data += [[int(d[-1]), int(d[1]), data]]
# rank_zero_info(self.data)
if args.my_qa_mask > 0:
self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document')
# self.data_pile = MMapIndexedDataset('/fsx/pile_deduped/pile_0.87_deduped_text_document')
self.data_pile_size = len(self.data_pile._bin_buffer) // self.data_pile._index._dtype_size
if args.my_pile_stage > 0:
# assert self.data_size == 332115325534 and self.vocab_size == 50277
self.samples_per_epoch = args.epoch_steps * args.real_bsz
assert self.samples_per_epoch == 40320
rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
dataset_slot = self.data_size // args.ctx_len
if args.my_pile_stage != 4:
assert MaybeIsPrime(args.magic_prime)
assert args.magic_prime % 3 == 2
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
elif args.data_type == "numpy":
self.data = np.load(args.data_file).astype("int")
self.vocab_size = args.vocab_size
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data)
rank_zero_info(f"Data has {self.data_size} tokens.")
elif args.data_type == "uint16":
self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
self.vocab_size = args.vocab_size
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = self.data.shape[0]
rank_zero_info(f"Data has {self.data_size} samples.")
elif args.data_type == "wds_img":
self.vocab_size = -1
self.data_size = -1
self.data = None
self.error_count = 0
else:
if args.data_type == "dummy":
rank_zero_info("Building dummy data...")
self.data = ""
for i in range(100000):
aa = (i) % 10000
bb = (i * i) % 10000
cc = aa + bb
self.data += f".{aa}+{bb}={cc}."
else:
self.data = open(args.data_file, "r", encoding=args.data_type).read()
rank_zero_info("Building token list...")
unique = sorted(list(set(self.data)))
self.vocab_size = len(unique)
# rank_zero_info()
# for u in unique:
# print(u, end=' ')
# rank_zero_info('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
self.data_size = len(self.data)
rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
def __len__(self):
return self.args.epoch_steps * self.args.micro_bsz
def __getitem__(self, idx):
args = self.args
rank = self.global_rank
epoch = self.real_epoch
world_size = self.world_size
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
if args.data_type == "wds_img":
def init_wds(self, bias=0):
def identity(x):
return x
import webdataset as wds
import torchvision.transforms as transforms
# img_transform = transforms.Compose(
# [transforms.CenterCrop(256)]
# )
img_transform = transforms.Compose([
transforms.CenterCrop(512),
transforms.Resize((args.my_img_size))
])
self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
for pp in self.data_raw.pipeline:
if 'Resampled' in str(pp):
pp.deterministic = True
def worker_seed():
return int(rank * 100000 + epoch + bias * 1e9) # worker seeds must be ints; bias*1e9 is a float
pp.worker_seed = worker_seed
self.data = iter(self.data_raw)
# print(f"WebDataset loaded for rank {rank} epoch {epoch}")
if self.data is None:
init_wds(self)
trial = 0
while trial < 10:
try:
dd = next(self.data) # jpg, json, txt
break
except:
print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
self.error_count += 1
init_wds(self, self.error_count)
trial += 1
pass
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
# with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
# tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
return dd[0], dd[2]
else:
if args.data_type == "uint16":
i = np.random.randint(0, self.data_size-1)
dix = self.data[i]
x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long)
else:
ctx_len = args.ctx_len
req_len = ctx_len + 1
magic_prime = args.magic_prime
data = self.data
if args.my_pile_stage > 0 and args.my_pile_stage != 4:
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
if args.my_qa_mask > 0:
ii_orig = ii
if ii % 2 == 0:
ii = -1
data = self.data_pile
else:
ii = ii // 2
if ii < 0:
i = np.random.randint(0, self.data_pile_size - req_len)
else:
factor = (math.sqrt(5) - 1) / 2
factor = int(magic_prime * factor)
i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
i = i + args.my_pile_shift
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
elif args.my_pile_stage == 4:
# cheat: pick a random spot in dataset
if args.my_pile_version == 1:
i = np.random.randint(0, self.data_size - req_len)
else:
i = np.random.randint(0, self.data_size)
else:
# cheat: pick a random spot in dataset
i = np.random.randint(0, self.data_size - req_len)
if args.data_type == "binidx":
if args.my_pile_version == 1:
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
else:
# self.data : cutoff, chunk_count, data
for j in range(len(data)):
if i < data[j][0]:
ii = i
i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1]
dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int)
# print(ii, j, i)
break
elif args.data_type == "numpy":
dix = data[i : i + req_len]
else:
dix = [self.stoi[s] for s in data[i : i + req_len]]
if args.my_qa_mask == 1:
if data == self.data_pile:
z = [1] * ctx_len
else:
z = [0] * ctx_len
z_sum = 0
isGood = False
for i in range(3, ctx_len):
if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
isGood = True
if dix[i] == 0:
isGood = False
if isGood:
z[i] = 1
z_sum += 1
if z_sum == 0:
z = [1] * ctx_len
i = np.random.randint(0, self.data_pile_size - req_len)
dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
z = torch.tensor(z, dtype=torch.bfloat16)
x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long)
# if ii_orig < 50:
# # if rank == 1:
# print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
# else:
# exit(0)
if args.my_qa_mask == 1:
return x, y, z
return x, y
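# Why magic_prime % 3 == 2: for a prime p with p % 3 == 2, gcd(3, p - 1) == 1,
# so ii -> ii**3 (mod p) is a bijection, and scaling by a fixed nonzero factor
# keeps it one. The sampler above therefore visits every ctx_len-aligned slot
# exactly once per pass, in a scrambled order. Tiny check on a toy prime:
if __name__ == "__main__":
    p, factor = 17, 10                 # 17 % 3 == 2; 0 < factor < p
    perm = sorted((factor * ii ** 3) % p for ii in range(p))
    assert perm == list(range(p))      # a permutation of 0..p-1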

@ -0,0 +1,610 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, math, gc, importlib
import torch
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
if importlib.util.find_spec('deepspeed'):
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
try:
print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
except:
os.environ["RWKV_MY_TESTING"] = ''
def __nop(ob):
return ob
MyModule = nn.Module
MyFunction = __nop
if os.environ["RWKV_JIT_ON"] == "1":
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
########################################################################################################
# CUDA Kernel
########################################################################################################
T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
from torch.utils.cpp_extension import load
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
w = -torch.exp(w.float().contiguous())
u = u.contiguous()
k = k.contiguous()
v = v.contiguous()
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
ctx.save_for_backward(w, u, k, v, y)
return y
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
w, u, k, v, y = ctx.saved_tensors
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
return (None, None, None, gw, gu, gk, gv)
else:
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
if "32" in os.environ["RWKV_FLOAT_MODE"]:
w = -torch.exp(w.contiguous())
u = u.contiguous()
k = k.contiguous()
v = v.contiguous()
else:
w = -torch.exp(w.float().contiguous())
u = u.float().contiguous()
k = k.float().contiguous()
v = v.float().contiguous()
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
ctx.save_for_backward(w, u, k, v, y)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
return y
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
return y.half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
return y.bfloat16()
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
w, u, k, v, y = ctx.saved_tensors
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
else:
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
return (None, None, None, gw, gu, gk, gv)
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
def RUN_CUDA(B, T, C, w, u, k, v):
return WKV.apply(B, T, C, w, u, k, v)
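# Illustrative pure-PyTorch reference of the recurrence the CUDA kernel computes --
# a minimal O(B*T*C) fp32 sketch without the running-max trick the kernel (and the
# RNN code later in this repo) uses for numerical stability. Defined for documentation
# only; the model never calls it.
def wkv_reference(w, u, k, v):
    # w = raw time_decay (C,), u = time_first (C,), k, v = (B, T, C)
    k, v = k.float(), v.float()
    decay = torch.exp(-torch.exp(w.float()))  # per-channel decay factor in (0, 1)
    a = torch.zeros_like(k[:, 0])  # running sum of decayed exp(k_i) * v_i
    b = torch.zeros_like(k[:, 0])  # running sum of decayed exp(k_i)
    y = torch.empty_like(v)
    for t in range(k.size(1)):
        e = torch.exp(k[:, t])
        bonus = torch.exp(u.float() + k[:, t])  # current token gets the time_first bonus
        y[:, t] = (a + bonus * v[:, t]) / (b + bonus)
        a = decay * a + e * v[:, t]
        b = decay * b + e
    return y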
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
class RWKV_TimeMix(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.ctx_len = args.ctx_len
self.n_embd = args.n_embd
with torch.no_grad(): # fancy init
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
# fancy time_decay
decay_speed = torch.ones(args.dim_att)
for h in range(args.dim_att):
decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
self.time_decay = nn.Parameter(decay_speed)
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
# fancy time_first
zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)
# fancy time_mix
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
if 'a' in os.environ["RWKV_MY_TESTING"]:
self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
d_qkv = args.n_embd // 16
self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
with torch.no_grad():
self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
if 'a' not in os.environ["RWKV_MY_TESTING"]:
@MyFunction
def jit_func(self, x):
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
sr = torch.sigmoid(r)
return sr, k, v
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v = self.jit_func(x)
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
return self.output(rwkv)
if 'a' in os.environ["RWKV_MY_TESTING"]:
@MyFunction
def QKV(self, q, k, v):
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.att_mask == 0, float('-inf'))
att = F.softmax(att, dim = -1)
x = att @ v
return x
@MyFunction
def jit_funcQKV(self, x):
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
sr = torch.sigmoid(r)
qq = self.qq(xqq)
kk = self.kk(xkk)
vv = self.vv(xvv)
return sr, k, v, qq, kk, vv
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
return rwkv
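# Note on time_shift: nn.ZeroPad2d((0, 0, 1, -1)) applied to a (B, T, C) tensor pads one
# zero step at the front of the T dimension and drops the last step, so
# time_shift(x)[:, t] == x[:, t-1] (zeros at t == 0). The time_mix_* parameters then
# interpolate per channel between the current and the previous token. Quick sanity check
# (illustrative): nn.ZeroPad2d((0, 0, 1, -1))(torch.arange(4.).view(1, 4, 1)).flatten()
# gives tensor([0., 0., 1., 2.]).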
########################################################################################################
class RWKV_ChannelMix(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # fancy init of time_mix
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
return torch.sigmoid(self.receptance(xr)) * kv
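# i.e. channel mix = sigmoid(R(xr)) * W_v(relu(W_k(xk))^2): a squared-ReLU FFN
# (n_embd -> dim_ffn -> n_embd) gated elementwise by the receptance sigmoid.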
class MishGLU(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad():
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
x = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
x[0, 0, i] = i / args.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x):
xx = self.time_shift(x)
xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
a = self.aa(xa)
b = self.bb(xb)
return self.value(a * F.mish(b))
########################################################################################################
# The RWKV Model with our blocks
########################################################################################################
class Block(nn.Module):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(args.n_embd)
self.ln2 = nn.LayerNorm(args.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(args.n_embd)
if args.my_pos_emb > 0:
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
if self.layer_id == 0 and self.args.pre_ffn > 0:
self.ffnPre = RWKV_ChannelMix(args, 0)
else:
self.att = RWKV_TimeMix(args, layer_id)
if 'g' in os.environ["RWKV_MY_TESTING"]:
self.ffn = MishGLU(args, layer_id)
else:
self.ffn = RWKV_ChannelMix(args, layer_id)
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
self.tiny_ln = nn.LayerNorm(args.n_embd)
self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
def forward(self, x, x_emb=None):
args = self.args
B, T, C = x.size()
if self.layer_id == 0:
x = self.ln0(x)
if args.my_pos_emb > 0:
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
x = x + pos_emb
if self.layer_id == 0 and args.pre_ffn > 0:
x = x + self.ffnPre(self.ln1(x))
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
xx = self.tiny_ln(x)
q = self.tiny_q(xx)[:, :T, :]
k = self.tiny_k(xx)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
x = x + c @ self.tiny_v(x_emb)
return x
class L2Wrap(torch.autograd.Function):
@staticmethod
def forward(ctx, loss, y):
ctx.save_for_backward(y)
return loss
@staticmethod
def backward(ctx, grad_output):
y = ctx.saved_tensors[0]
# to encourage the logits to be close to 0
factor = 1e-4 / (y.shape[0] * y.shape[1])
maxx, ids = torch.max(y, -1, keepdim=True)
gy = torch.zeros_like(y)
gy.scatter_(-1, ids, maxx * factor)
return (grad_output, gy)
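# L2Wrap passes the loss through unchanged but adds a gradient of factor * max_logit at
# each position's argmax logit -- the same gradient as adding 0.5 * factor * max_logit^2
# to the loss. So it acts as a tiny L2 penalty on the largest logit per position, which
# is what keeps the logits close to 0.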
class RWKV(pl.LightningModule):
def __init__(self, args):
super().__init__()
self.args = args
if not hasattr(args, 'dim_att'):
args.dim_att = args.n_embd
if not hasattr(args, 'dim_ffn'):
args.dim_ffn = args.n_embd * 4
if not hasattr(args, 'tiny_att_layer'):
args.tiny_att_layer = -1
if not hasattr(args, 'tiny_att_dim'):
args.tiny_att_dim = -1
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
self.ln_out = nn.LayerNorm(args.n_embd)
self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
if args.head_qk > 0:
self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
def configure_optimizers(self):
args = self.args
if args.layerwise_lr > 0:
lr_1x = set()
lr_2x = set()
lr_3x = set()
for n, p in self.named_parameters():
if "time_mix" in n:
if args.my_pile_stage == 2:
lr_2x.add(n)
else:
lr_1x.add(n)
elif "time_decay" in n:
if args.my_pile_stage == 2:
lr_3x.add(n)
else:
lr_2x.add(n)
elif "time_first" in n:
lr_3x.add(n)
else:
lr_1x.add(n)
lr_1x = sorted(list(lr_1x))
lr_2x = sorted(list(lr_2x))
lr_3x = sorted(list(lr_3x))
# print('1x', lr_1x)
# print('2x', lr_2x)
# print('3x', lr_3x)
param_dict = {n: p for n, p in self.named_parameters()}
if args.my_pile_stage == 2:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
]
else:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
else:
optim_groups = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
]
if self.deepspeed_offload:
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
@property
def deepspeed_offload(self) -> bool:
strategy = self.trainer.strategy
if isinstance(strategy, DeepSpeedStrategy):
cfg = strategy.config["zero_optimization"]
return cfg.get("offload_optimizer") or cfg.get("offload_param")
return False
def forward(self, idx):
args = self.args
B, T = idx.size()
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
x = self.emb(idx)
x_emb = x
if args.tiny_att_dim > 0:
for block in self.blocks:
if args.grad_cp == 1:
x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
else:
x = block(x, x_emb)
else:
for block in self.blocks:
if args.grad_cp == 1:
x = deepspeed.checkpointing.checkpoint(block, x)
else:
x = block(x)
x = self.ln_out(x)
if args.head_qk > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
c = c @ F.one_hot(idx, num_classes=args.vocab_size)
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
x = self.head(x) + c
else:
x = self.head(x)
return x
def training_step(self, batch, batch_idx):
args = self.args
if args.my_qa_mask != 1:
idx, targets = batch
logits = self(idx)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
else:
idx, targets, mask = batch
mask = mask.view(-1)
sum_mask = torch.sum(mask).item()
# if sum_mask == 0:
# return torch.tensor([0.0], requires_grad=True)
logits = self(idx)
if sum_mask == mask.shape[0]:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
# print('rank', self.global_rank, 'loss', loss.item())
else:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
# loss_raw = loss
loss = torch.sum(loss * mask) / sum_mask
# torch.set_printoptions(threshold=10000)
# if True: #self.global_rank == 1:
# tmp = ''
# sss = 0
# ccc = 0
# for i in range(mask.shape[0]):
# if mask[i] > 0:
# tmp += str(idx.view(-1)[i].item()) + ','
# sss += loss_raw.view(-1)[i].float().item()
# ccc += 1
# print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
return L2Wrap.apply(loss, logits)
def training_step_end(self, batch_parts):
all = self.all_gather(batch_parts)
if self.trainer.is_global_zero:
self.trainer.my_loss_all = all
def generate_init_weight(self):
print(
f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
)
m = {}
for n in self.state_dict():
p = self.state_dict()[n]
shape = p.shape
gain = 1.0
scale = 1.0
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
m[n] = p
else:
if n == "emb.weight":
scale = -1 * self.args.lr_init
else:
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
if kk in n:
scale = 0
if n == "head.weight":
scale = 0.5
if "head_k." in n:
scale = 0.1
if "head_q." in n:
scale = 0
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
if self.args.accelerator.upper() == "GPU":
m[n] = torch.empty((shape[0], shape[1]), device="cuda")
else:
m[n] = torch.empty((shape[0], shape[1]))
if scale == 0:
nn.init.zeros_(m[n])
elif scale < 0:
nn.init.uniform_(m[n], a=scale, b=-scale)
else:
nn.init.orthogonal_(m[n], gain=gain * scale)
m[n] = m[n].cpu()
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
m[n] = m[n].half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
m[n] = m[n].bfloat16()
# if n == "emb.weight":
# print(m[n])
gc.collect()
torch.cuda.empty_cache()
return m
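# Summary of the scheme above: LayerNorm / time_* / mask / positional tensors keep their
# constructor values; emb.weight is uniform in [-lr_init, lr_init]; the projections in
# the kk list plus head_q start at exactly 0; head.weight is orthogonal scaled by 0.5
# and head_k by 0.1; everything else is orthogonal with gain sqrt(shape[0] / shape[1])
# when the matrix is tall.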

@@ -0,0 +1,446 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import os, math, gc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as vision
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from pytorch_msssim import MS_SSIM
def __nop(ob):
return ob
MyModule = torch.jit.ScriptModule
# MyFunction = __nop
MyFunction = torch.jit.script_method
import clip
from transformers import CLIPModel
class L2pooling(nn.Module):
def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
super(L2pooling, self).__init__()
self.padding = (filter_size - 2) // 2
self.stride = stride
self.channels = channels
a = np.hanning(filter_size)[1:-1]
g = torch.Tensor(a[:, None] * a[None, :])
g = g / torch.sum(g)
self.register_buffer(
"filter", g[None, None, :, :].repeat((self.channels, 1, 1, 1))
)
def forward(self, input):
input = input**2
out = F.conv2d(
input,
self.filter,
stride=self.stride,
padding=self.padding,
groups=input.shape[1],
)
return (out + 1e-12).sqrt()
class DISTS(torch.nn.Module):
def __init__(self, load_weights=True):
super(DISTS, self).__init__()
vgg_pretrained_features = vision.models.vgg16(
weights="VGG16_Weights.IMAGENET1K_V1"
).features
self.stage1 = torch.nn.Sequential()
self.stage2 = torch.nn.Sequential()
self.stage3 = torch.nn.Sequential()
self.stage4 = torch.nn.Sequential()
self.stage5 = torch.nn.Sequential()
for x in range(0, 4):
self.stage1.add_module(str(x), vgg_pretrained_features[x])
self.stage2.add_module(str(4), L2pooling(channels=64))
for x in range(5, 9):
self.stage2.add_module(str(x), vgg_pretrained_features[x])
self.stage3.add_module(str(9), L2pooling(channels=128))
for x in range(10, 16):
self.stage3.add_module(str(x), vgg_pretrained_features[x])
self.stage4.add_module(str(16), L2pooling(channels=256))
for x in range(17, 23):
self.stage4.add_module(str(x), vgg_pretrained_features[x])
self.stage5.add_module(str(23), L2pooling(channels=512))
for x in range(24, 30):
self.stage5.add_module(str(x), vgg_pretrained_features[x])
self.register_buffer(
"mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
)
self.register_buffer(
"std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
)
self.chns = [3, 64, 128, 256, 512, 512]
self.register_buffer(
"alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))
)
self.register_buffer("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
self.alpha.data.normal_(0.1, 0.01)
self.beta.data.normal_(0.1, 0.01)
weights = torch.load("test/DISTS_weights.pt")
self.alpha.data = weights["alpha"]
self.beta.data = weights["beta"]
for param in self.parameters():
param.requires_grad = False
def forward_once(self, x):
h = (x - self.mean) / self.std
h = self.stage1(h)
h_relu1_2 = h
h = self.stage2(h)
h_relu2_2 = h
h = self.stage3(h)
h_relu3_3 = h
h = self.stage4(h)
h_relu4_3 = h
h = self.stage5(h)
h_relu5_3 = h
return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
def forward(self, x, y, require_grad=False, batch_average=False):
if require_grad:
feats0 = self.forward_once(x)
feats1 = self.forward_once(y)
else:
with torch.no_grad():
feats0 = self.forward_once(x)
feats1 = self.forward_once(y)
dist1 = 0
dist2 = 0
c1 = 1e-6
c2 = 1e-6
w_sum = self.alpha.sum() + self.beta.sum()
alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
beta = torch.split(self.beta / w_sum, self.chns, dim=1)
for k in range(len(self.chns)):
x_mean = feats0[k].mean([2, 3], keepdim=True)
y_mean = feats1[k].mean([2, 3], keepdim=True)
S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True)
x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)
y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)
xy_cov = (feats0[k] * feats1[k]).mean(
[2, 3], keepdim=True
) - x_mean * y_mean
S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)
score = 1 - (dist1 + dist2).squeeze()
if batch_average:
return score.mean()
else:
return score
class ToBinary(torch.autograd.Function):
@staticmethod
def forward(ctx, x):#, noise_scale):
# if noise_scale > 0:
# noise_min = 0.5 - noise_scale / 2
# noise_max = 0.5 + noise_scale / 2
# return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
# else:
return torch.floor(x + 0.5) # no need for noise when we have plenty of data
@staticmethod
def backward(ctx, grad_output):
return grad_output.clone()#, None
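# ToBinary is a straight-through estimator: the forward pass quantizes the encoder
# output to {0, 1} by round-to-nearest (floor(x + 0.5)), while the backward pass copies
# the incoming gradient unchanged, as if rounding were the identity. This is what lets
# the binary bottleneck train end-to-end.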
########################################################################################################
class R_ENCODER(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.Bxx = nn.BatchNorm2d(dd*64)
self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*4)
self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*64)
self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.B21 = nn.BatchNorm2d(dd*64)
# self.C24 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C25 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.C26 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C27 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
@MyFunction
def forward(self, img):
ACT = F.mish
x = self.CIN(img)
xx = self.Bxx(F.pixel_unshuffle(x, 8))
x = x + self.Cx1(ACT(self.Cx0(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
# x = x + self.C25(ACT(self.C24(ACT(self.B21(x)))))
# x = x + self.C27(ACT(self.C26(x)))
x = self.COUT(x + xx)
return torch.sigmoid(x)
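# Shape bookkeeping (square inputs of side S = args.my_img_size): the three
# pixel_unshuffle(2) stages shrink S -> S/8 while growing channels dd -> dd*4 -> dd*16
# -> dd*64, and the Bxx skip branch reaches the same (dd*64, S/8, S/8) shape in a single
# pixel_unshuffle(x, 8), so `x + xx` lines up before COUT projects to my_img_bit planes.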
########################################################################################################
class R_DECODER(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*64)
self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.B01 = nn.BatchNorm2d(dd*64)
# self.C04 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C05 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.C06 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C07 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*4)
self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
@MyFunction
def forward(self, code):
ACT = F.mish
x = self.CIN(code)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
# x = x + self.C05(ACT(self.C04(ACT(self.B01(x)))))
# x = x + self.C07(ACT(self.C06(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.Cx1(ACT(self.Cx0(x)))
x = self.COUT(x)
return torch.sigmoid(x)
########################################################################################################
def cosine_loss(x, y):
x = F.normalize(x, dim=-1)
y = F.normalize(y, dim=-1)
return 1 - torch.einsum('ij,ij->i',[x,y])
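# cosine_loss: both inputs are L2-normalized, so the einsum is a per-row dot product and
# the result is 1 - cosine_similarity(x_i, y_i) per row (0 when aligned, up to 2 when opposed).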
class RWKV_IMG(pl.LightningModule):
def __init__(self, args):
super().__init__()
self.args = args
self.encoder = R_ENCODER(args)
self.decoder = R_DECODER(args)
self.clip_model = None
clip_name = args.my_img_clip
if clip_name == 'B32':
clip_name = 'ViT-B/32'
elif clip_name == 'B16':
clip_name = 'ViT-B/16'
elif clip_name == 'L14':
clip_name = 'ViT-L/14'
elif clip_name == 'OB32':
clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
self.clip_model = CLIPModel.from_pretrained(clip_name)
self.clip_model.encode_image = self.clip_model.get_image_features
if self.clip_model is None:
self.clip_model, _ = clip.load(clip_name, jit = True)
self.register_buffer(
"clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
)
self.register_buffer(
"clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
)
for n, p in self.named_parameters():
if 'clip_model' in n:
p.requires_grad = False
self.loss_dists = DISTS()
# self.loss_ssim = MS_SSIM(data_range=1, size_average=True, channel=3)
def configure_optimizers(self):
args = self.args
optim_groups = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
]
if self.deepspeed_offload:
return DeepSpeedCPUAdam(
optim_groups,
lr=self.args.lr_init,
betas=self.args.betas,
eps=self.args.adam_eps,
bias_correction=True,
adamw_mode=False,
weight_decay=0,
amsgrad=False,
)
return FusedAdam(
optim_groups,
lr=self.args.lr_init,
betas=self.args.betas,
eps=self.args.adam_eps,
bias_correction=True,
adam_w_mode=False,
weight_decay=0,
amsgrad=False,
)
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
@property
def deepspeed_offload(self) -> bool:
strategy = self.trainer.strategy
if isinstance(strategy, DeepSpeedStrategy):
config = strategy.config["zero_optimization"]
return config.get("offload_optimizer") or config.get("offload_param")
return False
def forward(self, img):
z = self.encoder(img)
z = ToBinary.apply(z)#, self.args.my_img_noise_scale)
out = self.decoder(z)
return out
def training_step(self, batch, batch_idx):
args = self.args
img, txt = batch
out = self(img)
if self.trainer.is_global_zero:
if (self.trainer.global_step + 1) % (100 * int(args.devices)) == 0:
img_dir = f"test/image_model/{args.run_name}"
if not os.path.exists(img_dir):
os.makedirs(img_dir)
vision.utils.save_image(
img[:4], f"{img_dir}/{self.trainer.global_step}-src.jpg"#, padding=0
)
vision.utils.save_image(
out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0
)
# loss_ssim = 1 - self.loss_ssim(out, img)
loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)
iii = self.clip_model.encode_image((img - self.clip_mean) / self.clip_std)
ooo = self.clip_model.encode_image((out - self.clip_mean) / self.clip_std)
loss_clip = torch.mean(cosine_loss(iii, ooo))
if args.my_img_l1_scale > 0:
loss_l1 = F.l1_loss(out, img)
return loss_dists + loss_clip * args.my_img_clip_scale + loss_l1 * args.my_img_l1_scale
else:
return loss_dists + loss_clip * args.my_img_clip_scale
def training_step_end(self, batch_parts):
all = self.all_gather(batch_parts)
if self.trainer.is_global_zero:
self.trainer.my_loss_all = all
def generate_init_weight(self):
print(
f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
)
m = {}
for n in self.state_dict():
scale = 1
p = self.state_dict()[n]
shape = p.shape
ss = n.split('.')
# if ss[0] in ['encoder', 'decoder']:
# if ss[2] == 'bias':
# scale = 0
# # elif n == 'encoder.CIN.weight':
# # nn.init.dirac_(p)
# else:
# try:
# if ss[1][0] == 'C' and (int(ss[1][2]) % 2 == 1):
# scale = 0
# except:
# pass
# m[n] = p * scale
m[n] = p
m[n] = m[n].cpu()
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
m[n] = m[n].half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
m[n] = m[n].bfloat16()
gc.collect()
torch.cuda.empty_cache()
return m

@@ -0,0 +1,237 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types
import torch
import math, os, gc
from torch.nn import functional as F
import torch.nn as nn
from typing import List, Dict
MyModule = nn.Module
def __nop(ob):
return ob
MyFunction = __nop
# # try torchdynamo
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# try torch jit --> faster for fp32, slower for fp16 (why?)
if os.environ["RWKV_JIT_ON"] == "1":
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n')
DEBUG_TIME = False # True False - show trained time-coeffs
RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
############################################################################################################
class RWKV_RNN(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
self.FLOAT_MODE = args.FLOAT_MODE
self.RUN_DEVICE = args.RUN_DEVICE
with torch.no_grad():
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
# refine weights and send to correct device
keys = list(w.keys())
if 'pos_emb_x' in keys:
w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
keys = list(w.keys())
print_need_newline = False
for x in keys:
block_id = 0
if 'blocks.' in x:
block_id = int(x.split('.')[1])
if 'att.output.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if 'ffn.value.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if '.time_' in x:
w[x] = w[x].squeeze()
if DEBUG_TIME:
print(x, w[x].numpy())
if '.time_decay' in x:
w[x] = w[x].float()
w[x] = -torch.exp(w[x])
elif '.time_first' in x:
w[x] = w[x].float()
else:
if self.FLOAT_MODE == "fp32":
w[x] = w[x].float()
elif self.FLOAT_MODE == "bf16":
w[x] = w[x].bfloat16()
elif self.FLOAT_MODE == "fp16":
w[x] = w[x].half()
w[x].requires_grad = False
if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
w[x] = w[x].cuda()
if ('blocks.' not in x) or ('blocks.0.' in x):
if print_need_newline:
print('\n', end = '')
print_need_newline = False
print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device)
else:
print_need_newline = True
print('.', end = '', flush = True)
# store weights in self.w
keys = list(w.keys())
self.w = types.SimpleNamespace()
for x in keys:
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.eval()
gc.collect()
torch.cuda.empty_cache()
def LN(self, x, w):
return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
# state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
@MyFunction
def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+0] = x.float()
elif self.FLOAT_MODE == "fp16":
xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r)
state[5*i+0] = x.float()
else:
xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
state[5*i+0] = x
r = torch.sigmoid(rw @ xr)
k = torch.square(torch.relu(kw @ xk))
kv = vw @ k
return r * kv
@MyFunction
def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+1] = x.float()
elif self.FLOAT_MODE == "fp16":
xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r)
state[5*i+1] = x.float()
else:
xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
state[5*i+1] = x
r = torch.sigmoid(rw @ xr)
k = kw @ xk
v = vw @ xv
if '16' in self.FLOAT_MODE:
kk = k.float()
vv = v.float()
else:
kk = k
vv = v
aa = state[5*i+2]
bb = state[5*i+3]
pp = state[5*i+4]
ww = time_first + kk
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
a = e1 * aa + e2 * vv
b = e1 * bb + e2
ww = pp + time_decay
p = torch.maximum(ww, kk)
e1 = torch.exp(ww - p)
e2 = torch.exp(kk - p)
state[5*i+2] = e1 * aa + e2 * vv
state[5*i+3] = e1 * bb + e2
state[5*i+4] = p
if self.FLOAT_MODE == "bf16":
wkv = (a / b).type(torch.bfloat16)
elif self.FLOAT_MODE == "fp16":
wkv = (a / b).half()
else:
wkv = a / b
return ow @ (r * wkv)
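# The aa/bb/pp state implements the WKV recurrence safely: pp tracks a running maximum
# of the exponents, aa = sum_i exp(k_i - pp) * v_i and bb = sum_i exp(k_i - pp) (with
# the per-step time_decay folded into pp). The division a / b cancels the shared
# exp(-p) factor, so the output equals the unshifted formula while every intermediate
# exp() argument stays <= 0, avoiding overflow.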
def forward(self, ctx, state, preprocess_only = False):
with torch.no_grad():
w = self.w
args = self.args
x = w.emb.weight[ctx[-1]]
if self.RUN_DEVICE == 'cuda':
x = x.cuda()
try:
pos_emb = w.pos_emb[len(ctx)-1]
x = x + pos_emb
except:
pass
if state is None:
state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
for i in range(args.n_layer):
state[5*i+4] -= 1e30
for i in range(args.n_layer):
if i == 0:
x = self.LN(x, w.blocks[i].ln0)
ww = w.blocks[i].att
x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i,
ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
ww = w.blocks[i].ffn
x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
ww.time_mix_k, ww.time_mix_r,
ww.key.weight, ww.value.weight, ww.receptance.weight)
if (i+1) % RWKV_RESCALE_LAYER == 0:
x = x / 2
if preprocess_only:
return state
x = self.LN(x, w.ln_out)
x = w.head.weight @ x
return x.float(), state
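# Minimal usage sketch (illustrative; `tokens` and the args fields are assumptions, not
# part of this file):
#   model = RWKV_RNN(args)  # args needs MODEL_NAME, n_layer, n_embd, ctx_len, FLOAT_MODE, RUN_DEVICE
#   state = None
#   for i in range(len(tokens) - 1):  # feed the prompt, carrying the state forward
#       state = model.forward(tokens[:i + 1], state, preprocess_only=True)
#   logits, state = model.forward(tokens, state)  # logits for the next token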

@@ -0,0 +1,190 @@
import os, math, time, datetime, subprocess
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
def my_save(dd, ff):
if '14b-run1' not in ff:
torch.save(dd, ff)
else:
fn = ff.split('/')[-1]
fff = '/dev/shm/' + fn
torch.save(dd, fff)
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
class train_callback(pl.Callback):
def __init__(self, args):
super().__init__()
self.args = args
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
args = self.args
# if args.cuda_cleanup > 0:
# torch.cuda.empty_cache()
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
# LR schedule
w_step = args.warmup_steps
if args.lr_final == args.lr_init or args.epoch_count == 0:
lr = args.lr_init
else:
decay_step = real_step - args.my_pile_edecay * args.epoch_steps
decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
progress = (decay_step - w_step + 1) / (decay_total - w_step)
progress = min(1, max(0, progress))
if args.lr_final == 0 or args.lr_init == 0: # linear decay
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
else: # exp decay
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
if trainer.global_step < w_step:
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
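# Worked example of the schedule above: with lr_init = 6e-4, lr_final = 1e-5 and
# progress = 0.5, lr = 6e-4 * exp(ln(1e-5 / 6e-4) * 0.5) ~= 7.7e-5. Warmup then ramps
# the first w_step steps linearly from 0.2x to 1.0x of the scheduled lr.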
# if trainer.is_global_zero:
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
for param_group in trainer.optimizers[0].param_groups:
if args.layerwise_lr > 0:
param_group["lr"] = lr * param_group["my_lr_scale"]
# print(param_group["lr"], param_group["my_lr_scale"])
else:
param_group["lr"] = lr
trainer.my_lr = lr
# rank_zero_info(f"{real_step} {lr}")
if trainer.global_step == 0:
if trainer.is_global_zero: # logging
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
try:
print(f"\n{trainer.strategy.config}\n")
trainer.my_log.write(f"{trainer.strategy.config}\n")
except:
pass
trainer.my_log.flush()
if len(args.wandb) > 0:
print("Login to wandb...")
import wandb
wandb.init(
project=args.wandb,
name=args.run_name + " " + args.my_timestamp,
config=args,
save_code=False,
)
trainer.my_wandb = wandb
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
args = self.args
if trainer.is_global_zero: # logging
t_now = time.time_ns()
token_per_step = args.ctx_len * args.real_bsz
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
kt_s = 0
try:
t_cost = (t_now - trainer.my_time_ns) / 1e9
kt_s = token_per_step / t_cost / 1000
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
except:
pass
trainer.my_time_ns = t_now
trainer.my_loss = trainer.my_loss_all.float().mean().item()
trainer.my_loss_sum += trainer.my_loss
trainer.my_loss_count += 1
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
# self.log("s", real_step, prog_bar=True, on_step=True)
if len(args.wandb) > 0:
lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
if kt_s > 0:
lll["kt/s"] = kt_s
trainer.my_wandb.log(lll, step=int(real_step))
if args.magic_prime > 0:
expand_factor = 2 if args.my_qa_mask > 0 else 1
if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1:
to_save_dict = pl_module.state_dict()
my_save(
to_save_dict,
f"{args.proj_dir}/rwkv-final.pth",
)
def on_train_epoch_start(self, trainer, pl_module):
args = self.args
dataset = trainer.train_dataloader.dataset.datasets
assert "MyDataset" in str(dataset)
dataset.global_rank = trainer.global_rank
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
dataset.world_size = trainer.world_size
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
def on_train_epoch_end(self, trainer, pl_module):
args = self.args
if trainer.is_global_zero: # logging & save state_dict
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
if args.data_type == 'wds_img':
raw_dict = pl_module.state_dict()
to_save_dict = {}
for k in raw_dict:
if k.startswith('encoder.') or k.startswith('decoder.'):
to_save_dict[k] = raw_dict[k]
else:
to_save_dict = pl_module.state_dict()
try:
my_save(
to_save_dict,
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
)
except Exception as e:
print('Error\n\n', e, '\n\n')
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
trainer.my_log.flush()
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
@rank_zero_only
def generate_init_weight(model, init_weight_name):
mm = model.generate_init_weight()
if model.args.my_pile_stage == 1:
if len(model.args.load_model) > 0:
print(f"Combine weights from {model.args.load_model}...")
load_dict = torch.load(model.args.load_model, map_location="cpu")
for k in load_dict:
assert k in mm
src = load_dict[k]
try:
mm[k] = src.reshape(mm[k].shape)
except:
tmp = mm[k].squeeze().clone()
print(k, src.shape, '-->', mm[k].shape)
ss = src.shape[0]
dd = tmp.shape[0]
for i in range(dd):
pos = i / dd * ss
if pos >= ss - 1:
tmp[i] = src[ss-1]
else:
p0 = int(math.floor(pos))
ii = pos - p0
tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
mm[k] = tmp.reshape(mm[k].shape)
sss = src.squeeze().float().cpu().numpy()
print(sss[:10], '...', sss[-10:])
mmm = mm[k].squeeze().float().cpu().numpy()
print(mmm[:10], '...', mmm[-10:])
print(f"Save to {init_weight_name}...")
torch.save(mm, init_weight_name)
if model.args.my_pile_stage == 1:
print("Done. Now go for stage 2.")
exit(0)

@@ -0,0 +1,130 @@
import json, time, random, os
import numpy as np
import torch
from torch.nn import functional as F
time_slot = {}
time_ref = time.time_ns()
def record_time(name):
if name not in time_slot:
time_slot[name] = 1e20
tt = (time.time_ns() - time_ref) / 1e9
if tt < time_slot[name]:
time_slot[name] = tt
class TOKENIZER():
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
if 'list' in str(type(WORD_NAME)):
self.charMode = False
if WORD_NAME[0] == WORD_NAME[1]:
from transformers import PreTrainedTokenizerFast
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
else:
from transformers import GPT2TokenizerFast
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
self.vocab_size = len(self.tokenizer)
else:
self.charMode = True
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
self.word_table = json.load(result_file)
self.vocab_size = len(self.word_table)
self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(out, dim=-1)
if self.charMode:
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
else:
top_p = top_p_usual
if os.environ["RWKV_RUN_DEVICE"] == "cpu":
probs = probs.numpy()
sorted_probs = np.sort(probs)[::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if temperature != 1.0:
probs = probs ** (1.0 / temperature)  # probs is a numpy array here, so ** rather than .pow
probs = probs / np.sum(probs)
out = np.random.choice(a=len(probs), p=probs)
return out
else:
sorted_probs = torch.sort(probs, descending=True)[0]
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
out = torch.multinomial(probs, num_samples=1)[0]
return out
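# Worked example of the top-p cutoff above: sorted probs [0.5, 0.3, 0.15, 0.05] with
# top_p = 0.9 give cumulative [0.5, 0.8, 0.95, 1.0]; the first index exceeding 0.9 is 2,
# so cutoff = 0.15 and the 0.05 tail is zeroed before sampling (the survivors are
# renormalized when drawing).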
def MaybeIsPrime(number):
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
return True
else:
return False
def FermatPrimalityTest(number):
if number > 1:
for time in range(3):
randomNumber = random.randint(2, number) - 1
if pow(randomNumber, number - 1, number) != 1:
return False
return True
else:
return False
def MillerRabinPrimalityTest(number):
if number == 2:
return True
elif number == 1 or number % 2 == 0:
return False
oddPartOfNumber = number - 1
timesTwoDividNumber = 0
while oddPartOfNumber % 2 == 0:
oddPartOfNumber = oddPartOfNumber // 2
timesTwoDividNumber = timesTwoDividNumber + 1
for time in range(3):
while True:
randomNumber = random.randint(2, number) - 1
if randomNumber != 0 and randomNumber != 1:
break
randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
iterationNumber = 1
while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
iterationNumber = iterationNumber + 1
if randomNumberWithPower != (number - 1):
return False
return True

@@ -0,0 +1,349 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
if __name__ == "__main__":
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
rank_zero_info("########## work in progress ##########")
########################################################################################################
#
# example: train a simple L12-D768 RWKV on dummy data
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "" --data_type "dummy" --vocab_size 0 \
# --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
# --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
# --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: train a simple L6-D512 RWKV from scratch on enwik8
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
# --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
# example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
# --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
parser = ArgumentParser()
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--random_seed", default="-1", type=int)
parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--dim_att", default=0, type=int)
parser.add_argument("--dim_ffn", default=0, type=int)
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=-1, type=int) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
parser.add_argument("--my_pile_edecay", default=0, type=int)
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
parser.add_argument("--my_img_version", default=0, type=str)
parser.add_argument("--my_img_size", default=0, type=int)
parser.add_argument("--my_img_bit", default=0, type=int)
parser.add_argument("--my_img_clip", default='x', type=str)
parser.add_argument("--my_img_clip_scale", default=1, type=float)
parser.add_argument("--my_img_l1_scale", default=0, type=float)
parser.add_argument("--my_img_encoder", default='x', type=str)
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
parser.add_argument("--my_sample_len", default=0, type=int)
parser.add_argument("--my_ffn_shift", default=1, type=int)
parser.add_argument("--my_att_shift", default=1, type=int)
parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
########################################################################################################
import os, warnings, math, datetime, sys, time, importlib
import numpy as np
import torch
from torch.utils.data import DataLoader
if "deepspeed" in args.strategy:
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
seed_everything(args.random_seed)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
# os.environ["WDS_SHOW_SEED"] = "1"
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
args.enable_checkpointing = False
args.replace_sampler_ddp = False
args.logger = False
args.gradient_clip_val = 1.0
args.num_sanity_val_steps = 0
args.check_val_every_n_epoch = int(1e20)
args.log_every_n_steps = int(1e20)
args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
os.environ["RWKV_MY_TESTING"] = args.my_testing
if args.dim_att <= 0:
args.dim_att = args.n_embd
if args.dim_ffn <= 0:
args.dim_ffn = args.n_embd * 4
if args.data_type == "wds_img":
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
args.proj_dir = f"{args.proj_dir}-{args.run_name}"
else:
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
if not os.path.exists(args.proj_dir):
os.makedirs(args.proj_dir)
if args.my_pile_stage > 0:
magic_prime_bak = args.magic_prime
if args.my_pile_version == 1:
if args.ctx_len == 1024:
args.magic_prime = 324331313
args.epoch_count = 8043
elif args.ctx_len == 2048:
args.magic_prime = 162165671
args.epoch_count = 4021
elif args.ctx_len == 4096:
args.magic_prime = 81082817
args.epoch_count = 2010
elif args.ctx_len == 8192:
args.magic_prime = 40541399
args.epoch_count = 1005
else:
if args.ctx_len == 1024:
args.magic_prime = 1694947181
args.epoch_count = 42036
elif args.ctx_len == 2048:
args.magic_prime = 847473509
args.epoch_count = 21017
elif args.ctx_len == 4096:
args.magic_prime = 423736637
args.epoch_count = 10508
elif args.ctx_len == 6144:
args.magic_prime = 282491051
args.epoch_count = 7005
elif args.ctx_len == 8192:
args.magic_prime = 211868243
args.epoch_count = 5253
if args.my_pile_shift < 0:
args.my_pile_shift = 0
if magic_prime_bak > 0:
args.magic_prime = magic_prime_bak
args.epoch_steps = 40320 // args.real_bsz
assert args.epoch_steps * args.real_bsz == 40320
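# 40320 = 8!, so epoch_steps is an integer for any real_bsz dividing 8! (e.g. 8 GPUs x
# micro_bsz 8 -> real_bsz 64 -> epoch_steps 630); each pile "epoch" is a fixed 40320
# samples regardless of device count.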
if args.my_pile_stage == 2:
assert args.lr_final == args.lr_init
if args.my_pile_stage >= 2: # find latest saved model
list_p = []
for p in os.listdir(args.proj_dir):
if p.startswith("rwkv") and p.endswith(".pth"):
p = ((p.split("-"))[1].split("."))[0]
if p == "init":
p = -1
else:
p = int(p)
list_p += [p]
list_p.sort()
max_p = list_p[-1]
if len(list_p) > 1:
args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted
if max_p == -1:
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
else:
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
if args.warmup_steps < 0:
if args.my_pile_stage == 2:
args.warmup_steps = 10
else:
args.warmup_steps = 30
args.epoch_begin = max_p + 1
samples_per_epoch = args.epoch_steps * args.real_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len
rank_zero_info(
f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
#
############################################################################
"""
)
rank_zero_info(str(vars(args)) + "\n")
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
if args.lr_final == 0 or args.lr_init == 0:
    rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32":
for i in range(10):
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
os.environ["RWKV_JIT_ON"] = "1"
if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT_ON"] = "0"
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
if args.precision == "fp32":
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
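# TF32 runs fp32 matmuls/convolutions on tensor cores with a truncated
# mantissa (Ampere or newer GPUs); it is disabled only when the user asked
# for true fp32, keeping that mode bit-exact at the cost of speed.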
if "32" in args.precision:
args.precision = 32
elif args.precision == "fp16":
args.precision = 16
else:
args.precision = "bf16"
########################################################################################################
from src.trainer import train_callback, generate_init_weight
from src.dataset import MyDataset
train_data = MyDataset(args)
args.vocab_size = train_data.vocab_size
if args.data_type == 'wds_img':
    from src.model_img import RWKV_IMG
    model = RWKV_IMG(args)
else:
    from src.model import RWKV
    model = RWKV(args)
if len(args.load_model) == 0 or args.my_pile_stage == 1:  # shall we build the initial weights?
    init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
    generate_init_weight(model, init_weight_name)  # save initial weights
    args.load_model = init_weight_name
rank_zero_info(f"########## Loading {args.load_model}... ##########")
try:
    load_dict = torch.load(args.load_model, map_location="cpu")
except Exception:
    rank_zero_info(f"Bad checkpoint {args.load_model}")
    if args.my_pile_stage >= 2:  # try again using another checkpoint
        max_p = args.my_pile_prev_p
        if max_p == -1:
            args.load_model = f"{args.proj_dir}/rwkv-init.pth"
        else:
            args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
        args.epoch_begin = max_p + 1
        rank_zero_info(f"Trying {args.load_model}")
        load_dict = torch.load(args.load_model, map_location="cpu")
if args.load_partial == 1:
    load_keys = load_dict.keys()
    for k in model.state_dict():
        if k not in load_keys:
            load_dict[k] = model.state_dict()[k]  # fill missing keys from the freshly initialized weights
model.load_state_dict(load_dict)
trainer = Trainer.from_argparse_args(
    args,
    callbacks=[train_callback(args)],
)
if trainer.global_rank == 0:
    for n in model.state_dict():  # print every parameter's shape and name
        shape = model.state_dict()[n].shape
        shape = [i for i in shape if i != 1]  # drop singleton dimensions
        if len(shape) > 1:
            print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
        else:
            print(f"{str(shape[0]).ljust(5)} {n}")
if "deepspeed" in args.strategy:
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
# must set shuffle=False, persistent_workers=False (because the worker runs in another thread)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
trainer.fit(model, data_loader)
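
########################################################################################################
# A hypothetical launch command, assuming the argparse flags defined earlier
# in this file mirror the attribute names used above (the exact flag names
# are an assumption -- check the argparse section at the top of the file):
#
#   python train.py --load_model "" --proj_dir "out" \
#     --data_file "data/train.bin" --data_type "binidx" --vocab_size 50277 \
#     --ctx_len 1024 --n_layer 24 --n_embd 1024 --micro_bsz 4 \
#     --lr_init 6e-4 --lr_final 1e-5 --devices 8 --precision bf16 \
#     --strategy deepspeed_stage_2 --grad_cp 1
########################################################################################################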

@@ -0,0 +1,104 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# this verifies the results of different models and makes sure they agree with each other
import os, sys, types
import numpy as np
import torch
np.set_printoptions(precision=4, suppress=True, linewidth=200)
try:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]  # optional: pick a GPU via the first CLI argument
except IndexError:
    pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # bf16 or fp32
os.environ['RWKV_RUN_DEVICE'] = 'cuda' # currently model_train requires CUDA
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
TOKEN_MODE = 'pile'
if TOKEN_MODE == 'pile':
    WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
    MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
    n_layer = 32
    n_embd = 2560
    ctx_len = 1024
    UNKNOWN_CHAR = None
from src.utils import TOKENIZER
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == 'pile':
    tokenizer.vocab_size = 50277
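# 50277 is hard-coded to match the embedding size of the Pile-trained RWKV-4
# checkpoints; it must agree with the 20B_tokenizer.json vocabulary those
# models were trained with.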
########################################################################################################
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_T_MAX"] = str(ctx_len)
from src.model_run import RWKV_RNN
from src.model import RWKV
args = types.SimpleNamespace()
args.vocab_size = tokenizer.vocab_size
args.ctx_len = ctx_len
args.n_embd = n_embd
args.n_layer = n_layer
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
model_train = RWKV(args).to(RUN_DEVICE)
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
    model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
    model_train = model_train.bfloat16()
print('loading ' + MODEL_NAME)
m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu')
model_train.load_state_dict(m2)
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
    model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
    model_train = model_train.bfloat16()
args.MODEL_NAME = MODEL_NAME
args.RUN_DEVICE = RUN_DEVICE
args.FLOAT_MODE = os.environ['RWKV_FLOAT_MODE']
model_rnn = RWKV_RNN(args)
########################################################################################################
print(f"\nVerifying {os.environ['RWKV_RUN_DEVICE']} {os.environ['RWKV_FLOAT_MODE']}")
# context = '\nIn a'
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'
if TOKEN_MODE == 'pile':
    ctx = tokenizer.tokenizer.encode(context)
print(f'input len {len(ctx)} data {ctx}')
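# at this point ctx is a plain list of token ids (the print above shows both
# its length and the raw ids), which both models below consume directly.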
########################################################################################################
with torch.no_grad():
    print('\nRWKV-train output')
    # full-sequence (GPT-style) forward pass over the whole context at once
    out = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0].detach().cpu().float().numpy()
    print(out, '\n')
    print('\nRWKV-RNN output')
    state = None
    out = None
    src_len = len(ctx)
    for i in range(src_len):
        # token-by-token (RNN-style) pass, carrying the recurrent state forward
        x = ctx[:i+1]
        out, state = model_rnn.forward(x, state)
        if i < 3 or i >= src_len - 3:  # print only the first and last few steps
            print(out.detach().cpu().numpy())
        if i == 2:
            print('...')
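# A minimal sketch of turning the printed comparison into an automated check
# (hypothetical -- not part of the original script). After the loop, 'out'
# holds the RNN logits for the final token, so:
#
#   gpt_out = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0]
#   gpt_last = gpt_out[-1].detach().cpu().float().numpy()
#   rnn_last = out.detach().cpu().float().numpy()
#   print('max abs diff:', np.max(np.abs(gpt_last - rnn_last)))
#
# In bf16 the two code paths round differently, so expect a small nonzero
# difference rather than exact equality.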

Binary file not shown (image, 534 KiB).