@@ -3,10 +3,12 @@
 ########################################################################################################

 if __name__ == "__main__":
-    print("\n\n\n!!! NOTE: THIS IS STILL WIP !!!\n\n\n")
+    print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
     import os, warnings, math, datetime
     import numpy as np
     from argparse import ArgumentParser
+    import torch
+    import deepspeed
     import pytorch_lightning as pl
     from pytorch_lightning import Trainer
     from pytorch_lightning import seed_everything
@ -14,6 +16,19 @@ if __name__ == "__main__":
     from pytorch_lightning.callbacks import TQDMProgressBar
     from pytorch_lightning import Callback

+    rank_zero_info(
+        f"""
+############################################################################
+#
+# torch {torch.__version__}, recommend 1.12.1+cu116 or newer
+#
+# deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
+#
+# pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
+#
+############################################################################
+"""
+    )
     seed_everything(42)
     np.set_printoptions(precision=4, suppress=True, linewidth=200)
     warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
@ -23,37 +38,48 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser = Trainer.add_argparse_args(parser)

+    parser.add_argument("--load_model", default="", type=str)
     parser.add_argument("--wandb", default="", type=str)
     parser.add_argument("--proj_dir", default="out", type=str)
+
+    parser.add_argument("--data_file", default="", type=str)
+    parser.add_argument("--data_type", default="utf-8", type=str)
+    parser.add_argument("--vocab_size", default=0, type=int)
+
+    parser.add_argument("--ctx_len", default=1024, type=int)
+    parser.add_argument("--epoch_steps", default=1000, type=int)
+    parser.add_argument("--epoch_count", default=500, type=int)
+    parser.add_argument("--epoch_begin", default=0, type=int)
+    parser.add_argument("--epoch_save", default=5, type=int)
+
+    parser.add_argument("--micro_bsz", default=12, type=int)
     parser.add_argument("--n_layer", default=6, type=int)
     parser.add_argument("--n_embd", default=512, type=int)
     parser.add_argument("--pre_ffn", default=0, type=int)
     parser.add_argument("--head_qk", default=0, type=int)
     parser.add_argument("--lr_init", default=6e-4, type=float)
     parser.add_argument("--lr_final", default=1e-5, type=float)
     parser.add_argument("--warmup_steps", default=0, type=int)
-    parser.add_argument("--epoch_steps", default=1000, type=int)
-    parser.add_argument("--epoch_bias", default=0, type=int)
-    parser.add_argument("--epoch_save", default=5, type=int)
     parser.add_argument("--beta1", default=0.9, type=float)
     parser.add_argument("--beta2", default=0.99, type=float)
     parser.add_argument("--adam_eps", default=1e-8, type=float)
-    parser.add_argument("--ctx_len", default=1024, type=int)
-    parser.add_argument("--micro_bsz", default=12, type=int)
-    parser.add_argument("--data_workers", default=1, type=int)
     parser.add_argument("--grad_cp", default=0, type=int)
-    parser.add_argument("--load_model", default="", type=str)
-    parser.add_argument("--data_file", default="", type=str)
-    parser.add_argument("--data_type", default="utf-8", type=str)
-    parser.add_argument("--vocab_size", default=0, type=int)
+    parser.add_argument("--data_workers", default=1, type=int)

     args = parser.parse_args()
     args.enable_checkpointing = False
     args.logger = False
     args.gradient_clip_val = 1.0
     args.num_sanity_val_steps = 0
+    args.check_val_every_n_epoch = int(1e20)
+    args.auto_select_gpus = True
+    args.log_every_n_steps = int(1e20)
+    args.max_epochs = -1 # continue forever
     args.betas = (args.beta1, args.beta2)
-    args.proj_dir = args.proj_dir.strip().strip("\\/")

     samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
     tokens_per_epoch = samples_per_epoch * args.ctx_len
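Side note on the two context lines above: under the new arguments an "epoch" is a fixed-size slice of training (epoch_steps optimizer steps) rather than a full pass over the data, and the run simply continues after epoch_count of them. A quick standalone check of the arithmetic, assuming a single GPU (--devices 1) and the defaults declared earlier in this hunk (--epoch_steps 1000, --micro_bsz 12, --ctx_len 1024):

    epoch_steps, devices, micro_bsz, ctx_len = 1000, 1, 12, 1024
    samples_per_epoch = epoch_steps * devices * micro_bsz  # 12000 sequences per "epoch"
    tokens_per_epoch = samples_per_epoch * ctx_len         # 12288000 tokens per "epoch"
    print(samples_per_epoch, tokens_per_epoch)             # 12000 12288000
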
@ -61,11 +87,11 @@ if __name__ == "__main__":
         f"""
 ############################################################################
 #
 # RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()} {args.strategy.upper()} {'with grad_cp' if args.grad_cp > 0 else ''}
 #
 # Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
 #
-# Epoch = {args.epoch_bias} to {args.epoch_bias + args.max_epochs - 1}, save every {args.epoch_save} epoch
+# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
 #
 # Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
 #
@ -133,7 +159,12 @@ if __name__ == "__main__":
                 import wandb

                 model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd)
-                wandb.init(project=args.wandb, name=model_name + "-" + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"), config=args, save_code=False)
+                wandb.init(
+                    project=args.wandb,
+                    name=model_name + "-" + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"),
+                    config=args,
+                    save_code=False,
+                )
                 trainer.my_wandb = wandb

             # LR schedule
@ -141,7 +172,7 @@ if __name__ == "__main__":
             if g_step < w_step:
                 lr = args.lr_init * (g_step / w_step)
             else:
-                progress = (g_step - w_step) / (args.max_epochs * args.epoch_steps - w_step - 1)
+                progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
                 progress = min(1, max(0, progress))

                 if args.lr_final == 0 or args.lr_init == 0: # linear decay
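The functional change in this hunk is the decay horizon: progress is now measured against epoch_count * epoch_steps instead of max_epochs * epoch_steps (max_epochs becomes -1, "continue forever"). A minimal standalone sketch of the resulting schedule using the argument defaults from this diff; the decay body itself lies outside the hunk, so the final linear interpolation is an illustrative assumption:

    def lr_at(g_step, lr_init=6e-4, lr_final=1e-5, w_step=0, epoch_count=500, epoch_steps=1000):
        if g_step < w_step:  # warmup branch, as in the hunk above
            return lr_init * (g_step / w_step)
        # progress over the new horizon of epoch_count * epoch_steps steps, clamped to [0, 1]
        progress = (g_step - w_step) / (epoch_count * epoch_steps - w_step - 1)
        progress = min(1, max(0, progress))
        # assumed decay body: linear interpolation from lr_init down to lr_final
        return lr_init + (lr_final - lr_init) * progress
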
@ -160,14 +191,21 @@ if __name__ == "__main__":
             # logging
             if trainer.global_rank == 0:
                 if len(args.wandb) > 0:
-                    trainer.my_wandb.log({"loss": trainer.my_loss, "lr": trainer.my_lr}, step=trainer.global_step)
+                    trainer.my_wandb.log(
+                        {"loss": trainer.my_loss, "lr": trainer.my_lr},
+                        step=trainer.global_step,
+                    )

         def on_train_epoch_end(self, trainer, pl_module):
             args = self.args
-            if trainer.current_epoch % args.epoch_save == 0 or trainer.current_epoch == args.max_epochs - 1:
-                torch.save(pl_module.state_dict(), f"{args.proj_dir}/rwkv-{args.epoch_bias + trainer.current_epoch}.pth")
-                trainer.my_log.write(f"{args.epoch_bias + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
-                trainer.my_log.flush()
+            if trainer.global_rank == 0:
+                if trainer.current_epoch % args.epoch_save == 0 or trainer.current_epoch == args.epoch_count - 1:
+                    torch.save(
+                        pl_module.state_dict(),
+                        f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
+                    )
+                trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
+                trainer.my_log.flush()

     @rank_zero_only
     def generate_init_weight(model, temp_name):
@ -193,8 +231,6 @@ if __name__ == "__main__":
     if len(args.load_model) == 0:
         args.load_model = f"{args.proj_dir}/rwkv-init.pth" # init weights to tmp file
         generate_init_weight(model, args.load_model)
-    else:
-        args.load_model = f"{args.proj_dir}/{args.load_model}"

     print(f"\nLoading {args.load_model}...\n")
     load_dict = torch.load(args.load_model, map_location="cpu")
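Dropping the else branch changes how --load_model is interpreted: it is no longer joined with proj_dir, so a resumed run must pass a usable checkpoint path directly. A small sketch of the resulting resolution logic (the helper name is illustrative, not part of the source):

    def resolve_load_model(load_model: str, proj_dir: str = "out") -> str:
        if len(load_model) == 0:
            # fresh run: init weights are generated and written into proj_dir
            return f"{proj_dir}/rwkv-init.pth"
        # new behavior: used as given; the removed branch returned f"{proj_dir}/{load_model}"
        return load_model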