diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py index 8fb2d27..0cc35a1 100644 --- a/RWKV-v4neo/src/trainer.py +++ b/RWKV-v4neo/src/trainer.py @@ -12,6 +12,8 @@ class train_callback(pl.Callback): def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): args = self.args + # if args.cuda_cleanup > 0: + # torch.cuda.empty_cache() g_step = trainer.global_step # LR schedule diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index ec40cf9..0c43e10 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -94,6 +94,7 @@ if __name__ == "__main__": parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s) parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough + # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful) args = parser.parse_args() args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") @@ -123,13 +124,14 @@ if __name__ == "__main__": else: p = int(p) if p > max_p: + args.my_pile_prev_p = max_p # in case max_p is corrupted max_p = p if max_p == -1: args.load_model = f"{args.proj_dir}/rwkv-init.pth" else: args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" if args.my_pile_stage == 2: - args.warmup_steps = 10 + args.warmup_steps = 5 else: args.warmup_steps = 50 args.epoch_begin = max_p + 1 @@ -212,7 +214,21 @@ if __name__ == "__main__": args.load_model = init_weight_name print(f"########## Loading {args.load_model}... ##########") - load_dict = torch.load(args.load_model, map_location="cpu") + try: + load_dict = torch.load(args.load_model, map_location="cpu") + except: + print(f'Bad checkpoint {args.load_model}') + if args.my_pile_stage >= 2: # try again using another checkpoint + max_p = args.my_pile_prev_p + if max_p == -1: + args.load_model = f"{args.proj_dir}/rwkv-init.pth" + args.warmup_steps = 0 + else: + args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" + args.epoch_begin = max_p + 1 + print(f'Trying {args.load_model}') + load_dict = torch.load(args.load_model, map_location="cpu") + model.load_state_dict(load_dict) trainer = Trainer.from_argparse_args(