From cdb098c0e057e75082e2785972219f320fb4e282 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Mon, 5 Sep 2022 09:07:09 +0000
Subject: [PATCH] fix

---
 RWKV-v4neo/src/model.py | 56 +++++++++++++++-------------
 RWKV-v4neo/train.py     | 82 +++++++++++++++++++++++++++++------------
 2 files changed, 90 insertions(+), 48 deletions(-)

diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 13c0bf2..cd7cced 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -339,36 +339,42 @@ class RWKV(pl.LightningModule):
             gain = 1.0
             scale = 1.0
             if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n:
-                m[n] = p.cpu()
-                continue
-            elif n == "emb.weight":
-                scale = -25 * self.args.lr_init
+                m[n] = p
             else:
-                if shape[0] > shape[1]:
-                    gain = math.sqrt(shape[0] / shape[1])
-                for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q."]:
-                    if kk in n:
+                if n == "emb.weight":
+                    scale = -25 * self.args.lr_init
+                else:
+                    if shape[0] > shape[1]:
+                        gain = math.sqrt(shape[0] / shape[1])
+                    for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q."]:
+                        if kk in n:
+                            scale = 0
+                    if n == "head.weight":
+                        scale = 0.5
+                    if "head_k." in n:
+                        scale = 0.1
+                    if "head_q." in n:
                         scale = 0
-            if n == "head.weight":
-                scale = 0.5
-            if "head_k." in n:
-                scale = 0.1
-            if "head_q." in n:
-                scale = 0

-            print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
+                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")

-            if self.args.accelerator.upper() == "GPU":
-                m[n] = torch.empty((shape[0], shape[1]), device="cuda")
-            else:
-                m[n] = torch.empty((shape[0], shape[1]))
+                if self.args.accelerator.upper() == "GPU":
+                    m[n] = torch.empty((shape[0], shape[1]), device="cuda")
+                else:
+                    m[n] = torch.empty((shape[0], shape[1]))

-            if scale == 0:
-                nn.init.zeros_(m[n])
-            elif scale < 0:
-                nn.init.normal_(m[n], mean=0.0, std=-scale)
-            else:
-                nn.init.orthogonal_(m[n], gain=gain * scale)
+                if scale == 0:
+                    nn.init.zeros_(m[n])
+                elif scale < 0:
+                    nn.init.normal_(m[n], mean=0.0, std=-scale)
+                else:
+                    nn.init.orthogonal_(m[n], gain=gain * scale)
+
+            m[n] = m[n].cpu()
+            if os.environ["RWKV_FLOAT_MODE"] == "fp16":
+                m[n] = m[n].half()
+            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+                m[n] = m[n].bfloat16()

             # if n == "emb.weight":
             #     print(m[n])
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index e701dde..bbe5883 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -3,10 +3,12 @@
 ########################################################################################################

 if __name__ == "__main__":
-    print("\n\n\n!!! NOTE: THIS IS STILL WIP !!!\n\n\n")
+    print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
     import os, warnings, math, datetime
     import numpy as np
     from argparse import ArgumentParser
+    import torch
+    import deepspeed
     import pytorch_lightning as pl
     from pytorch_lightning import Trainer
     from pytorch_lightning import seed_everything
@@ -14,6 +16,19 @@ if __name__ == "__main__":
     from pytorch_lightning.callbacks import TQDMProgressBar
     from pytorch_lightning import Callback

+    rank_zero_info(
+        f"""
+############################################################################
+#
+# torch {torch.__version__}, recommend 1.12.1+cu116 or newer
+#
+# deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
+#
+# pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
+#
+############################################################################
+"""
+    )
     seed_everything(42)
     np.set_printoptions(precision=4, suppress=True, linewidth=200)
     warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
@@ -23,37 +38,48 @@ if __name__ == "__main__":

     parser = ArgumentParser()
     parser = Trainer.add_argparse_args(parser)
+
+    parser.add_argument("--load_model", default="", type=str)
     parser.add_argument("--wandb", default="", type=str)
     parser.add_argument("--proj_dir", default="out", type=str)
+
+    parser.add_argument("--data_file", default="", type=str)
+    parser.add_argument("--data_type", default="utf-8", type=str)
+    parser.add_argument("--vocab_size", default=0, type=int)
+
+    parser.add_argument("--ctx_len", default=1024, type=int)
+    parser.add_argument("--epoch_steps", default=1000, type=int)
+    parser.add_argument("--epoch_count", default=500, type=int)
+    parser.add_argument("--epoch_begin", default=0, type=int)
+    parser.add_argument("--epoch_save", default=5, type=int)
+
+    parser.add_argument("--micro_bsz", default=12, type=int)
     parser.add_argument("--n_layer", default=6, type=int)
     parser.add_argument("--n_embd", default=512, type=int)
     parser.add_argument("--pre_ffn", default=0, type=int)
     parser.add_argument("--head_qk", default=0, type=int)
+
     parser.add_argument("--lr_init", default=6e-4, type=float)
     parser.add_argument("--lr_final", default=1e-5, type=float)
     parser.add_argument("--warmup_steps", default=0, type=int)
-    parser.add_argument("--epoch_steps", default=1000, type=int)
-    parser.add_argument("--epoch_bias", default=0, type=int)
-    parser.add_argument("--epoch_save", default=5, type=int)
     parser.add_argument("--beta1", default=0.9, type=float)
     parser.add_argument("--beta2", default=0.99, type=float)
     parser.add_argument("--adam_eps", default=1e-8, type=float)
-    parser.add_argument("--ctx_len", default=1024, type=int)
-    parser.add_argument("--micro_bsz", default=12, type=int)
-    parser.add_argument("--data_workers", default=1, type=int)
+
     parser.add_argument("--grad_cp", default=0, type=int)
-    parser.add_argument("--load_model", default="", type=str)
-    parser.add_argument("--data_file", default="", type=str)
-    parser.add_argument("--data_type", default="utf-8", type=str)
-    parser.add_argument("--vocab_size", default=0, type=int)
+
+    parser.add_argument("--data_workers", default=1, type=int)

     args = parser.parse_args()

     args.enable_checkpointing = False
     args.logger = False
     args.gradient_clip_val = 1.0
     args.num_sanity_val_steps = 0
+    args.check_val_every_n_epoch = int(1e20)
+    args.auto_select_gpus = True
+    args.log_every_n_steps = int(1e20)
+    args.max_epochs = -1  # continue forever
     args.betas = (args.beta1, args.beta2)
-    args.proj_dir = args.proj_dir.strip().strip("\\/")
     samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
     tokens_per_epoch = samples_per_epoch * args.ctx_len
@@ -61,11 +87,11 @@ if __name__ == "__main__":
         f"""
 ############################################################################
 #
-# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()} {args.strategy.upper()} {'with grad_cp' if args.grad_cp > 0 else ''}
+# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()} {args.strategy.upper()} {'with grad_cp' if args.grad_cp > 0 else ''}
 #
 # Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
 #
-# Epoch = {args.epoch_bias} to {args.epoch_bias + args.max_epochs - 1}, save every {args.epoch_save} epoch
+# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
 #
 # Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
 #
@@ -133,7 +159,12 @@ if __name__ == "__main__":
             import wandb

             model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd)
-            wandb.init(project=args.wandb, name=model_name + "-" + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"), config=args, save_code=False)
+            wandb.init(
+                project=args.wandb,
+                name=model_name + "-" + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"),
+                config=args,
+                save_code=False,
+            )
             trainer.my_wandb = wandb

             # LR schedule
@@ -141,7 +172,7 @@ if __name__ == "__main__":
             if g_step < w_step:
                 lr = args.lr_init * (g_step / w_step)
             else:
-                progress = (g_step - w_step) / (args.max_epochs * args.epoch_steps - w_step - 1)
+                progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
                 progress = min(1, max(0, progress))
                 if args.lr_final == 0 or args.lr_init == 0:  # linear decay
@@ -160,14 +191,21 @@ if __name__ == "__main__":
             # logging
             if trainer.global_rank == 0:
                 if len(args.wandb) > 0:
-                    trainer.my_wandb.log({"loss": trainer.my_loss, "lr": trainer.my_lr}, step=trainer.global_step)
+                    trainer.my_wandb.log(
+                        {"loss": trainer.my_loss, "lr": trainer.my_lr},
+                        step=trainer.global_step,
+                    )

         def on_train_epoch_end(self, trainer, pl_module):
             args = self.args
-            if trainer.current_epoch % args.epoch_save == 0 or trainer.current_epoch == args.max_epochs - 1:
-                torch.save(pl_module.state_dict(), f"{args.proj_dir}/rwkv-{args.epoch_bias + trainer.current_epoch}.pth")
-            trainer.my_log.write(f"{args.epoch_bias + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
-            trainer.my_log.flush()
+            if trainer.global_rank == 0:
+                if trainer.current_epoch % args.epoch_save == 0 or trainer.current_epoch == args.epoch_count - 1:
+                    torch.save(
+                        pl_module.state_dict(),
+                        f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
+                    )
+                trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
+                trainer.my_log.flush()

     @rank_zero_only
     def generate_init_weight(model, temp_name):
@@ -193,8 +231,6 @@ if __name__ == "__main__":
     if len(args.load_model) == 0:
         args.load_model = f"{args.proj_dir}/rwkv-init.pth"  # init weights to tmp file
         generate_init_weight(model, args.load_model)
-    else:
-        args.load_model = f"{args.proj_dir}/{args.load_model}"

     print(f"\nLoading {args.load_model}...\n")
     load_dict = torch.load(args.load_model, map_location="cpu")
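
Note on the renamed schedule flags: with --epoch_steps, --epoch_count, --epoch_begin and --epoch_save, an "epoch" is just a fixed number of optimizer steps, so the samples and tokens reported in the startup banner follow directly from the launch arguments. A minimal sketch of that arithmetic, assuming a single device and the argparse defaults from the diff above (real runs take these values from the command line):

    # Assumed values; they mirror the argparse defaults in train.py above.
    epoch_steps = 1000  # --epoch_steps
    devices = 1         # Trainer --devices (assumption: one GPU)
    micro_bsz = 12      # --micro_bsz
    ctx_len = 1024      # --ctx_len

    samples_per_epoch = epoch_steps * devices * micro_bsz  # 1000 * 1 * 12 = 12000
    tokens_per_epoch = samples_per_epoch * ctx_len         # 12000 * 1024 = 12288000
    print(samples_per_epoch, tokens_per_epoch)

Checkpoints are then written every --epoch_save "epochs" as rwkv-{epoch_begin + current_epoch}.pth, and because max_epochs is forced to -1 the run keeps going past --epoch_count (the banner's "will continue afterwards").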