main
BlinkDL 3 years ago
parent 7476c69f32
commit c13879ab97

@ -153,14 +153,19 @@ class MyDataset(Dataset):
magic_prime = args.magic_prime magic_prime = args.magic_prime
data = self.data data = self.data
if args.my_pile_stage > 0: if args.my_pile_stage > 0 and args.my_pile_stage != 4:
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
if args.my_qa_mask > 0: if args.my_qa_mask > 0:
ii_orig = ii ii_orig = ii
if ii % 2 == 0: if ii % 2 == 0:
ii = (ii // 2) * args.magic_prime ii = (ii // 2) * args.magic_prime
magic_prime = 324331313 if args.ctx_len == 1024:
magic_prime = 324331313
elif args.ctx_len == 2048:
magic_prime = 162165671
elif args.ctx_len == 4096:
magic_prime = 81082817
data = self.data_pile data = self.data_pile
else: else:
ii = ii // 2 ii = ii // 2

@ -4,6 +4,8 @@
import os, math, gc import os, math, gc
import torch import torch
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
import pytorch_lightning as pl import pytorch_lightning as pl

@ -11,7 +11,7 @@ def my_save(dd, ff):
fn = ff.split('/')[-1] fn = ff.split('/')[-1]
fff = '/dev/shm/' + fn fff = '/dev/shm/' + fn
torch.save(dd, fff) torch.save(dd, fff)
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b/{fn} --quiet", shell=True) subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
class train_callback(pl.Callback): class train_callback(pl.Callback):
def __init__(self, args): def __init__(self, args):
@ -106,7 +106,8 @@ class train_callback(pl.Callback):
lll["kt/s"] = kt_s lll["kt/s"] = kt_s
trainer.my_wandb.log(lll, step=int(real_step)) trainer.my_wandb.log(lll, step=int(real_step))
if args.magic_prime > 0: if args.magic_prime > 0:
if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1: expand_factor = 2 if args.my_qa_mask > 0 else 1
if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1:
to_save_dict = pl_module.state_dict() to_save_dict = pl_module.state_dict()
my_save( my_save(
to_save_dict, to_save_dict,

@ -222,9 +222,9 @@ if __name__ == "__main__":
# #
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
# #
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer # Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions) # Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer # Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
# #
############################################################################ ############################################################################
""" """

Loading…
Cancel
Save