diff --git a/.gitignore b/.gitignore
index 9fa2131..3c9a4ff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,6 +9,8 @@ data/
 vocab.json
 *.sh
 *log/
+test/
+tools/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py
index d42856c..0648c6c 100644
--- a/RWKV-v4neo/src/trainer.py
+++ b/RWKV-v4neo/src/trainer.py
@@ -24,13 +24,17 @@ class train_callback(pl.Callback):
         if args.lr_final == args.lr_init or args.epoch_count == 0:
             lr = args.lr_init
         else:
-            progress = (real_step - w_step + 1) / (args.epoch_count * args.epoch_steps - w_step)
+            decay_step = real_step - args.my_pile_edecay * args.epoch_steps
+            decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
+            progress = (decay_step - w_step + 1) / (decay_total - w_step)
             progress = min(1, max(0, progress))
 
             if args.lr_final == 0 or args.lr_init == 0:  # linear decay
                 lr = args.lr_init + (args.lr_final - args.lr_init) * progress
             else:  # exp decay
                 lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
+            # if trainer.is_global_zero:
+            #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
 
         for param_group in trainer.optimizers[0].param_groups:
             if args.layerwise_lr > 0:
@@ -73,10 +77,12 @@ class train_callback(pl.Callback):
             t_now = time.time_ns()
             token_per_step = args.ctx_len * args.real_bsz
             real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
+            kt_s = 0
             try:
                 t_cost = (t_now - trainer.my_time_ns) / 1e9
+                kt_s = token_per_step / t_cost / 1000
                 self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
-                self.log("Kt/s", token_per_step / t_cost / 1000, prog_bar=True, on_step=True)
+                self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
             except:
                 pass
             trainer.my_time_ns = t_now
@@ -89,10 +95,10 @@ class train_callback(pl.Callback):
             # self.log("s", real_step, prog_bar=True, on_step=True)
 
             if len(args.wandb) > 0:
-                trainer.my_wandb.log(
-                    {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9},
-                    step=int(real_step),
-                )
+                lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
+                if kt_s > 0:
+                    lll["kt/s"] = kt_s
+                trainer.my_wandb.log(lll, step=int(real_step))
 
     def on_train_epoch_start(self, trainer, pl_module):
         args = self.args
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index 1246990..5b31797 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -93,6 +93,7 @@ if __name__ == "__main__":
     parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
     parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
     parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
+    parser.add_argument("--my_pile_edecay", default=0, type=int)
     parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
     parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
     # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)
@@ -117,10 +118,13 @@ if __name__ == "__main__":
     if args.my_pile_stage > 0:
         if args.ctx_len == 1024:
             args.magic_prime = 324331313
+            args.epoch_count = 8043
         elif args.ctx_len == 2048:
             args.magic_prime = 162165671
+            args.epoch_count = 4021
         elif args.ctx_len == 4096:
             args.magic_prime = 81082817
+            args.epoch_count = 2010
         if args.my_pile_shift < 0:
             if args.ctx_len == 1024:
                 args.my_pile_shift = 0
@@ -129,7 +133,6 @@ if __name__ == "__main__":
             elif args.ctx_len == 4096:
                 args.my_pile_shift = 768
 
-        args.epoch_count = 8043
         args.epoch_steps = 40320 // args.real_bsz
         assert args.epoch_steps * args.real_bsz == 40320
         if args.my_pile_stage == 2:
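
For context (not part of the patch): below is a minimal standalone sketch of the learning-rate schedule that the trainer.py hunk introduces, where --my_pile_edecay skips the first N epochs before decay begins and shortens the decay window by the same amount. The scheduled_lr helper, the SimpleNamespace container, and the example values are illustrative assumptions, not code from the repository.

    import math
    from types import SimpleNamespace

    def scheduled_lr(args, real_step, w_step=0):
        # Mirror of the patched schedule: decay starts only after my_pile_edecay
        # epochs, and the decay window shrinks by the same number of epochs.
        if args.lr_final == args.lr_init or args.epoch_count == 0:
            return args.lr_init
        decay_step = real_step - args.my_pile_edecay * args.epoch_steps
        decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
        progress = (decay_step - w_step + 1) / (decay_total - w_step)
        progress = min(1, max(0, progress))
        if args.lr_final == 0 or args.lr_init == 0:  # linear decay
            return args.lr_init + (args.lr_final - args.lr_init) * progress
        # exp decay (pow(progress, 1) in the patch is just progress)
        return args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * progress)

    # Hypothetical values, chosen only to show the shape of the curve
    # (real_bsz = 1 gives epoch_steps = 40320, as in train.py).
    args = SimpleNamespace(lr_init=6e-4, lr_final=1e-5, epoch_count=8043,
                           epoch_steps=40320, my_pile_edecay=2)
    for epoch in (0, 1, 2, 100, 8043):
        print(epoch, scheduled_lr(args, real_step=epoch * args.epoch_steps))

With my_pile_edecay=2 the printed lr stays at lr_init through epochs 0-2 and only then starts decaying toward lr_final, which is the intended effect of the new flag.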