faster (use torch 1.12.1+cu116 or newer)

main
BlinkDL 3 years ago
parent 2f33901c10
commit 2b4539cd08

@@ -152,7 +152,7 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in
             nn.init.normal_(m.weight, mean=0.0, std=-scale)
 
 
-class RWKV_TimeMix(nn.Module):
+class RWKV_TimeMix(torch.jit.ScriptModule):
     def __init__(self, config, layer_id):
         super().__init__()
         self.layer_id = layer_id
@@ -196,11 +196,11 @@ class RWKV_TimeMix(nn.Module):
         self.receptance.scale_init = 0
         self.output.scale_init = 0
 
-    def forward(self, x):
-        B, T, C = x.size() # x = (Batch,Time,Channel)
+    @torch.jit.script_method
+    def jit_func(self, x):
 
         # Mix x with the previous timestep to produce xk, xv, xr
-        xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
+        xx = self.time_shift(x)
         xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
         xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
         xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
@@ -209,13 +209,21 @@ class RWKV_TimeMix(nn.Module):
         k = self.key(xk)
         v = self.value(xv)
         r = self.receptance(xr)
+        sr = torch.sigmoid(r)
+
+        return sr, k, v
+
+    def forward(self, x):
+        B, T, C = x.size() # x = (Batch,Time,Channel)
+
+        sr, k, v = self.jit_func(x)
 
-        rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
+        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
 
         rwkv = self.output(rwkv)
         return rwkv
 
-class RWKV_ChannelMix(nn.Module):
+class RWKV_ChannelMix(torch.jit.ScriptModule):
     def __init__(self, config, layer_id):
         super().__init__()
         self.layer_id = layer_id
@@ -240,6 +248,7 @@ class RWKV_ChannelMix(nn.Module):
         self.value.scale_init = 0
         self.receptance.scale_init = 0
 
+    @torch.jit.script_method
     def forward(self, x):
         xx = self.time_shift(x)
         xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
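
For readers skimming the diff, the pattern this commit introduces is: the token-shift mixing and the k/v/r projections of RWKV_TimeMix move into a jit_func method compiled with TorchScript (torch.jit.ScriptModule plus @torch.jit.script_method), while forward stays a plain Python method wrapped around the RUN_CUDA call, presumably because the custom CUDA kernel cannot be TorchScript-compiled. The sketch below is a minimal, self-contained illustration of that split, not the repository's actual code: the constructor values, the n_embd argument, and the RUN_CUDA stub are simplified placeholders, and the real model supplies the fused WKV CUDA kernel instead of the stub.

import torch
import torch.nn as nn

# Stand-in for the custom WKV CUDA kernel used in the real model.
# It is stubbed here only so the sketch runs on its own.
def RUN_CUDA(B, T, C, time_decay, time_first, k, v):
    return k * v  # placeholder computation, not the real WKV recurrence

class TimeMixSketch(torch.jit.ScriptModule):
    # Minimal sketch of the split shown in the diff: the token-shift mixing and
    # projections are TorchScript-compiled, while forward() stays eager around RUN_CUDA.
    def __init__(self, n_embd: int):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # shift by one step along the time axis
        self.time_mix_k = nn.Parameter(torch.full((1, 1, n_embd), 0.5))
        self.time_mix_v = nn.Parameter(torch.full((1, 1, n_embd), 0.5))
        self.time_mix_r = nn.Parameter(torch.full((1, 1, n_embd), 0.5))
        self.time_decay = nn.Parameter(torch.zeros(n_embd))
        self.time_first = nn.Parameter(torch.zeros(n_embd))
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.output = nn.Linear(n_embd, n_embd, bias=False)

    @torch.jit.script_method
    def jit_func(self, x):
        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)
        return sr, k, v

    def forward(self, x):
        B, T, C = x.size()  # x = (Batch, Time, Channel)
        sr, k, v = self.jit_func(x)  # TorchScript-compiled mixing and projections
        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
        return self.output(rwkv)

if __name__ == "__main__":
    m = TimeMixSketch(n_embd=16)
    y = m(torch.randn(2, 8, 16))
    print(y.shape)  # torch.Size([2, 8, 16])

RWKV_ChannelMix has no custom kernel call, so the diff simply scripts its whole forward with the same decorator. The speedup named in the commit title presumably comes from TorchScript compiling these hot per-token paths; torch 1.12.1+cu116 is the build the author recommends for it.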
