From 79915b3696fc744d410b69a09510b7161bda835b Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Tue, 28 Mar 2023 04:30:44 +0000 Subject: [PATCH] better --- RWKV-v4neo/src/dataset.py | 24 +++++++++--------------- RWKV-v4neo/train.py | 22 ++++++++++++++++------ 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index 5261586..6519991 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -37,7 +37,8 @@ class MyDataset(Dataset): # rank_zero_info(self.data) if args.my_qa_mask > 0: - self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document') + self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document') + # self.data_pile = MMapIndexedDataset('/fsx/pile_deduped/pile_0.87_deduped_text_document') self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size if args.my_pile_stage > 0: @@ -164,23 +165,16 @@ class MyDataset(Dataset): if args.my_qa_mask > 0: ii_orig = ii if ii % 2 == 0: - ii = (ii // 2) * args.magic_prime - if args.ctx_len == 1024: - magic_prime = 324331313 - elif args.ctx_len == 2048: - magic_prime = 162165671 - elif args.ctx_len == 4096: - magic_prime = 81082817 - elif args.ctx_len == 8192: - magic_prime = 40541399 + ii = -1 data = self.data_pile else: ii = ii // 2 - - factor = (math.sqrt(5) - 1) / 2 - factor = int(magic_prime * factor) - i = ((factor * ii * ii * ii) % magic_prime) * ctx_len - if (args.my_qa_mask == 0) or (data == self.data_pile): + if ii < 0: + i = np.random.randint(0, self.data_pile_size - req_len) + else: + factor = (math.sqrt(5) - 1) / 2 + factor = int(magic_prime * factor) + i = ((factor * ii * ii * ii) % magic_prime) * ctx_len i = i + args.my_pile_shift # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}") elif args.my_pile_stage == 4: diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index 03eda09..e5d6c26 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -76,7 +76,7 @@ if __name__ == "__main__": parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 parser.add_argument("--lr_final", default=1e-5, type=float) - parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model + parser.add_argument("--warmup_steps", default=-1, type=int) # try 50 if you load a model parser.add_argument("--beta1", default=0.9, type=float) parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence parser.add_argument("--adam_eps", default=1e-8, type=float) @@ -173,9 +173,18 @@ if __name__ == "__main__": args.magic_prime = 40541399 args.epoch_count = 1005 else: - if args.ctx_len == 4096: + if args.ctx_len == 1024: + args.magic_prime = 1694947181 + args.epoch_count = 42036 + elif args.ctx_len == 2048: + args.magic_prime = 847473509 + args.epoch_count = 21017 + elif args.ctx_len == 4096: args.magic_prime = 423736637 args.epoch_count = 10508 + elif args.ctx_len == 6144: + args.magic_prime = 282491051 + args.epoch_count = 7005 elif args.ctx_len == 8192: args.magic_prime = 211868243 args.epoch_count = 5253 @@ -207,10 +216,11 @@ if __name__ == "__main__": args.load_model = f"{args.proj_dir}/rwkv-init.pth" else: args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" - if args.my_pile_stage == 2: - args.warmup_steps = 10 - else: - args.warmup_steps = 30 + if args.warmup_steps < 0: + if args.my_pile_stage == 2: + args.warmup_steps = 10 + else: + args.warmup_steps = 30 args.epoch_begin = max_p + 1 samples_per_epoch = args.epoch_steps * args.real_bsz