@@ -109,22 +109,21 @@ class RWKV_TimeMix(MyModule):
         self.ctx_len = args.ctx_len
         self.n_embd = args.n_embd
         self.my_testing = self.args.my_testing
-        attn_sz = args.n_embd

         with torch.no_grad(): # fancy init
             ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
             ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0

             # fancy time_decay
-            decay_speed = torch.ones(attn_sz)
-            for h in range(attn_sz):
-                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+            decay_speed = torch.ones(args.dim_att)
+            for h in range(args.dim_att):
+                decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
             self.time_decay = nn.Parameter(decay_speed)
             # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

             # fancy time_first
-            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
-            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
+            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
+            self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)

             # fancy time_mix
             x = torch.ones(1, 1, args.n_embd)
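Note: this hunk assumes `args.dim_att` (and, below, `args.dim_ffn`) is always populated. If the training entry point keeps the old behaviour as the default, a fallback along these lines would leave existing configs unchanged (a minimal sketch, assuming an argparse-style `args`; not part of this diff):

```python
from types import SimpleNamespace

# Hypothetical defaulting for the new knobs (not part of this diff):
# dim_att falls back to the embedding width, dim_ffn to the usual 4x
# expansion, which reproduces the old attn_sz / hidden_sz exactly.
args = SimpleNamespace(n_embd=512, dim_att=0, dim_ffn=0)
if args.dim_att <= 0:
    args.dim_att = args.n_embd        # old attn_sz
if args.dim_ffn <= 0:
    args.dim_ffn = args.n_embd * 4    # old hidden_sz
```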
@@ -135,10 +134,10 @@ class RWKV_TimeMix(MyModule):
             self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-        self.key = nn.Linear(args.n_embd, attn_sz, bias=False)
-        self.value = nn.Linear(args.n_embd, attn_sz, bias=False)
-        self.receptance = nn.Linear(args.n_embd, attn_sz, bias=False)
-        self.output = nn.Linear(attn_sz, args.n_embd, bias=False)
+        self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
+        self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
+        self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
+        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)

         if 'a' in os.environ["RWKV_MY_TESTING"]:
             self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
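With `dim_att` decoupled from `n_embd`, the four projections are no longer square. A quick standalone shape check of the resized layers (sizes are hypothetical):

```python
import torch
import torch.nn as nn

# Hypothetical sizes: attention width wider than the residual stream.
n_embd, dim_att = 512, 768
key = nn.Linear(n_embd, dim_att, bias=False)      # as in the hunk above
output = nn.Linear(dim_att, n_embd, bias=False)   # projects back to n_embd
x = torch.randn(2, 16, n_embd)                    # (Batch, Time, Channel)
assert key(x).shape == (2, 16, dim_att)
assert output(key(x)).shape == x.shape
```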
@@ -175,7 +174,7 @@ class RWKV_TimeMix(MyModule):
     def forward(self, x):
         B, T, C = x.size() # x = (Batch,Time,Channel)
         sr, k, v = self.jit_func(x)
-        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
+        rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
         return self.output(rwkv)

     if 'a' in os.environ["RWKV_MY_TESTING"]:
@@ -213,7 +212,7 @@ class RWKV_TimeMix(MyModule):
         def forward(self, x):
             B, T, C = x.size() # x = (Batch,Time,Channel)
             sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
-            rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
+            rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
             rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
             return rwkv
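Both forward hunks fix the same latent mismatch: `C` is the width of the block input (`n_embd`), while the CUDA kernel walks the rows of `k` and `v`, which now have width `dim_att`. The two only coincide when `dim_att == n_embd`, so `self.args.dim_att` is the correct channel argument in both branches. A sketch of the distinction (hypothetical sizes, mirroring `jit_func`'s shapes):

```python
# Hypothetical sizes, mirroring jit_func's shapes:
B, T, n_embd, dim_att = 2, 16, 512, 768
# x is (B, T, n_embd), so unpacking x.size() gives C == 512,
# but k and v come out of the resized projections as (B, T, dim_att).
# The kernel must be told the k/v row width, 768; passing C would
# mismatch the k/v memory layout whenever dim_att != n_embd.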
@@ -237,10 +236,9 @@ class RWKV_ChannelMix(MyModule):
             self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
             self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

-        hidden_sz = 4 * args.n_embd
-        self.key = nn.Linear(args.n_embd, hidden_sz, bias=False)
+        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
         self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
-        self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)
+        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

     @MyFunction
     def forward(self, x):
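The channel-mix FFN gets the same treatment: the hard-coded `hidden_sz = 4 * args.n_embd` becomes a configurable `dim_ffn`. A standalone check that `dim_ffn = 4 * n_embd` reproduces the old geometry (hypothetical sizes; the squared ReLU matches this class's `forward`):

```python
import torch
import torch.nn as nn

n_embd = 512
dim_ffn = 4 * n_embd                      # old hidden_sz, now a knob
key = nn.Linear(n_embd, dim_ffn, bias=False)
value = nn.Linear(dim_ffn, n_embd, bias=False)
x = torch.randn(2, 16, n_embd)
k = torch.square(torch.relu(key(x)))      # squared ReLU, as in forward()
assert value(k).shape == x.shape
```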
@@ -252,6 +250,36 @@ class RWKV_ChannelMix(MyModule):
         kv = self.value(k)
         return torch.sigmoid(self.receptance(xr)) * kv

+class MishGLU(MyModule):
+    def __init__(self, args, layer_id):
+        super().__init__()
+        self.args = args
+        self.layer_id = layer_id
+        self.my_testing = self.args.my_testing
+        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+        with torch.no_grad():
+            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
+
+            x = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                x[0, 0, i] = i / args.n_embd
+
+            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+            self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
+            self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
+            self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
+
+    @MyFunction
+    def forward(self, x):
+        xx = self.time_shift(x)
+        xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+        xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+        a = self.aa(xa)
+        b = self.bb(xb)
+        return self.value(a * F.mish(b))
+
 ########################################################################################################
 # The RWKV Model with our blocks
 ########################################################################################################
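The new `MishGLU` is a gated linear unit with a Mish-activated gate over the time-shifted token mix. A minimal standalone re-creation of its forward pass (plain `nn.Linear` parts instead of the repo's `MyModule`/JIT wrappers; sizes are hypothetical, and a single mix constant stands in for the learned `time_mix_k`/`time_mix_r`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_embd, dim_ffn = 8, 16
aa = nn.Linear(n_embd, dim_ffn, bias=False)
bb = nn.Linear(n_embd, dim_ffn, bias=False)
value = nn.Linear(dim_ffn, n_embd, bias=False)
time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # row t becomes x[t-1]; row 0 is zeros

x = torch.randn(2, 4, n_embd)              # (Batch, Time, Channel)
xx = time_shift(x)
mix = 0.5                                  # stand-in for the learned time_mix_k/r
xa = x * mix + xx * (1 - mix)
out = value(aa(xa) * F.mish(bb(xa)))       # GLU: linear branch gated by Mish branch
assert out.shape == x.shape
```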
@@ -277,7 +305,10 @@ class Block(nn.Module):
         else:
             self.att = RWKV_TimeMix(args, layer_id)

-        self.ffn = RWKV_ChannelMix(args, layer_id)
+        if 'g' in os.environ["RWKV_MY_TESTING"]:
+            self.ffn = MishGLU(args, layer_id)
+        else:
+            self.ffn = RWKV_ChannelMix(args, layer_id)

         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
             self.tiny_ln = nn.LayerNorm(args.n_embd)
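As with the `'a'` branch earlier, the new path is opt-in through the `RWKV_MY_TESTING` flag string, so default runs are unaffected; setting the flag swaps every block's FFN for `MishGLU`:

```python
import os

os.environ["RWKV_MY_TESTING"] = "g"   # any string containing 'g' enables MishGLU
# an empty flag string keeps the default RWKV_ChannelMix path
```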