@@ -92,6 +92,7 @@ if __name__ == "__main__":
 
     parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
     parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
+    parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
     parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
     parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
     # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
@@ -108,8 +109,27 @@ if __name__ == "__main__":
     args.max_epochs = -1 # continue forever
     args.betas = (args.beta1, args.beta2)
     args.real_bsz = int(args.devices) * args.micro_bsz
     os.environ["RWKV_T_MAX"] = str(args.ctx_len)
 
+    if not os.path.exists(args.proj_dir):
+        os.makedirs(args.proj_dir)
+
+    if args.my_pile_stage > 0:
+        if args.ctx_len == 1024:
+            args.magic_prime = 324331313
+        elif args.ctx_len == 2048:
+            args.magic_prime = 162165671
+        elif args.ctx_len == 4096:
+            args.magic_prime = 81082817
+        if args.my_pile_shift < 0:
+            if args.ctx_len == 1024:
+                args.my_pile_shift = 0
+            elif args.ctx_len == 2048:
+                args.my_pile_shift = 512
+            elif args.ctx_len == 4096:
+                args.my_pile_shift = 768
+
+        args.epoch_count = 8043
     args.epoch_steps = 40320 // args.real_bsz
     assert args.epoch_steps * args.real_bsz == 40320
     if args.my_pile_stage == 2:
@@ -125,7 +145,7 @@ if __name__ == "__main__":
                 else:
                     p = int(p)
                 if p > max_p:
-                    args.my_pile_prev_p = max_p # in case max_p is corrupted
+                    args.my_pile_prev_p = max_p  # in case max_p is corrupted
                     max_p = p
         if max_p == -1:
             args.load_model = f"{args.proj_dir}/rwkv-init.pth"
@@ -164,9 +184,6 @@ if __name__ == "__main__":
     )
     rank_zero_info(str(vars(args)) + "\n")
 
-    if not os.path.exists(args.proj_dir):
-        os.makedirs(args.proj_dir)
-
     assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy"]
 
     if args.lr_final == 0 or args.lr_init == 0:
@@ -218,8 +235,8 @@ if __name__ == "__main__":
     try:
         load_dict = torch.load(args.load_model, map_location="cpu")
     except:
-        print(f'Bad checkpoint {args.load_model}')
-        if args.my_pile_stage >= 2: # try again using another checkpoint
+        print(f"Bad checkpoint {args.load_model}")
+        if args.my_pile_stage >= 2:  # try again using another checkpoint
            max_p = args.my_pile_prev_p
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
@@ -227,7 +244,7 @@ if __name__ == "__main__":
             else:
                 args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
             args.epoch_begin = max_p + 1
-            print(f'Trying {args.load_model}')
+            print(f"Trying {args.load_model}")
             load_dict = torch.load(args.load_model, map_location="cpu")
 
     model.load_state_dict(load_dict)
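
Note on the hard-coded constants above. RWKV-LM's docs describe magic_prime as the largest prime p with p % 3 == 2 that fits under datalen / ctx_len; since gcd(3, p - 1) = 1 for such p, the map x -> x**3 % p permutes 0..p-1, which the pile-mode sampler can use as a deterministic full-coverage shuffle of the data. The sketch below shows how such a value can be derived; the 332,115,325,534 token count is the Pile size quoted elsewhere in RWKV-LM, and both it and the helper functions are assumptions for illustration, not part of this patch.

# Sketch (assumption, not part of this patch): derive a pile-mode "magic prime".
# magic_prime = the largest prime p <= datalen // ctx_len with p % 3 == 2, so that
# x -> x**3 % p is a bijection on range(p) and can drive a pseudo-random data walk.

def is_prime(n: int) -> bool:
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    d = 3
    while d * d <= n:  # trial division is fast enough at ~3e8 (sqrt ~ 1.8e4)
        if n % d == 0:
            return False
        d += 2
    return True

def magic_prime(datalen: int, ctx_len: int) -> int:
    p = datalen // ctx_len
    while not (p % 3 == 2 and is_prime(p)):
        p -= 1
    return p

if __name__ == "__main__":
    DATALEN = 332_115_325_534  # assumed Pile token count (from RWKV-LM docs)
    for ctx in (1024, 2048, 4096):
        print(ctx, magic_prime(DATALEN, ctx))  # expect values near the constants above

Relatedly, the 40320 in args.epoch_steps = 40320 // args.real_bsz is 8! = 40320, so the assert only requires that real_bsz (devices * micro_bsz) divide 8!; products such as 8, 16, 32, 64, or 96 all pass.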