|
|
|
|
@ -281,14 +281,8 @@ class RWKV(pl.LightningModule):
|
|
|
|
|
lr_3x = set()
|
|
|
|
|
for n, p in self.named_parameters():
|
|
|
|
|
if "time_mix" in n:
|
|
|
|
|
if self.args.my_pile_stage == 2:
|
|
|
|
|
lr_2x.add(n)
|
|
|
|
|
else:
|
|
|
|
|
lr_1x.add(n)
|
|
|
|
|
elif "time_decay" in n:
|
|
|
|
|
if self.args.my_pile_stage == 2:
|
|
|
|
|
lr_3x.add(n)
|
|
|
|
|
else:
|
|
|
|
|
lr_2x.add(n)
|
|
|
|
|
elif "time_first" in n:
|
|
|
|
|
lr_3x.add(n)
|
|
|
|
|
@ -301,13 +295,6 @@ class RWKV(pl.LightningModule):
|
|
|
|
|
# print('2x', lr_2x)
|
|
|
|
|
# print('3x', lr_3x)
|
|
|
|
|
param_dict = {n: p for n, p in self.named_parameters()}
|
|
|
|
|
if self.args.my_pile_stage == 2:
|
|
|
|
|
optim_groups = [
|
|
|
|
|
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
|
|
|
|
|
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
|
|
|
|
|
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
optim_groups = [
|
|
|
|
|
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
|
|
|
|
|
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
|
|
|
|
|
|