faster (use torch 1.12.1+cu116 or newer)

main
BlinkDL 3 years ago
parent 2f33901c10
commit 2b4539cd08

@ -152,7 +152,7 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in
nn.init.normal_(m.weight, mean=0.0, std=-scale) nn.init.normal_(m.weight, mean=0.0, std=-scale)
class RWKV_TimeMix(nn.Module): class RWKV_TimeMix(torch.jit.ScriptModule):
def __init__(self, config, layer_id): def __init__(self, config, layer_id):
super().__init__() super().__init__()
self.layer_id = layer_id self.layer_id = layer_id
@ -196,11 +196,11 @@ class RWKV_TimeMix(nn.Module):
self.receptance.scale_init = 0 self.receptance.scale_init = 0
self.output.scale_init = 0 self.output.scale_init = 0
def forward(self, x): @torch.jit.script_method
B, T, C = x.size() # x = (Batch,Time,Channel) def jit_func(self, x):
# Mix x with the previous timestep to produce xk, xv, xr # Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1)) xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
@ -209,13 +209,21 @@ class RWKV_TimeMix(nn.Module):
k = self.key(xk) k = self.key(xk)
v = self.value(xv) v = self.value(xv)
r = self.receptance(xr) r = self.receptance(xr)
sr = torch.sigmoid(r)
return sr, k, v
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v = self.jit_func(x)
rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v) rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
rwkv = self.output(rwkv) rwkv = self.output(rwkv)
return rwkv return rwkv
class RWKV_ChannelMix(nn.Module): class RWKV_ChannelMix(torch.jit.ScriptModule):
def __init__(self, config, layer_id): def __init__(self, config, layer_id):
super().__init__() super().__init__()
self.layer_id = layer_id self.layer_id = layer_id
@ -240,6 +248,7 @@ class RWKV_ChannelMix(nn.Module):
self.value.scale_init = 0 self.value.scale_init = 0
self.receptance.scale_init = 0 self.receptance.scale_init = 0
@torch.jit.script_method
def forward(self, x): def forward(self, x):
xx = self.time_shift(x) xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)

Loading…
Cancel
Save