BlinkDL 3 years ago
parent 8abea9c08d
commit cdb098c0e0

@@ -339,36 +339,42 @@ class RWKV(pl.LightningModule):
             gain = 1.0
             scale = 1.0
             if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n:
-                m[n] = p.cpu()
-                continue
-            elif n == "emb.weight":
-                scale = -25 * self.args.lr_init
+                m[n] = p
             else:
-                if shape[0] > shape[1]:
-                    gain = math.sqrt(shape[0] / shape[1])
-                for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q."]:
-                    if kk in n:
-                        scale = 0
-                if n == "head.weight":
-                    scale = 0.5
-                if "head_k." in n:
-                    scale = 0.1
-                if "head_q." in n:
-                    scale = 0
-            print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
-            if self.args.accelerator.upper() == "GPU":
-                m[n] = torch.empty((shape[0], shape[1]), device="cuda")
-            else:
-                m[n] = torch.empty((shape[0], shape[1]))
-            if scale == 0:
-                nn.init.zeros_(m[n])
-            elif scale < 0:
-                nn.init.normal_(m[n], mean=0.0, std=-scale)
-            else:
-                nn.init.orthogonal_(m[n], gain=gain * scale)
+                if n == "emb.weight":
+                    scale = -25 * self.args.lr_init
+                else:
+                    if shape[0] > shape[1]:
+                        gain = math.sqrt(shape[0] / shape[1])
+                    for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q."]:
+                        if kk in n:
+                            scale = 0
+                    if n == "head.weight":
+                        scale = 0.5
+                    if "head_k." in n:
+                        scale = 0.1
+                    if "head_q." in n:
+                        scale = 0
+                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
+                if self.args.accelerator.upper() == "GPU":
+                    m[n] = torch.empty((shape[0], shape[1]), device="cuda")
+                else:
+                    m[n] = torch.empty((shape[0], shape[1]))
+                if scale == 0:
+                    nn.init.zeros_(m[n])
+                elif scale < 0:
+                    nn.init.normal_(m[n], mean=0.0, std=-scale)
+                else:
+                    nn.init.orthogonal_(m[n], gain=gain * scale)
+            m[n] = m[n].cpu()
+            if os.environ["RWKV_FLOAT_MODE"] == "fp16":
+                m[n] = m[n].half()
+            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+                m[n] = m[n].bfloat16()
             # if n == "emb.weight":
             #     print(m[n])
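Editor's note: the init rule in this hunk reads more easily as a standalone function. The sketch below only restates the scheme above (zero init for the listed att/ffn/head_q projections, a small normal init for emb.weight via the negative scale, scaled orthogonal init elsewhere); the function name init_one_weight and the example shapes are made up for illustration and are not part of the commit.

import math
import torch
from torch import nn

def init_one_weight(name, shape, lr_init=6e-4):
    # Illustrative restatement of the rule above; not part of the commit.
    gain, scale = 1.0, 1.0
    w = torch.empty(shape)
    if name == "emb.weight":
        # scale = -25 * lr_init above means: normal init with std = 25 * lr_init
        nn.init.normal_(w, mean=0.0, std=25 * lr_init)
        return w
    if shape[0] > shape[1]:
        gain = math.sqrt(shape[0] / shape[1])
    zero_keys = [".att.key.", ".att.receptance.", ".att.output.", ".ffn.value.",
                 ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q."]
    if any(kk in name for kk in zero_keys):
        scale = 0
    if name == "head.weight":
        scale = 0.5
    if "head_k." in name:
        scale = 0.1
    if scale == 0:
        nn.init.zeros_(w)
    else:
        nn.init.orthogonal_(w, gain=gain * scale)
    return w

# e.g. a 2048x512 ffn.key projection gets orthogonal init with gain sqrt(2048/512) = 2
w = init_one_weight("blocks.0.ffn.key.weight", (2048, 512))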

@@ -3,10 +3,12 @@
 ########################################################################################################
 if __name__ == "__main__":
-    print("\n\n\n!!! NOTE: THIS IS STILL WIP !!!\n\n\n")
+    print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
     import os, warnings, math, datetime
     import numpy as np
     from argparse import ArgumentParser
+    import torch
+    import deepspeed
     import pytorch_lightning as pl
     from pytorch_lightning import Trainer
     from pytorch_lightning import seed_everything
@@ -14,6 +16,19 @@ if __name__ == "__main__":
     from pytorch_lightning.callbacks import TQDMProgressBar
     from pytorch_lightning import Callback

+    rank_zero_info(
+        f"""
+############################################################################
+#
+# torch {torch.__version__}, recommend 1.12.1+cu116 or newer
+#
+# deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
+#
+# pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
+#
+############################################################################
+"""
+    )
+
     seed_everything(42)
     np.set_printoptions(precision=4, suppress=True, linewidth=200)
     warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
@@ -23,37 +38,48 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser = Trainer.add_argparse_args(parser)
+    parser.add_argument("--load_model", default="", type=str)
     parser.add_argument("--wandb", default="", type=str)
     parser.add_argument("--proj_dir", default="out", type=str)
+    parser.add_argument("--data_file", default="", type=str)
+    parser.add_argument("--data_type", default="utf-8", type=str)
+    parser.add_argument("--vocab_size", default=0, type=int)
+    parser.add_argument("--ctx_len", default=1024, type=int)
+    parser.add_argument("--epoch_steps", default=1000, type=int)
+    parser.add_argument("--epoch_count", default=500, type=int)
+    parser.add_argument("--epoch_begin", default=0, type=int)
+    parser.add_argument("--epoch_save", default=5, type=int)
+    parser.add_argument("--micro_bsz", default=12, type=int)
     parser.add_argument("--n_layer", default=6, type=int)
     parser.add_argument("--n_embd", default=512, type=int)
     parser.add_argument("--pre_ffn", default=0, type=int)
     parser.add_argument("--head_qk", default=0, type=int)
     parser.add_argument("--lr_init", default=6e-4, type=float)
     parser.add_argument("--lr_final", default=1e-5, type=float)
     parser.add_argument("--warmup_steps", default=0, type=int)
-    parser.add_argument("--epoch_steps", default=1000, type=int)
-    parser.add_argument("--epoch_bias", default=0, type=int)
-    parser.add_argument("--epoch_save", default=5, type=int)
     parser.add_argument("--beta1", default=0.9, type=float)
     parser.add_argument("--beta2", default=0.99, type=float)
     parser.add_argument("--adam_eps", default=1e-8, type=float)
-    parser.add_argument("--ctx_len", default=1024, type=int)
-    parser.add_argument("--micro_bsz", default=12, type=int)
-    parser.add_argument("--data_workers", default=1, type=int)
     parser.add_argument("--grad_cp", default=0, type=int)
-    parser.add_argument("--load_model", default="", type=str)
-    parser.add_argument("--data_file", default="", type=str)
-    parser.add_argument("--data_type", default="utf-8", type=str)
-    parser.add_argument("--vocab_size", default=0, type=int)
+    parser.add_argument("--data_workers", default=1, type=int)
     args = parser.parse_args()

     args.enable_checkpointing = False
     args.logger = False
     args.gradient_clip_val = 1.0
     args.num_sanity_val_steps = 0
+    args.check_val_every_n_epoch = int(1e20)
+    args.auto_select_gpus = True
+    args.log_every_n_steps = int(1e20)
+    args.max_epochs = -1  # continue forever
     args.betas = (args.beta1, args.beta2)
+    args.proj_dir = args.proj_dir.strip().strip("\\/")

     samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
     tokens_per_epoch = samples_per_epoch * args.ctx_len
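With the defaults above (epoch_steps=1000, micro_bsz=12, ctx_len=1024), the two bookkeeping formulas work out as below; this is just a worked check of the arithmetic, with devices=1 assumed for the example.

# Worked check of the two lines above, assuming devices=1.
epoch_steps, devices, micro_bsz, ctx_len = 1000, 1, 12, 1024
samples_per_epoch = epoch_steps * devices * micro_bsz   # 12,000 samples
tokens_per_epoch = samples_per_epoch * ctx_len          # 12,288,000 tokens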
@@ -61,11 +87,11 @@ if __name__ == "__main__":
         f"""
 ############################################################################
 #
-# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()} {args.strateg.upper()} {'with grad_cp' if args.grad_cp > 0 else ''}
+# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()} {args.strategy.upper()} {'with grad_cp' if args.grad_cp > 0 else ''}
 #
 # Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
 #
-# Epoch = {args.epoch_bias} to {args.epoch_bias + args.max_epochs - 1}, save every {args.epoch_save} epoch
+# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
 #
 # Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
 #
@@ -133,7 +159,12 @@ if __name__ == "__main__":
             import wandb

             model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd)
-            wandb.init(project=args.wandb, name=model_name + "-" + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"), config=args, save_code=False)
+            wandb.init(
+                project=args.wandb,
+                name=model_name + "-" + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"),
+                config=args,
+                save_code=False,
+            )
             trainer.my_wandb = wandb

         # LR schedule
@@ -141,7 +172,7 @@ if __name__ == "__main__":
         if g_step < w_step:
             lr = args.lr_init * (g_step / w_step)
         else:
-            progress = (g_step - w_step) / (args.max_epochs * args.epoch_steps - w_step - 1)
+            progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
             progress = min(1, max(0, progress))
             if args.lr_final == 0 or args.lr_init == 0:  # linear decay
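Read with the warmup branch above it: the LR ramps linearly from 0 to lr_init over warmup_steps, then decays over epoch_count * epoch_steps total steps (which is why the rename away from max_epochs matters once max_epochs is pinned to -1). The hunk is cut off right after the linear-decay condition, so the decay bodies in the sketch below are assumptions (linear interpolation when either endpoint is 0, exponential interpolation between lr_init and lr_final otherwise), not a verbatim copy of the file.

import math

def lr_at_step(g_step, args):
    # Only the warmup branch and the progress formula appear in the hunk above;
    # the two decay bodies below are assumptions for illustration.
    w_step = args.warmup_steps
    if g_step < w_step:
        return args.lr_init * (g_step / w_step)
    progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
    progress = min(1, max(0, progress))
    if args.lr_final == 0 or args.lr_init == 0:  # linear decay (assumed body)
        return args.lr_init + (args.lr_final - args.lr_init) * progress
    # assumed: exponential interpolation from lr_init down to lr_final
    return args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * progress)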
@ -160,14 +191,21 @@ if __name__ == "__main__":
# logging # logging
if trainer.global_rank == 0: if trainer.global_rank == 0:
if len(args.wandb) > 0: if len(args.wandb) > 0:
trainer.my_wandb.log({"loss": trainer.my_loss, "lr": trainer.my_lr}, step=trainer.global_step) trainer.my_wandb.log(
{"loss": trainer.my_loss, "lr": trainer.my_lr},
step=trainer.global_step,
)
def on_train_epoch_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module):
args = self.args args = self.args
if trainer.current_epoch % args.epoch_save == 0 or trainer.current_epoch == args.max_epochs - 1: if trainer.global_rank == 0:
torch.save(pl_module.state_dict(), f"{args.proj_dir}/rwkv-{args.epoch_bias + trainer.current_epoch}.pth") if trainer.current_epoch % args.epoch_save == 0 or trainer.current_epoch == args.epoch_count - 1:
trainer.my_log.write(f"{args.epoch_bias + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n") torch.save(
trainer.my_log.flush() pl_module.state_dict(),
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
)
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
trainer.my_log.flush()
@rank_zero_only @rank_zero_only
def generate_init_weight(model, temp_name): def generate_init_weight(model, temp_name):
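A consequence of the renamed flag: the checkpoint written at Lightning epoch e is named rwkv-{epoch_begin + e}.pth, so a run restarted with --epoch_begin N continues the numbering instead of overwriting earlier files, and (per the next hunk) --load_model is now used as given rather than being prefixed with proj_dir. A minimal sketch of picking such a checkpoint back up; the helper name resume_from and the use of load_state_dict are illustrative assumptions, not code from this commit.

import torch

def resume_from(model, proj_dir, epoch_begin):
    # Hypothetical helper: under the naming scheme above, the last checkpoint
    # of the previous run is rwkv-{epoch_begin - 1}.pth.
    assert epoch_begin > 0
    ckpt = f"{proj_dir}/rwkv-{epoch_begin - 1}.pth"
    load_dict = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(load_dict)
    return model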
@@ -193,8 +231,6 @@ if __name__ == "__main__":
    if len(args.load_model) == 0:
        args.load_model = f"{args.proj_dir}/rwkv-init.pth"  # init weights to tmp file
        generate_init_weight(model, args.load_model)
-    else:
-        args.load_model = f"{args.proj_dir}/{args.load_model}"

    print(f"\nLoading {args.load_model}...\n")
    load_dict = torch.load(args.load_model, map_location="cpu")
