main
BlinkDL 3 years ago
parent 0c7cd08255
commit 79915b3696

@ -37,7 +37,8 @@ class MyDataset(Dataset):
# rank_zero_info(self.data)
if args.my_qa_mask > 0:
self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document')
# self.data_pile = MMapIndexedDataset('/fsx/pile_deduped/pile_0.87_deduped_text_document')
self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size
if args.my_pile_stage > 0:
@ -164,23 +165,16 @@ class MyDataset(Dataset):
if args.my_qa_mask > 0:
ii_orig = ii
if ii % 2 == 0:
ii = (ii // 2) * args.magic_prime
if args.ctx_len == 1024:
magic_prime = 324331313
elif args.ctx_len == 2048:
magic_prime = 162165671
elif args.ctx_len == 4096:
magic_prime = 81082817
elif args.ctx_len == 8192:
magic_prime = 40541399
ii = -1
data = self.data_pile
else:
ii = ii // 2
if ii < 0:
i = np.random.randint(0, self.data_pile_size - req_len)
else:
factor = (math.sqrt(5) - 1) / 2
factor = int(magic_prime * factor)
i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
if (args.my_qa_mask == 0) or (data == self.data_pile):
i = i + args.my_pile_shift
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
elif args.my_pile_stage == 4:

@ -76,7 +76,7 @@ if __name__ == "__main__":
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
parser.add_argument("--warmup_steps", default=-1, type=int) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)
@ -173,9 +173,18 @@ if __name__ == "__main__":
args.magic_prime = 40541399
args.epoch_count = 1005
else:
if args.ctx_len == 4096:
if args.ctx_len == 1024:
args.magic_prime = 1694947181
args.epoch_count = 42036
elif args.ctx_len == 2048:
args.magic_prime = 847473509
args.epoch_count = 21017
elif args.ctx_len == 4096:
args.magic_prime = 423736637
args.epoch_count = 10508
elif args.ctx_len == 6144:
args.magic_prime = 282491051
args.epoch_count = 7005
elif args.ctx_len == 8192:
args.magic_prime = 211868243
args.epoch_count = 5253
@ -207,6 +216,7 @@ if __name__ == "__main__":
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
else:
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
if args.warmup_steps < 0:
if args.my_pile_stage == 2:
args.warmup_steps = 10
else:

Loading…
Cancel
Save