Check https://github.com/BlinkDL/RWKV-v2-RNN-Pile for L24-D1024 and L12-D768 models.
Read the inference code in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/blob/main/src/model.py and try using the final hidden state (.xx .aa .bb) as a faithful sentence embedding for other tasks (you could probably begin with .xx and .aa/.bb, i.e. .aa divided by .bb).
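For illustration, a hypothetical sketch of that idea (the attribute names .xx/.aa/.bb follow the text above; the state object and its exact layout are assumptions, not the repo's actual API):

```python
import torch

def sentence_embedding(state) -> torch.Tensor:
    # state.xx / state.aa / state.bb: lists of per-layer hidden tensors (assumed)
    xx = torch.cat([x.flatten() for x in state.xx])
    aa = torch.cat([a.flatten() for a in state.aa])
    bb = torch.cat([b.flatten() for b in state.bb])
    # begin with .xx and .aa/.bb (.aa divided by .bb), as suggested above
    return torch.cat([xx, aa / bb])
```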
See the release here for a 27M params model on enwik8 with 0.72 BPC (dev). Run run.py in https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN. You can even run it in your browser: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng https://blinkdl.github.io/AI-Writer/eng/ (this is using tf.js WASM single-thread mode).
### Training / Fine-tuning
Training: https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN
You will be training the "GPT" version because it's parallelizable and faster to train. I find RWKV-2 can extrapolate, so training with ctxLen 768 can work for a ctxLen of several thousand. You can fine-tune the model with a longer ctxLen later and it can quickly adapt to longer ctxLens.
My LR schedule for the L24-D1024 RWKV-2:
![RWKV-v2-430M-Pile-LR](RWKV-v2-430M-Pile-LR.png)
Fixing NaN or loss spikes: load a previous checkpoint and decrease the LR a bit. I find you can decrease the LR faster than with GPT, eventually to 1/50 of LR_max.
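A minimal sketch of that recovery step, assuming a standard PyTorch loop (the stand-in model, checkpoint path, and 0.5 factor are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # stand-in for the RWKV model
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Roll back to the last good checkpoint (hypothetical path) ...
# model.load_state_dict(torch.load('last_good_checkpoint.pth'))
# ... and decrease the LR a bit:
for group in optimizer.param_groups:
    group['lr'] *= 0.5  # can eventually go as low as LR_max / 50
```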
**UPDATE: Search for "RWKV v2+" here and change RWKV-2 to PreLN to make it more stable.**
Fine-tuning: see https://github.com/BlinkDL/RWKV-v2-RNN-Pile.
I need a better CUDA kernel to pull off maxK, so there's no need to clamp k to 60.
Removing the maxK limitation will also make it easy to clean the state of a KV-V channel, by using a huge K.
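For intuition, here is one standard way (my illustration, not the repo's kernel) to drop the clamp: subtract a running max of k before exponentiating. The common factor cancels in the wkv/wk ratio, so the result is unchanged while exp() never overflows; the W-curve is omitted for brevity:

```python
import torch

T = 8
k = torch.randn(T) * 100     # toy scores, large enough that exp(k) would overflow
v = torch.randn(T)

m = k.cummax(dim=0).values   # running max of k over time
out = torch.empty(T)
for t in range(T):
    ek = torch.exp(k[: t + 1] - m[t])             # every entry <= 1, no overflow
    out[t] = (ek * v[: t + 1]).sum() / ek.sum()   # same ratio as exp(k)-weighted avg
```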
========================================================================
### Explaining the code for RWKV v2+ GPT mode
Note: this is for the latest v2+ model.
#### The GPT mode - overview
The building blocks of RWKV-2 GPT mode are similar to those of a usual preLN GPT.
The only difference is an extra LN after the embedding. Note that you can absorb this LN into the embedding matrix once training is finished.
```python
x = self.emb(idx) # input: idx = token indices
x = self.ln_emb(x) # extra LN after embedding
x = x + self.att_0(self.ln_att_0(x)) # preLN
x = x + self.ffn_0(self.ln_ffn_0(x))
...
x = x + self.att_n(self.ln_att_n(x))
x = x + self.ffn_n(self.ln_ffn_n(x))
x = self.ln_head(x) # final LN before projection
x = self.head(x) # output: x = logits
```
It is important to initialize the embedding to tiny values, e.g. nn.init.uniform_(self.emb.weight, a=-1e-4, b=1e-4), to utilize my trick https://github.com/BlinkDL/SmallInitEmb.
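A minimal sketch of both embedding tricks together (tiny init, then folding ln_emb into the embedding weights after training); the layer names follow the overview above, and the folding step is my illustration:

```python
import torch
import torch.nn as nn

vocab_size, n_embd = 50277, 1024   # illustrative sizes
emb = nn.Embedding(vocab_size, n_embd)
ln_emb = nn.LayerNorm(n_embd)

# SmallInitEmb: start the embedding near zero so ln_emb dominates early training
nn.init.uniform_(emb.weight, a=-1e-4, b=1e-4)

# After training: since ln_emb normalizes each row independently, it can be
# folded into the embedding table, and ln_emb dropped at inference time.
with torch.no_grad():
    emb.weight.copy_(ln_emb(emb.weight))
```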
For the 1.5B RWKV-2, I use the Adam optimizer (no weight decay, no dropout) on 8 * A100 40G.
batchSz = 32 * 896, ctxLen = 896. I am using tf32, so the batchSz is a bit small.
For the first 15B tokens, the LR is fixed at 3e-4 with beta=(0.9, 0.99).
Then I set beta=(0.9, 0.999) and decay the LR exponentially, reaching 1e-5 at 332B tokens.
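A hedged sketch of that schedule (the token counts and betas come from the text; the exponential-interpolation formula and the loop bookkeeping are my framing):

```python
import math
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the 1.5B RWKV-2
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.99))

def lr_and_betas(tokens_seen: float):
    # Phase 1: first 15B tokens at a fixed LR of 3e-4 with beta=(0.9, 0.99)
    if tokens_seen <= 15e9:
        return 3e-4, (0.9, 0.99)
    # Phase 2: beta=(0.9, 0.999), LR decays exponentially to 1e-5 at 332B tokens
    progress = min((tokens_seen - 15e9) / (332e9 - 15e9), 1.0)
    return 3e-4 * math.exp(progress * math.log(1e-5 / 3e-4)), (0.9, 0.999)

# Inside the training loop, before each optimizer step:
lr, betas = lr_and_betas(tokens_seen=20e9)
for group in optimizer.param_groups:
    group['lr'], group['betas'] = lr, betas
```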
#### The GPT mode - ATT block
The RWKV-2 does not have any attention in the usual sense, but we will call this block ATT anyway.
```python
B, T, C = x.size() # x = (Batch,Time,Channel)
# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# Use xk, xv, xr to produce k, v, r
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
k = torch.clamp(k, max=60) # clamp k to avoid overflow
k = torch.exp(k)
kv = k * v
# Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)]
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(x.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)
# Use W to mix kv and k respectively. Add K_EPS to wk to avoid divide-by-zero
if RUN_DEVICE == 'cuda':
wkv = TimeX.apply(w, kv, B,C,T, 0)
wk = TimeX.apply(w, k, B,C,T, K_EPS)
else:
w = w[:,-T:].unsqueeze(1)
wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + K_EPS
# The RWKV formula
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv) # final output projection
```
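To make the W-mixing concrete, here is a reference loop (my illustration of what the TimeX / depthwise conv1d call computes, not the actual kernel): each channel takes a causal weighted sum of its own past, with weights read off its W-curve:

```python
import torch

def naive_time_mix(w: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    # w:  (C, T) per-channel W-curve, newest weight last (w[:, -1] = e^time_first)
    # kv: (B, C, T), e.g. exp(k) * v from the ATT block above
    B, C, T = kv.shape
    out = torch.zeros_like(kv)
    for t in range(T):
        for u in range(t + 1):  # causal: only positions u <= t contribute
            out[:, :, t] += w[:, T - 1 - (t - u)] * kv[:, :, u]
    return out  # equals F.conv1d(ZeroPad2d((T-1,0,0,0))(kv), w.unsqueeze(1), groups=C)
```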
The self.key, self.receptance, self.output matrices are all initialized to zero.
The time_mix, time_decay, time_first vectors are transferred from a smaller trained model.
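A minimal sketch of that transfer, assuming the smaller model shares n_embd so the time_* vector shapes line up (the checkpoint path and matching rule are illustrative):

```python
import torch

def transfer_time_vectors(model: torch.nn.Module, ckpt_path: str) -> None:
    # Copy time_mix / time_decay / time_first from a smaller trained model.
    small = torch.load(ckpt_path, map_location='cpu')
    with torch.no_grad():
        for name, p in model.named_parameters():
            if any(s in name for s in ('time_mix', 'time_decay', 'time_first')):
                if name in small and small[name].shape == p.shape:
                    p.copy_(small[name])  # reuse the trained time vectors
```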
#### The GPT mode - FFN block
The FFN block has three tricks compared with the usual GPT FFN:
1. My time_mix trick.
2. The sqReLU from the Primer paper.
3. An extra receptance-gate (similar to the receptance-gate in ATT block).
```python
# Mix x with the previous timestep to produce xk, xr
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# The usual FFN operation
k = self.key(xk)
k = torch.square(torch.relu(k)) # from the Primer paper
kv = self.value(k)
# Apply an extra receptance-gate to kv
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
```
The self.value, self.receptance matrices are all initialized to zero.
========================================================================
### From GPT to RWKV-2 (the formulas)
Let F[t] be the system state at t.
