@@ -4,10 +4,11 @@

 if __name__ == "__main__":
+    print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
-    import os, warnings, math, datetime, sys
+    import os, warnings, math, datetime, sys, time
     import numpy as np
     from argparse import ArgumentParser
     import torch
     from torch.utils.data import DataLoader
     import deepspeed
     import pytorch_lightning as pl
     from pytorch_lightning import Trainer
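
(The only functional change above is adding `time` to the stdlib imports; it feeds the nanosecond-resolution throughput timer, `time.time_ns()`, that the reworked `on_train_batch_end` below relies on.)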

@@ -83,10 +84,12 @@ if __name__ == "__main__":
     parser.add_argument("--adam_eps", default=1e-8, type=float)

+    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpointing: saves VRAM, but slower
+    parser.add_argument("--my_pile_mode", default=0, type=int)  # my special pile mode

     args = parser.parse_args()
     args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
     args.enable_checkpointing = False
     args.replace_sampler_ddp = False
     args.logger = False
     args.gradient_clip_val = 1.0
     args.num_sanity_val_steps = 0
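
These assignments configure Lightning, not the script itself: `enable_checkpointing`, `replace_sampler_ddp`, `logger`, and `gradient_clip_val` are all `pl.Trainer` constructor flags, so the namespace is presumably handed to the Trainer later, along the lines of (a sketch; the actual construction is outside the hunks shown):

    trainer = Trainer.from_argparse_args(args, callbacks=[train_callback(args)])

`from_argparse_args` is the standard PyTorch Lightning 1.x API for this. `replace_sampler_ddp = False` is the flag that matters most for this diff: it stops Lightning from injecting a `DistributedSampler`, because `MyDataset` (see `on_train_epoch_start` below) shards data across ranks itself.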

@@ -95,6 +98,12 @@ if __name__ == "__main__":
     args.max_epochs = -1  # continue forever
     args.betas = (args.beta1, args.beta2)

+    if args.my_pile_mode > 0:
+        args.epoch_steps = 40320 // (int(args.devices) * args.micro_bsz)
+        assert args.epoch_steps * int(args.devices) * args.micro_bsz == 40320
+        if args.my_pile_mode == 2:
+            assert args.lr_final == args.lr_init
+
     samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
     tokens_per_epoch = samples_per_epoch * args.ctx_len
     rank_zero_info(
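
A quick sanity check of the pile-mode arithmetic, with hypothetical values (8 GPUs, `micro_bsz=5`, `ctx_len=1024`; none of these numbers come from the diff):

    epoch_steps = 40320 // (8 * 5)               # 1008 steps per "epoch"
    samples_per_epoch = epoch_steps * 8 * 5      # 40320 samples, by construction
    tokens_per_epoch = samples_per_epoch * 1024  # 41,287,680 tokens

The assert exists because `//` rounds down silently: if `devices * micro_bsz` does not divide 40320 evenly, `epoch_steps * devices * micro_bsz` lands short of 40320 and the run aborts instead of quietly shrinking the epoch.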

@@ -138,9 +147,9 @@ if __name__ == "__main__":
     if args.precision == "fp16":
         rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")

-    os.environ["RWKV_JIT"] = "1"
+    os.environ["RWKV_JIT_ON"] = "1"
     if "deepspeed_stage_3" in args.strategy:
-        os.environ["RWKV_JIT"] = "0"
+        os.environ["RWKV_JIT_ON"] = "0"

     import torch
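
The rename has to happen before the `src.model` import further down, since the flag is read at import time. A minimal sketch of the consuming side, assuming the model module switches its base classes on the flag the way RWKV's model code typically does (treat the exact names as illustrative):

    import os
    import torch
    import torch.nn as nn

    if os.environ.get("RWKV_JIT_ON", "0") == "1":
        MyModule = torch.jit.ScriptModule   # TorchScript-compiled modules
        MyFunction = torch.jit.script_method
    else:
        MyModule = nn.Module                # plain eager modules
        MyFunction = lambda f: f            # no-op decorator

Disabling the JIT under `deepspeed_stage_3` is presumably because ZeRO stage 3 shards parameters and re-gathers them via module hooks, which scripted modules do not cooperate with.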

@@ -170,9 +179,37 @@ if __name__ == "__main__":
             args = self.args
             g_step = trainer.global_step

-            # logging
-            if trainer.global_rank == 0:
+            # LR schedule
+            w_step = args.warmup_steps
+            if g_step < w_step:
+                lr = args.lr_init * (g_step / w_step)
+            else:
+                if args.lr_final == args.lr_init:
+                    lr = args.lr_init
+                else:
+                    progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
+                    progress = min(1, max(0, progress))
+
+                    if args.lr_final == 0 or args.lr_init == 0:  # linear decay
+                        lr = args.lr_init + (args.lr_final - args.lr_init) * progress
+                    else:  # exp decay
+                        lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
+
+            for param_group in trainer.optimizers[0].param_groups:
+                if self.args.my_pile_mode == 0:
+                    param_group["lr"] = lr * param_group["my_lr_scale"]
+                elif self.args.my_pile_mode == 2:
+                    if param_group["my_lr_scale"] > 1:
+                        param_group["lr"] = lr * 5
+                    else:
+                        param_group["lr"] = lr
+                # print(param_group["lr"], param_group["my_lr_scale"])
+
+            trainer.my_lr = lr
+            # rank_zero_info(f"{g_step} {lr}")
+
             if g_step == 0:
+                if trainer.is_global_zero:  # logging
                     trainer.my_loss_sum = 0
                     trainer.my_loss_count = 0
                     trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
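
Extracted from the callback, the schedule is a pure function of the step and is easy to check in isolation (same math as above; `total_steps = epoch_count * epoch_steps`):

    import math

    def lr_at(step, lr_init, lr_final, warmup, total_steps):
        if step < warmup:                  # linear warmup from 0 to lr_init
            return lr_init * (step / warmup)
        if lr_final == lr_init:            # constant schedule (pile mode 2)
            return lr_init
        progress = min(1, max(0, (step - warmup) / (total_steps - warmup - 1)))
        if lr_final == 0 or lr_init == 0:  # linear decay
            return lr_init + (lr_final - lr_init) * progress
        return lr_init * math.exp(math.log(lr_final / lr_init) * progress)  # exp decay

    assert lr_at(0, 6e-4, 1e-5, 50, 1000) == 0.0                 # warmup starts at zero
    assert abs(lr_at(999, 6e-4, 1e-5, 50, 1000) - 1e-5) < 1e-12  # ends exactly at lr_final

One consequence worth noticing: the very first optimizer step runs at lr = 0, since warmup interpolates from zero rather than from some floor.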

@@ -196,39 +233,42 @@ if __name__ == "__main__":
                         )
                         trainer.my_wandb = wandb
-
-            # LR schedule
-            w_step = args.warmup_steps
-            if g_step < w_step:
-                lr = args.lr_init * (g_step / w_step)
-            else:
-                progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
-                progress = min(1, max(0, progress))
-
-                if args.lr_final == 0 or args.lr_init == 0:  # linear decay
-                    lr = args.lr_init + (args.lr_final - args.lr_init) * progress
-                else:  # exp decay
-                    lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
-
-            for param_group in trainer.optimizers[0].param_groups:
-                param_group["lr"] = lr
-
-            trainer.my_lr = lr
-            # rank_zero_info(f"{g_step} {lr}")

         def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
             args = self.args
-            # logging
-            if trainer.global_rank == 0:
+            if trainer.is_global_zero:  # logging
+                t_now = time.time_ns()
+                try:
+                    t_cost = (t_now - trainer.my_time_ns) / 1e9
+                    self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
+                    self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
+                except AttributeError:  # first batch: trainer.my_time_ns is not set yet
+                    pass
+                trainer.my_time_ns = t_now
                 trainer.my_loss = trainer.my_loss_all.float().mean().item()
                 trainer.my_loss_sum += trainer.my_loss
                 trainer.my_loss_count += 1
                 trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
                 self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
                 self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)

                 if len(args.wandb) > 0:
                     trainer.my_wandb.log(
                         {"loss": trainer.my_loss, "lr": trainer.my_lr},
                         step=trainer.global_step,
                     )
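
`REAL it/s` and `token/s` differ only by a constant factor, `ctx_len * devices * micro_bsz` (tokens consumed per global step). With the same hypothetical shape as before (1024 ctx, 8 devices, micro_bsz 5) and a measured `t_cost` of 0.5 s:

    it_per_s = 1.0 / 0.5                 # 2.0 steps/s
    tok_per_s = 1024 * 8 * 5 * it_per_s  # 81,920 tokens/s

Because the timer spans consecutive `on_train_batch_end` calls, data loading and gradient synchronization are included; that is what makes it the "REAL" rate rather than a compute-only number.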

         def on_train_epoch_start(self, trainer, pl_module):
             args = self.args
             dataset = trainer.train_dataloader.dataset.datasets
             assert "MyDataset" in str(dataset)
             dataset.global_rank = trainer.global_rank
             dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
             dataset.world_size = trainer.world_size
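
With Lightning's sampler injection turned off and `shuffle=False` on the loader, all randomization and rank-sharding live inside `MyDataset`; the callback only has to hand it the rank, the effective epoch, and the world size. A minimal sketch of the pattern (illustrative only, not `MyDataset`'s actual indexing):

    import numpy as np
    import torch

    class ShardedDataset(torch.utils.data.Dataset):
        global_rank, real_epoch, world_size = 0, 0, 1  # overwritten by the callback

        def __init__(self, tokens, ctx_len, epoch_steps, micro_bsz):
            self.tokens, self.ctx_len = tokens, ctx_len
            self.samples = epoch_steps * micro_bsz  # per-rank samples per epoch

        def __len__(self):
            return self.samples

        def __getitem__(self, idx):
            # seed from (epoch, rank, idx): deterministic, but different per
            # rank and per epoch, so no two ranks ever see the same slice
            rng = np.random.default_rng([self.real_epoch, self.global_rank, idx])
            i = int(rng.integers(0, len(self.tokens) - self.ctx_len - 1))
            x = torch.tensor(self.tokens[i : i + self.ctx_len], dtype=torch.long)
            y = torch.tensor(self.tokens[i + 1 : i + self.ctx_len + 1], dtype=torch.long)
            return x, y

`real_epoch = epoch_begin + current_epoch` keeps the sampling sequence continuous when a run is resumed partway through.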

         def on_train_epoch_end(self, trainer, pl_module):
             args = self.args
-            if trainer.global_rank == 0:
-                if trainer.current_epoch % args.epoch_save == 0 or trainer.current_epoch == args.epoch_count - 1:
+            if trainer.is_global_zero:  # logging & save state_dict
+                if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
                     torch.save(
                         pl_module.state_dict(),
                         f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
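
The reworked condition does two things. It gives `--epoch_save 0` a meaning (never checkpoint mid-run, only at the final epoch), and it fixes a latent crash: the old expression evaluated `trainer.current_epoch % args.epoch_save` first, and in Python

    >>> 3 % 0
    ZeroDivisionError: integer division or modulo by zero

so any run with `epoch_save=0` died in the first epoch-end callback. Short-circuiting on `args.epoch_save > 0` makes the modulo unreachable in that case.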

@@ -251,7 +291,6 @@ if __name__ == "__main__":

     ########################################################################################################

-    from torch.utils.data import DataLoader
     from src.dataset import MyDataset
     from src.model import RWKV

@@ -261,8 +300,8 @@ if __name__ == "__main__":
     model = RWKV(args)

     if len(args.load_model) == 0:
-        args.load_model = f"{args.proj_dir}/rwkv-init.pth"  # init weights to tmp file
-        generate_init_weight(model, args.load_model)
+        args.load_model = f"{args.proj_dir}/rwkv-init.pth"
+        generate_init_weight(model, args.load_model)  # save initial weights to tmp file

     print(f"########## Loading {args.load_model}... ##########")
     load_dict = torch.load(args.load_model, map_location="cpu")
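
Both branches converge on one load path: a cold start writes freshly generated weights to `rwkv-init.pth` and immediately reads them back, so initialization and resumption exercise the same code. Presumably the lines just after this hunk apply the dict, roughly (a sketch, not the diff's actual continuation):

    model.load_state_dict(load_dict)

and `map_location="cpu"` keeps the checkpoint out of GPU memory until Lightning moves the model to its device.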

@@ -273,5 +312,7 @@ if __name__ == "__main__":
         callbacks=[train_callback(args)],
     )

-    train_loader = DataLoader(train_data, batch_size=args.micro_bsz, num_workers=1)
-    trainer.fit(model, train_loader)
+    # must set shuffle=False, persistent_workers=False (because worker is in another thread)
+    data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
+
+    trainer.fit(model, data_loader)
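
Each new flag earns its place: `shuffle=False` because `MyDataset` already randomizes per rank and epoch (a shuffling sampler here would fight it); `persistent_workers=False` so the worker is recreated each epoch and, per the comment above, actually observes the `global_rank` / `real_epoch` / `world_size` attributes that `on_train_epoch_start` just set, rather than a stale copy; `pin_memory=True` for faster host-to-GPU transfers; and `drop_last=True` so every rank runs the same number of steps and DDP never blocks waiting on a rank that got one extra partial batch.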