From 92da17d68f9ecb6ab1b8c5119c222c25c6991c06 Mon Sep 17 00:00:00 2001
From: PENG Bo <33809201+BlinkDL@users.noreply.github.com>
Date: Mon, 27 Jun 2022 00:54:04 +0800
Subject: [PATCH] Update README.md

---
 README.md | 112 +++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 103 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index d3ec092..baade69 100644
--- a/README.md
+++ b/README.md
@@ -33,9 +33,7 @@ Check https://github.com/BlinkDL/RWKV-v2-RNN-Pile for L24-D1024 and L12-D768 mod

 Read the inference code in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/blob/main/src/model.py and try using the final hidden state (.xx .aa .bb) as a faithful sentence embedding for other tasks (probably you should begin with .xx and .aa/.bb (.aa divided by .bb)).

-See the release here for a 27M params model on enwik8 with 0.72 BPC(dev). Run run.py in https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN.
-
-You can even run it in your browser: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng https://blinkdl.github.io/AI-Writer/eng/ (this is using tf.js WASM single-thread mode).
+See the release here for a 27M params model on enwik8 with 0.72 BPC(dev). Run run.py in https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN. You can even run it in your browser: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng https://blinkdl.github.io/AI-Writer/eng/ (this is using tf.js WASM single-thread mode).

 ### Training / Fine-tuning

@@ -43,12 +41,6 @@ Training: https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN

 You will be training the "GPT" version because it's parallelizable and faster to train. I find RWKV-2 can extrapolate, so training with ctxLen 768 can work for a ctxLen of several thousand. You can fine-tune the model with a longer ctxLen later, and it can quickly adapt to longer ctxLens.

-My LR schedule for the L24-D1024 RWKV-2:
-
-![RWKV-v2-430M-Pile-LR](RWKV-v2-430M-Pile-LR.png)
-
-Fixing NaN or loss spikes: load a previous checkpoint, decrease LR a bit. I find you can decrease the LR faster than GPT, and eventually to 1/50 of LR_max.
-
 **UPDATE: Search for "RWKV v2+" here and change RWKV-2 to PreLN to make it more stable.**

 Fine-tuning: see https://github.com/BlinkDL/RWKV-v2-RNN-Pile.
@@ -109,6 +101,108 @@ I need a better CUDA kernel to (1) pull off maxK so there's need to clamp k to 6

 Removing the maxK limitation will also make it easy to clean the state of a KV-V channel, by using a huge K.

+========================================================================
+
+### Explaining the code for RWKV v2+ GPT mode
+
+Note: this is for the latest v2+ model.
+
+#### The GPT mode - overview
+
+The building blocks of RWKV-2 GPT mode are similar to those of a usual preLN GPT.
+
+The only difference is an extra LN after the embedding. Note you can absorb this LN into the embedding after finishing the training.
+```python
+x = self.emb(idx) # input: idx = token indices
+x = self.ln_emb(x) # extra LN after embedding
+x = x + self.att_0(self.ln_att_0(x)) # preLN
+x = x + self.ffn_0(self.ln_ffn_0(x))
+...
+x = x + self.att_n(self.ln_att_n(x))
+x = x + self.ffn_n(self.ln_ffn_n(x))
+x = self.ln_head(x) # final LN before projection
+x = self.head(x) # output: x = logits
+```
+It is important to initialize emb to tiny values, such as nn.init.uniform_(a=-1e-4, b=1e-4), to utilize my trick https://github.com/BlinkDL/SmallInitEmb.
+
+For the 1.5B RWKV-2, I use the Adam optimizer (no weight decay, no dropout) on 8 * A100 40G.
+
+batchSz = 32 * 896, ctxLen = 896. I am using tf32, so the batchSz is a bit small.
+
+For the first 15B tokens, the LR is fixed at 3e-4, with beta = (0.9, 0.99).
+
+Then I set beta = (0.9, 0.999) and decay the LR exponentially, reaching 1e-5 at 332B tokens.
+
+#### The GPT mode - ATT block
+
+RWKV-2 does not have any attention in the usual sense, but we will call this block ATT anyway.
+```python
+B, T, C = x.size() # x = (Batch, Time, Channel)
+
+# Mix x with the previous timestep to produce xk, xv, xr
+xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
+xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
+xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+
+# Use xk, xv, xr to produce k, v, r
+k = self.key(xk).transpose(-1, -2)
+v = self.value(xv).transpose(-1, -2)
+r = self.receptance(xr)
+k = torch.clamp(k, max=60) # clamp k to avoid overflow
+k = torch.exp(k)
+kv = k * v
+
+# Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)]
+self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(x.device), self.time_first], dim=-1)
+w = torch.exp(self.time_w)
+
+# Use W to mix kv and k, respectively. Add K_EPS to wk to avoid divide-by-zero
+if RUN_DEVICE == 'cuda':
+    wkv = TimeX.apply(w, kv, B, C, T, 0)
+    wk = TimeX.apply(w, k, B, C, T, K_EPS)
+else:
+    w = w[:, -T:].unsqueeze(1)
+    wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
+    wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + K_EPS
+
+# The RWKV formula
+rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
+rwkv = self.output(rwkv) # final output projection
+```
+
+The self.key, self.receptance, and self.output matrices are all initialized to zero.
+
+The time_mix, time_decay, and time_first vectors are transferred from a smaller trained model.
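+
+For intuition, here is how the same W-curve can be unrolled into an RNN. This is a minimal sketch of my own, derived from the W-curve comment above rather than copied from the repo's inference code; att_rnn_step is a hypothetical helper name, and the running sums aa and bb correspond to the .aa/.bb hidden state mentioned at the top of this README.
+```python
+import torch
+
+def att_rnn_step(k, v, r, aa, bb, time_decay, time_first, K_EPS=1e-8):
+    # One recurrent step per token; all tensors have shape (C,).
+    # k, v, r are the already time-mixed projections of the current token;
+    # aa, bb are running sums of exp(k)*v and exp(k) over past tokens.
+    # K_EPS guards against divide-by-zero (exact value is my assumption).
+    k = torch.exp(torch.clamp(k, max=60))  # same clamp as in GPT mode
+    kv = k * v
+    tf = torch.exp(time_first)             # W-curve weight of the current token
+    rwkv = torch.sigmoid(r) * (aa + tf * kv) / (bb + tf * k + K_EPS)
+    d = torch.exp(-torch.exp(time_decay))  # per-step decay, from the W-curve
+    aa = d * aa + kv                       # the previous token enters with weight 1
+    bb = d * bb + k
+    return rwkv, aa, bb                    # apply self.output to rwkv afterwards
+```
+The GPT-mode convolution applies the whole W-curve at once, in parallel over T, while this recurrent form needs only aa and bb (plus the time-shift state .xx) per channel to generate one token at a time.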
+
+#### The GPT mode - FFN block
+
+The FFN block has three tricks compared with the usual GPT:
+
+1. My time_mix trick.
+
+2. The sqReLU (squared ReLU) from the Primer paper.
+
+3. An extra receptance-gate (similar to the receptance-gate in the ATT block).
+```python
+# Mix x with the previous timestep to produce xk, xr
+xx = self.time_shift(x)
+xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+
+# The usual FFN operation
+k = self.key(xk)
+k = torch.square(torch.relu(k)) # sqReLU, from the Primer paper
+kv = self.value(k)
+
+# Apply an extra receptance-gate to kv
+rkv = torch.sigmoid(self.receptance(xr)) * kv
+return rkv
+```
+The self.value and self.receptance matrices are both initialized to zero.
+
+========================================================================
+
 ### From GPT to RWKV-2 (the formulas)

 Let F[t] be the system state at t.