# RWKV-LM

We propose the RWKV language model, with alternating time-mix and channel-mix layers:

$$
\text{Time-mix:}\quad \text{TM}_{t,c} = \text{sigmoid}(\text{R}_{t,c}) \cdot \sum_{u} \textbf{W}_{t,u,c} \cdot \text{softmax}_t(\text{K}_{u,c}) \cdot \text{V}_{u,c}
$$

$$
\text{Channel-mix:}\quad \text{CM}_{t,c} = \text{sigmoid}(\text{R}_{t,c}) \cdot \sum_{d} \textbf{W}_{c,d} \cdot \text{gelu}(\text{K}_{t,d}) \cdot \text{V}_{t,d}
$$

* The R, K, V are generated by linear transforms of the input, and W is a parameter. The idea of RWKV is to decompose attention into R(target) * W(src, target) * K(src). So we call R the "receptance"; the sigmoid keeps it in the 0~1 range.

* The Time-mix is similar to AFT (https://arxiv.org/abs/2105.14103). There are two differences.

(1) We changed the normalization (denominator). For masked language models, we define:

$$
\text{softmax}_t(\text{K}_{u,c}) = \frac{\exp(\text{K}_{u,c})}{\sum_{v \leq t}\exp(\text{K}_{v,c})}
$$

(2) We decompose W_{t,u,c} and introduce a multi-head W (here h is the head corresponding to channel c):

$$
W_{t,u,c}=f_h(t-u)\cdot \alpha_h(u) \cdot \beta_h(t)
$$

Moreover, we multiply the final output of the Time-mix layer by γ(t). The α, β, γ factors are needed because the effective context is shorter when t is small, and they compensate for this.
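
A hypothetical sketch of this factorization. Only the product structure f_h(t-u) · α_h(u) · β_h(t) and the γ(t) output scale come from the text above; how f, α, β, γ are stored and initialized here is an assumption for illustration.

```python
import torch
import torch.nn as nn

class DecomposedW(nn.Module):
    """W[t, u, c] = f_h(t - u) * alpha_h(u) * beta_h(t), with h the head that channel c belongs to."""
    def __init__(self, ctx_len, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        self.head_size = n_embd // n_head
        # illustrative parameterization: one free value per head and per position / distance
        self.f = nn.Parameter(torch.ones(n_head, ctx_len))       # f_h(t - u), indexed by distance
        self.alpha = nn.Parameter(torch.ones(n_head, ctx_len))   # alpha_h(u)
        self.beta = nn.Parameter(torch.ones(n_head, ctx_len))    # beta_h(t)
        self.gamma = nn.Parameter(torch.ones(ctx_len, 1))        # gamma(t), multiplies the Time-mix output

    def build_W(self, T):
        t = torch.arange(T, device=self.f.device)
        dist = (t.view(T, 1) - t.view(1, T)).clamp(min=0)        # (T, T) table of t - u (u > t is masked elsewhere)
        W = self.f[:, dist] * self.alpha[:, None, :T] * self.beta[:, :T, None]   # (H, T, T)
        # broadcast each head back over its channels: (H, T, T) -> (T, T, H * head_size)
        return W.permute(1, 2, 0).repeat_interleave(self.head_size, dim=2)

# usage sketch: out = w.gamma[:T] * time_mix_output   # gamma(t) scales the final Time-mix output
```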

* The Channel-mix is similar to GeGLU (https://arxiv.org/abs/2002.05202) with an extra R factor.

* Finally, we add extra time-shift mixing as in (https://github.com/BlinkDL/minGPT-tuned).

***

The time-shift mixing means explicitly using both (half of the channels of this token) and (half of the channels of the previous token) to generate all vectors.

I find that dividing the channels by 2 and shifting by 1 works best. Looking at the weights, I found you may want to use less mixing in the higher layers.
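
A minimal sketch of this time-shift, assuming x has shape (B, T, C); the exact split used in minGPT-tuned and in this repo may differ slightly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TimeShift(nn.Module):
    """Concatenate half of this token's channels with half of the previous token's channels."""
    def forward(self, x):                               # x: (B, T, C)
        B, T, C = x.shape
        prev = F.pad(x, (0, 0, 1, 0))[:, :T, :]         # shift by one step: position t sees x[t-1], zeros at t = 0
        return torch.cat([x[..., :C // 2], prev[..., C // 2:]], dim=-1)

# the shifted-and-mixed x is then fed to the R / K / V projections instead of the raw x
```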

Here is my theory:

When you train a GPT, the hidden representation of a token has to accomplish two different objectives:

1. Predict the next token. Sometimes this is easy (an obvious next token).

2. Collect all previous context information so that later tokens can use it. This is always hard.

The time-shifted channels can focus on (2), so we get good propagation of information. It acts like a kind of residual connection.

You can use time_shift in usual QKV self-attention too. When I studied the weights, I found that V really likes the time-shifted channels, less so for Q. Makes sense if you think about it.

***

P.S. There is another MHA_pro model in this repo with strong performance. Give it a try :)

***

We also propose a new sampling method (as in src/utils.py):

(1) Find the max probability p_max after softmax.

(2) Remove all entries whose probability is lower than 0.02 * pow(p_max, 2).

(3) Feel free to tune the 0.02 and 2 factors.
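
A sketch of this sampling rule (the function name and argument names are made up; see src/utils.py for the actual implementation):

```python
import torch

def sample_logits(logits, coef=0.02, power=2.0):
    """Sample a token id, dropping every entry below coef * p_max ** power."""
    probs = torch.softmax(logits, dim=-1)                        # (vocab_size,)
    p_max = probs.max()                                          # (1) max probability after softmax
    probs = torch.where(probs < coef * p_max ** power,           # (2) remove low-probability entries
                        torch.zeros_like(probs), probs)
    probs = probs / probs.sum()                                  # renormalize the survivors
    return torch.multinomial(probs, num_samples=1).item()
```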

***

Character-level loss on the simplebooks-92 dataset (https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip):

![RWKV-vs-MHA](RWKV-vs-MHA.png)

Gray: usual MHA+Rotary+GeGLU - performance not as good.

Red: RWKV ("linear" attention) - VRAM-friendly - noticeably faster when the ctx window is long - good performance.

Black: MHA_pro (MHA with various tweaks & RWKV-type FFN) - slow - needs more VRAM - good performance.

Parameter counts: 17.2M vs 18.5M vs 18.5M.

```
@software{peng_bo_2021_5196578,
  author    = {PENG Bo},
  title     = {BlinkDL/RWKV-LM: 0.01},
  month     = aug,
  year      = 2021,
  publisher = {Zenodo},
  version   = {0.01},
  doi       = {10.5281/zenodo.5196578},
  url       = {https://doi.org/10.5281/zenodo.5196578}
}
```