|
|
|
|
@ -107,10 +107,11 @@ if __name__ == "__main__":
|
|
|
|
|
args.log_every_n_steps = int(1e20)
|
|
|
|
|
args.max_epochs = -1 # continue forever
|
|
|
|
|
args.betas = (args.beta1, args.beta2)
|
|
|
|
|
args.real_bsz = int(args.devices) * args.micro_bsz
|
|
|
|
|
|
|
|
|
|
if args.my_pile_stage > 0:
|
|
|
|
|
args.epoch_steps = 40320 // (int(args.devices) * args.micro_bsz)
|
|
|
|
|
assert args.epoch_steps * int(args.devices) * args.micro_bsz == 40320
|
|
|
|
|
args.epoch_steps = 40320 // args.real_bsz
|
|
|
|
|
assert args.epoch_steps * args.real_bsz == 40320
|
|
|
|
|
if args.my_pile_stage == 2:
|
|
|
|
|
assert args.lr_final == args.lr_init
|
|
|
|
|
if args.my_pile_stage >= 2: # find latest saved model
|
|
|
|
|
@ -131,12 +132,12 @@ if __name__ == "__main__":
|
|
|
|
|
else:
|
|
|
|
|
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
|
|
|
|
if args.my_pile_stage == 2:
|
|
|
|
|
args.warmup_steps = 5
|
|
|
|
|
args.warmup_steps = 10
|
|
|
|
|
else:
|
|
|
|
|
args.warmup_steps = 50
|
|
|
|
|
args.epoch_begin = max_p + 1
|
|
|
|
|
|
|
|
|
|
samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
|
|
|
|
|
samples_per_epoch = args.epoch_steps * args.real_bsz
|
|
|
|
|
tokens_per_epoch = samples_per_epoch * args.ctx_len
|
|
|
|
|
rank_zero_info(
|
|
|
|
|
f"""
|
|
|
|
|
|