From ceafd4e7af8cf722da48c3323fd79f3aae61894f Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Tue, 6 Sep 2022 17:03:11 +0000 Subject: [PATCH] better --- RWKV-v4neo/src/dataset.py | 2 +- RWKV-v4neo/src/model.py | 29 +++++++++++++++++++++-------- RWKV-v4neo/src/trainer.py | 8 +------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index f0e8262..ccd9f00 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -24,7 +24,7 @@ class MyDataset(Dataset): if args.my_pile_stage > 0: assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024 - self.samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz + self.samples_per_epoch = args.epoch_steps * args.real_bsz assert self.samples_per_epoch == 40320 print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########") self.magic_prime = 324331313 diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py index 68e63c0..afbe750 100644 --- a/RWKV-v4neo/src/model.py +++ b/RWKV-v4neo/src/model.py @@ -280,10 +280,16 @@ class RWKV(pl.LightningModule): lr_2x = set() lr_3x = set() for n, p in self.named_parameters(): - if ("time_mix" in n) and (self.args.my_pile_stage == 2): - lr_2x.add(n) + if "time_mix" in n: + if self.args.my_pile_stage == 2: + lr_2x.add(n) + else: + lr_1x.add(n) elif "time_decay" in n: - lr_2x.add(n) + if self.args.my_pile_stage == 2: + lr_3x.add(n) + else: + lr_2x.add(n) elif "time_first" in n: lr_3x.add(n) else: @@ -295,11 +301,18 @@ class RWKV(pl.LightningModule): # print('2x', lr_2x) # print('3x', lr_3x) param_dict = {n: p for n, p in self.named_parameters()} - optim_groups = [ - {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, - {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, - {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, - ] + if self.args.my_pile_stage == 2: + optim_groups = [ + {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, + {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0}, + ] + else: + optim_groups = [ + {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, + {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, + ] else: optim_groups = [ {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py index 0cc35a1..81bea20 100644 --- a/RWKV-v4neo/src/trainer.py +++ b/RWKV-v4neo/src/trainer.py @@ -34,13 +34,7 @@ class train_callback(pl.Callback): for param_group in trainer.optimizers[0].param_groups: if args.layerwise_lr > 0: - if self.args.my_pile_stage != 2: - param_group["lr"] = lr * param_group["my_lr_scale"] - else: - if param_group["my_lr_scale"] > 1: - param_group["lr"] = lr * 5 - else: - param_group["lr"] = lr + param_group["lr"] = lr * param_group["my_lr_scale"] # print(param_group["lr"], param_group["my_lr_scale"]) else: param_group["lr"] = lr