From 24e3cf6e937f6b973b33121678a34962d98b5e43 Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Wed, 7 Sep 2022 14:22:50 +0000 Subject: [PATCH] better --- RWKV-v4neo/src/model.py | 27 +++++++-------------------- RWKV-v4neo/src/trainer.py | 2 +- RWKV-v4neo/train.py | 1 - 3 files changed, 8 insertions(+), 22 deletions(-) diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py index 9dbf6ac..4115979 100644 --- a/RWKV-v4neo/src/model.py +++ b/RWKV-v4neo/src/model.py @@ -281,15 +281,9 @@ class RWKV(pl.LightningModule): lr_3x = set() for n, p in self.named_parameters(): if "time_mix" in n: - if self.args.my_pile_stage == 2: - lr_2x.add(n) - else: - lr_1x.add(n) + lr_1x.add(n) elif "time_decay" in n: - if self.args.my_pile_stage == 2: - lr_3x.add(n) - else: - lr_2x.add(n) + lr_2x.add(n) elif "time_first" in n: lr_3x.add(n) else: @@ -301,18 +295,11 @@ class RWKV(pl.LightningModule): # print('2x', lr_2x) # print('3x', lr_3x) param_dict = {n: p for n, p in self.named_parameters()} - if self.args.my_pile_stage == 2: - optim_groups = [ - {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, - {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, - {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0}, - ] - else: - optim_groups = [ - {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, - {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, - {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, - ] + optim_groups = [ + {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, + {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, + ] else: optim_groups = [ {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py index 74f3ab9..ee4b5bd 100644 --- a/RWKV-v4neo/src/trainer.py +++ b/RWKV-v4neo/src/trainer.py @@ -19,7 +19,7 @@ class train_callback(pl.Callback): # LR schedule w_step = args.warmup_steps if trainer.global_step < w_step: - lr = args.lr_init * (trainer.global_step / w_step) + lr = args.lr_init * (0.1 + 0.9 * trainer.global_step / w_step) else: if args.lr_final == args.lr_init or args.epoch_count == 0: lr = args.lr_init diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index 734489c..e41f3a0 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -240,7 +240,6 @@ if __name__ == "__main__": max_p = args.my_pile_prev_p if max_p == -1: args.load_model = f"{args.proj_dir}/rwkv-init.pth" - args.warmup_steps = 0 else: args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" args.epoch_begin = max_p + 1