main
BlinkDL 3 years ago
parent 3d2b04ba0c
commit 1945cb58ed

@@ -19,22 +19,26 @@ class MyDataset(Dataset):
             self.vocab_size = args.vocab_size
             rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")

-            if args.data_file.endswith('/'):
-                d_all = []
-                for p in os.listdir(args.data_file):
-                    if p.endswith(".idx"):
-                        d_all += [p[:-4]]
-                d_all.sort()
-                rank_zero_info(d_all)
-                exit(0)
-            else:
+            if args.my_pile_version == 1:
                 self.data = MMapIndexedDataset(args.data_file)
-                self.data_size = len(self.data._bin_buffer) // 2
+                self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
                 rank_zero_info(f"Data has {self.data_size} tokens.")
+            else:
+                data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n')
+                data_list = [i.strip().split(' ') for i in data_list]
+                self.data = []
+                self.data_size = int(data_list[-1][-1])
+                rank_zero_info(f"Data has {self.data_size} chunks.")
+                for d in data_list:
+                    data = MMapIndexedDataset(d[0])
+                    data_size = len(data._bin_buffer) // data._index._dtype_size
+                    assert (data_size - args.ctx_len) == int(d[1])
+                    self.data += [[int(d[-1]), int(d[1]), data]]
+                # rank_zero_info(self.data)

             if args.my_qa_mask > 0:
                 self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
-                self.data_pile_size = len(self.data_pile._bin_buffer) // 2
+                self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size

             if args.my_pile_stage > 0:
                 # assert self.data_size == 332115325534 and self.vocab_size == 50277
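Note: with --my_pile_version 2, args.data_file is read as a plain-text list rather than a single binidx prefix. The commit does not document the list format, but the parsing above implies one space-separated shard per line: the binidx prefix first, the chunk count (token count minus ctx_len) as the second field, and a cumulative cutoff as the last field (the loader takes the last field of the last line as the total data_size). A minimal builder sketch under that assumption; build_data_list, the import path, and the example paths are illustrative, not part of this commit:

from src.binidx import MMapIndexedDataset  # assumed location of the helper used above

def build_data_list(prefixes, ctx_len, out_path):
    # Hypothetical helper: writes one "<prefix> <chunk_count> <cutoff>" line per shard.
    # chunk_count = token_count - ctx_len (this is what the assert in the loader checks);
    # cutoff is assumed to be the running sum of chunk counts, since the loader reads
    # the final cutoff as data_size and rebases indices against the previous cutoff.
    cutoff = 0
    lines = []
    for prefix in prefixes:
        data = MMapIndexedDataset(prefix)
        tokens = len(data._bin_buffer) // data._index._dtype_size
        chunk_count = tokens - ctx_len
        cutoff += chunk_count
        lines.append(f"{prefix} {chunk_count} {cutoff}")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")

# Usage sketch (paths are placeholders):
# build_data_list(["/data/shard_a_text_document", "/data/shard_b_text_document"],
#                 ctx_len=4096, out_path="my_shards.txt")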
@@ -184,7 +188,17 @@ class MyDataset(Dataset):
                 i = np.random.randint(0, self.data_size - req_len)

             if args.data_type == "binidx":
-                dix = data.get(idx=0, offset=i, length=req_len).astype(int)
+                if args.my_pile_version == 1:
+                    dix = data.get(idx=0, offset=i, length=req_len).astype(int)
+                else:
+                    # self.data : cutoff, chunk_count, data
+                    for j in range(len(data)):
+                        if i < data[j][0]:
+                            ii = i
+                            i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1]
+                            dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int)
+                            # print(ii, j, i)
+                            break
             elif args.data_type == "numpy":
                 dix = data[i : i + req_len]
             else:
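Note: the version-2 branch above maps a global chunk index onto a shard by scanning the cumulative cutoffs, then rebases the index into that shard before calling .get(). The same idea as a self-contained helper; shard_for_index is a hypothetical name, and the triples mirror what self.data holds in the loader:

def shard_for_index(shards, i):
    # shards: list of [cumulative_cutoff, chunk_count, dataset] triples in
    # ascending cutoff order, as built by the my_pile_version == 2 loader.
    # Returns (dataset, local_offset) for global chunk index i.
    prev_cutoff = 0
    for cutoff, chunk_count, dataset in shards:
        if i < cutoff:
            # Rebase into this shard; the modulo mirrors the training code and
            # is a no-op when cutoffs are exact running sums of chunk counts.
            return dataset, (i - prev_cutoff) % chunk_count
        prev_cutoff = cutoff
    raise IndexError(f"index {i} is beyond the last cutoff {prev_cutoff}")

# Usage sketch:
#   shard, local_i = shard_for_index(self.data, i)
#   dix = shard.get(idx=0, offset=local_i, length=req_len).astype(int)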

@@ -80,8 +80,9 @@ if __name__ == "__main__":
     parser.add_argument("--beta1", default=0.9, type=float)
     parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
     parser.add_argument("--adam_eps", default=1e-8, type=float)
     parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
+    parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version
     parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
     parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
     parser.add_argument("--my_pile_edecay", default=0, type=int)
@@ -157,18 +158,27 @@ if __name__ == "__main__":
     if args.my_pile_stage > 0:
         magic_prime_bak = args.magic_prime

-        if args.ctx_len == 1024:
-            args.magic_prime = 324331313
-            args.epoch_count = 8043
-        elif args.ctx_len == 2048:
-            args.magic_prime = 162165671
-            args.epoch_count = 4021
-        elif args.ctx_len == 4096:
-            args.magic_prime = 81082817
-            args.epoch_count = 2010
-        elif args.ctx_len == 8192:
-            args.magic_prime = 40541399
-            args.epoch_count = 1005
+        if args.my_pile_version == 1:
+            if args.ctx_len == 1024:
+                args.magic_prime = 324331313
+                args.epoch_count = 8043
+            elif args.ctx_len == 2048:
+                args.magic_prime = 162165671
+                args.epoch_count = 4021
+            elif args.ctx_len == 4096:
+                args.magic_prime = 81082817
+                args.epoch_count = 2010
+            elif args.ctx_len == 8192:
+                args.magic_prime = 40541399
+                args.epoch_count = 1005
+        else:
+            if args.ctx_len == 4096:
+                args.magic_prime = 423736637
+                args.epoch_count = 10508
+            elif args.ctx_len == 8192:
+                args.magic_prime = 211868309
+                args.epoch_count = 5253

         if args.my_pile_shift < 0:
             args.my_pile_shift = 0
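Note: the new version-2 constants are picked the same way as the version-1 ones. Assuming the convention used elsewhere in this repo (magic_prime is a prime p with p % 3 == 2, slightly smaller than the dataset length measured in ctx_len-sized samples, so that the cubic index shuffle x -> x**3 mod p is a permutation), the hard-coded values can be sanity-checked with a small sketch; check_magic_prime and n_samples are illustrative names, not part of this commit:

def is_prime(n):
    # Trial division is plenty fast for values around 4e8 when run once at startup.
    if n < 2:
        return False
    f = 2
    while f * f <= n:
        if n % f == 0:
            return False
        f += 1
    return True

def check_magic_prime(magic_prime, n_samples):
    # Assumed convention (not stated in this commit): prime, p % 3 == 2
    # (so x -> x**3 mod p is a bijection), and not larger than the number
    # of ctx_len-sized samples in the data.
    assert is_prime(magic_prime), "magic_prime is not prime"
    assert magic_prime % 3 == 2, "need magic_prime % 3 == 2"
    assert magic_prime <= n_samples, "magic_prime exceeds the dataset length"

# All of the hard-coded values above pass the prime / mod-3 checks, e.g.:
# for p in (324331313, 162165671, 81082817, 40541399, 423736637, 211868309):
#     assert is_prime(p) and p % 3 == 2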
