|
|
|
|
@ -63,20 +63,15 @@ class RWKV_TimeMix(nn.Module):
|
|
|
|
|
self.head_size = config.n_attn // config.n_head
|
|
|
|
|
|
|
|
|
|
with torch.no_grad(): # initial time_w curves for better convergence
|
|
|
|
|
ww = torch.zeros(config.n_head, config.ctx_len)
|
|
|
|
|
curve = torch.tensor([0.9 ** (config.ctx_len - 1 - i) for i in range(config.ctx_len)])
|
|
|
|
|
curve = curve * 2 + 0.7
|
|
|
|
|
ww = torch.ones(config.n_head, config.ctx_len)
|
|
|
|
|
curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
|
|
|
|
|
for h in range(config.n_head):
|
|
|
|
|
if config.n_head > 1:
|
|
|
|
|
mix_strength = 1 - 1.2 * h / (config.n_head - 1) # mix_strength from 1 to -0.2
|
|
|
|
|
if h < config.n_head - 1:
|
|
|
|
|
decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
|
|
|
|
|
else:
|
|
|
|
|
mix_strength = 0.5
|
|
|
|
|
ww[h] = (1 - mix_strength) + curve * mix_strength
|
|
|
|
|
# special tweaks because of time_shift
|
|
|
|
|
ww[h][config.ctx_len - 3] = (ww[h][config.ctx_len - 3] * 2 + 1) / 3
|
|
|
|
|
ww[h][config.ctx_len - 2] = (ww[h][config.ctx_len - 2] * 1 + 2) / 3
|
|
|
|
|
ww[h][config.ctx_len - 1] = 1
|
|
|
|
|
# print(h, mix_strength, ww[h])
|
|
|
|
|
decay_speed = 0
|
|
|
|
|
ww[h] = torch.exp(curve * decay_speed)
|
|
|
|
|
# print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy())
|
|
|
|
|
self.time_w = nn.Parameter(ww)
|
|
|
|
|
|
|
|
|
|
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
|
|
|
|
|
|