@@ -2,9 +2,8 @@ import os, math, time, datetime
 import torch
 from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 from pytorch_lightning import seed_everything
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.callbacks import TQDMProgressBar


 class train_callback(pl.Callback):
     def __init__(self, args):
@@ -33,9 +32,9 @@ class train_callback(pl.Callback):
         for param_group in trainer.optimizers[0].param_groups:
             if args.layerwise_lr > 0:
                 if self.args.my_pile_mode == 0:
                     if self.args.my_pile_stage != 2:
                         param_group["lr"] = lr * param_group["my_lr_scale"]
                 elif self.args.my_pile_mode == 2:
                 else:
                     if param_group["my_lr_scale"] > 1:
                         param_group["lr"] = lr * 5
                     else:
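For context, the branch above multiplies each optimizer param group's base learning rate by a per-group "my_lr_scale" value carried in the group dict. A minimal standalone sketch of that pattern follows; the layer count, scale values and base LR here are illustrative, not taken from the repo.

import torch

# Toy setup: three layers, each given its own LR scale via an extra key
# in the param-group dict (PyTorch keeps unknown keys in the group).
layers = [torch.nn.Linear(4, 4) for _ in range(3)]
param_groups = [
    {"params": l.parameters(), "my_lr_scale": s}
    for l, s in zip(layers, [1.0, 2.0, 3.0])
]
optimizer = torch.optim.Adam(param_groups, lr=1e-4)

lr = 1e-4  # base LR from whatever schedule is in effect
for param_group in optimizer.param_groups:
    # Same move as in the callback: per-group LR = base LR * per-group scale.
    param_group["lr"] = lr * param_group["my_lr_scale"]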
@@ -63,10 +62,10 @@ class train_callback(pl.Callback):
                 print("Login to wandb...")
                 import wandb

-                model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd)
+                model_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
                 wandb.init(
                     project=args.wandb,
-                    name=model_name + "-" + args.my_timestamp,
+                    name=model_name + " " + args.my_timestamp,
                     config=args,
                     save_code=False,
                 )
@@ -76,10 +75,12 @@ class train_callback(pl.Callback):
         args = self.args
         if trainer.is_global_zero:  # logging
             t_now = time.time_ns()
+            token_per_step = args.ctx_len * float(args.devices) * args.micro_bsz
+            real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
             try:
                 t_cost = (t_now - trainer.my_time_ns) / 1e9
                 self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
-                self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
+                self.log("Kt/s", token_per_step / t_cost / 1000, prog_bar=True, on_step=True)
             except:
                 pass
             trainer.my_time_ns = t_now
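To put the new counters in perspective, here is the arithmetic with illustrative settings (the ctx_len, device count, batch size, step time and step count are made up, not from this diff):

ctx_len, devices, micro_bsz = 1024, 8, 12
token_per_step = ctx_len * devices * micro_bsz   # 98,304 tokens per optimizer step
t_cost = 1.5                                     # hypothetical seconds per step
print(token_per_step / t_cost / 1000)            # ~65.5, the value logged as "Kt/s"
real_step = 100_000
print(real_step * token_per_step / 1e9)          # ~9.83, the "Gtokens" value logged below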
@@ -89,11 +90,12 @@ class train_callback(pl.Callback):
             trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
             self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
             self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
             # self.log("s", real_step, prog_bar=True, on_step=True)

             if len(args.wandb) > 0:
                 trainer.my_wandb.log(
-                    {"loss": trainer.my_loss, "lr": trainer.my_lr},
-                    step=trainer.global_step,
+                    {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9},
+                    step=int(real_step),
                 )

     def on_train_epoch_start(self, trainer, pl_module):
@@ -120,11 +122,19 @@ class train_callback(pl.Callback):
 @rank_zero_only
-def generate_init_weight(model, temp_name):
-    try:
-        os.remove(temp_name)
-    except:
-        pass
+def generate_init_weight(model, init_weight_name):
     mm = model.generate_init_weight()
-    print(f"Saving to {temp_name}...")
-    torch.save(mm, temp_name)
+
+    if model.args.my_pile_stage == 1:
+        print(f"Combine weights from {model.args.load_model}...")
+        load_dict = torch.load(model.args.load_model, map_location="cpu")
+        for k in load_dict:
+            assert k in mm
+            mm[k] = load_dict[k].reshape(mm[k].shape)
+
+    print(f"Save to {init_weight_name}...")
+    torch.save(mm, init_weight_name)
+
+    if model.args.my_pile_stage == 1:
+        print("Done. Now go for stage 2.")
+        exit(0)
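The rewritten generate_init_weight above, when my_pile_stage == 1, overwrites freshly initialized tensors with the corresponding tensors loaded from args.load_model, reshaping each one to the target shape. A minimal sketch of that merge step follows; the keys, shapes and output path are made up for illustration, and reshape only bridges layout differences (the element counts must already match).

import torch

# Stand-in for model.generate_init_weight(): freshly initialized tensors.
mm = {"emb.weight": torch.zeros(4, 8), "head.weight": torch.zeros(8, 4)}

# Stand-in for the checkpoint being combined in; every key must already exist in mm.
load_dict = {"emb.weight": torch.randn(32), "head.weight": torch.randn(8, 4)}

for k in load_dict:
    assert k in mm                              # no unexpected keys
    mm[k] = load_dict[k].reshape(mm[k].shape)   # fails unless numel() matches

torch.save(mm, "init-weights-example.pth")      # hypothetical output path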