@@ -24,13 +24,17 @@ class train_callback(pl.Callback):
         if args.lr_final == args.lr_init or args.epoch_count == 0:
             lr = args.lr_init
         else:
-            progress = (real_step - w_step + 1) / (args.epoch_count * args.epoch_steps - w_step)
+            decay_step = real_step - args.my_pile_edecay * args.epoch_steps
+            decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
+            progress = (decay_step - w_step + 1) / (decay_total - w_step)
             progress = min(1, max(0, progress))
 
             if args.lr_final == 0 or args.lr_init == 0:  # linear decay
                 lr = args.lr_init + (args.lr_final - args.lr_init) * progress
             else:  # exp decay
                 lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
+            # if trainer.is_global_zero:
+            #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
 
         for param_group in trainer.optimizers[0].param_groups:
             if args.layerwise_lr > 0:
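The first hunk changes the decay window of the learning-rate schedule: instead of decaying over the whole run, decay now starts after my_pile_edecay epochs and is spread over the remaining epoch_count - my_pile_edecay epochs (with my_pile_edecay = 0 it reduces to the old formula). Below is a minimal standalone sketch of that schedule; the function name and the hyperparameter values are illustrative, and the warmup scaling applied elsewhere in the callback is omitted.

import math

def lr_at(real_step, lr_init=6e-4, lr_final=1e-5, epoch_count=100,
          epoch_steps=1000, warmup_steps=50, my_pile_edecay=0):
    # Flat schedule: no decay requested.
    if lr_final == lr_init or epoch_count == 0:
        return lr_init
    w_step = warmup_steps
    # Decay only counts steps after the first my_pile_edecay epochs.
    decay_step = real_step - my_pile_edecay * epoch_steps
    decay_total = (epoch_count - my_pile_edecay) * epoch_steps
    progress = (decay_step - w_step + 1) / (decay_total - w_step)
    progress = min(1, max(0, progress))
    if lr_final == 0 or lr_init == 0:  # linear decay
        return lr_init + (lr_final - lr_init) * progress
    # exp decay: interpolate log(lr) between log(lr_init) and log(lr_final)
    return lr_init * math.exp(math.log(lr_final / lr_init) * progress)

for step in (0, 10_000, 50_000, 100_000):
    print(step, f"{lr_at(step):.2e}")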
@@ -73,10 +77,12 @@ class train_callback(pl.Callback):
             t_now = time.time_ns()
             token_per_step = args.ctx_len * args.real_bsz
             real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
+            kt_s = 0
             try:
                 t_cost = (t_now - trainer.my_time_ns) / 1e9
+                kt_s = token_per_step / t_cost / 1000
                 self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
-                self.log("Kt/s", token_per_step / t_cost / 1000, prog_bar=True, on_step=True)
+                self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
             except:
                 pass
             trainer.my_time_ns = t_now
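The second hunk factors the tokens-per-second figure into a kt_s variable: it is initialised to 0 before the try block, computed once from the step time, and reused both for the progress-bar log here and for the wandb payload in the next hunk, so a failed timing lookup simply leaves kt_s at 0. A small self-contained sketch of the same bookkeeping, with illustrative ctx_len and real_bsz values standing in for the training args:

import time

ctx_len, real_bsz = 1024, 8             # illustrative; the real values come from args
token_per_step = ctx_len * real_bsz

last_ns = time.time_ns()                # plays the role of trainer.my_time_ns
time.sleep(0.1)                         # stand-in for one training step
t_now = time.time_ns()

kt_s = 0                                # stays 0 if no timing sample is available
t_cost = (t_now - last_ns) / 1e9        # seconds spent on the step
kt_s = token_per_step / t_cost / 1000   # thousands of tokens per second
last_ns = t_now

print(f"{1.0 / t_cost:.2f} it/s, {kt_s:.1f} Kt/s")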
@@ -89,10 +95,10 @@ class train_callback(pl.Callback):
             # self.log("s", real_step, prog_bar=True, on_step=True)
 
             if len(args.wandb) > 0:
-                trainer.my_wandb.log(
-                    {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9},
-                    step=int(real_step),
-                )
+                lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
+                if kt_s > 0:
+                    lll["kt/s"] = kt_s
+                trainer.my_wandb.log(lll, step=int(real_step))
 
     def on_train_epoch_start(self, trainer, pl_module):
         args = self.args
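The last hunk assembles the wandb payload as a dict before logging, so optional metrics can be attached conditionally without duplicating the log call: throughput is only reported once kt_s holds a positive sample. A hedged sketch of that pattern follows; the function name and the sample values are illustrative and not part of the diff, and the returned dict would then be passed to trainer.my_wandb.log(lll, step=int(real_step)) as above.

def build_log_dict(loss, lr, real_step, token_per_step, kt_s):
    # Mirrors the payload built in the hunk above.
    lll = {"loss": loss, "lr": lr, "Gtokens": real_step * token_per_step / 1e9}
    if kt_s > 0:  # only report throughput once a timing sample exists
        lll["kt/s"] = kt_s
    return lll

# On the very first step there is no throughput sample yet, so kt/s is omitted:
print(build_log_dict(loss=2.35, lr=6e-4, real_step=0, token_per_step=8192, kt_s=0))
# Later steps include it:
print(build_log_dict(loss=2.31, lr=5.8e-4, real_step=1200, token_per_step=8192, kt_s=95.4))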