From 710d3e34b7651d34c391f71e3c0717b92eeef148 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Sat, 28 Aug 2021 03:30:54 +0800
Subject: [PATCH] better init for RWKV

---
 src/model.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/src/model.py b/src/model.py
index c06bd1d..c97e82f 100644
--- a/src/model.py
+++ b/src/model.py
@@ -63,20 +63,15 @@ class RWKV_TimeMix(nn.Module):
         self.head_size = config.n_attn // config.n_head
 
         with torch.no_grad(): # initial time_w curves for better convergence
-            ww = torch.zeros(config.n_head, config.ctx_len)
-            curve = torch.tensor([0.9 ** (config.ctx_len - 1 - i) for i in range(config.ctx_len)])
-            curve = curve * 2 + 0.7
+            ww = torch.ones(config.n_head, config.ctx_len)
+            curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
             for h in range(config.n_head):
-                if config.n_head > 1:
-                    mix_strength = 1 - 1.2 * h / (config.n_head - 1) # mix_strength from 1 to -0.2
+                if h < config.n_head - 1:
+                    decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
                 else:
-                    mix_strength = 0.5
-                ww[h] = (1 - mix_strength) + curve * mix_strength
-                # special tweaks because of time_shift
-                ww[h][config.ctx_len - 3] = (ww[h][config.ctx_len - 3] * 2 + 1) / 3
-                ww[h][config.ctx_len - 2] = (ww[h][config.ctx_len - 2] * 1 + 2) / 3
-                ww[h][config.ctx_len - 1] = 1
-                # print(h, mix_strength, ww[h])
+                    decay_speed = 0
+                ww[h] = torch.exp(curve * decay_speed)
+                # print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy())
             self.time_w = nn.Parameter(ww)
 
         self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))