diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py
index a8d868c..089da07 100644
--- a/RWKV-v4neo/src/dataset.py
+++ b/RWKV-v4neo/src/dataset.py
@@ -32,6 +32,10 @@ class MyDataset(Dataset):
             self.data_size = len(self.data._bin_buffer) // 2
             print(f"Data has {self.data_size} tokens.")
 
+            if args.my_qa_mask == 1:
+                self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
+                self.data_pile_size = len(self.data_pile._bin_buffer) // 2
+
             if args.my_pile_stage > 0:
                 # assert self.data_size == 332115325534 and self.vocab_size == 50277
                 self.samples_per_epoch = args.epoch_steps * args.real_bsz
@@ -146,25 +150,69 @@ class MyDataset(Dataset):
         else:
             ctx_len = args.ctx_len
             req_len = ctx_len + 1
+            magic_prime = args.magic_prime
+            data = self.data
 
             if args.my_pile_stage > 0:
                 ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
+
+                if args.my_qa_mask == 1:
+                    ii_orig = ii
+                    if ii % 2 == 0:
+                        ii = (ii // 2) * args.magic_prime
+                        magic_prime = 324331313
+                        data = self.data_pile
+                    else:
+                        ii = ii // 2
+
                 factor = (math.sqrt(5) - 1) / 2
-                factor = int(args.magic_prime * factor)
-                i = ((factor * ii * ii * ii) % args.magic_prime) * ctx_len
-                i = i + args.my_pile_shift
+                factor = int(magic_prime * factor)
+                i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
+                if (args.my_qa_mask == 0) or (data == self.data_pile):
+                    i = i + args.my_pile_shift
                 # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
             else:
                 # cheat: pick a random spot in dataset
                 i = np.random.randint(0, self.data_size - req_len)
 
             if args.data_type == "binidx":
-                dix = self.data.get(idx=0, offset=i, length=req_len).astype(int)
+                dix = data.get(idx=0, offset=i, length=req_len).astype(int)
             elif args.data_type == "numpy":
-                dix = self.data[i : i + req_len]
+                dix = data[i : i + req_len]
             else:
-                dix = [self.stoi[s] for s in self.data[i : i + req_len]]
+                dix = [self.stoi[s] for s in data[i : i + req_len]]
 
+            if args.my_qa_mask == 1:
+                if data == self.data_pile:
+                    z = [1] * ctx_len
+                else:
+                    z = [0] * ctx_len
+                    z_sum = 0
+                    isGood = False
+                    for i in range(3, ctx_len):
+                        if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
+                            isGood = True
+                        if dix[i] == 0:
+                            isGood = False
+                        if isGood:
+                            z[i] = 1
+                            z_sum += 1
+                    if z_sum == 0:
+                        z = [1] * ctx_len
+                        i = np.random.randint(0, self.data_pile_size - req_len)
+                        dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
+                z = torch.tensor(z, dtype=torch.bfloat16)
+
             x = torch.tensor(dix[:-1], dtype=torch.long)
             y = torch.tensor(dix[1:], dtype=torch.long)
+
+            # if ii_orig < 50:
+            #     # if rank == 1:
+            #     print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
+            # else:
+            #     exit(0)
+
+            if args.my_qa_mask == 1:
+                return x, y, z
+
             return x, y
diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 38906b4..3857e03 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -108,7 +108,7 @@ class RWKV_TimeMix(MyModule):
         self.layer_id = layer_id
         self.ctx_len = args.ctx_len
         self.n_embd = args.n_embd
-
+        self.my_testing = self.args.my_testing
         attn_sz = args.n_embd
 
         with torch.no_grad(): # fancy init
@@ -142,6 +142,9 @@
 
         self.output = nn.Linear(attn_sz, args.n_embd, bias=False)
 
+        # if self.my_testing > 0:
+        #     self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))
+
     @MyFunction
     def jit_func(self, x):
 
@@ -174,7 +177,7 @@ class RWKV_ChannelMix(MyModule):
         super().__init__()
         self.args = args
         self.layer_id = layer_id
-
+        self.my_testing = self.args.my_testing
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
 
         with torch.no_grad(): # fancy init of time_mix
@@ -192,6 +195,12 @@
         self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
         self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)
 
+        # if self.my_testing in [1]:
+        #     self.aaa = nn.Parameter(torch.zeros(1, 1, hidden_sz))
+        # elif self.my_testing in [2]:
+        #     self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))
+
+
     @MyFunction
     def forward(self, x):
         xx = self.time_shift(x)
@@ -205,6 +214,19 @@
 
         rkv = torch.sigmoid(self.receptance(xr)) * kv
         return rkv
+        # k = self.key(xk)
+        # # if self.my_testing in [0, 2]:
+        # k = torch.square(torch.relu(k))
+        # # elif self.my_testing == 1:
+        # #     k = torch.square(torch.relu(k)) + k * self.aaa
+        # kv = self.value(k)
+        # r = self.receptance(xr)
+        # # if self.my_testing == 0:
+        # r = torch.sigmoid(r)
+        # # elif self.my_testing == 2:
+        # #     r = torch.sigmoid(r) + r * self.aaa
+        # rkv = r * kv
+        # return rkv
 
 ########################################################################################################
 # The RWKV Model with our blocks
@@ -401,9 +423,38 @@
 
     def training_step(self, batch, batch_idx):
         args = self.args
-        idx, targets = batch
-        logits = self(idx)
-        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+        if args.my_qa_mask == 0:
+            idx, targets = batch
+            logits = self(idx)
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+        else:
+            idx, targets, mask = batch
+            mask = mask.view(-1)
+            sum_mask = torch.sum(mask).item()
+            # if sum_mask == 0:
+            #     return torch.tensor([0.0], requires_grad=True)
+
+            logits = self(idx)
+            if sum_mask == mask.shape[0]:
+                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+                # print('rank', self.global_rank, 'loss', loss.item())
+            else:
+                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
+                # loss_raw = loss
+                loss = torch.sum(loss * mask) / sum_mask
+
+                # torch.set_printoptions(threshold=10000)
+                # if True: #self.global_rank == 1:
+                #     tmp = ''
+                #     sss = 0
+                #     ccc = 0
+                #     for i in range(mask.shape[0]):
+                #         if mask[i] > 0:
+                #             tmp += str(idx.view(-1)[i].item()) + ','
+                #             sss += loss_raw.view(-1)[i].float().item()
+                #             ccc += 1
+                #     print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
+
         return L2Wrap.apply(loss, logits)
 
     def training_step_end(self, batch_parts):
diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py
index a109724..ac82d38 100644
--- a/RWKV-v4neo/src/trainer.py
+++ b/RWKV-v4neo/src/trainer.py
@@ -98,7 +98,7 @@ class train_callback(pl.Callback):
                     lll["kt/s"] = kt_s
                 trainer.my_wandb.log(lll, step=int(real_step))
         if args.magic_prime > 0:
-            if int(real_step) == int(args.magic_prime // args.real_bsz) - 1:
+            if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
                 to_save_dict = pl_module.state_dict()
                 torch.save(
                     to_save_dict,
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index fd90b10..073a5f7 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -100,6 +100,7 @@
     parser.add_argument("--my_pos_emb", default=0, type=int)
    parser.add_argument("--load_partial", default=0, type=int)
     parser.add_argument("--magic_prime", default=0, type=int)
+    parser.add_argument("--my_qa_mask", default=0, type=int)
     parser.add_argument("--my_testing", default=0, type=int)
 
     parser = Trainer.add_argparse_args(parser)
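
Note on the sampling change above (editor's sketch, not part of the patch): with --my_qa_mask 1, __getitem__ splits the deterministic sample counter ii into two streams, even ii reading from the Pile copy (data_pile, with its own hard-coded prime 324331313) and odd ii reading from the QA data. Each stream picks its chunk offset as i = ((factor * ii^3) % magic_prime) * ctx_len, where factor is the golden ratio times magic_prime. Note 324331313 % 3 == 2, and args.magic_prime is expected to be chosen the same way; for a prime p with p % 3 == 2, cubing modulo p is a bijection, so consecutive sample indices hit distinct chunks until p samples have been drawn. A self-contained check with a small stand-in prime (p = 2003 is an assumption for illustration, not a value from the patch):

    import math

    p = 2003  # stand-in for args.magic_prime: prime, and p % 3 == 2
    factor = int(p * ((math.sqrt(5) - 1) / 2))  # golden-ratio multiplier, as in the patch

    # ii -> (factor * ii**3) % p is the chunk index computed in __getitem__
    positions = {(factor * ii * ii * ii) % p for ii in range(p)}
    assert len(positions) == p  # a permutation: no chunk is visited twice
    print(f"{p} sample indices -> {len(positions)} distinct chunk offsets")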
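
On the training side, the new branch in training_step averages per-token cross-entropy over the positions where the mask is 1 (detected answer spans, or every token of a Pile sample), falling back to the ordinary mean when the whole batch is unmasked. A minimal sketch of just that computation, with made-up shapes (B, T and the random tensors are placeholders, not values from the patch; 50277 is the vocab size mentioned in the dataset's commented assert):

    import torch
    import torch.nn.functional as F

    B, T, V = 2, 8, 50277                  # batch, ctx_len, vocab size (illustrative)
    logits = torch.randn(B, T, V)          # stand-in for the model output
    targets = torch.randint(0, V, (B, T))  # stand-in for the shifted labels
    mask = torch.zeros(B, T)
    mask[:, 4:] = 1                        # pretend the answer is the last 4 tokens
    mask = mask.view(-1)

    sum_mask = torch.sum(mask).item()
    if sum_mask == mask.shape[0]:
        # fully unmasked (e.g. a Pile sample): plain mean cross-entropy
        loss = F.cross_entropy(logits.view(-1, V), targets.view(-1))
    else:
        # per-token loss, averaged over masked positions only
        loss = F.cross_entropy(logits.view(-1, V), targets.view(-1), reduction="none")
        loss = torch.sum(loss * mask) / sum_mask
    print(loss.item())

Because every QA sample is now paired with a Pile sample, covering the QA data takes twice as many steps, which is why trainer.py scales the final-checkpoint step by (1 + args.my_qa_mask).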