diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py
index 33f40c1..0e28642 100644
--- a/RWKV-v4neo/src/dataset.py
+++ b/RWKV-v4neo/src/dataset.py
@@ -2,12 +2,13 @@
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 ########################################################################################################
 
-import json
+import json, math
 import numpy as np
 import torch
 from torch.utils.data import Dataset
 from pytorch_lightning.utilities import rank_zero_info
 from .binidx import MMapIndexedDataset
+from .utils import MaybeIsPrime
 
 
 class MyDataset(Dataset):
@@ -20,6 +21,18 @@ class MyDataset(Dataset):
             print("current vocab size =", self.vocab_size, "(make sure it's correct)")
             self.data_size = len(self.data._bin_buffer) // 2
             print(f"data has {self.data_size} tokens.")
+
+            if args.my_pile_mode > 0:
+                assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024
+                self.samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
+                assert self.samples_per_epoch == 40320
+                print(f"########## Pile 20b-tokenized mode {args.my_pile_mode} ##########")
+                self.magic_prime = 324331313
+                dataset_slot = self.data_size // args.ctx_len
+                assert MaybeIsPrime(self.magic_prime)
+                assert self.magic_prime % 3 == 2
+                assert self.magic_prime / dataset_slot > 0.999999 and self.magic_prime / dataset_slot <= 1
+
         elif args.data_type == "numpy":
             self.data = np.load(args.data_file).astype("int")
             self.vocab_size = args.vocab_size
@@ -48,15 +61,29 @@ class MyDataset(Dataset):
             self.itos = {i: ch for i, ch in enumerate(unique)}
 
     def __len__(self):
-        return self.args.epoch_steps * int(self.args.devices) * self.args.micro_bsz
+        return self.args.epoch_steps * self.args.micro_bsz
 
     def __getitem__(self, idx):
         #
         # we are cheating: pick a random spot in dataset
         #
+        rank = self.global_rank
+        epoch = self.real_epoch
+        world_size = self.world_size
+        # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
+
         ctx_len = self.args.ctx_len
         req_len = ctx_len + 1
-        i = np.random.randint(0, self.data_size - req_len)
+
+        if self.args.my_pile_mode > 0:
+            ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
+            factor = (math.sqrt(5) - 1) / 2
+            factor = int(self.magic_prime * factor)
+            i = ((factor * ii * ii * ii) % self.magic_prime) * ctx_len
+            # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
+        else:
+            i = np.random.randint(0, self.data_size - req_len)
+
         if "MMapIndexedDataset" in str(type(self.data)):
             dix = self.data.get(idx=0, offset=i, length=req_len).astype(int)
         elif "numpy" in str(type(self.data)):
diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index b091bef..d29db67 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -2,7 +2,7 @@
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 ########################################################################################################
 
-import os, math, gc, time
+import os, math, gc
 from re import L
 import torch
 import torch.nn as nn
@@ -20,7 +20,7 @@ def __nop(ob):
 
 MyModule = nn.Module
 MyFunction = __nop
-if os.environ["RWKV_JIT"] == "1":
+if os.environ["RWKV_JIT_ON"] == "1":
     MyModule = torch.jit.ScriptModule
     MyFunction = torch.jit.script_method
 
@@ -273,9 +273,31 @@ class RWKV(pl.LightningModule):
             self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
 
     def configure_optimizers(self):
+        lr_1x = set()
+        lr_2x = set()
+        lr_3x = set()
+        for n, p in self.named_parameters():
+            if ("time_mix" in n) and (self.args.my_pile_mode == 2):
+                lr_2x.add(n)
+            elif "time_decay" in n:
+                lr_2x.add(n)
+            elif "time_first" in n:
+                lr_3x.add(n)
+            else:
+                lr_1x.add(n)
+        lr_1x = sorted(list(lr_1x))
+        lr_2x = sorted(list(lr_2x))
+        lr_3x = sorted(list(lr_3x))
+        # print('1x', lr_1x)
+        # print('2x', lr_2x)
+        # print('3x', lr_3x)
+        param_dict = {n: p for n, p in self.named_parameters()}
         optim_groups = [
-            {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
+            {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
+            {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
+            {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
         ]
+
         if self.deepspeed_offload:
             return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
         return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
@@ -326,25 +348,13 @@ class RWKV(pl.LightningModule):
         idx, targets = batch
         logits = self(idx)
         loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
-
-        if self.trainer.global_rank == 0:
-            t_now = time.time_ns()
-            try:
-                t_cost = (t_now - self.trainer.my_time_ns) / 1e9
-                self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
-                self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
-            except:
-                pass
-            self.trainer.my_time_ns = t_now
-            self.trainer.my_loss = loss.item()
-            self.trainer.my_loss_sum += self.trainer.my_loss
-            self.trainer.my_loss_count += 1
-            self.trainer.my_epoch_loss = self.trainer.my_loss_sum / self.trainer.my_loss_count
-            self.log("lr", self.trainer.my_lr, prog_bar=True, on_step=True)
-            self.log("loss", self.trainer.my_epoch_loss, prog_bar=True, on_step=True)
-
         return L2Wrap.apply(loss, logits)
 
+    def training_step_end(self, batch_parts):
+        all = self.all_gather(batch_parts)
+        if self.trainer.is_global_zero:
+            self.trainer.my_loss_all = all
+
     def generate_init_weight(self):
         print(
             f"""
diff --git a/RWKV-v4neo/src/utils.py b/RWKV-v4neo/src/utils.py
new file mode 100644
index 0000000..f2cbe99
--- /dev/null
+++ b/RWKV-v4neo/src/utils.py
@@ -0,0 +1,50 @@
+import random
+
+
+def MaybeIsPrime(number):
+    if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
+        return True
+    else:
+        return False
+
+
+def FermatPrimalityTest(number):
+    if number > 1:
+        for time in range(3):
+            randomNumber = random.randint(2, number) - 1
+            if pow(randomNumber, number - 1, number) != 1:
+                return False
+        return True
+    else:
+        return False
+
+
+def MillerRabinPrimalityTest(number):
+    if number == 2:
+        return True
+    elif number == 1 or number % 2 == 0:
+        return False
+    oddPartOfNumber = number - 1
+    timesTwoDividNumber = 0
+    while oddPartOfNumber % 2 == 0:
+        oddPartOfNumber = oddPartOfNumber // 2
+        timesTwoDividNumber = timesTwoDividNumber + 1
+
+    for time in range(3):
+        while True:
+            randomNumber = random.randint(2, number) - 1
+            if randomNumber != 0 and randomNumber != 1:
+                break
+
+        randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
+
+        if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
+            iterationNumber = 1
+
+            while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
+                randomNumberWithPower = pow(randomNumberWithPower, 2, number)
+                iterationNumber = iterationNumber + 1
+            if randomNumberWithPower != (number - 1):
+                return False
+
+    return True
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index 28c0928..33ae2d3 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -4,10 +4,11 @@
 if __name__ == "__main__":
     print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
 
-    import os, warnings, math, datetime, sys
+    import os, warnings, math, datetime, sys, time
    import numpy as np
     from argparse import ArgumentParser
     import torch
+    from torch.utils.data import DataLoader
     import deepspeed
     import pytorch_lightning as pl
     from pytorch_lightning import Trainer
@@ -83,10 +84,12 @@ if __name__ == "__main__":
     parser.add_argument("--adam_eps", default=1e-8, type=float)
 
     parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
+    parser.add_argument("--my_pile_mode", default=0, type=int)  # my special pile mode
     args = parser.parse_args()
 
     args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
     args.enable_checkpointing = False
+    args.replace_sampler_ddp = False
     args.logger = False
     args.gradient_clip_val = 1.0
     args.num_sanity_val_steps = 0
@@ -95,6 +98,12 @@ if __name__ == "__main__":
     args.max_epochs = -1  # continue forever
     args.betas = (args.beta1, args.beta2)
 
+    if args.my_pile_mode > 0:
+        args.epoch_steps = 40320 // (int(args.devices) * args.micro_bsz)
+        assert args.epoch_steps * int(args.devices) * args.micro_bsz == 40320
+        if args.my_pile_mode == 2:
+            assert args.lr_final == args.lr_init
+
     samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
     tokens_per_epoch = samples_per_epoch * args.ctx_len
     rank_zero_info(
@@ -138,9 +147,9 @@ if __name__ == "__main__":
     if args.precision == "fp16":
         rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
 
-    os.environ["RWKV_JIT"] = "1"
+    os.environ["RWKV_JIT_ON"] = "1"
     if "deepspeed_stage_3" in args.strategy:
-        os.environ["RWKV_JIT"] = "0"
+        os.environ["RWKV_JIT_ON"] = "0"
 
     import torch
 
@@ -170,9 +179,37 @@ if __name__ == "__main__":
             args = self.args
             g_step = trainer.global_step
 
-            # logging
-            if trainer.global_rank == 0:
-                if g_step == 0:
+            # LR schedule
+            w_step = args.warmup_steps
+            if g_step < w_step:
+                lr = args.lr_init * (g_step / w_step)
+            else:
+                if args.lr_final == args.lr_init:
+                    lr = args.lr_init
+                else:
+                    progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
+                    progress = min(1, max(0, progress))
+
+                    if args.lr_final == 0 or args.lr_init == 0:  # linear decay
+                        lr = args.lr_init + (args.lr_final - args.lr_init) * progress
+                    else:  # exp decay
+                        lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
+
+            for param_group in trainer.optimizers[0].param_groups:
+                if self.args.my_pile_mode == 0:
+                    param_group["lr"] = lr * param_group["my_lr_scale"]
+                elif self.args.my_pile_mode == 2:
+                    if param_group["my_lr_scale"] > 1:
+                        param_group["lr"] = lr * 5
+                    else:
+                        param_group["lr"] = lr
+                # print(param_group["lr"], param_group["my_lr_scale"])
+
+            trainer.my_lr = lr
+            # rank_zero_info(f"{g_step} {lr}")
+
+            if g_step == 0:
+                if trainer.is_global_zero:  # logging
                     trainer.my_loss_sum = 0
                     trainer.my_loss_count = 0
                     trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
@@ -196,39 +233,42 @@ if __name__ == "__main__":
                     )
                     trainer.my_wandb = wandb
 
-            # LR schedule
-            w_step = args.warmup_steps
-            if g_step < w_step:
-                lr = args.lr_init * (g_step / w_step)
-            else:
-                progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
-                progress = min(1, max(0, progress))
-
-                if args.lr_final == 0 or args.lr_init == 0:  # linear decay
-                    lr = args.lr_init + (args.lr_final - args.lr_init) * progress
-                else:  # exp decay
-                    lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
-
-            for param_group in trainer.optimizers[0].param_groups:
-                param_group["lr"] = lr
-
-            trainer.my_lr = lr
-            # rank_zero_info(f"{g_step} {lr}")
-
         def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
             args = self.args
-            # logging
-            if trainer.global_rank == 0:
+            if trainer.is_global_zero:  # logging
+                t_now = time.time_ns()
+                try:
+                    t_cost = (t_now - trainer.my_time_ns) / 1e9
+                    self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
+                    self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
+                except:
+                    pass
+                trainer.my_time_ns = t_now
+                trainer.my_loss = trainer.my_loss_all.float().mean().item()
+                trainer.my_loss_sum += trainer.my_loss
+                trainer.my_loss_count += 1
+                trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
+                self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
+                self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
+
                 if len(args.wandb) > 0:
                     trainer.my_wandb.log(
                         {"loss": trainer.my_loss, "lr": trainer.my_lr},
                         step=trainer.global_step,
                     )
 
+        def on_train_epoch_start(self, trainer, pl_module):
+            args = self.args
+            dataset = trainer.train_dataloader.dataset.datasets
+            assert "MyDataset" in str(dataset)
+            dataset.global_rank = trainer.global_rank
+            dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
+            dataset.world_size = trainer.world_size
+
         def on_train_epoch_end(self, trainer, pl_module):
             args = self.args
-            if trainer.global_rank == 0:
-                if trainer.current_epoch % args.epoch_save == 0 or trainer.current_epoch == args.epoch_count - 1:
+            if trainer.is_global_zero:  # logging & save state_dict
+                if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
                     torch.save(
                         pl_module.state_dict(),
                         f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
@@ -251,7 +291,6 @@ if __name__ == "__main__":
 
     ########################################################################################################
 
-    from torch.utils.data import DataLoader
     from src.dataset import MyDataset
     from src.model import RWKV
 
@@ -261,8 +300,8 @@ if __name__ == "__main__":
     model = RWKV(args)
 
     if len(args.load_model) == 0:
-        args.load_model = f"{args.proj_dir}/rwkv-init.pth"  # init weights to tmp file
-        generate_init_weight(model, args.load_model)
+        args.load_model = f"{args.proj_dir}/rwkv-init.pth"
+        generate_init_weight(model, args.load_model)  # save initial weights to tmp file
 
     print(f"########## Loading {args.load_model}... ##########")
     load_dict = torch.load(args.load_model, map_location="cpu")
@@ -273,5 +312,7 @@ if __name__ == "__main__":
         callbacks=[train_callback(args)],
     )
 
-    train_loader = DataLoader(train_data, batch_size=args.micro_bsz, num_workers=1)
-    trainer.fit(model, train_loader)
+    # must set shuffle=False, persistent_workers=False (because worker is in another thread)
+    data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
+
+    trainer.fit(model, data_loader)
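
Note on the deterministic sampler added in dataset.py: the assert magic_prime % 3 == 2 is what makes the index map a permutation. For a prime p with p % 3 == 2 we have gcd(3, p - 1) = 1, so x -> x**3 is a bijection mod p, and multiplying by a fixed nonzero factor keeps it one, so consecutive ii values visit distinct ctx_len-aligned slots instead of sampling with replacement. A minimal standalone sketch of that property (not part of the patch; the small prime 2003 and the helper name check_cubic_permutation are illustrative stand-ins for the real magic_prime = 324331313):

    import math

    def check_cubic_permutation(p, factor):
        # the sampler's slot map: ii -> (factor * ii**3) mod p
        seen = {(factor * ii * ii * ii) % p for ii in range(p)}
        return len(seen) == p  # True iff every slot in [0, p) is hit exactly once

    p = 2003  # small prime with p % 3 == 2, the same congruence the patch asserts
    factor = int(p * (math.sqrt(5) - 1) / 2)  # same golden-ratio factor construction as __getitem__
    assert factor % p != 0
    assert check_cubic_permutation(p, factor)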