diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 4115979..fc2fc5f 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -275,15 +275,22 @@ class RWKV(pl.LightningModule):
         self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
 
     def configure_optimizers(self):
-        if self.args.layerwise_lr > 0:
+        args = self.args
+        if args.layerwise_lr > 0:
             lr_1x = set()
             lr_2x = set()
             lr_3x = set()
             for n, p in self.named_parameters():
                 if "time_mix" in n:
-                    lr_1x.add(n)
+                    if args.my_pile_stage == 2:
+                        lr_2x.add(n)
+                    else:
+                        lr_1x.add(n)
                 elif "time_decay" in n:
-                    lr_2x.add(n)
+                    if args.my_pile_stage == 2:
+                        lr_3x.add(n)
+                    else:
+                        lr_2x.add(n)
                 elif "time_first" in n:
                     lr_3x.add(n)
                 else:
@@ -295,11 +302,18 @@ class RWKV(pl.LightningModule):
             # print('2x', lr_2x)
             # print('3x', lr_3x)
             param_dict = {n: p for n, p in self.named_parameters()}
-            optim_groups = [
-                {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
-                {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
-                {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
-            ]
+            if args.my_pile_stage == 2:
+                optim_groups = [
+                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
+                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
+                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
+                ]
+            else:
+                optim_groups = [
+                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
+                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
+                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
+                ]
         else:
             optim_groups = [
                 {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py
index ee4b5bd..6337030 100644
--- a/RWKV-v4neo/src/trainer.py
+++ b/RWKV-v4neo/src/trainer.py
@@ -19,7 +19,7 @@ class train_callback(pl.Callback):
         # LR schedule
         w_step = args.warmup_steps
         if trainer.global_step < w_step:
-            lr = args.lr_init * (0.1 + 0.9 * trainer.global_step / w_step)
+            lr = args.lr_init * (0.2 + 0.8 * trainer.global_step / w_step)
         else:
             if args.lr_final == args.lr_init or args.epoch_count == 0:
                 lr = args.lr_init
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index e41f3a0..bd4847c 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -135,18 +135,19 @@ if __name__ == "__main__":
     if args.my_pile_stage == 2:
         assert args.lr_final == args.lr_init
     if args.my_pile_stage >= 2:  # find latest saved model
-        pths = os.listdir(args.proj_dir)
-        max_p = -1
-        for p in pths:
+        list_p = []
+        for p in os.listdir(args.proj_dir):
             if p.startswith("rwkv") and p.endswith(".pth"):
                 p = ((p.split("-"))[1].split("."))[0]
                 if p == "init":
                     p = -1
                 else:
                     p = int(p)
-                if p > max_p:
-                    args.my_pile_prev_p = max_p # in case max_p is corrupted
-                    max_p = p
+                list_p += [p]
+        list_p.sort()
+        max_p = list_p[-1]
+        if len(list_p) > 1:
+            args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted
         if max_p == -1:
             args.load_model = f"{args.proj_dir}/rwkv-init.pth"
         else:
@@ -163,7 +164,7 @@ if __name__ == "__main__":
         f"""
 ############################################################################
 #
-# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
+# RWKV-4 {args.precision.upper()} on {args.devices}x{args.accelerator.upper()}, bsz {args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
 #
 # Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
 #
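
Note on the model.py hunks: when args.my_pile_stage == 2, "time_mix" parameters move up to the 2x group and "time_decay" to the 3x group, and both boosted groups now carry my_lr_scale = 5.0. my_lr_scale is a custom key stored alongside each optimizer param group; PyTorch preserves unknown keys in param groups, so a schedule can read the multiplier back later (presumably what trainer.py does). A minimal sketch of that pattern, assuming a hypothetical apply_lr helper and toy parameters, not code from the repo:

import torch

# Each param group carries a custom "my_lr_scale" key; a schedule step
# multiplies the base LR by it. apply_lr is illustrative only.
def apply_lr(optimizer, base_lr):
    for group in optimizer.param_groups:
        group["lr"] = base_lr * group.get("my_lr_scale", 1.0)

w_slow = torch.nn.Parameter(torch.zeros(4))
w_fast = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.Adam([
    {"params": [w_slow], "weight_decay": 0.0, "my_lr_scale": 1.0},
    {"params": [w_fast], "weight_decay": 0.0, "my_lr_scale": 5.0},
])
apply_lr(opt, base_lr=6e-4)
print([g["lr"] for g in opt.param_groups])  # [0.0006, 0.003]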
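On the trainer.py hunk: the warmup floor rises from 10% to 20% of lr_init, with a correspondingly shallower ramp, so the LR still lands exactly on lr_init at warmup_steps. An endpoint check (plain arithmetic on the formula, not training output):

def warmup_lr(lr_init, step, w_step):
    # Linear warmup as in the changed line: starts at 0.2 * lr_init,
    # reaches lr_init when step == w_step.
    return lr_init * (0.2 + 0.8 * step / w_step)

for step in (0, 500, 1000):
    print(step, warmup_lr(6e-4, step, w_step=1000))
# 0    -> 1.2e-04  (the old 0.1 + 0.9 formula would start at 6.0e-05)
# 500  -> 3.6e-04
# 1000 -> 6.0e-04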
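On the train.py checkpoint scan: the old running-max loop only recorded my_pile_prev_p at the moment a new maximum was found, and since os.listdir returns names in arbitrary order, that value could be stale rather than the true second-latest checkpoint. Collecting all indices and sorting makes list_p[-1] / list_p[-2] order-independent. A standalone comparison of the two behaviors, using hypothetical step numbers as parsed from names like "rwkv-30.pth":

def old_scan(steps):
    # Running max: prev_p is whatever max_p held when the final maximum
    # appeared -- depends on traversal order.
    max_p, prev_p = -1, -1
    for p in steps:
        if p > max_p:
            prev_p = max_p
            max_p = p
    return max_p, prev_p

def new_scan(steps):
    # Sort once, then take the last two entries.
    steps = sorted(steps)
    return steps[-1], (steps[-2] if len(steps) > 1 else -1)

steps = [10, 30, 20]  # os.listdir order is not guaranteed
print(old_scan(steps))  # (30, 10) -- "previous" should be 20
print(new_scan(steps))  # (30, 20)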