main
BlinkDL 3 years ago
parent 24e3cf6e93
commit 470ac7d1fa

@ -275,15 +275,22 @@ class RWKV(pl.LightningModule):
self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
def configure_optimizers(self): def configure_optimizers(self):
if self.args.layerwise_lr > 0: args = self.args
if args.layerwise_lr > 0:
lr_1x = set() lr_1x = set()
lr_2x = set() lr_2x = set()
lr_3x = set() lr_3x = set()
for n, p in self.named_parameters(): for n, p in self.named_parameters():
if "time_mix" in n: if "time_mix" in n:
lr_1x.add(n) if args.my_pile_stage == 2:
lr_2x.add(n)
else:
lr_1x.add(n)
elif "time_decay" in n: elif "time_decay" in n:
lr_2x.add(n) if args.my_pile_stage == 2:
lr_3x.add(n)
else:
lr_2x.add(n)
elif "time_first" in n: elif "time_first" in n:
lr_3x.add(n) lr_3x.add(n)
else: else:
@ -295,11 +302,18 @@ class RWKV(pl.LightningModule):
# print('2x', lr_2x) # print('2x', lr_2x)
# print('3x', lr_3x) # print('3x', lr_3x)
param_dict = {n: p for n, p in self.named_parameters()} param_dict = {n: p for n, p in self.named_parameters()}
optim_groups = [ if args.my_pile_stage == 2:
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, optim_groups = [
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
] {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
]
else:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
else: else:
optim_groups = [ optim_groups = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},

@ -19,7 +19,7 @@ class train_callback(pl.Callback):
# LR schedule # LR schedule
w_step = args.warmup_steps w_step = args.warmup_steps
if trainer.global_step < w_step: if trainer.global_step < w_step:
lr = args.lr_init * (0.1 + 0.9 * trainer.global_step / w_step) lr = args.lr_init * (0.2 + 0.8 * trainer.global_step / w_step)
else: else:
if args.lr_final == args.lr_init or args.epoch_count == 0: if args.lr_final == args.lr_init or args.epoch_count == 0:
lr = args.lr_init lr = args.lr_init

@ -135,18 +135,19 @@ if __name__ == "__main__":
if args.my_pile_stage == 2: if args.my_pile_stage == 2:
assert args.lr_final == args.lr_init assert args.lr_final == args.lr_init
if args.my_pile_stage >= 2: # find latest saved model if args.my_pile_stage >= 2: # find latest saved model
pths = os.listdir(args.proj_dir) list_p = []
max_p = -1 for p in os.listdir(args.proj_dir):
for p in pths:
if p.startswith("rwkv") and p.endswith(".pth"): if p.startswith("rwkv") and p.endswith(".pth"):
p = ((p.split("-"))[1].split("."))[0] p = ((p.split("-"))[1].split("."))[0]
if p == "init": if p == "init":
p = -1 p = -1
else: else:
p = int(p) p = int(p)
if p > max_p: list_p += [p]
args.my_pile_prev_p = max_p # in case max_p is corrupted list_p.sort()
max_p = p max_p = list_p[-1]
if len(list_p) > 1:
args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted
if max_p == -1: if max_p == -1:
args.load_model = f"{args.proj_dir}/rwkv-init.pth" args.load_model = f"{args.proj_dir}/rwkv-init.pth"
else: else:
@ -163,7 +164,7 @@ if __name__ == "__main__":
f""" f"""
############################################################################ ############################################################################
# #
# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''} # RWKV-4 {args.precision.upper()} on {args.devices}x{args.accelerator.upper()}, bsz {args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
# #
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir} # Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
# #

Loading…
Cancel
Save