RWKV is an RNN with transformer-level LLM performance. It can be directly trained like a GPT (parallelizable). So it combines the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding.

RWKV-LM

We propose the RWKV language model, with alternating time-mix and channel-mix layers:

\begin{align*}
\text{Time-mix:} \quad & \text{TM}_{t,c} = \text{sigmoid}(\text{R}_{t,c}) \cdot \sum_{u} \textbf{W}_{t,u,c} \cdot \text{softmax}_t(\text{K}_{u,c}) \cdot \text{V}_{u,c}\\
\text{Channel-mix:} \quad & \text{CM}_{t,c} = \text{sigmoid}(\text{R}_{t,c}) \cdot \sum_{d} \textbf{W}_{c,d} \cdot \text{gelu}(\text{K}_{t,d}) \cdot \text{V}_{t,d}
\end{align*}

  • The R, K, V are generated by linear transforms of the input, and W is a parameter. The idea of RWKV is to decompose attention into R(target) * W(src, target) * K(src). So we call R the "receptance", and the sigmoid means it is in the 0~1 range.

  • The Time-mix is similar to AFT (https://arxiv.org/abs/2105.14103). There are two differences.

(1) We changed the normalization (denominator). For masked language models, we define:

\text{softmax}_t(\text{K}_{u,c}) = \frac{\exp(\text{K}_{u,c})}{\sum_{v \leq t}\exp(\text{K}_{v,c})}

(2) We decompose W_{t,u,c} and introduce multi-head W (here h is the corresponding head of c):

W_{t,u,c}=f_h(t-u)\cdot \alpha_h(u) \cdot \beta_h(t)

Moreover, we multiply the final output of the Time-mix layer by γ(t). The reason for the α, β, γ factors is that the context size is smaller when t is small, and these factors compensate for it.
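
To make this concrete, here is a minimal PyTorch sketch of a Time-mix layer under the definitions above. It is not the repo's implementation: the class name TimeMix and the parameter names time_w, time_alpha, time_beta, time_gamma are illustrative, and the unvectorized construction of W trades speed for readability.

```python
import torch
import torch.nn as nn

class TimeMix(nn.Module):
    """Illustrative Time-mix layer:
    TM_{t,c} = sigmoid(R_{t,c}) * sum_{u<=t} W_{t,u,c} * softmax_t(K_{u,c}) * V_{u,c},
    with W_{t,u,c} = f_h(t-u) * alpha_h(u) * beta_h(t) and the output scaled by gamma(t)."""

    def __init__(self, n_embd, n_head, ctx_len):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head, self.head_size = n_head, n_embd // n_head

        # R, K, V are linear transforms of the input
        self.receptance = nn.Linear(n_embd, n_embd)
        self.key = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)

        self.time_w = nn.Parameter(torch.ones(n_head, ctx_len))         # f_h(t-u)
        self.time_alpha = nn.Parameter(torch.ones(n_head, 1, ctx_len))  # alpha_h(u)
        self.time_beta = nn.Parameter(torch.ones(n_head, ctx_len, 1))   # beta_h(t)
        self.time_gamma = nn.Parameter(torch.ones(ctx_len, 1))          # gamma(t)

    def forward(self, x):                                    # x: (B, T, C)
        B, T, C = x.size()
        H, HS = self.n_head, self.head_size

        # Build W[h, t, u] = f_h(t-u) * alpha_h(u) * beta_h(t), keeping only u <= t.
        t_idx = torch.arange(T, device=x.device).view(T, 1)
        u_idx = torch.arange(T, device=x.device).view(1, T)
        curve = self.time_w[:, (t_idx - u_idx).clamp(min=0)]            # (H, T, T)
        w = curve * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
        w = w.masked_fill(u_idx > t_idx, 0.0)                           # causal mask

        r = torch.sigmoid(self.receptance(x))                # "receptance" gate in 0~1
        k = torch.exp(self.key(x).clamp(max=30))             # exp(K), clamped for stability
        v = self.value(x)
        k = k.view(B, T, H, HS).transpose(1, 2)              # (B, H, T, HS)
        v = v.view(B, T, H, HS).transpose(1, 2)

        # Numerator: sum_{u<=t} W[t,u] * exp(K_u) * V_u.
        # Denominator: sum_{v<=t} exp(K_v), i.e. the softmax_t normalization from (1).
        num = torch.einsum('htu,bhuc->bhtc', w, k * v)
        den = torch.cumsum(k, dim=2)
        wkv = (num / den).transpose(1, 2).reshape(B, T, C)

        return r * wkv * self.time_gamma[:T, :]              # gate and scale by gamma(t)
```

The denominator is a causal cumulative sum of exp(K), matching the softmax_t definition in (1), and W is built per head from f_h(t-u), α_h(u), β_h(t) as in (2).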


The time-shift mixing means explicitly using both (half of the channels of this token) and (half of the channels of the previous token) to generate all vectors.

I find that dividing the channels in half and shifting by 1 works best. Looking at the trained weights, it seems you may want to use less mixing in higher layers.
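
For reference, here is a minimal PyTorch sketch of the trick, assuming a (batch, time, channel) input. The module name TimeShiftMix is made up for this example, and which half of the channels takes the shifted copy is an arbitrary choice.

```python
import torch
import torch.nn as nn

class TimeShiftMix(nn.Module):
    """Feed a projection with a mix of the current and the previous token."""
    def __init__(self, n_embd):
        super().__init__()
        # Prepends one zero step along the time axis and drops the last step,
        # so position t sees position t-1 (position 0 sees zeros).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(n_embd, n_embd)

    def forward(self, x):                       # x: (B, T, C)
        C = x.size(-1)
        x_prev = self.time_shift(x)             # x_prev[:, t] == x[:, t - 1]
        # "divide by 2 and shift-1": half of the channels from the previous token,
        # the other half from the current token.
        x_mixed = torch.cat([x_prev[..., :C // 2], x[..., C // 2:]], dim=-1)
        return self.key(x_mixed)                # same mixing can feed R/K/V (or Q/K/V)
```

The same mixed input can also be fed to the V (and, less usefully, Q) projections of a standard self-attention block.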

Here is my theory:

When you train a GPT, the hidden representation of a token has to accomplish two different objectives:

  1. Predict the next token. Sometimes this is easy (an obvious next token).

  2. Collect all previous context information so that later tokens can use it. This is always hard.

The time_shifted channels can focus on (2), so we get good propagation of information. It's like a kind of residual connection.

You can use time_shift in the usual QKV self-attention too. When I studied the weights, I found that V really likes time_shift, and less so for Q, which makes sense if you think about it.


P.S. There is also an MHA_pro model in this repo with strong performance. Give it a try :)


We also propose a new sampling method (as in src/utils.py):

(1) Find the max probability p_max after softmax.

(2) Remove all entries whose probability is lower than 0.02 * pow(p_max, 2).

(3) Feel free to tune the 0.02 and 2 factors.
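
A minimal PyTorch sketch of this rule (the function name sample_logits and the temperature argument are assumptions added for illustration; the 0.02 and the power of 2 are the factors above):

```python
import torch
import torch.nn.functional as F

def sample_logits(logits, temperature=1.0):
    """Sample a token id from 1-D logits using the p_max-based cutoff described above."""
    probs = F.softmax(logits / temperature, dim=-1)          # (1) softmax
    p_max = probs.max()
    cutoff = 0.02 * p_max ** 2                               # (3) both factors are tunable
    probs = probs.masked_fill(probs < cutoff, 0.0)           # (2) remove unlikely entries
    probs = probs / probs.sum()                              # renormalize the survivors
    return torch.multinomial(probs, num_samples=1).item()
```

Note that the cutoff scales with p_max squared, so the filter prunes aggressively when the model is confident (large p_max) and keeps many more candidates when it is uncertain.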


Character-level loss on simplebooks-92 dataset https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip

[Figure: RWKV-vs-MHA.png]

Gray: usual MHA+Rotary+GeGLU - performance not as good.

Red: RWKV ("linear" attention) - VRAM-friendly - quite a bit faster when the ctx window is long - good performance.

Black: MHA_pro (MHA with various tweaks & RWKV-type-FFN) - slow - needs more VRAM - good performance.

Parameter counts: 17.2 vs 18.5 vs 18.5.