main
BlinkDL 3 years ago
parent 1189c9e238
commit 24e3cf6e93

@@ -281,14 +281,8 @@ class RWKV(pl.LightningModule):
lr_3x = set()
for n, p in self.named_parameters():
if "time_mix" in n:
if self.args.my_pile_stage == 2:
lr_2x.add(n)
else:
lr_1x.add(n)
elif "time_decay" in n:
if self.args.my_pile_stage == 2:
lr_3x.add(n)
else:
lr_2x.add(n)
elif "time_first" in n:
lr_3x.add(n)
@@ -301,13 +295,6 @@ class RWKV(pl.LightningModule):
# print('2x', lr_2x)
# print('3x', lr_3x)
param_dict = {n: p for n, p in self.named_parameters()}
if self.args.my_pile_stage == 2:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},
]
else:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},

@@ -19,7 +19,7 @@ class train_callback(pl.Callback):
# LR schedule
w_step = args.warmup_steps
if trainer.global_step < w_step:
lr = args.lr_init * (trainer.global_step / w_step)
lr = args.lr_init * (0.1 + 0.9 * trainer.global_step / w_step)
else:
if args.lr_final == args.lr_init or args.epoch_count == 0:
lr = args.lr_init

@@ -240,7 +240,6 @@ if __name__ == "__main__":
max_p = args.my_pile_prev_p
if max_p == -1:
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
args.warmup_steps = 0
else:
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
args.epoch_begin = max_p + 1

Loading…
Cancel
Save