diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index ccd9f00..45c25e4 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -23,15 +23,14 @@ class MyDataset(Dataset): print(f"Data has {self.data_size} tokens.") if args.my_pile_stage > 0: - assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024 + assert self.data_size == 332115325534 and self.vocab_size == 50277 self.samples_per_epoch = args.epoch_steps * args.real_bsz assert self.samples_per_epoch == 40320 print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########") - self.magic_prime = 324331313 dataset_slot = self.data_size // args.ctx_len - assert MaybeIsPrime(self.magic_prime) - assert self.magic_prime % 3 == 2 - assert self.magic_prime / dataset_slot > 0.999999 and self.magic_prime / dataset_slot <= 1 + assert MaybeIsPrime(args.magic_prime) + assert args.magic_prime % 3 == 2 + assert args.magic_prime / dataset_slot > 0.999999 and args.magic_prime / dataset_slot <= 1 elif args.data_type == "numpy": self.data = np.load(args.data_file).astype("int") self.vocab_size = args.vocab_size @@ -87,8 +86,9 @@ class MyDataset(Dataset): if args.my_pile_stage > 0: ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank factor = (math.sqrt(5) - 1) / 2 - factor = int(self.magic_prime * factor) - i = ((factor * ii * ii * ii) % self.magic_prime) * ctx_len + factor = int(args.magic_prime * factor) + i = ((factor * ii * ii * ii) % args.magic_prime) * ctx_len + i = i + args.my_pile_shift # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}") else: i = np.random.randint(0, self.data_size - req_len) diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py index afbe750..9dbf6ac 100644 --- a/RWKV-v4neo/src/model.py +++ b/RWKV-v4neo/src/model.py @@ -31,12 +31,12 @@ if os.environ["RWKV_JIT_ON"] == "1": # CUDA Kernel ######################################################################################################## -T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!] +T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM! # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice from torch.utils.cpp_extension import load -wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"]) +wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"]) class WKV(torch.autograd.Function): diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py index 81bea20..74f3ab9 100644 --- a/RWKV-v4neo/src/trainer.py +++ b/RWKV-v4neo/src/trainer.py @@ -14,17 +14,17 @@ class train_callback(pl.Callback): args = self.args # if args.cuda_cleanup > 0: # torch.cuda.empty_cache() - g_step = trainer.global_step + real_step = trainer.global_step + args.epoch_begin * args.epoch_steps # LR schedule w_step = args.warmup_steps - if g_step < w_step: - lr = args.lr_init * (g_step / w_step) + if trainer.global_step < w_step: + lr = args.lr_init * (trainer.global_step / w_step) else: - if args.lr_final == args.lr_init: + if args.lr_final == args.lr_init or args.epoch_count == 0: lr = args.lr_init else: - progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1) + progress = (real_step - w_step + 1) / (args.epoch_count * args.epoch_steps - w_step) progress = min(1, max(0, progress)) if args.lr_final == 0 or args.lr_init == 0: # linear decay @@ -40,9 +40,9 @@ class train_callback(pl.Callback): param_group["lr"] = lr trainer.my_lr = lr - # rank_zero_info(f"{g_step} {lr}") + # rank_zero_info(f"{real_step} {lr}") - if g_step == 0: + if trainer.global_step == 0: if trainer.is_global_zero: # logging trainer.my_loss_sum = 0 trainer.my_loss_count = 0 diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index 042c160..734489c 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -92,6 +92,7 @@ if __name__ == "__main__": parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode + parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s) parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful) @@ -108,8 +109,27 @@ if __name__ == "__main__": args.max_epochs = -1 # continue forever args.betas = (args.beta1, args.beta2) args.real_bsz = int(args.devices) * args.micro_bsz + os.environ["RWKV_T_MAX"] = str(args.ctx_len) + + if not os.path.exists(args.proj_dir): + os.makedirs(args.proj_dir) if args.my_pile_stage > 0: + if args.ctx_len == 1024: + args.magic_prime = 324331313 + elif args.ctx_len == 2048: + args.magic_prime = 162165671 + elif args.ctx_len == 4096: + args.magic_prime = 81082817 + if args.my_pile_shift < 0: + if args.ctx_len == 1024: + args.my_pile_shift = 0 + elif args.ctx_len == 2048: + args.my_pile_shift = 512 + elif args.ctx_len == 4096: + args.my_pile_shift = 768 + + args.epoch_count = 8043 args.epoch_steps = 40320 // args.real_bsz assert args.epoch_steps * args.real_bsz == 40320 if args.my_pile_stage == 2: @@ -125,7 +145,7 @@ if __name__ == "__main__": else: p = int(p) if p > max_p: - args.my_pile_prev_p = max_p # in case max_p is corrupted + args.my_pile_prev_p = max_p # in case max_p is corrupted max_p = p if max_p == -1: args.load_model = f"{args.proj_dir}/rwkv-init.pth" @@ -164,9 +184,6 @@ if __name__ == "__main__": ) rank_zero_info(str(vars(args)) + "\n") - if not os.path.exists(args.proj_dir): - os.makedirs(args.proj_dir) - assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy"] if args.lr_final == 0 or args.lr_init == 0: @@ -218,8 +235,8 @@ if __name__ == "__main__": try: load_dict = torch.load(args.load_model, map_location="cpu") except: - print(f'Bad checkpoint {args.load_model}') - if args.my_pile_stage >= 2: # try again using another checkpoint + print(f"Bad checkpoint {args.load_model}") + if args.my_pile_stage >= 2: # try again using another checkpoint max_p = args.my_pile_prev_p if max_p == -1: args.load_model = f"{args.proj_dir}/rwkv-init.pth" @@ -227,7 +244,7 @@ if __name__ == "__main__": else: args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" args.epoch_begin = max_p + 1 - print(f'Trying {args.load_model}') + print(f"Trying {args.load_model}") load_dict = torch.load(args.load_model, map_location="cpu") model.load_state_dict(load_dict)