diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index ee5ce8e..662023f 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -19,22 +19,26 @@ class MyDataset(Dataset): self.vocab_size = args.vocab_size rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)") - if args.data_file.endswith('/'): - d_all = [] - for p in os.listdir(args.data_file): - if p.endswith(".idx"): - d_all += [p[:-4]] - d_all.sort() - rank_zero_info(d_all) - exit(0) - else: + if args.my_pile_version == 1: self.data = MMapIndexedDataset(args.data_file) - self.data_size = len(self.data._bin_buffer) // 2 + self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size rank_zero_info(f"Data has {self.data_size} tokens.") + else: + data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n') + data_list = [i.strip().split(' ') for i in data_list] + self.data = [] + self.data_size = int(data_list[-1][-1]) + rank_zero_info(f"Data has {self.data_size} chunks.") + for d in data_list: + data = MMapIndexedDataset(d[0]) + data_size = len(data._bin_buffer) // data._index._dtype_size + assert (data_size - args.ctx_len) == int(d[1]) + self.data += [[int(d[-1]), int(d[1]), data]] + # rank_zero_info(self.data) if args.my_qa_mask > 0: self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document') - self.data_pile_size = len(self.data_pile._bin_buffer) // 2 + self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size if args.my_pile_stage > 0: # assert self.data_size == 332115325534 and self.vocab_size == 50277 @@ -184,7 +188,17 @@ class MyDataset(Dataset): i = np.random.randint(0, self.data_size - req_len) if args.data_type == "binidx": - dix = data.get(idx=0, offset=i, length=req_len).astype(int) + if args.my_pile_version == 1: + dix = data.get(idx=0, offset=i, length=req_len).astype(int) + else: + # self.data : cutoff, chunk_count, data + for j in range(len(data)): + if i < data[j][0]: + ii = i + i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1] + dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int) + # print(ii, j, i) + break elif args.data_type == "numpy": dix = data[i : i + req_len] else: diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index 873bd52..3368c8c 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -80,8 +80,9 @@ if __name__ == "__main__": parser.add_argument("--beta1", default=0.9, type=float) parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence parser.add_argument("--adam_eps", default=1e-8, type=float) - parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower + + parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift parser.add_argument("--my_pile_edecay", default=0, type=int) @@ -157,18 +158,27 @@ if __name__ == "__main__": if args.my_pile_stage > 0: magic_prime_bak = args.magic_prime - if args.ctx_len == 1024: - args.magic_prime = 324331313 - args.epoch_count = 8043 - elif args.ctx_len == 2048: - args.magic_prime = 162165671 - args.epoch_count = 4021 - elif args.ctx_len == 4096: - args.magic_prime = 81082817 - args.epoch_count = 2010 - elif args.ctx_len == 8192: - args.magic_prime = 40541399 - args.epoch_count = 1005 + + if args.my_pile_version == 1: + if args.ctx_len == 1024: + args.magic_prime = 324331313 + args.epoch_count = 8043 + elif args.ctx_len == 2048: + args.magic_prime = 162165671 + args.epoch_count = 4021 + elif args.ctx_len == 4096: + args.magic_prime = 81082817 + args.epoch_count = 2010 + elif args.ctx_len == 8192: + args.magic_prime = 40541399 + args.epoch_count = 1005 + else: + if args.ctx_len == 4096: + args.magic_prime = 423736637 + args.epoch_count = 10508 + elif args.ctx_len == 8192: + args.magic_prime = 211868309 + args.epoch_count = 5253 if args.my_pile_shift < 0: args.my_pile_shift = 0