## RWKV v2: RNN with Transformer-level Performance
## RWKV-2: RNN with Transformer-level Performance
RWKV v2 is a RNN with Transformer-level performance, which can also be directly trained like a GPT transformer (parallelizable). And it's attention-free. You only need x_t, a_t, b_t of position t to compute the vectors for position t+1. You can use the "GPT" mode to quickly build the hidden state for the "RNN" mode.
RWKV-2 is a RNN with Transformer-level performance, which can also be directly trained like a GPT transformer (parallelizable). And it's attention-free. You only need x_t, a_t, b_t of position t to compute the vectors for position t+1. You can use the "GPT" mode to quickly build the hidden state for the "RNN" mode.
So it's combining the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding.
@ -14,7 +14,7 @@ I am training it on the Pile (https://github.com/BlinkDL/RWKV-v2-RNN-Pile) and i
All of the trained models will be open-source. Inference is very fast (only matrix-vector multiplications, no matrix-matrix multiplications) even on CPUs, and I believe you can run a 1B params RWKV-v2-RNN with reasonable speed on your phone.
### Quick start
## Quick start
See https://github.com/BlinkDL/RWKV-v2-RNN-Pile for L24-D1024 and L12-D768 models trained on the Pile (and the latest code).
@ -31,7 +31,48 @@ Note: For fine-tuning the Pile model, change 1e-15 to 1e-9 in https://github.com
RWKV is inspired by Apple's AFT (https://arxiv.org/abs/2105.14103).
The pseudocode (execution from top to bottom):
### From GPT to RWKV-2 (the formulas)
Let F[t] be the system state at t.
Let x[t] be the new external input at t.
In GPT, predicting F[t+1] requires considering F[0], F[1], .. F[t]. So it takes O(T^2) to generate a length T sequence.
Here R, K, V are trainable matrices, and W is a trainable vector (time-decay factor for each channel).
In GPT, the contribution of F[i] to F[t+1] is weighted by ![ \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}]) ](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle++%5Cexp+%28%5Cmathbf%7BQ%7Dx%5B%5Cmathrm%7Bt%7D%5D+%2A+%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+).
In RWKV-2, the contribution of F[i] to F[t+1] is weighted by ![\sigma(\mathbf{R}x[\mathrm{t}]) \cdot \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K}F[\mathrm{i}]) ](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Csigma%28%5Cmathbf%7BR%7Dx%5B%5Cmathrm%7Bt%7D%5D%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BW%7D+%5Ccdot%28%5Cmathrm%7Bt%7D-%5Cmathrm%7Bi%7D%29%29+%5Ccdot+%5Cexp+%28%5Cmathbf%7BK%7DF%5B%5Cmathrm%7Bi%7D%5D%29+).
* Here  is a non-linearity and we can use sigmoid.
* Note ![\sigma(\mathbf{R}x[\mathrm{t}])](https://render.githubusercontent.com/render/math?math=%5Ccolor%7Bblack%7D%5Cdisplaystyle+%5Csigma%28%5Cmathbf%7BR%7Dx%5B%5Cmathrm%7Bt%7D%5D%29) is not in the denominator, and I call R the "receptance".
* Here  is the time-decay factor. I proposed the same idea (scaling the attention by distance) in Aug 2020 and called it the "time-weighting" (check the commit history of https://github.com/BlinkDL/minGPT-tuned).
Now here is the punchline: we can rewrite it into a RNN (recursive formula). Note:
where A[t] and B[t] are the numerator and denominator of the previous step, respectively.
I believe RWKV-2 is performant because W is like repeatedly applying a diagonal matrix. Note (P^{-1} D P)^n = P^{-1} D^n P, so it is similar to repeatedly applying a general diagonalizable matrix.
Moreover it's possible to turn it into a continuous ODE (a bit similar to State Space Models). I will write about it later.
### The pseudocode (execution from top to bottom):

@ -49,7 +90,7 @@ It's also using my SmallInitEmb trick https://github.com/BlinkDL/SmallInitEmb (a
I find it might be nice to make the model stay on a mid-lr for a long period, because in theory that's where most learning shall happen. For example: constant 6e-4 for 10% of steps, 6e-4 to 1e-4 in 15% of steps, stays at 1e-4 for 25% of steps (actually I monitor the loss and decay the lr when it plateaus), then 1e-4 to 1e-5 in 50% of steps.
# Better Learning Rate Schedule via Variantional Method of Loss Curve
## Better Learning Rate Schedule via Variantional Method of Loss Curve
I propose a simple new method to find better LR schedules. The method is cost-efficient and practical for large LMs. The takeaway is we can model the loss curve dynamics (phenomenology) w.r.t. the LR, and a nice closed-form LR curve can be directly computed from it using variantional method. Moreover we can predict the final loss with reasonable accuracy.
@ -61,7 +102,7 @@ In the last three plots, black = predicted loss curve of the new LR schedule, bl