diff --git a/src/model.py b/src/model.py index 71bc098..1eeb086 100644 --- a/src/model.py +++ b/src/model.py @@ -69,7 +69,7 @@ class RWKV_TimeMix(nn.Module): if h < config.n_head - 1: decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1)) else: - decay_speed = 0 + decay_speed = 0.0 ww[h] = torch.exp(curve * decay_speed) # print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy()) self.time_w = nn.Parameter(ww)