Update README.md

main
PENG Bo 4 years ago committed by GitHub
parent 325ddb76f7
commit 64c9015dd1

@@ -1,17 +1,27 @@
# The RWKV Language Model
## RWKV v2: RNN with Transformer-level Performance
RWKV v2 is an RNN with Transformer-level performance, which can also be directly trained like a GPT transformer (parallelizable). And it's attention-free. You only need x_t, a_t, b_t of position t to compute the vectors for position t+1.
So it combines the best of RNN and transformer: great performance, fast inference, low VRAM usage, and fast training. Moreover, you get a free sentence embedding.
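As a rough single-channel sketch of that recurrent step (my own reading of the pseudocode further below; `W` as the per-channel time-decay and `X` as an extra weight on the current token are illustrative names, not the exact kernel):

```python
import numpy as np

def rwkv_v2_step(k_t, v_t, a, b, W, X):
    """One recurrent step for a single channel (sketch, not the exact kernel).

    a, b carry the time-decayed sums of exp(k)*v and exp(k) over past tokens,
    so only the current inputs and (a, b) are needed to go from t to t+1.
    """
    out = (a + X * np.exp(k_t) * v_t) / (b + X * np.exp(k_t))  # output at position t
    a = W * a + np.exp(k_t) * v_t   # decay history by W, add current token with weight 1
    b = W * b + np.exp(k_t)
    return out, a, b                # constant-size state -> O(1) memory per token
```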
I am training it on the Pile: https://github.com/BlinkDL/RWKV-v2-RNN-Pile
It might reach GPT-Neo performance within 100B tokens:
![RWKV-v2-430M-Pile](RWKV-v2-430M-Pile.png)
See the release for a 27M params model on enwik8 with 0.72 BPC(dev).
## How it works
RWKV is inspired by Apple's AFT (https://arxiv.org/abs/2105.14103).
The pseudocode (execution from top to bottom):
![RWKV-v2-RNN](RWKV-v2-RNN.png)
The a b c d factors work together to build a time-decay curve: X, 1, W, W^2, W^3, ...
Write out the formulas for "token at pos 2" and "token at pos 3" and you will get the idea:
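For example, writing W for the per-channel time-decay and X for the extra weight on the current token (my own expansion of the pseudocode above, so treat the exact notation as illustrative), the outputs at positions 2 and 3 come out as:

```latex
x_2 = \frac{X e^{K_2} V_2 + e^{K_1} V_1}{X e^{K_2} + e^{K_1}}
\qquad
x_3 = \frac{X e^{K_3} V_3 + e^{K_2} V_2 + W e^{K_1} V_1}{X e^{K_3} + e^{K_2} + W e^{K_1}}
```

Each past token's weight is just its exp(K) scaled by the decay curve above, and since that curve is data-independent, the whole sequence can also be computed in parallel at training time.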
@@ -24,11 +34,7 @@ RWKV v2 is parallelizable because the time-decay of each channel is data-independent
It's also using my SmallInitEmb trick https://github.com/BlinkDL/SmallInitEmb (applicable to all transformers), and a custom CUDA kernel https://github.com/BlinkDL/RWKV-CUDA .
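For context, a minimal sketch of the SmallInitEmb idea from the linked repo as I understand it (class and parameter names here are illustrative): initialize the token embedding with tiny values and put a LayerNorm right after it, so early training is not dominated by the random embedding init.

```python
import torch.nn as nn

class SmallInitEmb(nn.Module):
    """Embedding with tiny init + LayerNorm right after it (sketch)."""
    def __init__(self, vocab_size, d_model, init_scale=1e-4):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        nn.init.uniform_(self.emb.weight, -init_scale, init_scale)  # tiny init
        self.ln = nn.LayerNorm(d_model)  # normalize immediately after the embedding

    def forward(self, idx):
        return self.ln(self.emb(idx))
```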
I find it can help to keep the model at a mid-range lr for a long period, because in theory that's where most of the learning happens. For example: 6e-4 to 1e-4 in 20% of steps, stay at 1e-4 for 30% of steps (in practice I monitor the loss and decay the lr when it plateaus), then 1e-4 to 1e-5 in 50% of steps.
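A hypothetical way to express that schedule as a piecewise function over training progress (the 20%/30%/50% split mirrors the example above; function and argument names are my own):

```python
def lr_at(progress, lr_max=6e-4, lr_mid=1e-4, lr_min=1e-5):
    """Piecewise lr schedule sketch: decay to a mid lr, hold it, then decay again.

    `progress` is the fraction of total steps completed (0..1). In practice the
    lr is decayed when the loss plateaus rather than on this fixed split.
    """
    if progress < 0.2:                        # 6e-4 -> 1e-4, exponential decay
        t = progress / 0.2
        return lr_max * (lr_mid / lr_max) ** t
    elif progress < 0.5:                      # hold at 1e-4
        return lr_mid
    else:                                     # 1e-4 -> 1e-5, exponential decay
        t = (progress - 0.5) / 0.5
        return lr_mid * (lr_min / lr_mid) ** t
```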
# Better Learning Rate Schedule via Variational Method of Loss Curve
