main
BlinkDL 3 years ago
parent 94300caba1
commit ceafd4e7af

@@ -24,7 +24,7 @@ class MyDataset(Dataset):
if args.my_pile_stage > 0:
assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024
self.samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
self.samples_per_epoch = args.epoch_steps * args.real_bsz
assert self.samples_per_epoch == 40320
print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
self.magic_prime = 324331313

@@ -280,10 +280,16 @@ class RWKV(pl.LightningModule):
lr_2x = set()
lr_3x = set()
for n, p in self.named_parameters():
if ("time_mix" in n) and (self.args.my_pile_stage == 2):
lr_2x.add(n)
if "time_mix" in n:
if self.args.my_pile_stage == 2:
lr_2x.add(n)
else:
lr_1x.add(n)
elif "time_decay" in n:
lr_2x.add(n)
if self.args.my_pile_stage == 2:
lr_3x.add(n)
else:
lr_2x.add(n)
elif "time_first" in n:
lr_3x.add(n)
else:
@@ -295,11 +301,18 @@ class RWKV(pl.LightningModule):
# print('2x', lr_2x)
# print('3x', lr_3x)
param_dict = {n: p for n, p in self.named_parameters()}
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
if self.args.my_pile_stage == 2:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},
]
else:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
else:
optim_groups = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},

@@ -34,13 +34,7 @@ class train_callback(pl.Callback):
for param_group in trainer.optimizers[0].param_groups:
if args.layerwise_lr > 0:
if self.args.my_pile_stage != 2:
param_group["lr"] = lr * param_group["my_lr_scale"]
else:
if param_group["my_lr_scale"] > 1:
param_group["lr"] = lr * 5
else:
param_group["lr"] = lr
param_group["lr"] = lr * param_group["my_lr_scale"]
# print(param_group["lr"], param_group["my_lr_scale"])
else:
param_group["lr"] = lr

Loading…
Cancel
Save