commit 44f07e44d2 (parent c22c7dadca) on branch main, by BlinkDL, 3 years ago

.gitignore (vendored): 2 lines changed

@@ -9,6 +9,8 @@ data/
 vocab.json
 *.sh
 *log/
+test/
+tools/
 # Byte-compiled / optimized / DLL files
 __pycache__/

@@ -24,13 +24,17 @@ class train_callback(pl.Callback):
         if args.lr_final == args.lr_init or args.epoch_count == 0:
             lr = args.lr_init
         else:
-            progress = (real_step - w_step + 1) / (args.epoch_count * args.epoch_steps - w_step)
+            decay_step = real_step - args.my_pile_edecay * args.epoch_steps
+            decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
+            progress = (decay_step - w_step + 1) / (decay_total - w_step)
             progress = min(1, max(0, progress))
             if args.lr_final == 0 or args.lr_init == 0:  # linear decay
                 lr = args.lr_init + (args.lr_final - args.lr_init) * progress
             else:  # exp decay
                 lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
+            # if trainer.is_global_zero:
+            #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
         for param_group in trainer.optimizers[0].param_groups:
             if args.layerwise_lr > 0:
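Net effect of this hunk: decay no longer spans the whole run; it is postponed by my_pile_edecay mini-epochs and compressed into the remaining (epoch_count - my_pile_edecay) of them. A minimal standalone sketch of the revised schedule (assuming w_step comes from args.warmup_steps and real_step counts optimizer steps, as in the surrounding code; this is not the repo's exact function):

    import math

    def lr_at(real_step, args):
        # Hedged sketch of the schedule in the hunk above.
        w_step = args.warmup_steps  # assumption: warmup length in steps
        if args.lr_final == args.lr_init or args.epoch_count == 0:
            return args.lr_init
        # decay begins only after my_pile_edecay mini-epochs (epoch_steps each)
        decay_step = real_step - args.my_pile_edecay * args.epoch_steps
        decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
        progress = min(1, max(0, (decay_step - w_step + 1) / (decay_total - w_step)))
        if args.lr_final == 0 or args.lr_init == 0:  # linear decay
            return args.lr_init + (args.lr_final - args.lr_init) * progress
        # exp decay: interpolate log(lr) from log(lr_init) to log(lr_final)
        return args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * progress)

With my_pile_edecay=0 (the new flag's default), decay_step == real_step and decay_total == epoch_count * epoch_steps, which reproduces the old one-line formula exactly.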
@@ -73,10 +77,12 @@ class train_callback(pl.Callback):
             t_now = time.time_ns()
             token_per_step = args.ctx_len * args.real_bsz
             real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
+            kt_s = 0
             try:
                 t_cost = (t_now - trainer.my_time_ns) / 1e9
+                kt_s = token_per_step / t_cost / 1000
                 self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
-                self.log("Kt/s", token_per_step / t_cost / 1000, prog_bar=True, on_step=True)
+                self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
             except:
                 pass
             trainer.my_time_ns = t_now
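The new kt_s variable is initialized to 0 so the wandb block further down can tell "measured" from "not yet measured": on a step where the previous timestamp is missing or unusable, the bare except swallows the error and kt_s stays 0. The throughput math itself, with purely illustrative numbers (not from the commit):

    # Illustrative: ctx_len=1024, real_bsz=8, and a measured 0.25 s per step.
    token_per_step = 1024 * 8                  # 8192 tokens per optimizer step
    t_cost = 0.25                              # seconds between successive batch ends
    print(1.0 / t_cost)                        # "REAL it/s" -> 4.0
    print(token_per_step / t_cost / 1000)      # "Kt/s"      -> 32.768 kilotokens/s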
@@ -89,10 +95,10 @@ class train_callback(pl.Callback):
             # self.log("s", real_step, prog_bar=True, on_step=True)
             if len(args.wandb) > 0:
-                trainer.my_wandb.log(
-                    {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9},
-                    step=int(real_step),
-                )
+                lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
+                if kt_s > 0:
+                    lll["kt/s"] = kt_s
+                trainer.my_wandb.log(lll, step=int(real_step))

     def on_train_epoch_start(self, trainer, pl_module):
         args = self.args
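Building the payload in a dict first is what allows the conditional key: wandb treats every key it is given as a metric series, so skipping "kt/s" while kt_s == 0 avoids plotting a spurious zero on unmeasured steps. The pattern in isolation (values illustrative; wandb.log(dict, step=...) is the standard call shape):

    metrics = {"loss": 1.23, "lr": 6e-4, "Gtokens": 0.041}  # illustrative values
    kt_s = 0                                  # e.g. first step, timing unavailable
    if kt_s > 0:
        metrics["kt/s"] = kt_s                # only report throughput once measured
    # trainer.my_wandb.log(metrics, step=0)   # same shape as the hunk above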

@@ -93,6 +93,7 @@ if __name__ == "__main__":
     parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
     parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
     parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
+    parser.add_argument("--my_pile_edecay", default=0, type=int)
     parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
     parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
     # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)
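The new flag feeds args.my_pile_edecay in the schedule above; its default of 0 preserves the old behavior. A hypothetical invocation (the script name and the availability of the other flags are assumptions from this diff, not verified against the repo):

    python train.py --my_pile_stage 1 --ctx_len 1024 --my_pile_edecay 2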
@@ -117,10 +118,13 @@ if __name__ == "__main__":
     if args.my_pile_stage > 0:
         if args.ctx_len == 1024:
             args.magic_prime = 324331313
+            args.epoch_count = 8043
         elif args.ctx_len == 2048:
             args.magic_prime = 162165671
+            args.epoch_count = 4021
         elif args.ctx_len == 4096:
             args.magic_prime = 81082817
+            args.epoch_count = 2010
         if args.my_pile_shift < 0:
             if args.ctx_len == 1024:
                 args.my_pile_shift = 0
@@ -129,7 +133,6 @@ if __name__ == "__main__":
             elif args.ctx_len == 4096:
                 args.my_pile_shift = 768
-        args.epoch_count = 8043
         args.epoch_steps = 40320 // args.real_bsz
         assert args.epoch_steps * args.real_bsz == 40320
         if args.my_pile_stage == 2:
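A sanity check on the new constants (my arithmetic, not part of the commit): magic_prime * ctx_len is ~3.321e11 for all three context lengths, so magic_prime is roughly the number of ctx_len-sized samples in the tokenized dataset, and each new epoch_count makes epoch_count * 40320 * ctx_len cover that almost exactly once, which is why the single hard-coded epoch_count = 8043 above could be dropped in favor of per-ctx_len values:

    # Pure arithmetic over the constants in the two hunks above.
    for ctx_len, magic_prime, epoch_count in [
        (1024, 324331313, 8043),
        (2048, 162165671, 4021),
        (4096, 81082817, 2010),
    ]:
        dataset_tokens = magic_prime * ctx_len  # ~3.321e11 in every row
        epoch_tokens = 40320 * ctx_len          # tokens per mini-epoch (see the assert)
        print(ctx_len, dataset_tokens, round(epoch_count * epoch_tokens / dataset_tokens, 4))
    # -> coverage ratios 0.9999 / 0.9998 / 0.9995: one pass over the data per run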
