diff --git a/src/model.py b/src/model.py
index 41808a3..eb2724a 100644
--- a/src/model.py
+++ b/src/model.py
@@ -371,6 +371,19 @@ class GPT(nn.Module):
                         ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
                     elif 'mlp.weight.weight' in k:
                         ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
+                    elif 'attn.time_w' in k:
+                        if self.config.n_head > 1: # different time_w for different head
+                            for h in range(self.config.n_head):
+                                curve = torch.tensor([i for i in range(self.config.ctx_len)]) / (self.config.ctx_len - 1)
+                                curve = torch.pow(curve, 24) # concentrated effect
+                                curve = curve - torch.mean(curve) + 1 # normalize mean to 1
+                                mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
+                                ww[k][h] = (1 - mix_strength) + curve * mix_strength
+                                # special tweak because of time_shift
+                                ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 2] * 2 + 1) / 3
+                                ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] + 1) / 2
+                                ww[k][h][self.config.ctx_len - 1] = 1
+                                # print(k, h, mix_strength, ww[k][h])
 
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
 
@@ -437,6 +450,6 @@ class GPT(nn.Module):
 
         loss = None
         if targets is not None:
-            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(x.view(-1, x.size(-1)), targets.view(-1)) # try increasing smoothing if you see nan
+            loss = LabelSmoothingCrossEntropy(smoothing=5e-5)(x.view(-1, x.size(-1)), targets.view(-1))
 
         return x, loss
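
A minimal standalone sketch of the per-head time_w initialization added above, for reference only. The values n_head = 4 and ctx_len = 8 are hypothetical examples; the actual code reads them from self.config and writes into the model's state dict rather than a fresh tensor.

import torch

n_head = 4   # hypothetical example; the diff uses self.config.n_head
ctx_len = 8  # hypothetical example; the diff uses self.config.ctx_len

time_w = torch.ones(n_head, ctx_len)
for h in range(n_head):
    # ramp from 0 to 1 over the context, raised to a high power so the
    # weight is concentrated on the most recent positions
    curve = torch.arange(ctx_len, dtype=torch.float32) / (ctx_len - 1)
    curve = torch.pow(curve, 24)
    curve = curve - torch.mean(curve) + 1  # normalize mean to 1
    # mix_strength goes from 1 (first head) down to -0.2 (last head)
    mix_strength = 1 - 1.2 * h / (n_head - 1)
    time_w[h] = (1 - mix_strength) + curve * mix_strength
    # special tweak for the last few positions because of time_shift
    time_w[h][ctx_len - 3] = (time_w[h][ctx_len - 2] * 2 + 1) / 3
    time_w[h][ctx_len - 2] = (time_w[h][ctx_len - 2] + 1) / 2
    time_w[h][ctx_len - 1] = 1

print(time_w)  # one time-weighting curve per head, flatter as mix_strength shrinks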