|
|
|
|
@ -19,22 +19,26 @@ class MyDataset(Dataset):
|
|
|
|
|
self.vocab_size = args.vocab_size
|
|
|
|
|
rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
|
|
|
|
|
|
|
|
|
|
if args.data_file.endswith('/'):
|
|
|
|
|
d_all = []
|
|
|
|
|
for p in os.listdir(args.data_file):
|
|
|
|
|
if p.endswith(".idx"):
|
|
|
|
|
d_all += [p[:-4]]
|
|
|
|
|
d_all.sort()
|
|
|
|
|
rank_zero_info(d_all)
|
|
|
|
|
exit(0)
|
|
|
|
|
else:
|
|
|
|
|
if args.my_pile_version == 1:
|
|
|
|
|
self.data = MMapIndexedDataset(args.data_file)
|
|
|
|
|
self.data_size = len(self.data._bin_buffer) // 2
|
|
|
|
|
self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
|
|
|
|
|
rank_zero_info(f"Data has {self.data_size} tokens.")
|
|
|
|
|
else:
|
|
|
|
|
data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n')
|
|
|
|
|
data_list = [i.strip().split(' ') for i in data_list]
|
|
|
|
|
self.data = []
|
|
|
|
|
self.data_size = int(data_list[-1][-1])
|
|
|
|
|
rank_zero_info(f"Data has {self.data_size} chunks.")
|
|
|
|
|
for d in data_list:
|
|
|
|
|
data = MMapIndexedDataset(d[0])
|
|
|
|
|
data_size = len(data._bin_buffer) // data._index._dtype_size
|
|
|
|
|
assert (data_size - args.ctx_len) == int(d[1])
|
|
|
|
|
self.data += [[int(d[-1]), int(d[1]), data]]
|
|
|
|
|
# rank_zero_info(self.data)
|
|
|
|
|
|
|
|
|
|
if args.my_qa_mask > 0:
|
|
|
|
|
self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
|
|
|
|
|
self.data_pile_size = len(self.data_pile._bin_buffer) // 2
|
|
|
|
|
self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size
|
|
|
|
|
|
|
|
|
|
if args.my_pile_stage > 0:
|
|
|
|
|
# assert self.data_size == 332115325534 and self.vocab_size == 50277
|
|
|
|
|
@ -184,7 +188,17 @@ class MyDataset(Dataset):
|
|
|
|
|
i = np.random.randint(0, self.data_size - req_len)
|
|
|
|
|
|
|
|
|
|
if args.data_type == "binidx":
|
|
|
|
|
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
|
|
|
|
|
if args.my_pile_version == 1:
|
|
|
|
|
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
|
|
|
|
|
else:
|
|
|
|
|
# self.data : cutoff, chunk_count, data
|
|
|
|
|
for j in range(len(data)):
|
|
|
|
|
if i < data[j][0]:
|
|
|
|
|
ii = i
|
|
|
|
|
i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1]
|
|
|
|
|
dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int)
|
|
|
|
|
# print(ii, j, i)
|
|
|
|
|
break
|
|
|
|
|
elif args.data_type == "numpy":
|
|
|
|
|
dix = data[i : i + req_len]
|
|
|
|
|
else:
|
|
|
|
|
|