improve RWKV time_w initialization

main
BlinkDL 4 years ago
parent 1ea53a2f03
commit 4fd8716976

@ -371,6 +371,19 @@ class GPT(nn.Module):
ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
elif 'mlp.weight.weight' in k:
ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
elif 'attn.time_w' in k:
if self.config.n_head > 1: # different time_w for different head
for h in range(self.config.n_head):
curve = torch.tensor([i for i in range(self.config.ctx_len)]) / (self.config.ctx_len - 1)
curve = torch.pow(curve, 24) # concentrated effect
curve = curve - torch.mean(curve) + 1 # normalize mean to 1
mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
ww[k][h] = (1 - mix_strength) + curve * mix_strength
# special tweak because of time_shift
ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 2] * 2 + 1) / 3
ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] + 1) / 2
ww[k][h][self.config.ctx_len - 1] = 1
# print(k, h, mix_strength, ww[k][h])
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
@ -437,6 +450,6 @@ class GPT(nn.Module):
loss = None
if targets is not None:
loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(x.view(-1, x.size(-1)), targets.view(-1)) # try increasing smoothing if you see nan
loss = LabelSmoothingCrossEntropy(smoothing=5e-5)(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss

Loading…
Cancel
Save