diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index c66f4cf..44e4789 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -153,14 +153,19 @@ class MyDataset(Dataset): magic_prime = args.magic_prime data = self.data - if args.my_pile_stage > 0: + if args.my_pile_stage > 0 and args.my_pile_stage != 4: ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank if args.my_qa_mask > 0: ii_orig = ii if ii % 2 == 0: ii = (ii // 2) * args.magic_prime - magic_prime = 324331313 + if args.ctx_len == 1024: + magic_prime = 324331313 + elif args.ctx_len == 2048: + magic_prime = 162165671 + elif args.ctx_len == 4096: + magic_prime = 81082817 data = self.data_pile else: ii = ii // 2 diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py index 621a398..67804c8 100644 --- a/RWKV-v4neo/src/model.py +++ b/RWKV-v4neo/src/model.py @@ -4,6 +4,8 @@ import os, math, gc import torch +torch._C._jit_set_profiling_executor(True) +torch._C._jit_set_profiling_mode(True) import torch.nn as nn from torch.nn import functional as F import pytorch_lightning as pl diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py index 89407f0..d5cf452 100644 --- a/RWKV-v4neo/src/trainer.py +++ b/RWKV-v4neo/src/trainer.py @@ -11,7 +11,7 @@ def my_save(dd, ff): fn = ff.split('/')[-1] fff = '/dev/shm/' + fn torch.save(dd, fff) - subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b/{fn} --quiet", shell=True) + subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True) class train_callback(pl.Callback): def __init__(self, args): @@ -106,7 +106,8 @@ class train_callback(pl.Callback): lll["kt/s"] = kt_s trainer.my_wandb.log(lll, step=int(real_step)) if args.magic_prime > 0: - if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1: + expand_factor = 2 if args.my_qa_mask > 0 else 1 + if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1: to_save_dict = pl_module.state_dict() my_save( to_save_dict, diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index ba63e03..d148d90 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -222,9 +222,9 @@ if __name__ == "__main__": # # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} # -# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer +# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer # Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions) -# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer +# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer # ############################################################################ """