From 64c9015dd13c6caadc1147a03c4c1bb251978f42 Mon Sep 17 00:00:00 2001 From: PENG Bo <33809201+BlinkDL@users.noreply.github.com> Date: Wed, 11 May 2022 02:38:26 +0800 Subject: [PATCH] Update README.md --- README.md | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 8916ebf..91d6c15 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,27 @@ # The RWKV Language Model -## RWKV v2: RNN with Transformer Performance +## RWKV v2: RNN with Transformer-level Performance -RWKV v2 is a RNN which can also be directly trained like a GPT transformer (parallelizable). You only need x_t, a_t, b_t of position t to compute the vectors for position t+1. Hence it can be 100x faster than GPT, and 100x more VRAM friendly, and you get a free sentence embedding. +RWKV v2 is a RNN with Transformer-level performance, which can also be directly trained like a GPT transformer (parallelizable). And it's attention-free. You only need x_t, a_t, b_t of position t to compute the vectors for position t+1. + +So it's combining the best of RNN and transformer - great performance, fast inference, saves VRAM, and fast training. Moreover you get a free sentence embedding. I am training it on the Pile: https://github.com/BlinkDL/RWKV-v2-RNN-Pile -See the release for a **27M params model on enwik8 with 0.72 BPC(dev)**. +It might reach GPT-Neo performance within 100B tokens: + +![RWKV-v2-430M-Pile](RWKV-v2-430M-Pile.png) -![RWKV-v2-RNN](RWKV-v2-RNN-run.png) +See the release for a 27M params model on enwik8 with 0.72 BPC(dev). ## How it works +RWKV is inspired by Apple's AFT (https://arxiv.org/abs/2105.14103). + +The pseudocode (execution from top to bottom): + +![RWKV-v2-RNN](RWKV-v2-RNN.png) + The a b c d factors work together to build a time-decay curve: X, 1, W, W^2, W^3, ... Write out the formulas for "token at pos 2" and "token at pos 3" and you will get the idea: @@ -24,11 +34,7 @@ RWKV v2 is parallelizable because the time-decay of each channel is data-indepen It's also using my SmallInitEmb trick https://github.com/BlinkDL/SmallInitEmb (applicable to all transformers), and a custom CUDA kernel https://github.com/BlinkDL/RWKV-CUDA . -I find it might be nice to make the model stay on a mid-lr for a long period, because in theory that's where most learning shall happen. For example: 6e-4 to 1e-4 in 15% of steps, stays on 1e-4 for 60% of steps (actually I monitor the loss and decay the lr when it plateaus), then 1e-4 to 1e-5 in 25% of steps. - -The pseudocode (execution from top to bottom): - -![RWKV-v2-RNN](RWKV-v2-RNN.png) +I find it might be nice to make the model stay on a mid-lr for a long period, because in theory that's where most learning shall happen. For example: 6e-4 to 1e-4 in 20% of steps, stays at 1e-4 for 30% of steps (actually I monitor the loss and decay the lr when it plateaus), then 1e-4 to 1e-5 in 50% of steps. # Better Learning Rate Schedule via Variantional Method of Loss Curve