From 4ca274aad7cd21b3472b3cc5a6ad8498b8a86fc7 Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Mon, 20 Mar 2023 10:14:36 +0000 Subject: [PATCH] fix --- RWKV-v4neo/src/dataset.py | 7 +++++-- RWKV-v4neo/train.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index 662023f..71a8275 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -183,9 +183,12 @@ class MyDataset(Dataset): if (args.my_qa_mask == 0) or (data == self.data_pile): i = i + args.my_pile_shift # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}") - else: + elif args.my_pile_stage == 4: # cheat: pick a random spot in dataset - i = np.random.randint(0, self.data_size - req_len) + if args.my_pile_version == 1: + i = np.random.randint(0, self.data_size - req_len) + else: + i = np.random.randint(0, self.data_size) if args.data_type == "binidx": if args.my_pile_version == 1: diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index 3368c8c..03eda09 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -177,7 +177,7 @@ if __name__ == "__main__": args.magic_prime = 423736637 args.epoch_count = 10508 elif args.ctx_len == 8192: - args.magic_prime = 211868309 + args.magic_prime = 211868243 args.epoch_count = 5253 if args.my_pile_shift < 0: args.my_pile_shift = 0