@@ -80,7 +80,7 @@ class RWKV_ChannelMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
 
         wkv = self.weight(F.mish(k) * v) # mish is a bit better than gelu
         y = torch.sigmoid(r) * wkv
 
         return y
@@ -292,8 +292,9 @@ class Block(nn.Module):
         self.ln2 = nn.LayerNorm(config.n_embd)
 
         if config.model_type == 'RWKV':
-            self.ln1 = nn.Identity() # remove first LayerNorm -> faster convergence for deep models
-            self.ln2 = SimpleRMSNorm(config.n_embd) # SimpleRMSNorm is good enough for RWKV -> less parameters
+            self.ln1 = nn.Identity()
+            # self.ln1 = SimpleRMSNorm(config.n_embd) # turn on this if you see nan in large RWKV models
+            self.ln2 = SimpleRMSNorm(config.n_embd)
             self.attn = RWKV_TimeMix(config)
             self.mlp = RWKV_ChannelMix(config)
         elif config.model_type == 'RotaryMHA':
@@ -319,7 +320,7 @@ class GPT(nn.Module):
         self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
 
         if config.model_type == 'RWKV':
-            self.ln_f = SimpleRMSNorm(config.n_embd) # SimpleRMSNorm is good enough for RWKV -> less parameters
+            self.ln_f = SimpleRMSNorm(config.n_embd)
         else:
             self.ln_f = nn.LayerNorm(config.n_embd)
 
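A note on SimpleRMSNorm, which the two hunks above use in place of LayerNorm on the RWKV path: the comment this diff removes ("-> less parameters") suggests it is an RMS normalization with no learnable affine terms. A minimal sketch under that assumption (the repository's actual class may differ):

import math
import torch
import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    # RMS normalization with no learnable gain or bias, so it adds no parameters.
    def __init__(self, d, eps=1e-8):
        super().__init__()
        self.d = d
        self.eps = eps

    def forward(self, x):
        # root-mean-square over the channel (last) dimension
        rms = x.norm(2, dim=-1, keepdim=True) / math.sqrt(self.d)
        return x / (rms + self.eps)

y = SimpleRMSNorm(512)(torch.randn(1, 10, 512))  # usage: same call shape as nn.LayerNorm(512)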
@@ -328,6 +329,16 @@ class GPT(nn.Module):
         self.ctx_size = config.ctx_size
         self.apply(self._init_weights)
 
+        if self.config.model_type == 'RWKV':
+            ww = self.state_dict()
+            for k in ww: # reduce weight to avoid nan
+                if 'receptance.weight' in k:
+                    ww[k] /= math.pow(config.n_embd, 0.5)
+                elif 'key.weight' in k:
+                    ww[k] /= math.pow(config.n_embd, 0.25)
+                elif 'value.weight' in k:
+                    ww[k] /= math.pow(config.n_embd, 0.25)
+
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
 
     def get_ctx_size(self):
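The hunk above shrinks the initial scale of the receptance projections by n_embd^0.5 and the key/value projections by n_embd^0.25 after _init_weights has run, which is what the "reduce weight to avoid nan" comment refers to. A standalone sketch of the same post-init rescaling on stand-in layers (hypothetical module names, not the repository's classes); state_dict() tensors share storage with the parameters, so the in-place division updates the model directly, as in the diff:

import math
import torch.nn as nn

n_embd = 512  # hypothetical embedding size

# stand-ins for the projections matched by the loop in the hunk above
model = nn.ModuleDict({
    'receptance': nn.Linear(n_embd, n_embd, bias=False),
    'key':        nn.Linear(n_embd, n_embd, bias=False),
    'value':      nn.Linear(n_embd, n_embd, bias=False),
})

ww = model.state_dict()  # shares storage with model parameters
for k in ww:             # same rescaling rule as the diff
    if 'receptance.weight' in k:
        ww[k] /= math.pow(n_embd, 0.5)
    elif 'key.weight' in k:
        ww[k] /= math.pow(n_embd, 0.25)
    elif 'value.weight' in k:
        ww[k] /= math.pow(n_embd, 0.25)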
@@ -335,14 +346,17 @@ class GPT(nn.Module):
 
     def _init_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
-            # if self.config.model_type == 'RWKV' and isinstance(module, nn.Linear):
-            #     gain_layer = min(3, module.weight.shape[0] / module.weight.shape[1])
-            #     depth_factor = min(1, 1 / math.sqrt(self.config.n_layer / 5))
-            #     nn.init.orthogonal_(module.weight, gain = gain_layer * depth_factor) # will nan for large models
-            # else:
-            module.weight.data.normal_(mean=0.0, std=0.01)
-            if isinstance(module, nn.Linear) and module.bias is not None:
-                module.bias.data.zero_()
+            if self.config.model_type == 'RWKV':
+                if isinstance(module, nn.Linear):
+                    gain_layer = min(3, module.weight.shape[0] / module.weight.shape[1])
+                    depth_factor = 1 # min(1, 1 / math.sqrt(self.config.n_layer / 5))
+                    nn.init.orthogonal_(module.weight, gain = gain_layer * depth_factor)
+                else:
+                    nn.init.orthogonal_(module.weight, gain = 1.0)
+            else:
+                module.weight.data.normal_(mean=0.0, std=0.01)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
 
     def configure_optimizers(self, train_config):
         # separate out all parameters to those that will and won't experience regularizing weight decay
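In the rewritten _init_weights above, Linear layers on the RWKV path get an orthogonal init with gain min(3, out_features / in_features) (nn.Linear stores weight as out_features x in_features), Embedding layers get gain 1.0, and non-RWKV models keep the normal(0, 0.01) init. A quick check of the resulting gain for a few hypothetical layer shapes:

def rwkv_ortho_gain(out_features, in_features):
    # mirrors gain_layer = min(3, module.weight.shape[0] / module.weight.shape[1])
    return min(3, out_features / in_features)

n_embd = 512                                   # hypothetical sizes
print(rwkv_ortho_gain(n_embd, n_embd))         # square projection    -> 1.0
print(rwkv_ortho_gain(4 * n_embd, n_embd))     # 4x-wide hidden layer -> 3 (capped from 4.0)
print(rwkv_ortho_gain(n_embd, 4 * n_embd))     # projection back down -> 0.25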