@ -69,7 +69,7 @@ class RWKV_TimeMix(nn.Module):
if h < config.n_head - 1:
decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
else:
decay_speed = 0
decay_speed = 0.0
ww[h] = torch.exp(curve * decay_speed)
# print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy())
self.time_w = nn.Parameter(ww)