@@ -275,15 +275,22 @@ class RWKV(pl.LightningModule):
             self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
 
     def configure_optimizers(self):
-        if self.args.layerwise_lr > 0:
+        args = self.args
+        if args.layerwise_lr > 0:
             lr_1x = set()
             lr_2x = set()
             lr_3x = set()
             for n, p in self.named_parameters():
                 if "time_mix" in n:
-                    lr_1x.add(n)
+                    if args.my_pile_stage == 2:
+                        lr_2x.add(n)
+                    else:
+                        lr_1x.add(n)
                 elif "time_decay" in n:
-                    lr_2x.add(n)
+                    if args.my_pile_stage == 2:
+                        lr_3x.add(n)
+                    else:
+                        lr_2x.add(n)
                 elif "time_first" in n:
                     lr_3x.add(n)
                 else:
@@ -295,11 +302,18 @@ class RWKV(pl.LightningModule):
             # print('2x', lr_2x)
             # print('3x', lr_3x)
             param_dict = {n: p for n, p in self.named_parameters()}
-            optim_groups = [
-                {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
-                {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
-                {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
-            ]
+            if args.my_pile_stage == 2:
+                optim_groups = [
+                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
+                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
+                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
+                ]
+            else:
+                optim_groups = [
+                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
+                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
+                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
+                ]
         else:
             optim_groups = [
                 {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
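The groups built above carry a non-standard `my_lr_scale` key; nothing in these hunks reads it, so the per-group multiplier has to be applied from the training loop. Below is a minimal sketch of that idea, assuming a hypothetical helper `apply_layerwise_lr` and a toy warmup schedule (neither comes from this diff); the only fact relied on is that PyTorch optimizers keep extra keys such as `my_lr_scale` inside `optimizer.param_groups`.

```python
# Sketch only: illustrates how a training-loop callback could consume the
# custom "my_lr_scale" key attached to each optimizer group in the diff above.
# The warmup schedule and helper name are illustrative assumptions.
import torch


def apply_layerwise_lr(optimizer: torch.optim.Optimizer, base_lr: float,
                       step: int, warmup_steps: int = 100) -> None:
    """Set each param group's lr to the scheduled base lr times its my_lr_scale."""
    lr = base_lr * min(1.0, (step + 1) / warmup_steps)  # toy linear warmup
    for group in optimizer.param_groups:
        # Groups without the key (e.g. layerwise_lr disabled) fall back to 1.0.
        group["lr"] = lr * group.get("my_lr_scale", 1.0)


# Usage: torch preserves unknown keys passed in a param-group dict, so the
# "my_lr_scale" values from configure_optimizers() survive into param_groups.
model = torch.nn.Linear(8, 8)
optim_groups = [
    {"params": model.parameters(), "weight_decay": 0.0, "my_lr_scale": 2.0},
]
optimizer = torch.optim.Adam(optim_groups, lr=1e-4)
apply_layerwise_lr(optimizer, base_lr=1e-4, step=0)
```

Keeping the scale on the group this way leaves `configure_optimizers` purely declarative: the actual learning-rate schedule stays in the trainer-side callback.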