@@ -108,7 +108,7 @@ class RWKV_TimeMix(MyModule):
        self.layer_id = layer_id
        self.ctx_len = args.ctx_len
        self.n_embd = args.n_embd
        self.my_testing = self.args.my_testing
        attn_sz = args.n_embd

        with torch.no_grad():  # fancy init
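
The hunk cuts off at the top of the `fancy init` block. For reviewers skimming the diff, a minimal standalone sketch of the kind of per-channel decay initialization that block typically performs in this codebase; the exact constants and parameter names live outside this hunk, so treat them as assumptions:

```python
import torch

def fancy_decay_init(n_layer: int, layer_id: int, attn_sz: int) -> torch.Tensor:
    # Sketch only: the constants (-5, 8, 0.7, 1.3) are illustrative
    # assumptions, not values taken from this diff.
    ratio_0_to_1 = layer_id / max(n_layer - 1, 1)  # 0 at the first layer, 1 at the last
    decay_speed = torch.ones(attn_sz)
    for h in range(attn_sz):
        # later channels decay more slowly; layer depth bends the curve
        decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
    return decay_speed

print(fancy_decay_init(n_layer=12, layer_id=0, attn_sz=8))
```
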
@@ -142,6 +142,9 @@ class RWKV_TimeMix(MyModule):
        self.output = nn.Linear(attn_sz, args.n_embd, bias=False)

        # if self.my_testing > 0:
        #     self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))

    @MyFunction
    def jit_func(self, x):
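
The body of `jit_func` sits outside this hunk. As context, a self-contained sketch of the token-shift front half it typically computes in RWKV-4: each of k, v, r sees a per-channel interpolation between the current token and the previous one. The `time_mix_k/v/r` names are assumptions here (the real parameters are created in the truncated fancy-init block):

```python
import torch
import torch.nn as nn

class TimeMixShiftSketch(nn.Module):
    """Standalone sketch of the token-shift front half of RWKV_TimeMix.jit_func."""
    def __init__(self, n_embd: int):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # shift the sequence right by one step
        self.time_mix_k = nn.Parameter(torch.rand(1, 1, n_embd))  # assumed names
        self.time_mix_v = nn.Parameter(torch.rand(1, 1, n_embd))
        self.time_mix_r = nn.Parameter(torch.rand(1, 1, n_embd))
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x):  # x: (B, T, C)
        xx = self.time_shift(x)                                 # previous token's features
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)   # per-channel blend
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        return torch.sigmoid(self.receptance(xr)), self.key(xk), self.value(xv)
```
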
@@ -174,7 +177,7 @@ class RWKV_ChannelMix(MyModule):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.my_testing = self.args.my_testing
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
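
`nn.ZeroPad2d((0, 0, 1, -1))` is the token-shift trick: applied to a `(B, T, C)` activation it pads one zero row at the front of the time axis and crops the last row, so position `t` ends up holding position `t-1`'s features. A quick check:

```python
import torch
import torch.nn as nn

time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # (left, right, top, bottom) on the last two dims

x = torch.arange(1., 7.).view(1, 3, 2)    # B=1, T=3, C=2
print(x)
# tensor([[[1., 2.],
#          [3., 4.],
#          [5., 6.]]])
print(time_shift(x))
# tensor([[[0., 0.],    <- zeros at t=0
#          [1., 2.],    <- t=1 now sees t=0's features
#          [3., 4.]]])  <- the last row [5., 6.] is cropped off
```
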
@@ -192,6 +195,12 @@ class RWKV_ChannelMix(MyModule):
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)

        # if self.my_testing in [1]:
        #     self.aaa = nn.Parameter(torch.zeros(1, 1, hidden_sz))
        # elif self.my_testing in [2]:
        #     self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
@@ -205,6 +214,19 @@ class RWKV_ChannelMix(MyModule):
        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

        # k = self.key(xk)
        # # if self.my_testing in [0, 2]:
        # k = torch.square(torch.relu(k))
        # # elif self.my_testing == 1:
        # #     k = torch.square(torch.relu(k)) + k * self.aaa
        # kv = self.value(k)
        # r = self.receptance(xr)
        # # if self.my_testing == 0:
        # r = torch.sigmoid(r)
        # # elif self.my_testing == 2:
        # #     r = torch.sigmoid(r) + r * self.aaa
        # rkv = r * kv
        # return rkv
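
Putting the active (non-`my_testing`) path of this `forward` together: blend each token with its shifted predecessor, square-ReLU the key, project it up and back down through `value`, and gate the result with a sigmoid `receptance`. A self-contained sketch, with the `time_mix_k`/`time_mix_r` parameters assumed from the truncated fancy-init block:

```python
import torch
import torch.nn as nn

class ChannelMixSketch(nn.Module):
    """Self-contained sketch of the active RWKV_ChannelMix.forward path above."""
    def __init__(self, n_embd: int, hidden_sz: int):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        # assumed to come from the "fancy init of time_mix" block in this diff
        self.time_mix_k = nn.Parameter(torch.rand(1, 1, n_embd))
        self.time_mix_r = nn.Parameter(torch.rand(1, 1, n_embd))
        self.key = nn.Linear(n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, n_embd, bias=False)

    def forward(self, x):  # x: (B, T, C)
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = torch.square(torch.relu(self.key(xk)))      # squared-ReLU keys
        kv = self.value(k)
        return torch.sigmoid(self.receptance(xr)) * kv  # sigmoid gate on the value path

out = ChannelMixSketch(n_embd=4, hidden_sz=16)(torch.randn(2, 5, 4))
print(out.shape)  # torch.Size([2, 5, 4])
```
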
########################################################################################################
# The RWKV Model with our blocks
@@ -401,9 +423,38 @@ class RWKV(pl.LightningModule):

    def training_step(self, batch, batch_idx):
        args = self.args
        if args.my_qa_mask == 0:
            idx, targets = batch
            logits = self(idx)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            idx, targets, mask = batch
            mask = mask.view(-1)
            sum_mask = torch.sum(mask).item()
            # if sum_mask == 0:
            #     return torch.tensor([0.0], requires_grad=True)

            logits = self(idx)
            if sum_mask == mask.shape[0]:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
                # print('rank', self.global_rank, 'loss', loss.item())
            else:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
                # loss_raw = loss
                loss = torch.sum(loss * mask) / sum_mask

                # torch.set_printoptions(threshold=10000)
                # if True:  # self.global_rank == 1:
                #     tmp = ''
                #     sss = 0
                #     ccc = 0
                #     for i in range(mask.shape[0]):
                #         if mask[i] > 0:
                #             tmp += str(idx.view(-1)[i].item()) + ','
                #             sss += loss_raw.view(-1)[i].float().item()
                #             ccc += 1
                #     print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)  # , 'tmp', tmp, 'input', idx)

        return L2Wrap.apply(loss, logits)
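
Two notes on this hunk. `L2Wrap.apply(loss, logits)` is defined elsewhere in model.py (not in this diff); it returns the loss unchanged in the forward pass and adds a small backward-pass penalty on each position's largest logit. And the `sum_mask == mask.shape[0]` branch is just a fast path: with a fully unmasked batch, the masked mean equals the plain `cross_entropy` mean. A quick numerical check of that equivalence:

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 4, 10
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

mask = torch.ones(B * T)  # fully unmasked batch
per_tok = F.cross_entropy(logits.view(-1, V), targets.view(-1), reduction='none')
masked = torch.sum(per_tok * mask) / torch.sum(mask)
plain = F.cross_entropy(logits.view(-1, V), targets.view(-1))
print(torch.allclose(masked, plain))  # True

mask[: B * T // 2] = 0  # mask out the first half (e.g. the question tokens)
masked = torch.sum(per_tok * mask) / torch.sum(mask)
print(masked)  # mean loss over the unmasked (answer) tokens only
```
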

    def training_step_end(self, batch_parts):