@@ -371,6 +371,19 @@ class GPT(nn.Module):
                 ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
             elif 'mlp.weight.weight' in k:
                 ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
+            elif 'attn.time_w' in k:
+                if self.config.n_head > 1: # different time_w for different head
+                    for h in range(self.config.n_head):
+                        curve = torch.tensor([i for i in range(self.config.ctx_len)]) / (self.config.ctx_len - 1)
+                        curve = torch.pow(curve, 24) # concentrated effect
+                        curve = curve - torch.mean(curve) + 1 # normalize mean to 1
+                        mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
+                        ww[k][h] = (1 - mix_strength) + curve * mix_strength
+                        # special tweak because of time_shift
+                        ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 2] * 2 + 1) / 3
+                        ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] + 1) / 2
+                        ww[k][h][self.config.ctx_len - 1] = 1
+                        # print(k, h, mix_strength, ww[k][h])
 
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
 
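Context on the time_w initialization added above: for each head h it builds a mixing curve over the ctx_len positions, sharply concentrated on the most recent positions (pow(curve, 24)), re-centered so its mean is 1, and blended with a flat profile by a per-head mix_strength that runs from 1 for head 0 down to -0.2 for the last head, so early heads start strongly position-dependent while late heads start nearly flat or slightly inverted. Below is a standalone sketch of the same construction for inspecting the resulting curves; ctx_len and n_head are placeholder values standing in for self.config.ctx_len and self.config.n_head, not values from the repository.

import torch

# Standalone sketch of the per-head time_w init from the hunk above.
# ctx_len = 768 and n_head = 12 are placeholder values, not the repo's config.
ctx_len, n_head = 768, 12

time_w = torch.ones(n_head, ctx_len)
for h in range(n_head):
    curve = torch.arange(ctx_len, dtype=torch.float32) / (ctx_len - 1)
    curve = torch.pow(curve, 24)                 # concentrate weight on recent positions
    curve = curve - torch.mean(curve) + 1        # normalize mean to 1
    mix_strength = 1 - 1.2 * h / (n_head - 1)    # 1 for head 0, -0.2 for the last head
    time_w[h] = (1 - mix_strength) + curve * mix_strength
    # special tweak because of time_shift (same as in the diff)
    time_w[h][ctx_len - 3] = (time_w[h][ctx_len - 2] * 2 + 1) / 3
    time_w[h][ctx_len - 2] = (time_w[h][ctx_len - 2] + 1) / 2
    time_w[h][ctx_len - 1] = 1

print(time_w[:, -5:])   # last few positions per head; the final position is always 1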
@@ -437,6 +450,6 @@ class GPT(nn.Module):
 
         loss = None
         if targets is not None:
-            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(x.view(-1, x.size(-1)), targets.view(-1)) # try increasing smoothing if you see nan
+            loss = LabelSmoothingCrossEntropy(smoothing=5e-5)(x.view(-1, x.size(-1)), targets.view(-1))
 
         return x, loss
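The LabelSmoothingCrossEntropy class used above is not defined in this hunk. For reference, here is a minimal sketch of a typical label-smoothing cross-entropy module with the same call signature; the class name and signature match the usage above, but the body is an assumption, not the repository's implementation.

import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothingCrossEntropy(nn.Module):
    # Minimal sketch; the repository's own implementation may differ.
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, target):
        # logits: (N, vocab_size), target: (N,) integer class indices
        logprobs = F.log_softmax(logits, dim=-1)
        nll = -logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        smooth = -logprobs.mean(dim=-1)   # uniform component over the vocabulary
        loss = (1.0 - self.smoothing) * nll + self.smoothing * smooth
        return loss.mean()

At smoothing=5e-5 this is essentially plain cross-entropy; the small uniform component keeps the target probability away from exactly 1, which is presumably what the removed comment ("try increasing smoothing if you see nan") was guarding against.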