diff --git a/RWKV-vs-MHA.png b/RWKV-vs-MHA.png index 28ef683..6668850 100644 Binary files a/RWKV-vs-MHA.png and b/RWKV-vs-MHA.png differ diff --git a/src/model.py b/src/model.py index dcd8153..2d9f118 100644 --- a/src/model.py +++ b/src/model.py @@ -12,11 +12,11 @@ logger = logging.getLogger(__name__) ######################################################################################################## # RWKV: RWKV Time-mix + RWKV Channel-mix ######################################################################################################## -# -# fancy initialization of lin & emb layers, for faster convergence -# note it will change ALL lin & emb layers in the module (including token emb & final projection) -# -def RWKV_Init(module, config): + +rwkv_emb_scale = 0.4 # try 0.4 for char-level english. try 1.0 for chinese. +rwkv_layer_decay = 1.0 # decay weights in higher layers. try 0.5 ~ 1.0. + +def RWKV_Init(module, config): # fancy initialization of every lin & emb layer in the module for m in module.modules(): if not isinstance(m, (nn.Linear, nn.Embedding)): continue @@ -36,12 +36,12 @@ def RWKV_Init(module, config): if shape[0] > shape[1]: gain = math.sqrt(shape[0] / shape[1]) if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection? - scale = 0.4 # 0.4 is a safe choice, 0.8 is better for chinese + scale = rwkv_emb_scale if isinstance(m, nn.Embedding): gain = math.sqrt(max(shape[0], shape[1])) if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb? - scale = 0.4 # 0.4 is a safe choice, 0.8 is better for chinese + scale = rwkv_emb_scale if hasattr(m, 'scale_init'): scale = m.scale_init @@ -90,12 +90,12 @@ class RWKV_TimeMix(nn.Module): self.key = nn.Linear(config.n_embd, config.n_attn) self.value = nn.Linear(config.n_embd, config.n_attn) self.receptance = nn.Linear(config.n_embd, config.n_attn) - + self.output = nn.Linear(config.n_attn, config.n_embd) self.key.scale_init = 0 self.receptance.scale_init = 0 - self.output.scale_init = 1 / pow(1+layer_id, 0.5) # 0.5 ~ 0.7 gives similar results + self.output.scale_init = 1 / pow(1+layer_id, rwkv_layer_decay) # decay weight in higher layers def forward(self, x): B, T, C = x.size() @@ -137,7 +137,7 @@ class RWKV_ChannelMix(nn.Module): self.receptance = nn.Linear(config.n_embd, config.n_embd) self.receptance.scale_init = 0 - self.weight.scale_init = 1 / pow(1+layer_id, 0.5) # 0.5 ~ 0.7 gives similar results + self.weight.scale_init = 1 / pow(1+layer_id, rwkv_layer_decay) # decay weight in higher layers def forward(self, x): B, T, C = x.size() @@ -359,6 +359,7 @@ class GPTConfig: class Block(nn.Module): def __init__(self, config, layer_id): super().__init__() + self.config = config self.ln1 = nn.LayerNorm(config.n_embd) self.ln2 = nn.LayerNorm(config.n_embd)