+ new comparison

main
BlinkDL 4 years ago
parent 04852faf04
commit 4ffd8f1b76

Binary file not shown.

Before

Width:  |  Height:  |  Size: 85 KiB

After

Width:  |  Height:  |  Size: 93 KiB

@ -12,11 +12,11 @@ logger = logging.getLogger(__name__)
######################################################################################################## ########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix # RWKV: RWKV Time-mix + RWKV Channel-mix
######################################################################################################## ########################################################################################################
#
# fancy initialization of lin & emb layers, for faster convergence rwkv_emb_scale = 0.4 # try 0.4 for char-level english. try 1.0 for chinese.
# note it will change ALL lin & emb layers in the module (including token emb & final projection) rwkv_layer_decay = 1.0 # decay weights in higher layers. try 0.5 ~ 1.0.
#
def RWKV_Init(module, config): def RWKV_Init(module, config): # fancy initialization of every lin & emb layer in the module
for m in module.modules(): for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)): if not isinstance(m, (nn.Linear, nn.Embedding)):
continue continue
@ -36,12 +36,12 @@ def RWKV_Init(module, config):
if shape[0] > shape[1]: if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1]) gain = math.sqrt(shape[0] / shape[1])
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection? if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
scale = 0.4 # 0.4 is a safe choice, 0.8 is better for chinese scale = rwkv_emb_scale
if isinstance(m, nn.Embedding): if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1])) gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb? if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
scale = 0.4 # 0.4 is a safe choice, 0.8 is better for chinese scale = rwkv_emb_scale
if hasattr(m, 'scale_init'): if hasattr(m, 'scale_init'):
scale = m.scale_init scale = m.scale_init
@ -90,12 +90,12 @@ class RWKV_TimeMix(nn.Module):
self.key = nn.Linear(config.n_embd, config.n_attn) self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn) self.value = nn.Linear(config.n_embd, config.n_attn)
self.receptance = nn.Linear(config.n_embd, config.n_attn) self.receptance = nn.Linear(config.n_embd, config.n_attn)
self.output = nn.Linear(config.n_attn, config.n_embd) self.output = nn.Linear(config.n_attn, config.n_embd)
self.key.scale_init = 0 self.key.scale_init = 0
self.receptance.scale_init = 0 self.receptance.scale_init = 0
self.output.scale_init = 1 / pow(1+layer_id, 0.5) # 0.5 ~ 0.7 gives similar results self.output.scale_init = 1 / pow(1+layer_id, rwkv_layer_decay) # decay weight in higher layers
def forward(self, x): def forward(self, x):
B, T, C = x.size() B, T, C = x.size()
@ -137,7 +137,7 @@ class RWKV_ChannelMix(nn.Module):
self.receptance = nn.Linear(config.n_embd, config.n_embd) self.receptance = nn.Linear(config.n_embd, config.n_embd)
self.receptance.scale_init = 0 self.receptance.scale_init = 0
self.weight.scale_init = 1 / pow(1+layer_id, 0.5) # 0.5 ~ 0.7 gives similar results self.weight.scale_init = 1 / pow(1+layer_id, rwkv_layer_decay) # decay weight in higher layers
def forward(self, x): def forward(self, x):
B, T, C = x.size() B, T, C = x.size()
@ -359,6 +359,7 @@ class GPTConfig:
class Block(nn.Module): class Block(nn.Module):
def __init__(self, config, layer_id): def __init__(self, config, layer_id):
super().__init__() super().__init__()
self.config = config
self.ln1 = nn.LayerNorm(config.n_embd) self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd) self.ln2 = nn.LayerNorm(config.n_embd)

Loading…
Cancel
Save