|
|
|
|
@ -55,6 +55,8 @@ class RWKV_RNN(nn.Module):
|
|
|
|
|
w[x] = w[x].float()
|
|
|
|
|
elif self.FLOAT_MODE == "bf16":
|
|
|
|
|
w[x] = w[x].bfloat16()
|
|
|
|
|
elif self.FLOAT_MODE == "fp16":
|
|
|
|
|
w[x] = w[x].half()
|
|
|
|
|
|
|
|
|
|
w[x].requires_grad = False
|
|
|
|
|
if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
|
|
|
|
|
@ -106,6 +108,10 @@ class RWKV_RNN(nn.Module):
|
|
|
|
|
xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
|
|
|
|
|
xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
|
|
|
|
|
state[5*i+0] = x.float()
|
|
|
|
|
elif self.FLOAT_MODE == "fp16":
|
|
|
|
|
xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k)
|
|
|
|
|
xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r)
|
|
|
|
|
state[5*i+0] = x.float()
|
|
|
|
|
else:
|
|
|
|
|
xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
|
|
|
|
|
xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
|
|
|
|
|
@ -124,6 +130,11 @@ class RWKV_RNN(nn.Module):
|
|
|
|
|
xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
|
|
|
|
|
xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
|
|
|
|
|
state[5*i+1] = x.float()
|
|
|
|
|
elif self.FLOAT_MODE == "fp16":
|
|
|
|
|
xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k)
|
|
|
|
|
xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v)
|
|
|
|
|
xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r)
|
|
|
|
|
state[5*i+1] = x.float()
|
|
|
|
|
else:
|
|
|
|
|
xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
|
|
|
|
|
xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
|
|
|
|
|
@ -134,7 +145,7 @@ class RWKV_RNN(nn.Module):
|
|
|
|
|
k = kw @ xk
|
|
|
|
|
v = vw @ xv
|
|
|
|
|
|
|
|
|
|
if self.FLOAT_MODE == "bf16":
|
|
|
|
|
if '16' in self.FLOAT_MODE:
|
|
|
|
|
kk = k.float()
|
|
|
|
|
vv = v.float()
|
|
|
|
|
else:
|
|
|
|
|
@ -158,6 +169,8 @@ class RWKV_RNN(nn.Module):
|
|
|
|
|
state[5*i+4] = p
|
|
|
|
|
if self.FLOAT_MODE == "bf16":
|
|
|
|
|
wkv = (a / b).type(torch.bfloat16)
|
|
|
|
|
elif self.FLOAT_MODE == "fp16":
|
|
|
|
|
wkv = (a / b).half()
|
|
|
|
|
else:
|
|
|
|
|
wkv = a / b
|
|
|
|
|
|
|
|
|
|
|