|
|
|
|
@ -135,42 +135,89 @@ class RWKV_TimeMix(MyModule):
|
|
|
|
|
self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
|
|
|
|
|
|
|
|
|
|
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
|
|
|
|
|
|
|
|
|
self.key = nn.Linear(args.n_embd, attn_sz, bias=False)
|
|
|
|
|
self.value = nn.Linear(args.n_embd, attn_sz, bias=False)
|
|
|
|
|
self.receptance = nn.Linear(args.n_embd, attn_sz, bias=False)
|
|
|
|
|
|
|
|
|
|
self.output = nn.Linear(attn_sz, args.n_embd, bias=False)
|
|
|
|
|
|
|
|
|
|
# if self.my_testing > 0:
|
|
|
|
|
# self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))
|
|
|
|
|
|
|
|
|
|
@MyFunction
|
|
|
|
|
def jit_func(self, x):
    """Token-shift mixing for the RWKV time-mix path.

    Blends each position with the previous timestep (per-channel learned
    interpolation), then projects into receptance / key / value.
    Returns (sigmoid(r), k, v).
    """
    prev = self.time_shift(x)  # sequence shifted right by one step, zero-padded front
    # Learned per-channel interpolation between the current and previous token.
    xk = x * self.time_mix_k + prev * (1 - self.time_mix_k)
    xv = x * self.time_mix_v + prev * (1 - self.time_mix_v)
    xr = x * self.time_mix_r + prev * (1 - self.time_mix_r)
    # Receptance is squashed to (0, 1) and used as a gate by the caller.
    sr = torch.sigmoid(self.receptance(xr))
    return sr, self.key(xk), self.value(xv)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
    """Full time-mix step: gated WKV recurrence followed by output projection."""
    B, T, C = x.size()  # batch, sequence length, embedding channels
    sr, k, v = self.jit_func(x)
    # The CUDA kernel evaluates the WKV recurrence; sr gates it per channel.
    wkv = RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
    return self.output(sr * wkv)
|
|
|
|
|
if 'a' in os.environ["RWKV_MY_TESTING"]:
    # Experimental 'a' variant: adds a small auxiliary self-attention branch
    # alongside the RWKV time-mix (used by QKV / jit_funcQKV elsewhere in this file).
    # Causal mask buffer for the auxiliary attention: ctx_len x ctx_len lower triangle.
    self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
    # Auxiliary attention head width -- 1/16 of the embedding width.
    d_qkv = args.n_embd // 16
    self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
    self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
    self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
    # Projects the attention result back to the embedding width.
    self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
    with torch.no_grad():
        # Per-channel ramp i / n_embd, used as the base of the time-mix factors.
        x = torch.ones(1, 1, args.n_embd)
        for i in range(args.n_embd):
            x[0, 0, i] = i / args.n_embd
        # NOTE(review): ratio_1_to_almost0 / ratio_0_to_1 come from the enclosing
        # __init__ (not visible in this chunk) -- presumably layer-depth-dependent
        # schedules; confirm against the full file.
        self.time_mix_qq = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
        self.time_mix_kk = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
        self.time_mix_vv = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
|
|
|
|
|
|
|
|
|
if 'a' not in os.environ["RWKV_MY_TESTING"]:
|
|
|
|
|
@MyFunction
|
|
|
|
|
def jit_func(self, x):
    """Token-shift mixing: interpolate each position with its predecessor,
    then produce the gated receptance plus key and value projections.

    Returns (sigmoid(r), k, v).
    """
    shifted = self.time_shift(x)  # previous timestep (zero at position 0)
    one = 1
    # Per-channel learned blend of current vs. shifted input for each projection.
    mixed_k = x * self.time_mix_k + shifted * (one - self.time_mix_k)
    mixed_v = x * self.time_mix_v + shifted * (one - self.time_mix_v)
    mixed_r = x * self.time_mix_r + shifted * (one - self.time_mix_r)
    k = self.key(mixed_k)
    v = self.value(mixed_v)
    # Sigmoid gate in (0, 1), applied to the WKV output by the caller.
    sr = torch.sigmoid(self.receptance(mixed_r))
    return sr, k, v
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
    """Time-mix step: run the WKV recurrence kernel, gate it, project out."""
    B, T, C = x.size()  # batch, time, channel
    sr, k, v = self.jit_func(x)
    gated = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
    out = self.output(gated)
    return out
|
|
|
|
|
|
|
|
|
|
if 'a' in os.environ["RWKV_MY_TESTING"]:
|
|
|
|
|
@MyFunction
|
|
|
|
|
def QKV(self, q, k, v):
|
|
|
|
|
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
|
|
|
|
att = att.masked_fill(self.att_mask == 0, float('-inf'))
|
|
|
|
|
att = F.softmax(att, dim = -1)
|
|
|
|
|
x = att @ v
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
@MyFunction
|
|
|
|
|
def jit_funcQKV(self, x):
    """Token-shift mixing for both the RWKV path (r/k/v) and the auxiliary
    attention path (qq/kk/vv) of the experimental 'a' variant.

    Returns (sigmoid(r), k, v, qq, kk, vv).
    """
    prev = self.time_shift(x)  # sequence shifted right by one timestep
    # --- RWKV projections: per-channel blend of current and previous token ---
    k = self.key(x * self.time_mix_k + prev * (1 - self.time_mix_k))
    v = self.value(x * self.time_mix_v + prev * (1 - self.time_mix_v))
    sr = torch.sigmoid(self.receptance(x * self.time_mix_r + prev * (1 - self.time_mix_r)))
    # --- Auxiliary attention projections, each with its own mix factor ---
    qq = self.qq(x * self.time_mix_qq + prev * (1 - self.time_mix_qq))
    kk = self.kk(x * self.time_mix_kk + prev * (1 - self.time_mix_kk))
    vv = self.vv(x * self.time_mix_vv + prev * (1 - self.time_mix_vv))
    return sr, k, v, qq, kk, vv
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
    """Time-mix step plus the auxiliary attention branch ('a' test variant)."""
    B, T, C = x.size()  # batch, sequence length, channels
    sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
    wkv = RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
    # Sum of the gated, projected WKV output and the projected attention output.
    return self.output(sr * wkv) + self.oo(self.QKV(qq, kk, vv))
|
|
|
|
|
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
class RWKV_ChannelMix(MyModule):
|
|
|
|
|
def __init__(self, args, layer_id):
|
|
|
|
|
@ -195,38 +242,15 @@ class RWKV_ChannelMix(MyModule):
|
|
|
|
|
self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
|
|
|
|
|
self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)
|
|
|
|
|
|
|
|
|
|
# if self.my_testing in [1]:
|
|
|
|
|
# self.aaa = nn.Parameter(torch.zeros(1, 1, hidden_sz))
|
|
|
|
|
# elif self.my_testing in [2]:
|
|
|
|
|
# self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@MyFunction
|
|
|
|
|
def forward(self, x):
    """Channel-mix (FFN) step: squared-ReLU MLP gated by sigmoid receptance."""
    prev = self.time_shift(x)  # previous timestep, zero-padded at the front
    # Per-channel learned blend of current and previous token.
    xk = x * self.time_mix_k + prev * (1 - self.time_mix_k)
    xr = x * self.time_mix_r + prev * (1 - self.time_mix_r)
    # Squared ReLU keeps the hidden activation non-negative and sparse.
    hidden = torch.relu(self.key(xk))
    kv = self.value(hidden * hidden)
    gate = torch.sigmoid(self.receptance(xr))
    return gate * kv
|
|
|
|
|
|
|
|
|
|
# k = self.key(xk)
|
|
|
|
|
# # if self.my_testing in [0, 2]:
|
|
|
|
|
# k = torch.square(torch.relu(k))
|
|
|
|
|
# # elif self.my_testing == 1:
|
|
|
|
|
# # k = torch.square(torch.relu(k)) + k * self.aaa
|
|
|
|
|
# kv = self.value(k)
|
|
|
|
|
# r = self.receptance(xr)
|
|
|
|
|
# # if self.my_testing == 0:
|
|
|
|
|
# r = torch.sigmoid(r)
|
|
|
|
|
# # elif self.my_testing == 2:
|
|
|
|
|
# # r = torch.sigmoid(r) + r * self.aaa
|
|
|
|
|
# rkv = r * kv
|
|
|
|
|
# return rkv
|
|
|
|
|
return torch.sigmoid(self.receptance(xr)) * kv
|
|
|
|
|
|
|
|
|
|
########################################################################################################
|
|
|
|
|
# The RWKV Model with our blocks
|
|
|
|
|
@ -479,7 +503,7 @@ class RWKV(pl.LightningModule):
|
|
|
|
|
|
|
|
|
|
gain = 1.0
|
|
|
|
|
scale = 1.0
|
|
|
|
|
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n:
|
|
|
|
|
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
|
|
|
|
|
m[n] = p
|
|
|
|
|
else:
|
|
|
|
|
if n == "emb.weight":
|
|
|
|
|
@ -487,7 +511,7 @@ class RWKV(pl.LightningModule):
|
|
|
|
|
else:
|
|
|
|
|
if shape[0] > shape[1]:
|
|
|
|
|
gain = math.sqrt(shape[0] / shape[1])
|
|
|
|
|
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q."]:
|
|
|
|
|
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
|
|
|
|
|
if kk in n:
|
|
|
|
|
scale = 0
|
|
|
|
|
if n == "head.weight":
|
|
|
|
|
|