diff --git a/RWKV-v4/src/model.py b/RWKV-v4/src/model.py index e4faf89..94dc0cd 100644 --- a/RWKV-v4/src/model.py +++ b/RWKV-v4/src/model.py @@ -152,7 +152,7 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in nn.init.normal_(m.weight, mean=0.0, std=-scale) -class RWKV_TimeMix(nn.Module): +class RWKV_TimeMix(torch.jit.ScriptModule): def __init__(self, config, layer_id): super().__init__() self.layer_id = layer_id @@ -196,11 +196,11 @@ class RWKV_TimeMix(nn.Module): self.receptance.scale_init = 0 self.output.scale_init = 0 - def forward(self, x): - B, T, C = x.size() # x = (Batch,Time,Channel) + @torch.jit.script_method + def jit_func(self, x): # Mix x with the previous timestep to produce xk, xv, xr - xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1)) + xx = self.time_shift(x) xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) @@ -209,13 +209,21 @@ class RWKV_TimeMix(nn.Module): k = self.key(xk) v = self.value(xv) r = self.receptance(xr) + sr = torch.sigmoid(r) + + return sr, k, v + + def forward(self, x): + B, T, C = x.size() # x = (Batch,Time,Channel) + + sr, k, v = self.jit_func(x) - rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v) + rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v) rwkv = self.output(rwkv) return rwkv -class RWKV_ChannelMix(nn.Module): +class RWKV_ChannelMix(torch.jit.ScriptModule): def __init__(self, config, layer_id): super().__init__() self.layer_id = layer_id @@ -240,6 +248,7 @@ class RWKV_ChannelMix(nn.Module): self.value.scale_init = 0 self.receptance.scale_init = 0 + @torch.jit.script_method def forward(self, x): xx = self.time_shift(x) xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)