main
BlinkDL 3 years ago
parent 1189c9e238
commit 24e3cf6e93

@ -281,15 +281,9 @@ class RWKV(pl.LightningModule):
lr_3x = set() lr_3x = set()
for n, p in self.named_parameters(): for n, p in self.named_parameters():
if "time_mix" in n: if "time_mix" in n:
if self.args.my_pile_stage == 2: lr_1x.add(n)
lr_2x.add(n)
else:
lr_1x.add(n)
elif "time_decay" in n: elif "time_decay" in n:
if self.args.my_pile_stage == 2: lr_2x.add(n)
lr_3x.add(n)
else:
lr_2x.add(n)
elif "time_first" in n: elif "time_first" in n:
lr_3x.add(n) lr_3x.add(n)
else: else:
@ -301,18 +295,11 @@ class RWKV(pl.LightningModule):
# print('2x', lr_2x) # print('2x', lr_2x)
# print('3x', lr_3x) # print('3x', lr_3x)
param_dict = {n: p for n, p in self.named_parameters()} param_dict = {n: p for n, p in self.named_parameters()}
if self.args.my_pile_stage == 2: optim_groups = [
optim_groups = [ {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0}, ]
]
else:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
else: else:
optim_groups = [ optim_groups = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},

@ -19,7 +19,7 @@ class train_callback(pl.Callback):
# LR schedule # LR schedule
w_step = args.warmup_steps w_step = args.warmup_steps
if trainer.global_step < w_step: if trainer.global_step < w_step:
lr = args.lr_init * (trainer.global_step / w_step) lr = args.lr_init * (0.1 + 0.9 * trainer.global_step / w_step)
else: else:
if args.lr_final == args.lr_init or args.epoch_count == 0: if args.lr_final == args.lr_init or args.epoch_count == 0:
lr = args.lr_init lr = args.lr_init

@ -240,7 +240,6 @@ if __name__ == "__main__":
max_p = args.my_pile_prev_p max_p = args.my_pile_prev_p
if max_p == -1: if max_p == -1:
args.load_model = f"{args.proj_dir}/rwkv-init.pth" args.load_model = f"{args.proj_dir}/rwkv-init.pth"
args.warmup_steps = 0
else: else:
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
args.epoch_begin = max_p + 1 args.epoch_begin = max_p + 1

Loading…
Cancel
Save