diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index 71a8275..5261586 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -189,6 +189,9 @@ class MyDataset(Dataset): i = np.random.randint(0, self.data_size - req_len) else: i = np.random.randint(0, self.data_size) + else: + # cheat: pick a random spot in dataset + i = np.random.randint(0, self.data_size - req_len) if args.data_type == "binidx": if args.my_pile_version == 1: