|
|
|
@ -76,7 +76,7 @@ if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
|
|
|
|
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
|
|
|
|
parser.add_argument("--lr_final", default=1e-5, type=float)
|
|
|
|
parser.add_argument("--lr_final", default=1e-5, type=float)
|
|
|
|
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
|
|
|
|
parser.add_argument("--warmup_steps", default=-1, type=int) # try 50 if you load a model
|
|
|
|
parser.add_argument("--beta1", default=0.9, type=float)
|
|
|
|
parser.add_argument("--beta1", default=0.9, type=float)
|
|
|
|
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
|
|
|
|
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
|
|
|
|
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
|
|
|
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
|
|
|
@ -173,9 +173,18 @@ if __name__ == "__main__":
|
|
|
|
args.magic_prime = 40541399
|
|
|
|
args.magic_prime = 40541399
|
|
|
|
args.epoch_count = 1005
|
|
|
|
args.epoch_count = 1005
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
if args.ctx_len == 4096:
|
|
|
|
if args.ctx_len == 1024:
|
|
|
|
|
|
|
|
args.magic_prime = 1694947181
|
|
|
|
|
|
|
|
args.epoch_count = 42036
|
|
|
|
|
|
|
|
elif args.ctx_len == 2048:
|
|
|
|
|
|
|
|
args.magic_prime = 847473509
|
|
|
|
|
|
|
|
args.epoch_count = 21017
|
|
|
|
|
|
|
|
elif args.ctx_len == 4096:
|
|
|
|
args.magic_prime = 423736637
|
|
|
|
args.magic_prime = 423736637
|
|
|
|
args.epoch_count = 10508
|
|
|
|
args.epoch_count = 10508
|
|
|
|
|
|
|
|
elif args.ctx_len == 6144:
|
|
|
|
|
|
|
|
args.magic_prime = 282491051
|
|
|
|
|
|
|
|
args.epoch_count = 7005
|
|
|
|
elif args.ctx_len == 8192:
|
|
|
|
elif args.ctx_len == 8192:
|
|
|
|
args.magic_prime = 211868243
|
|
|
|
args.magic_prime = 211868243
|
|
|
|
args.epoch_count = 5253
|
|
|
|
args.epoch_count = 5253
|
|
|
|
@ -207,6 +216,7 @@ if __name__ == "__main__":
|
|
|
|
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
|
|
|
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
|
|
|
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
|
|
|
|
|
|
|
if args.warmup_steps < 0:
|
|
|
|
if args.my_pile_stage == 2:
|
|
|
|
if args.my_pile_stage == 2:
|
|
|
|
args.warmup_steps = 10
|
|
|
|
args.warmup_steps = 10
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
|