better init for RWKV

main
BlinkDL 4 years ago
parent 619ed00e4b
commit 710d3e34b7

@@ -63,20 +63,15 @@ class RWKV_TimeMix(nn.Module):
self.head_size = config.n_attn // config.n_head self.head_size = config.n_attn // config.n_head
with torch.no_grad(): # initial time_w curves for better convergence with torch.no_grad(): # initial time_w curves for better convergence
ww = torch.zeros(config.n_head, config.ctx_len) ww = torch.ones(config.n_head, config.ctx_len)
curve = torch.tensor([0.9 ** (config.ctx_len - 1 - i) for i in range(config.ctx_len)]) curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
curve = curve * 2 + 0.7
for h in range(config.n_head): for h in range(config.n_head):
if config.n_head > 1: if h < config.n_head - 1:
mix_strength = 1 - 1.2 * h / (config.n_head - 1) # mix_strength from 1 to -0.2 decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
else: else:
mix_strength = 0.5 decay_speed = 0
ww[h] = (1 - mix_strength) + curve * mix_strength ww[h] = torch.exp(curve * decay_speed)
# special tweaks because of time_shift # print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy())
ww[h][config.ctx_len - 3] = (ww[h][config.ctx_len - 3] * 2 + 1) / 3
ww[h][config.ctx_len - 2] = (ww[h][config.ctx_len - 2] * 1 + 2) / 3
ww[h][config.ctx_len - 1] = 1
# print(h, mix_strength, ww[h])
self.time_w = nn.Parameter(ww) self.time_w = nn.Parameter(ww)
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len)) self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))

Loading…
Cancel
Save