main
BlinkDL 3 years ago
parent f03efd0218
commit 038f06b996

@ -135,42 +135,89 @@ class RWKV_TimeMix(MyModule):
self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.key = nn.Linear(args.n_embd, attn_sz, bias=False)
self.value = nn.Linear(args.n_embd, attn_sz, bias=False)
self.receptance = nn.Linear(args.n_embd, attn_sz, bias=False)
self.output = nn.Linear(attn_sz, args.n_embd, bias=False)
# if self.my_testing > 0:
# self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))
@MyFunction
def jit_func(self, x):
# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# Use xk, xv, xr to produce k, v, r
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
sr = torch.sigmoid(r)
return sr, k, v
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v = self.jit_func(x)
rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
rwkv = self.output(rwkv)
return rwkv
if 'a' in os.environ["RWKV_MY_TESTING"]:
self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
d_qkv = args.n_embd // 16
self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
with torch.no_grad():
x = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
x[0, 0, i] = i / args.n_embd
self.time_mix_qq = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_kk = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_vv = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
if 'a' not in os.environ["RWKV_MY_TESTING"]:
@MyFunction
def jit_func(self, x):
# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# Use xk, xv, xr to produce k, v, r
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
sr = torch.sigmoid(r)
return sr, k, v
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v = self.jit_func(x)
rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
return self.output(rwkv)
if 'a' in os.environ["RWKV_MY_TESTING"]:
@MyFunction
def QKV(self, q, k, v):
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.att_mask == 0, float('-inf'))
att = F.softmax(att, dim = -1)
x = att @ v
return x
@MyFunction
def jit_funcQKV(self, x):
# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
# Use xk, xv, xr to produce k, v, r
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
sr = torch.sigmoid(r)
qq = self.qq(xqq)
kk = self.kk(xkk)
vv = self.vv(xvv)
return sr, k, v, qq, kk, vv
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
return rwkv
########################################################################################################
class RWKV_ChannelMix(MyModule):
def __init__(self, args, layer_id):
@ -195,38 +242,15 @@ class RWKV_ChannelMix(MyModule):
self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)
# if self.my_testing in [1]:
# self.aaa = nn.Parameter(torch.zeros(1, 1, hidden_sz))
# elif self.my_testing in [2]:
# self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))
@MyFunction
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
# k = self.key(xk)
# # if self.my_testing in [0, 2]:
# k = torch.square(torch.relu(k))
# # elif self.my_testing == 1:
# # k = torch.square(torch.relu(k)) + k * self.aaa
# kv = self.value(k)
# r = self.receptance(xr)
# # if self.my_testing == 0:
# r = torch.sigmoid(r)
# # elif self.my_testing == 2:
# # r = torch.sigmoid(r) + r * self.aaa
# rkv = r * kv
# return rkv
return torch.sigmoid(self.receptance(xr)) * kv
########################################################################################################
# The RWKV Model with our blocks
@ -479,7 +503,7 @@ class RWKV(pl.LightningModule):
gain = 1.0
scale = 1.0
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n:
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
m[n] = p
else:
if n == "emb.weight":
@ -487,7 +511,7 @@ class RWKV(pl.LightningModule):
else:
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q."]:
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
if kk in n:
scale = 0
if n == "head.weight":

Loading…
Cancel
Save