From 7b92a979d8be0a314e176729ada7448c2269122d Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Tue, 6 Sep 2022 17:12:22 +0000 Subject: [PATCH] fix --- RWKV-v4neo/train.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index 0c43e10..042c160 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -107,10 +107,11 @@ if __name__ == "__main__": args.log_every_n_steps = int(1e20) args.max_epochs = -1 # continue forever args.betas = (args.beta1, args.beta2) + args.real_bsz = int(args.devices) * args.micro_bsz if args.my_pile_stage > 0: - args.epoch_steps = 40320 // (int(args.devices) * args.micro_bsz) - assert args.epoch_steps * int(args.devices) * args.micro_bsz == 40320 + args.epoch_steps = 40320 // args.real_bsz + assert args.epoch_steps * args.real_bsz == 40320 if args.my_pile_stage == 2: assert args.lr_final == args.lr_init if args.my_pile_stage >= 2: # find latest saved model @@ -131,12 +132,12 @@ if __name__ == "__main__": else: args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" if args.my_pile_stage == 2: - args.warmup_steps = 5 + args.warmup_steps = 10 else: args.warmup_steps = 50 args.epoch_begin = max_p + 1 - samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz + samples_per_epoch = args.epoch_steps * args.real_bsz tokens_per_epoch = samples_per_epoch * args.ctx_len rank_zero_info( f"""