BlinkDL 3 years ago
parent ceafd4e7af
commit 7b92a979d8

@ -107,10 +107,11 @@ if __name__ == "__main__":
args.log_every_n_steps = int(1e20) args.log_every_n_steps = int(1e20)
args.max_epochs = -1 # continue forever args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2) args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.devices) * args.micro_bsz
if args.my_pile_stage > 0: if args.my_pile_stage > 0:
args.epoch_steps = 40320 // (int(args.devices) * args.micro_bsz) args.epoch_steps = 40320 // args.real_bsz
assert args.epoch_steps * int(args.devices) * args.micro_bsz == 40320 assert args.epoch_steps * args.real_bsz == 40320
if args.my_pile_stage == 2: if args.my_pile_stage == 2:
assert args.lr_final == args.lr_init assert args.lr_final == args.lr_init
if args.my_pile_stage >= 2: # find latest saved model if args.my_pile_stage >= 2: # find latest saved model
@ -131,12 +132,12 @@ if __name__ == "__main__":
else: else:
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
if args.my_pile_stage == 2: if args.my_pile_stage == 2:
args.warmup_steps = 5 args.warmup_steps = 10
else: else:
args.warmup_steps = 50 args.warmup_steps = 50
args.epoch_begin = max_p + 1 args.epoch_begin = max_p + 1
samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz samples_per_epoch = args.epoch_steps * args.real_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len tokens_per_epoch = samples_per_epoch * args.ctx_len
rank_zero_info( rank_zero_info(
f""" f"""

Loading…
Cancel
Save