|
|
|
|
@ -152,7 +152,7 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in
|
|
|
|
|
nn.init.normal_(m.weight, mean=0.0, std=-scale)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RWKV_TimeMix(nn.Module):
|
|
|
|
|
class RWKV_TimeMix(torch.jit.ScriptModule):
|
|
|
|
|
def __init__(self, config, layer_id):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.layer_id = layer_id
|
|
|
|
|
@ -196,11 +196,11 @@ class RWKV_TimeMix(nn.Module):
|
|
|
|
|
self.receptance.scale_init = 0
|
|
|
|
|
self.output.scale_init = 0
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
B, T, C = x.size() # x = (Batch,Time,Channel)
|
|
|
|
|
@torch.jit.script_method
|
|
|
|
|
def jit_func(self, x):
|
|
|
|
|
|
|
|
|
|
# Mix x with the previous timestep to produce xk, xv, xr
|
|
|
|
|
xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
|
|
|
|
|
xx = self.time_shift(x)
|
|
|
|
|
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
|
|
|
|
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
|
|
|
|
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
|
|
|
|
@ -209,13 +209,21 @@ class RWKV_TimeMix(nn.Module):
|
|
|
|
|
k = self.key(xk)
|
|
|
|
|
v = self.value(xv)
|
|
|
|
|
r = self.receptance(xr)
|
|
|
|
|
sr = torch.sigmoid(r)
|
|
|
|
|
|
|
|
|
|
return sr, k, v
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
B, T, C = x.size() # x = (Batch,Time,Channel)
|
|
|
|
|
|
|
|
|
|
sr, k, v = self.jit_func(x)
|
|
|
|
|
|
|
|
|
|
rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
|
|
|
|
|
rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
|
|
|
|
|
rwkv = self.output(rwkv)
|
|
|
|
|
return rwkv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RWKV_ChannelMix(nn.Module):
|
|
|
|
|
class RWKV_ChannelMix(torch.jit.ScriptModule):
|
|
|
|
|
def __init__(self, config, layer_id):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.layer_id = layer_id
|
|
|
|
|
@ -240,6 +248,7 @@ class RWKV_ChannelMix(nn.Module):
|
|
|
|
|
self.value.scale_init = 0
|
|
|
|
|
self.receptance.scale_init = 0
|
|
|
|
|
|
|
|
|
|
@torch.jit.script_method
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
xx = self.time_shift(x)
|
|
|
|
|
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
|
|
|
|
|