All of the trained models will be open-source. Inference is very fast (only matrix-vector multiplications), even on CPU.
## Quick start
### Inference
Check https://github.com/BlinkDL/RWKV-v2-RNN-Pile for L24-D1024 and L12-D768 models trained on the Pile (and the latest code). It's very fast on CPU (the default mode).
Read the inference code in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/blob/main/src/model.py and try using the final hidden state (.xx .aa .bb) as a faithful sentence embedding for other tasks.
See the release here for a 27M params model on enwik8 with 0.72 BPC (dev). Run run.py in https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN.
You can even run it in your browser: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng (live demo at https://blinkdl.github.io/AI-Writer/eng/, using tf.js WASM single-thread mode).
### Training / Fine-tuning
Training: https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN
You will be training the "GPT" version because it is parallelizable and faster to train. I find that RWKV-2 can extrapolate, so training with a ctxLen of 768 can work for a ctxLen of several thousand. You can fine-tune the model with a longer ctxLen later and it quickly adapts.
I find it helps to keep the model at a mid-range LR for a long period, because in theory that is where most of the learning happens. For example: constant 6e-4 for 10% of the steps, 6e-4 to 1e-4 over 15% of the steps, stay at 1e-4 for 25% of the steps (I monitor the loss and decay the LR when it plateaus or hits a NaN), then 1e-4 to 1e-5 over the remaining 50% of the steps.
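As a rough illustration, here is a minimal sketch of that schedule (not the repo's trainer code; the 10% / 15% / 25% / 50% breakpoints follow the example above, and the exponential decay between the plateaus is my assumption):

```python
def lr_at(progress: float, lr_hi: float = 6e-4, lr_mid: float = 1e-4, lr_lo: float = 1e-5) -> float:
    """LR for a given training progress (= current_step / total_steps, in [0, 1])."""
    if progress < 0.10:                      # constant 6e-4
        return lr_hi
    if progress < 0.25:                      # 6e-4 -> 1e-4 over 15% of the steps
        t = (progress - 0.10) / 0.15
        return lr_hi * (lr_mid / lr_hi) ** t
    if progress < 0.50:                      # stay at 1e-4
        return lr_mid
    t = (progress - 0.50) / 0.50             # 1e-4 -> 1e-5 over the remaining steps
    return lr_mid * (lr_lo / lr_mid) ** t

print([round(lr_at(p), 6) for p in (0.05, 0.20, 0.40, 0.75, 1.00)])
```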
Fine-tuning: for a small model, try 4e-5 lr, and decay to 1e-5 when it plateaus.
**Important**: For fine-tuning the Pile model, change K_EPS from 1e-16 to 1e-9 (to avoid NaN) in https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v2-RNN/src/model.py and https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v2-RNN/src/model_run.py, and disable HeadQK (so it's a pure RNN). You can compare the output with the latest code (https://github.com/BlinkDL/RWKV-v2-RNN-Pile) to verify it.
**Fixing NaN or loss spikes**: load a previous checkpoint, decrease LR a bit, and increase beta2 (try 0.99 -> 0.999 -> 0.9999 as time goes on).
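For concreteness, a minimal sketch of that recovery step (illustrative only: the stand-in model, checkpoint path, and LR factor are placeholders, not the repo's trainer code):

```python
import torch

model = torch.nn.Linear(768, 768)                                        # stand-in for the real model
torch.save({"model": model.state_dict(), "lr": 1e-4}, "last_good.pth")   # pretend previous checkpoint

# Reload the last good checkpoint, lower the LR a bit, and raise beta2
# (0.99 -> 0.999 -> 0.9999 as training goes on).
ckpt = torch.load("last_good.pth", map_location="cpu")
model.load_state_dict(ckpt["model"])
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=ckpt["lr"] * 0.8,        # decrease LR a bit
    betas=(0.9, 0.999),         # beta2 raised from the previous 0.99
)
```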
## How it works
kv / k is the memory mechanism. A token with a high k can be remembered for a long duration, if W is close to 1 in the channel.
RWKV v2 is parallelizable because the time-decay of each channel is data-independent (and trainable). For example, in a usual RNN you can adjust the time-decay of a channel from, say, 0.8 to 0.5 (these are called "gates"), while in RWKV v2 you simply move the information from a W-0.8 channel to a W-0.5 channel to achieve the same effect.
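To make the parallelization argument concrete, here is a minimal sketch (my own illustration, not RWKV code) showing that a recurrence with a fixed, data-independent per-channel decay can be computed either step by step or all at once from the inputs:

```python
import torch

T, C = 8, 4                       # sequence length, channels
x = torch.randn(T, C)             # per-step inputs (think k*v contributions)
w = torch.rand(C)                 # per-channel decay, independent of the data

# Sequential, RNN-style evaluation: state[t] = w * state[t-1] + x[t]
state = torch.zeros(C)
seq_out = []
for t in range(T):
    state = w * state + x[t]
    seq_out.append(state)
seq_out = torch.stack(seq_out)

# Parallel evaluation of the same thing: out[t] = sum_{s <= t} w^(t-s) * x[s]
t_idx = torch.arange(T)
powers = (t_idx[:, None] - t_idx[None, :]).clamp(min=0)    # exponents t-s
mask = (t_idx[:, None] >= t_idx[None, :]).float()          # keep only s <= t
decay = (w[None, None, :] ** powers[:, :, None]) * mask[:, :, None]
par_out = torch.einsum('tsc,sc->tc', decay, x)

assert torch.allclose(seq_out, par_out, atol=1e-5)
```

If the decay depended on the data (like a usual gate), w would change per step and this closed-form unrolling would no longer be possible.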
## How to sample a large dataset (for training)
I am using a trick to sample the Pile deterministically yet randomly enough.
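For illustration, one common way to get deterministic-but-shuffled sampling is to stride through the data with a step that is coprime to its length (a sketch of the general idea only; the repo's actual trick may differ):

```python
def sample_offset(step: int, n_chunks: int, prime: int = 2147483647) -> int:
    """Chunk index to read at a given training step (names are illustrative)."""
    # With `prime` coprime to n_chunks, i -> (i * prime) % n_chunks visits every
    # chunk exactly once, in an order that looks scrambled but is fully reproducible.
    return (step * prime) % n_chunks

n_chunks = 10_000                       # e.g. total_tokens // ctx_len
print([sample_offset(s, n_chunks) for s in range(5)])   # deterministic, spread out
```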
