|
|
|
|
@ -11,7 +11,7 @@ def my_save(dd, ff):
|
|
|
|
|
fn = ff.split('/')[-1]
|
|
|
|
|
fff = '/dev/shm/' + fn
|
|
|
|
|
torch.save(dd, fff)
|
|
|
|
|
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b/{fn} --quiet", shell=True)
|
|
|
|
|
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
|
|
|
|
|
|
|
|
|
|
class train_callback(pl.Callback):
|
|
|
|
|
def __init__(self, args):
|
|
|
|
|
@ -106,7 +106,8 @@ class train_callback(pl.Callback):
|
|
|
|
|
lll["kt/s"] = kt_s
|
|
|
|
|
trainer.my_wandb.log(lll, step=int(real_step))
|
|
|
|
|
if args.magic_prime > 0:
|
|
|
|
|
if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
|
|
|
|
|
expand_factor = 2 if args.my_qa_mask > 0 else 1
|
|
|
|
|
if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1:
|
|
|
|
|
to_save_dict = pl_module.state_dict()
|
|
|
|
|
my_save(
|
|
|
|
|
to_save_dict,
|
|
|
|
|
|