@@ -237,7 +237,7 @@ class RotaryMHA_Plus(nn.Module):
 # The GPT Model with our blocks
 ########################################################################################################
 
-class LabelSmoothingCrossEntropy(nn.Module): # might be able to avoid nan loss
+class LabelSmoothingCrossEntropy(nn.Module): # can avoid nan loss
     def __init__(self, smoothing=0.0):
         super().__init__()
         self.confidence = 1.0 - smoothing
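Note on the renamed comment: the visible lines of forward() treat pred as log-probabilities, and the smoothed target distribution keeps loss mass off any single hard label. A minimal usage sketch with made-up shapes (passing log-probs explicitly is safe even if the elided forward() also applies log_softmax, since log_softmax is idempotent on log-probs):

    import torch

    crit = LabelSmoothingCrossEntropy(smoothing=0.1)
    logits = torch.randn(4, 100)            # (batch, vocab_size) -- assumed shapes
    targets = torch.randint(0, 100, (4,))   # hard class indices
    loss = crit(logits.log_softmax(dim=-1), targets)  # scalar loss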
@@ -251,6 +251,29 @@ class LabelSmoothingCrossEntropy(nn.Module): # might be able to avoid nan loss
             true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
         return torch.mean(torch.sum(-true_dist * pred, dim=-1))
 
+class RMSNorm(nn.Module):
+    def __init__(self, d):
+        super().__init__()
+        self.dd = d ** (-1. / 2)
+        self.weight = nn.Parameter(torch.ones(d))
+
+    def forward(self, x):
+        norm_x = x.norm(2, dim=-1, keepdim=True)
+        x_normed = x / (norm_x * self.dd + 1e-12)
+        return self.weight * x_normed
+
+class SimpleRMSNorm(nn.Module):
+    def __init__(self, d):
+        super().__init__()
+        self.dd = d ** (-1. / 2)
+
+    def forward(self, x):
+        norm_x = x.norm(2, dim=-1, keepdim=True)
+        x_normed = x / (norm_x * self.dd + 1e-12)
+        return x_normed
+
+########################################################################################################
+
 class GPTConfig:
     def __init__(self, vocab_size, ctx_size, **kwargs):
         self.vocab_size = vocab_size
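Sanity check on the math of the two classes added above (not part of the diff): dividing by norm_x * d ** -0.5 is exactly dividing by the root-mean-square of x, since ||x||_2 / sqrt(d) == sqrt(mean(x ** 2)). RMSNorm then applies a learned per-channel gain; SimpleRMSNorm is the parameter-free variant.

    import torch

    d = 8
    x = torch.randn(2, d)
    rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
    alt = x.norm(2, dim=-1, keepdim=True) * d ** (-1. / 2)
    assert torch.allclose(rms, alt, atol=1e-6)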
@@ -266,6 +289,8 @@ class Block(nn.Module):
         self.ln2 = nn.LayerNorm(config.n_embd)
 
         if config.model_type == 'RWKV':
+            self.ln1 = nn.Identity() # remove first LayerNorm -> faster convergence for deep models
+            self.ln2 = SimpleRMSNorm(config.n_embd) # SimpleRMSNorm is good enough for RWKV -> less parameters
            self.attn = RWKV_TimeMix(config)
             self.mlp = RWKV_ChannelMix(config)
         elif config.model_type == 'RotaryMHA':
@@ -278,6 +303,7 @@ class Block(nn.Module):
     def forward(self, x):
         x = x + self.attn(self.ln1(x))
         x = x + self.mlp(self.ln2(x))
+
         return x
 
 class GPT(nn.Module):
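Taken together with the previous hunk, the unchanged forward() means the RWKV path skips normalization on the time-mix branch entirely. A runnable sketch of the effective dataflow (attn/mlp are hypothetical Linear stand-ins for RWKV_TimeMix / RWKV_ChannelMix, which the diff does not show):

    import torch
    import torch.nn as nn

    d = 16
    ln1, ln2 = nn.Identity(), SimpleRMSNorm(d)
    attn = mlp = nn.Linear(d, d)   # stand-ins, not the real mix modules
    x = torch.randn(2, 3, d)
    x = x + attn(ln1(x))           # time-mix sees un-normalized input
    x = x + mlp(ln2(x))            # channel-mix sees RMS-normalized input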
@@ -288,7 +314,11 @@ class GPT(nn.Module):
 
         self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
 
-        self.ln_f = nn.LayerNorm(config.n_embd)
+        if config.model_type == 'RWKV':
+            self.ln_f = SimpleRMSNorm(config.n_embd) # SimpleRMSNorm is good enough for RWKV -> less parameters
+        else:
+            self.ln_f = nn.LayerNorm(config.n_embd)
+
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         self.ctx_size = config.ctx_size
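The "less parameters" comments are literal: LayerNorm carries a per-channel weight and bias, RMSNorm only a weight, SimpleRMSNorm nothing. A quick count, assuming the classes above are in scope and n_embd = 512:

    import torch.nn as nn

    def count(m):
        return sum(p.numel() for p in m.parameters())

    print(count(nn.LayerNorm(512)))    # 1024 (weight + bias)
    print(count(RMSNorm(512)))         # 512  (weight only)
    print(count(SimpleRMSNorm(512)))   # 0    (parameter-free)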
@@ -311,7 +341,7 @@ class GPT(nn.Module):
         no_decay = set()
 
         whitelist_weight_modules = (nn.Linear, )
-        blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)
+        blacklist_weight_modules = (RMSNorm, nn.LayerNorm, nn.Embedding)
         for mn, m in self.named_modules():
             for pn, p in m.named_parameters():
                 fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
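Adding RMSNorm to the blacklist routes its gain parameter into the no-weight-decay group, the same treatment LayerNorm and Embedding weights already get. The grouping below the visible loop head is not in the diff; this is a hedged paraphrase of the usual minGPT-style logic it presumably feeds:

    decay, no_decay = set(), set()
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                no_decay.add(fpn)   # RMSNorm.weight now lands here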