@@ -94,6 +94,7 @@ if __name__ == "__main__":
     parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
     parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
     parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
+    # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
 
     args = parser.parse_args()
     args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
@@ -123,13 +124,14 @@ if __name__ == "__main__":
                     else:
                         p = int(p)
                     if p > max_p:
+                        args.my_pile_prev_p = max_p # in case max_p is corrupted
                         max_p = p
             if max_p == -1:
                 args.load_model = f"{args.proj_dir}/rwkv-init.pth"
             else:
                 args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
                 if args.my_pile_stage == 2:
-                    args.warmup_steps = 10
+                    args.warmup_steps = 5
                 else:
                     args.warmup_steps = 50
             args.epoch_begin = max_p + 1
@@ -212,7 +214,21 @@ if __name__ == "__main__":
         args.load_model = init_weight_name
 
     print(f"########## Loading {args.load_model}... ##########")
-    load_dict = torch.load(args.load_model, map_location="cpu")
+    try:
+        load_dict = torch.load(args.load_model, map_location="cpu")
+    except:
+        print(f'Bad checkpoint {args.load_model}')
+        if args.my_pile_stage >= 2: # try again using another checkpoint
+            max_p = args.my_pile_prev_p
+            if max_p == -1:
+                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
+                args.warmup_steps = 0
+            else:
+                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
+            args.epoch_begin = max_p + 1
+            print(f'Trying {args.load_model}')
+            load_dict = torch.load(args.load_model, map_location="cpu")
+
     model.load_state_dict(load_dict)
 
     trainer = Trainer.from_argparse_args(
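
For readers skimming the diff, here is a minimal, self-contained sketch of the fallback pattern the last hunk introduces: remember the previous checkpoint number while scanning for the newest one, and if the newest file fails to load, retry with the previous one. The function name `load_latest_or_previous`, the helper `path_for`, and the return shape are illustrative assumptions, not the project's API:

```python
import os
import torch

def load_latest_or_previous(proj_dir):
    # Gather saved epochs: rwkv-<N>.pth -> N, rwkv-init.pth -> -1
    # (mirrors the max_p / my_pile_prev_p bookkeeping in the diff).
    epochs = []
    for name in os.listdir(proj_dir):
        if name.startswith("rwkv-") and name.endswith(".pth"):
            stem = name[len("rwkv-"):-len(".pth")]
            if stem == "init":
                epochs.append(-1)
            elif stem.isdigit():
                epochs.append(int(stem))
    epochs.sort()
    max_p = epochs[-1]                              # newest checkpoint
    prev_p = epochs[-2] if len(epochs) > 1 else -1  # fallback if the newest one is corrupted

    def path_for(p):
        return os.path.join(proj_dir, "rwkv-init.pth" if p == -1 else f"rwkv-{p}.pth")

    try:
        return torch.load(path_for(max_p), map_location="cpu"), max_p
    except Exception:
        # Newest file is unreadable (e.g. truncated by a crash mid-save): retry the previous one.
        print(f"Bad checkpoint {path_for(max_p)}, trying {path_for(prev_p)}")
        return torch.load(path_for(prev_p), map_location="cpu"), prev_p
```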