From 99a3dff414409ec5ceda8e6b2cda759f1f9c05ef Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Tue, 6 Sep 2022 11:12:47 +0000
Subject: [PATCH] code for pile training

---
 .gitignore                |  4 ++++
 RWKV-v4neo/src/dataset.py |  8 ++++----
 RWKV-v4neo/src/model.py   |  8 +++++---
 RWKV-v4neo/src/trainer.py | 42 ++++++++++++++++++++++++---------------
 RWKV-v4neo/train.py       | 40 ++++++++++++++++++++++++++++---------
 5 files changed, 70 insertions(+), 32 deletions(-)

diff --git a/.gitignore b/.gitignore
index 19616df..9fa2131 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,10 @@
 *.xlsx
 *.xls
 wandb/
+data/
+vocab.json
+*.sh
+*log/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py
index 51a9935..f0e8262 100644
--- a/RWKV-v4neo/src/dataset.py
+++ b/RWKV-v4neo/src/dataset.py
@@ -22,11 +22,11 @@ class MyDataset(Dataset):
             self.data_size = len(self.data._bin_buffer) // 2
             print(f"Data has {self.data_size} tokens.")
 
-            if args.my_pile_mode > 0:
+            if args.my_pile_stage > 0:
                 assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024
                 self.samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
                 assert self.samples_per_epoch == 40320
-                print(f"########## Pile 20b-tokenized mode {args.my_pile_mode} ##########")
+                print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
                 self.magic_prime = 324331313
                 dataset_slot = self.data_size // args.ctx_len
                 assert MaybeIsPrime(self.magic_prime)
@@ -46,7 +46,7 @@ class MyDataset(Dataset):
                 aa = (i) % 10000
                 bb = (i * i) % 10000
                 cc = aa + bb
-                self.data += f'.{aa}+{bb}={cc}.'
+                self.data += f".{aa}+{bb}={cc}."
         else:
             self.data = open(args.data_file, "r", encoding=args.data_type).read()
             print("Building token list...")
@@ -84,7 +84,7 @@ class MyDataset(Dataset):
         ctx_len = args.ctx_len
         req_len = ctx_len + 1
 
-        if args.my_pile_mode > 0:
+        if args.my_pile_stage > 0:
             ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
             factor = (math.sqrt(5) - 1) / 2
             factor = int(self.magic_prime * factor)
diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index acf006b..68e63c0 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -12,8 +12,10 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.strategies import DeepSpeedStrategy
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+
 # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
 
+
 def __nop(ob):
     return ob
 
@@ -278,7 +280,7 @@ class RWKV(pl.LightningModule):
         lr_2x = set()
         lr_3x = set()
         for n, p in self.named_parameters():
-            if ("time_mix" in n) and (self.args.my_pile_mode == 2):
+            if ("time_mix" in n) and (self.args.my_pile_stage == 2):
                 lr_2x.add(n)
             elif "time_decay" in n:
                 lr_2x.add(n)
@@ -382,7 +384,7 @@ class RWKV(pl.LightningModule):
                 m[n] = p
             else:
                 if n == "emb.weight":
-                    scale = -25 * self.args.lr_init
+                    scale = -1 * self.args.lr_init
                 else:
                     if shape[0] > shape[1]:
                         gain = math.sqrt(shape[0] / shape[1])
@@ -406,7 +408,7 @@ class RWKV(pl.LightningModule):
             if scale == 0:
                 nn.init.zeros_(m[n])
             elif scale < 0:
-                nn.init.normal_(m[n], mean=0.0, std=-scale)
+                nn.init.uniform_(m[n], a=scale, b=-scale)
             else:
                 nn.init.orthogonal_(m[n], gain=gain * scale)
 
diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py
index fd770b6..8fb2d27 100644
--- a/RWKV-v4neo/src/trainer.py
+++ b/RWKV-v4neo/src/trainer.py
@@ -2,9 +2,8 @@ import os, math, time, datetime
 import torch
 from torch.utils.data import DataLoader
 import pytorch_lightning as pl
-from pytorch_lightning import seed_everything
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
-from pytorch_lightning.callbacks import TQDMProgressBar
+
 
 class train_callback(pl.Callback):
     def __init__(self, args):
@@ -33,9 +32,9 @@ class train_callback(pl.Callback):
 
         for param_group in trainer.optimizers[0].param_groups:
             if args.layerwise_lr > 0:
-                if self.args.my_pile_mode == 0:
+                if self.args.my_pile_stage != 2:
                     param_group["lr"] = lr * param_group["my_lr_scale"]
-                elif self.args.my_pile_mode == 2:
+                else:
                     if param_group["my_lr_scale"] > 1:
                         param_group["lr"] = lr * 5
                     else:
@@ -63,10 +62,10 @@ class train_callback(pl.Callback):
                 print("Login to wandb...")
                 import wandb
 
-                model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd)
+                model_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
                 wandb.init(
                     project=args.wandb,
-                    name=model_name + "-" + args.my_timestamp,
+                    name=model_name + " " + args.my_timestamp,
                     config=args,
                     save_code=False,
                 )
@@ -76,10 +75,12 @@ class train_callback(pl.Callback):
         args = self.args
         if trainer.is_global_zero:  # logging
             t_now = time.time_ns()
+            token_per_step = args.ctx_len * float(args.devices) * args.micro_bsz
+            real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
             try:
                 t_cost = (t_now - trainer.my_time_ns) / 1e9
                 self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
-                self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
+                self.log("Kt/s", token_per_step / t_cost / 1000, prog_bar=True, on_step=True)
             except:
                 pass
             trainer.my_time_ns = t_now
@@ -89,11 +90,12 @@ class train_callback(pl.Callback):
             trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
             self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
             self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
+            # self.log("s", real_step, prog_bar=True, on_step=True)
 
             if len(args.wandb) > 0:
                 trainer.my_wandb.log(
-                    {"loss": trainer.my_loss, "lr": trainer.my_lr},
-                    step=trainer.global_step,
+                    {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9},
+                    step=int(real_step),
                 )
 
     def on_train_epoch_start(self, trainer, pl_module):
@@ -120,11 +122,19 @@ class train_callback(pl.Callback):
 
 
 @rank_zero_only
-def generate_init_weight(model, temp_name):
-    try:
-        os.remove(temp_name)
-    except:
-        pass
+def generate_init_weight(model, init_weight_name):
     mm = model.generate_init_weight()
-    print(f"Saving to {temp_name}...")
-    torch.save(mm, temp_name)
+
+    if model.args.my_pile_stage == 1:
+        print(f"Combine weights from {model.args.load_model}...")
+        load_dict = torch.load(model.args.load_model, map_location="cpu")
+        for k in load_dict:
+            assert k in mm
+            mm[k] = load_dict[k].reshape(mm[k].shape)
+
+    print(f"Save to {init_weight_name}...")
+    torch.save(mm, init_weight_name)
+
+    if model.args.my_pile_stage == 1:
+        print("Done. Now go for stage 2.")
+        exit(0)
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index d99ec22..ec40cf9 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -3,7 +3,7 @@
 ########################################################################################################
 
 if __name__ == "__main__":
-    print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
+    print("\n!!! work in progress !!!\n")
     import os, warnings, math, datetime, sys, time
     import numpy as np
     from argparse import ArgumentParser
@@ -23,7 +23,7 @@ if __name__ == "__main__":
     warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
     ########################################################################################################
-    
+
     #
     # example: train a simple L12-D768 RWKV on dummy data
     #
     # python train.py --load_model "" --wandb "" --proj_dir "out" \
@@ -91,7 +91,7 @@ if __name__ == "__main__":
     parser.add_argument("--adam_eps", default=1e-8, type=float)
     parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
 
-    parser.add_argument("--my_pile_mode", default=0, type=int)  # my special pile mode
+    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
     parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
 
     parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
@@ -107,11 +107,32 @@ if __name__ == "__main__":
     args.max_epochs = -1  # continue forever
     args.betas = (args.beta1, args.beta2)
 
-    if args.my_pile_mode > 0:
+    if args.my_pile_stage > 0:
         args.epoch_steps = 40320 // (int(args.devices) * args.micro_bsz)
         assert args.epoch_steps * int(args.devices) * args.micro_bsz == 40320
-        if args.my_pile_mode == 2:
+        if args.my_pile_stage == 2:
             assert args.lr_final == args.lr_init
+        if args.my_pile_stage >= 2:  # find latest saved model
+            pths = os.listdir(args.proj_dir)
+            max_p = -1
+            for p in pths:
+                if p.startswith("rwkv") and p.endswith(".pth"):
+                    p = ((p.split("-"))[1].split("."))[0]
+                    if p == "init":
+                        p = -1
+                    else:
+                        p = int(p)
+                    if p > max_p:
+                        max_p = p
+            if max_p == -1:
+                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
+            else:
+                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
+                if args.my_pile_stage == 2:
+                    args.warmup_steps = 10
+                else:
+                    args.warmup_steps = 50
+            args.epoch_begin = max_p + 1
 
     samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
     tokens_per_epoch = samples_per_epoch * args.ctx_len
@@ -175,7 +196,7 @@ if __name__ == "__main__":
             args.precision = "bf16"
 
     ########################################################################################################
-    
+
     from src.trainer import train_callback, generate_init_weight
     from src.dataset import MyDataset
     from src.model import RWKV
@@ -185,9 +206,10 @@ if __name__ == "__main__":
 
     model = RWKV(args)
 
-    if len(args.load_model) == 0:
-        args.load_model = f"{args.proj_dir}/rwkv-init.pth"
-        generate_init_weight(model, args.load_model)  # save initial weights to tmp file
+    if len(args.load_model) == 0 or args.my_pile_stage == 1:  # shall we build the initial weights?
+        init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
+        generate_init_weight(model, init_weight_name)  # save initial weights
+        args.load_model = init_weight_name
 
     print(f"########## Loading {args.load_model}... ##########")
     load_dict = torch.load(args.load_model, map_location="cpu")