main
BlinkDL 3 years ago
parent 99a3dff414
commit 94300caba1

@@ -12,6 +12,8 @@ class train_callback(pl.Callback):
     def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
         args = self.args
+        # if args.cuda_cleanup > 0:
+        #     torch.cuda.empty_cache()
         g_step = trainer.global_step
         # LR schedule

@@ -94,6 +94,7 @@ if __name__ == "__main__":
     parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
     parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
     parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
+    # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)
     args = parser.parse_args()
     args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
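
The cuda_cleanup hook is added here only in commented-out form. A minimal standalone sketch of how the flag could gate torch.cuda.empty_cache() once per batch, assuming the same train_callback shape as above (the getattr default is an assumption, since the argparse flag itself is still commented out):

import torch
import pytorch_lightning as pl

class train_callback(pl.Callback):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        args = self.args
        # assumption: args.cuda_cleanup may not exist yet, so default to 0 (off);
        # calling empty_cache() every step frees cached blocks but can cost throughput
        if getattr(args, "cuda_cleanup", 0) > 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()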
@@ -123,13 +124,14 @@ if __name__ == "__main__":
                     else:
                         p = int(p)
                     if p > max_p:
+                        args.my_pile_prev_p = max_p  # in case max_p is corrupted
                         max_p = p
             if max_p == -1:
                 args.load_model = f"{args.proj_dir}/rwkv-init.pth"
             else:
                 args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
                 if args.my_pile_stage == 2:
-                    args.warmup_steps = 10
+                    args.warmup_steps = 5
                 else:
                     args.warmup_steps = 50
                 args.epoch_begin = max_p + 1
@@ -212,7 +214,21 @@ if __name__ == "__main__":
         args.load_model = init_weight_name
     print(f"########## Loading {args.load_model}... ##########")
-    load_dict = torch.load(args.load_model, map_location="cpu")
+    try:
+        load_dict = torch.load(args.load_model, map_location="cpu")
+    except:
+        print(f'Bad checkpoint {args.load_model}')
+        if args.my_pile_stage >= 2:  # try again using another checkpoint
+            max_p = args.my_pile_prev_p
+            if max_p == -1:
+                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
+                args.warmup_steps = 0
+            else:
+                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
+            args.epoch_begin = max_p + 1
+            print(f'Trying {args.load_model}')
+            load_dict = torch.load(args.load_model, map_location="cpu")
     model.load_state_dict(load_dict)
     trainer = Trainer.from_argparse_args(
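
Read as a standalone routine, the new try/except amounts to: load the newest checkpoint, and if it fails to deserialize, fall back to the previously seen index stored in my_pile_prev_p. A minimal sketch of that behavior, assuming args carries the same fields used in the hunk above (proj_dir, load_model, my_pile_stage, my_pile_prev_p, warmup_steps, epoch_begin); the function name and everything else is illustrative:

import torch

def load_checkpoint_with_fallback(args):
    # First attempt: the checkpoint picked earlier (newest rwkv-{max_p}.pth or rwkv-init.pth).
    try:
        return torch.load(args.load_model, map_location="cpu")
    except Exception:
        print(f"Bad checkpoint {args.load_model}")
        if args.my_pile_stage < 2:
            raise  # no earlier checkpoint index tracked, nothing to fall back to
        max_p = args.my_pile_prev_p  # previous index; -1 means "use init weights"
        if max_p == -1:
            args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            args.warmup_steps = 0
        else:
            args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
        args.epoch_begin = max_p + 1  # resume epoch numbering from the fallback checkpoint
        print(f"Trying {args.load_model}")
        return torch.load(args.load_model, map_location="cpu")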
