From d1674732ed183ea16f4c5f7651d3343ebc79c2eb Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Tue, 6 Sep 2022 06:51:12 +0000 Subject: [PATCH] clean code --- RWKV-v4neo/src/dataset.py | 2 +- RWKV-v4neo/src/model.py | 7 +- RWKV-v4neo/src/trainer.py | 130 +++++++++++++++++++++++++++++++++++++ RWKV-v4neo/train.py | 133 +------------------------------------- 4 files changed, 138 insertions(+), 134 deletions(-) create mode 100644 RWKV-v4neo/src/trainer.py diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index 4884315..51a9935 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -26,7 +26,7 @@ class MyDataset(Dataset): assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024 self.samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz assert self.samples_per_epoch == 40320 - print("########## Pile 20b-tokenized mode {args.my_pile_mode} ##########") + print(f"########## Pile 20b-tokenized mode {args.my_pile_mode} ##########") self.magic_prime = 324331313 dataset_slot = self.data_size // args.ctx_len assert MaybeIsPrime(self.magic_prime) diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py index 39a6326..acf006b 100644 --- a/RWKV-v4neo/src/model.py +++ b/RWKV-v4neo/src/model.py @@ -12,7 +12,7 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from pytorch_lightning.strategies import DeepSpeedStrategy import deepspeed from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam - +# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam def __nop(ob): return ob @@ -44,7 +44,7 @@ class WKV(torch.autograd.Function): ctx.T = T ctx.C = C assert T <= T_MAX - assert B * C % min(C, 1024) == 0 + assert B * C % min(C, 32) == 0 if "32" in os.environ["RWKV_FLOAT_MODE"]: w = -torch.exp(w.contiguous()) u = u.contiguous() @@ -71,7 +71,7 @@ class WKV(torch.autograd.Function): T = ctx.T C = ctx.C assert T <= T_MAX - assert B * C % min(C, 1024) == 0 + assert B * C % min(C, 32) == 0 w, u, k, v = ctx.saved_tensors gw = torch.zeros((B, C), device="cuda").contiguous() gu = torch.zeros((B, C), device="cuda").contiguous() @@ -306,6 +306,7 @@ class RWKV(pl.LightningModule): if self.deepspeed_offload: return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) @property def deepspeed_offload(self) -> bool: diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py new file mode 100644 index 0000000..fd770b6 --- /dev/null +++ b/RWKV-v4neo/src/trainer.py @@ -0,0 +1,130 @@ +import os, math, time, datetime +import torch +from torch.utils.data import DataLoader +import pytorch_lightning as pl +from pytorch_lightning import seed_everything +from pytorch_lightning.utilities import rank_zero_info, rank_zero_only +from pytorch_lightning.callbacks import TQDMProgressBar + +class train_callback(pl.Callback): + def __init__(self, args): + super().__init__() + self.args = args + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + args = self.args + g_step = trainer.global_step + + # LR schedule + w_step = args.warmup_steps + if g_step < w_step: + lr = args.lr_init * (g_step / w_step) + else: + if args.lr_final == args.lr_init: + lr = args.lr_init + else: + progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1) + progress = min(1, max(0, progress)) + + if args.lr_final == 0 or args.lr_init == 0: # linear decay + lr = args.lr_init + (args.lr_final - args.lr_init) * progress + else: # exp decay + lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1)) + + for param_group in trainer.optimizers[0].param_groups: + if args.layerwise_lr > 0: + if self.args.my_pile_mode == 0: + param_group["lr"] = lr * param_group["my_lr_scale"] + elif self.args.my_pile_mode == 2: + if param_group["my_lr_scale"] > 1: + param_group["lr"] = lr * 5 + else: + param_group["lr"] = lr + # print(param_group["lr"], param_group["my_lr_scale"]) + else: + param_group["lr"] = lr + + trainer.my_lr = lr + # rank_zero_info(f"{g_step} {lr}") + + if g_step == 0: + if trainer.is_global_zero: # logging + trainer.my_loss_sum = 0 + trainer.my_loss_count = 0 + trainer.my_log = open(args.proj_dir + "/train_log.txt", "a") + trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n") + try: + print(f"\n{trainer.strategy.config}\n") + trainer.my_log.write(f"{trainer.strategy.config}\n") + except: + pass + trainer.my_log.flush() + if len(args.wandb) > 0: + print("Login to wandb...") + import wandb + + model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd) + wandb.init( + project=args.wandb, + name=model_name + "-" + args.my_timestamp, + config=args, + save_code=False, + ) + trainer.my_wandb = wandb + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + args = self.args + if trainer.is_global_zero: # logging + t_now = time.time_ns() + try: + t_cost = (t_now - trainer.my_time_ns) / 1e9 + self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True) + self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True) + except: + pass + trainer.my_time_ns = t_now + trainer.my_loss = trainer.my_loss_all.float().mean().item() + trainer.my_loss_sum += trainer.my_loss + trainer.my_loss_count += 1 + trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count + self.log("lr", trainer.my_lr, prog_bar=True, on_step=True) + self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True) + + if len(args.wandb) > 0: + trainer.my_wandb.log( + {"loss": trainer.my_loss, "lr": trainer.my_lr}, + step=trainer.global_step, + ) + + def on_train_epoch_start(self, trainer, pl_module): + args = self.args + dataset = trainer.train_dataloader.dataset.datasets + assert "MyDataset" in str(dataset) + dataset.global_rank = trainer.global_rank + dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch) + dataset.world_size = trainer.world_size + + def on_train_epoch_end(self, trainer, pl_module): + args = self.args + if trainer.is_global_zero: # logging & save state_dict + if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1: + torch.save( + pl_module.state_dict(), + f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth", + ) + trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n") + trainer.my_log.flush() + + trainer.my_loss_sum = 0 + trainer.my_loss_count = 0 + + +@rank_zero_only +def generate_init_weight(model, temp_name): + try: + os.remove(temp_name) + except: + pass + mm = model.generate_init_weight() + print(f"Saving to {temp_name}...") + torch.save(mm, temp_name) diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index 4716da0..d99ec22 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -14,8 +14,6 @@ if __name__ == "__main__": from pytorch_lightning import Trainer from pytorch_lightning import seed_everything from pytorch_lightning.utilities import rank_zero_info, rank_zero_only - from pytorch_lightning.callbacks import TQDMProgressBar - from pytorch_lightning import Callback # print("WARNING: THIS IS ONLY FOR DEBUG") # seed_everything(42) @@ -95,7 +93,7 @@ if __name__ == "__main__": parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower parser.add_argument("--my_pile_mode", default=0, type=int) # my special pile mode parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s) - parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 500 might be faster (but more VRAM) + parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough args = parser.parse_args() args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") @@ -161,8 +159,6 @@ if __name__ == "__main__": if "deepspeed_stage_3" in args.strategy: os.environ["RWKV_JIT_ON"] = "0" - import torch - torch.backends.cudnn.benchmark = True if args.precision == "fp32": torch.backends.cudnn.allow_tf32 = False @@ -179,131 +175,8 @@ if __name__ == "__main__": args.precision = "bf16" ######################################################################################################## - - class train_callback(pl.Callback): - def __init__(self, args): - super().__init__() - self.args = args - - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): - args = self.args - g_step = trainer.global_step - - # LR schedule - w_step = args.warmup_steps - if g_step < w_step: - lr = args.lr_init * (g_step / w_step) - else: - if args.lr_final == args.lr_init: - lr = args.lr_init - else: - progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1) - progress = min(1, max(0, progress)) - - if args.lr_final == 0 or args.lr_init == 0: # linear decay - lr = args.lr_init + (args.lr_final - args.lr_init) * progress - else: # exp decay - lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1)) - - for param_group in trainer.optimizers[0].param_groups: - if args.layerwise_lr > 0: - if self.args.my_pile_mode == 0: - param_group["lr"] = lr * param_group["my_lr_scale"] - elif self.args.my_pile_mode == 2: - if param_group["my_lr_scale"] > 1: - param_group["lr"] = lr * 5 - else: - param_group["lr"] = lr - # print(param_group["lr"], param_group["my_lr_scale"]) - else: - param_group["lr"] = lr - - trainer.my_lr = lr - # rank_zero_info(f"{g_step} {lr}") - - if g_step == 0: - if trainer.is_global_zero: # logging - trainer.my_loss_sum = 0 - trainer.my_loss_count = 0 - trainer.my_log = open(args.proj_dir + "/train_log.txt", "a") - trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n") - try: - print(f"\n{trainer.strategy.config}\n") - trainer.my_log.write(f"{trainer.strategy.config}\n") - except: - pass - trainer.my_log.flush() - if len(args.wandb) > 0: - print("Login to wandb...") - import wandb - - model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd) - wandb.init( - project=args.wandb, - name=model_name + "-" + args.my_timestamp, - config=args, - save_code=False, - ) - trainer.my_wandb = wandb - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - args = self.args - if trainer.is_global_zero: # logging - t_now = time.time_ns() - try: - t_cost = (t_now - trainer.my_time_ns) / 1e9 - self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True) - self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True) - except: - pass - trainer.my_time_ns = t_now - trainer.my_loss = trainer.my_loss_all.float().mean().item() - trainer.my_loss_sum += trainer.my_loss - trainer.my_loss_count += 1 - trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count - self.log("lr", trainer.my_lr, prog_bar=True, on_step=True) - self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True) - - if len(args.wandb) > 0: - trainer.my_wandb.log( - {"loss": trainer.my_loss, "lr": trainer.my_lr}, - step=trainer.global_step, - ) - - def on_train_epoch_start(self, trainer, pl_module): - args = self.args - dataset = trainer.train_dataloader.dataset.datasets - assert "MyDataset" in str(dataset) - dataset.global_rank = trainer.global_rank - dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch) - dataset.world_size = trainer.world_size - - def on_train_epoch_end(self, trainer, pl_module): - args = self.args - if trainer.is_global_zero: # logging & save state_dict - if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1: - torch.save( - pl_module.state_dict(), - f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth", - ) - trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n") - trainer.my_log.flush() - - trainer.my_loss_sum = 0 - trainer.my_loss_count = 0 - - @rank_zero_only - def generate_init_weight(model, temp_name): - try: - os.remove(temp_name) - except: - pass - mm = model.generate_init_weight() - print(f"Saving to {temp_name}...") - torch.save(mm, temp_name) - - ######################################################################################################## - + + from src.trainer import train_callback, generate_init_weight from src.dataset import MyDataset from src.model import RWKV