|
|
|
@ -69,7 +69,7 @@ class RWKV_TimeMix(nn.Module):
|
|
|
|
if h < config.n_head - 1:
|
|
|
|
if h < config.n_head - 1:
|
|
|
|
decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
|
|
|
|
decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
decay_speed = 0
|
|
|
|
decay_speed = 0.0
|
|
|
|
ww[h] = torch.exp(curve * decay_speed)
|
|
|
|
ww[h] = torch.exp(curve * decay_speed)
|
|
|
|
# print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy())
|
|
|
|
# print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy())
|
|
|
|
self.time_w = nn.Parameter(ww)
|
|
|
|
self.time_w = nn.Parameter(ww)
|
|
|
|
|