|
|
|
@ -183,9 +183,12 @@ class MyDataset(Dataset):
|
|
|
|
if (args.my_qa_mask == 0) or (data == self.data_pile):
|
|
|
|
if (args.my_qa_mask == 0) or (data == self.data_pile):
|
|
|
|
i = i + args.my_pile_shift
|
|
|
|
i = i + args.my_pile_shift
|
|
|
|
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
|
|
|
|
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
|
|
|
|
else:
|
|
|
|
elif args.my_pile_stage == 4:
|
|
|
|
# cheat: pick a random spot in dataset
|
|
|
|
# cheat: pick a random spot in dataset
|
|
|
|
|
|
|
|
if args.my_pile_version == 1:
|
|
|
|
i = np.random.randint(0, self.data_size - req_len)
|
|
|
|
i = np.random.randint(0, self.data_size - req_len)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
i = np.random.randint(0, self.data_size)
|
|
|
|
|
|
|
|
|
|
|
|
if args.data_type == "binidx":
|
|
|
|
if args.data_type == "binidx":
|
|
|
|
if args.my_pile_version == 1:
|
|
|
|
if args.my_pile_version == 1:
|
|
|
|
|