@@ -12,11 +12,11 @@ logger = logging.getLogger(__name__)
 ########################################################################################################
 # RWKV: RWKV Time-mix + RWKV Channel-mix
 ########################################################################################################
-#
-# fancy initialization of lin & emb layers, for faster convergence
-#
-
-def RWKV_Init(module, config):
+rwkv_emb_scale = 0.4 # try 0.4 for char-level english. try 1.0 for chinese.
+# note it will change ALL lin & emb layers in the module (including token emb & final projection)
+rwkv_layer_decay = 1.0 # decay weights in higher layers. try 0.5 ~ 1.0.
+
+def RWKV_Init(module, config): # fancy initialization of every lin & emb layer in the module
     for m in module.modules():
         if not isinstance(m, (nn.Linear, nn.Embedding)):
             continue
@@ -36,12 +36,12 @@ def RWKV_Init(module, config):
             if shape[0] > shape[1]:
                 gain = math.sqrt(shape[0] / shape[1])
             if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
-                scale = 0.4 # 0.4 is a safe choice, 0.8 is better for chinese
+                scale = rwkv_emb_scale
 
         if isinstance(m, nn.Embedding):
             gain = math.sqrt(max(shape[0], shape[1]))
             if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
-                scale = 0.4 # 0.4 is a safe choice, 0.8 is better for chinese
+                scale = rwkv_emb_scale
 
         if hasattr(m, 'scale_init'):
             scale = m.scale_init
@@ -90,12 +90,12 @@ class RWKV_TimeMix(nn.Module):
         self.key = nn.Linear(config.n_embd, config.n_attn)
         self.value = nn.Linear(config.n_embd, config.n_attn)
         self.receptance = nn.Linear(config.n_embd, config.n_attn)
 
         self.output = nn.Linear(config.n_attn, config.n_embd)
 
         self.key.scale_init = 0
         self.receptance.scale_init = 0
-        self.output.scale_init = 1 / pow(1+layer_id, 0.5) # 0.5 ~ 0.7 gives similar results
+        self.output.scale_init = 1 / pow(1+layer_id, rwkv_layer_decay) # decay weight in higher layers
 
     def forward(self, x):
         B, T, C = x.size()
@@ -137,7 +137,7 @@ class RWKV_ChannelMix(nn.Module):
         self.receptance = nn.Linear(config.n_embd, config.n_embd)
 
         self.receptance.scale_init = 0
-        self.weight.scale_init = 1 / pow(1+layer_id, 0.5) # 0.5 ~ 0.7 gives similar results
+        self.weight.scale_init = 1 / pow(1+layer_id, rwkv_layer_decay) # decay weight in higher layers
 
     def forward(self, x):
         B, T, C = x.size()
@@ -359,6 +359,7 @@ class GPTConfig:
 class Block(nn.Module):
     def __init__(self, config, layer_id):
         super().__init__()
+        self.config = config
         self.ln1 = nn.LayerNorm(config.n_embd)
         self.ln2 = nn.LayerNorm(config.n_embd)
 
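
A quick sketch (not part of the diff above) of what the two new module-level knobs control; output_scale_init is a throwaway helper name used only for this illustration:

# rwkv_emb_scale replaces the hardcoded 0.4 used when initializing the token embedding and the
# final projection, and rwkv_layer_decay replaces the hardcoded 0.5 exponent, so the per-layer
# shrink of scale_init -- 1 / (1 + layer_id) ** rwkv_layer_decay -- can be tuned in one place.

rwkv_emb_scale = 0.4   # try 0.4 for char-level english. try 1.0 for chinese.
rwkv_layer_decay = 1.0 # decay weights in higher layers. try 0.5 ~ 1.0.

def output_scale_init(layer_id, decay=rwkv_layer_decay):
    # same formula the patch installs for RWKV_TimeMix.output and RWKV_ChannelMix.weight
    return 1 / pow(1 + layer_id, decay)

if __name__ == '__main__':
    for layer_id in range(6):
        # compare the old fixed exponent (0.5) with the new default (1.0)
        print(layer_id, round(output_scale_init(layer_id, 0.5), 3), round(output_scale_init(layer_id, 1.0), 3))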