@@ -108,11 +108,13 @@ class RWKV_TimeMix(MyModule):
         self.layer_id = layer_id
         self.ctx_len = args.ctx_len
         self.n_embd = args.n_embd
-        self.my_testing = self.args.my_testing

         with torch.no_grad(): # fancy init
             ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
             ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
+            ddd = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                ddd[0, 0, i] = i / args.n_embd

             # fancy time_decay
             decay_speed = torch.ones(args.dim_att)
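The `ddd` tensor added above is just a per-channel ramp from 0 to (n_embd-1)/n_embd; combined with the two per-layer ratios it becomes the `time_mix_*` curves set in the next hunk. A minimal standalone sketch of what this init produces (the sizes n_embd=8, n_layer=6, layer_id=2 are made up for illustration):

import torch

n_embd, n_layer, layer_id = 8, 6, 2                  # hypothetical sizes, illustration only
ratio_0_to_1 = layer_id / (n_layer - 1)              # 0 at the first layer, 1 at the last
ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)      # 1 at the first layer, ~0 at the last

ddd = torch.ones(1, 1, n_embd)
for i in range(n_embd):
    ddd[0, 0, i] = i / n_embd                        # per-channel ramp 0, 1/8, ..., 7/8

time_mix_k = torch.pow(ddd, ratio_1_to_almost0)      # early layers: steep curve; late layers: nearly flat
print(time_mix_k.flatten())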
@@ -126,12 +128,9 @@ class RWKV_TimeMix(MyModule):
             self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)

             # fancy time_mix
-            x = torch.ones(1, 1, args.n_embd)
-            for i in range(args.n_embd):
-                x[0, 0, i] = i / args.n_embd
-            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
-            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
+            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+            self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+            self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))

         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
         self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
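`self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))` in the trailing context is the token-shift trick: it pads one step at the start of the time dimension and crops one at the end, so position t receives the features of position t-1. A quick check on a toy (B, T, C) tensor:

import torch
import torch.nn as nn

time_shift = nn.ZeroPad2d((0, 0, 1, -1))       # pad 1 before, crop 1 after, along T
x = torch.arange(1., 5.).view(1, 4, 1)         # toy sequence [1, 2, 3, 4], shape (B=1, T=4, C=1)
print(time_shift(x).flatten().tolist())        # [0.0, 1.0, 2.0, 3.0]: each step now holds the previous token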
@@ -147,24 +146,17 @@ class RWKV_TimeMix(MyModule):
             self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
             self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
             with torch.no_grad():
-                x = torch.ones(1, 1, args.n_embd)
-                for i in range(args.n_embd):
-                    x[0, 0, i] = i / args.n_embd
-                self.time_mix_qq = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-                self.time_mix_kk = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-                self.time_mix_vv = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+                self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+                self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+                self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)

     if 'a' not in os.environ["RWKV_MY_TESTING"]:
         @MyFunction
         def jit_func(self, x):
-
-            # Mix x with the previous timestep to produce xk, xv, xr
-            xx = self.time_shift(x)
+            xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
             xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
             xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
             xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
-
-            # Use xk, xv, xr to produce k, v, r
             k = self.key(xk)
             v = self.value(xv)
             r = self.receptance(xr)
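The rewritten `jit_func` keeps the same arithmetic: each projection input is a per-channel blend of the current token `x` and the shifted previous token `xx`, weighted by the learned `time_mix_*` ramps from the init above. A self-contained sketch of that blend, with a random stand-in for the learned parameter:

import torch
import torch.nn as nn

B, T, C = 1, 4, 6
x = torch.randn(B, T, C)
xx = nn.ZeroPad2d((0, 0, 1, -1))(x)              # previous token's features at every position
time_mix_k = torch.rand(1, 1, C)                 # stand-in for the learned per-channel blend
xk = x * time_mix_k + xx * (1 - time_mix_k)      # mix near 1 -> mostly current token; near 0 -> mostly previous
print(xk.shape)                                  # torch.Size([1, 4, 6])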
@@ -188,25 +180,20 @@ class RWKV_TimeMix(MyModule):

         @MyFunction
         def jit_funcQKV(self, x):
-            # Mix x with the previous timestep to produce xk, xv, xr
-            xx = self.time_shift(x)
+            xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
             xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
             xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
             xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
             xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
             xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
             xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
-
-            # Use xk, xv, xr to produce k, v, r
             k = self.key(xk)
             v = self.value(xv)
             r = self.receptance(xr)
             sr = torch.sigmoid(r)
-
             qq = self.qq(xqq)
             kk = self.kk(xkk)
             vv = self.vv(xvv)
-
             return sr, k, v, qq, kk, vv

     def forward(self, x):
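`sr = torch.sigmoid(r)` squashes the receptance into (0, 1); in RWKV it is then used as a per-channel gate on the time-mixing output (that use is outside this hunk, so the snippet below is an assumption). A toy illustration with made-up shapes and a stand-in `wkv` tensor:

import torch

B, T, C = 2, 4, 8                      # made-up sizes
r = torch.randn(B, T, C)               # raw receptance projection
wkv = torch.randn(B, T, C)             # stand-in for the time-mixing output
gated = torch.sigmoid(r) * wkv         # sigmoid gate scales each channel by a value in (0, 1)
print(gated.shape)                     # torch.Size([2, 4, 8])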
@@ -223,18 +210,15 @@ class RWKV_ChannelMix(MyModule):
         super().__init__()
         self.args = args
         self.layer_id = layer_id
-        self.my_testing = self.args.my_testing
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

         with torch.no_grad(): # fancy init of time_mix
             ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
-
-            x = torch.ones(1, 1, args.n_embd)
+            ddd = torch.ones(1, 1, args.n_embd)
             for i in range(args.n_embd):
-                x[0, 0, i] = i / args.n_embd
-
-            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+                ddd[0, 0, i] = i / args.n_embd
+            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+            self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))

         self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
         self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
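The trailing context shows the channel-mix projections: `key` expands to `dim_ffn`, `value` (not in this hunk) maps back down, and `receptance` stays at `n_embd`. A rough sketch of the usual RWKV channel-mix computation these feed; the forward pass itself is not part of this diff, so treat the exact formula here as an assumption:

import torch
import torch.nn as nn

n_embd, dim_ffn, T = 8, 32, 4                    # made-up sizes
key = nn.Linear(n_embd, dim_ffn, bias=False)
value = nn.Linear(dim_ffn, n_embd, bias=False)
receptance = nn.Linear(n_embd, n_embd, bias=False)

xk = torch.randn(1, T, n_embd)                   # token-shifted key input, as in the time-mix above
xr = torch.randn(1, T, n_embd)                   # token-shifted receptance input
k = torch.square(torch.relu(key(xk)))            # squared-ReLU activation
out = torch.sigmoid(receptance(xr)) * value(k)   # sigmoid receptance gates the FFN output
print(out.shape)                                 # torch.Size([1, 4, 8])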
@@ -255,7 +239,6 @@ class MishGLU(MyModule):
         super().__init__()
         self.args = args
         self.layer_id = layer_id
-        self.my_testing = self.args.my_testing
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

         with torch.no_grad():
@@ -478,7 +461,7 @@ class RWKV(pl.LightningModule):

     def training_step(self, batch, batch_idx):
         args = self.args
-        if args.my_qa_mask == 0:
+        if args.my_qa_mask != 1:
             idx, targets = batch
             logits = self(idx)
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
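For the loss line kept as context above: `F.cross_entropy` expects (N, vocab) logits against (N,) integer targets, so the (B, T, vocab) logits and (B, T) targets are flattened over batch and time. A toy check with made-up sizes:

import torch
import torch.nn.functional as F

B, T, vocab = 2, 3, 10                            # made-up sizes
logits = torch.randn(B, T, vocab)
targets = torch.randint(0, vocab, (B, T))
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))  # (B*T, vocab) vs (B*T,)
print(loss.item())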