main
BlinkDL 3 years ago
parent 7476c69f32
commit c13879ab97

@ -153,14 +153,19 @@ class MyDataset(Dataset):
magic_prime = args.magic_prime
data = self.data
if args.my_pile_stage > 0:
if args.my_pile_stage > 0 and args.my_pile_stage != 4:
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
if args.my_qa_mask > 0:
ii_orig = ii
if ii % 2 == 0:
ii = (ii // 2) * args.magic_prime
if args.ctx_len == 1024:
magic_prime = 324331313
elif args.ctx_len == 2048:
magic_prime = 162165671
elif args.ctx_len == 4096:
magic_prime = 81082817
data = self.data_pile
else:
ii = ii // 2

@ -4,6 +4,8 @@
import os, math, gc
import torch
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl

@ -11,7 +11,7 @@ def my_save(dd, ff):
fn = ff.split('/')[-1]
fff = '/dev/shm/' + fn
torch.save(dd, fff)
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b/{fn} --quiet", shell=True)
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
class train_callback(pl.Callback):
def __init__(self, args):
@ -106,7 +106,8 @@ class train_callback(pl.Callback):
lll["kt/s"] = kt_s
trainer.my_wandb.log(lll, step=int(real_step))
if args.magic_prime > 0:
if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
expand_factor = 2 if args.my_qa_mask > 0 else 1
if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1:
to_save_dict = pl_module.state_dict()
my_save(
to_save_dict,

@ -222,9 +222,9 @@ if __name__ == "__main__":
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
#
############################################################################
"""

Loading…
Cancel
Save