From e9fbd9bf701811c5e7dc0481895ede72a6cb7e40 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Wed, 11 Aug 2021 14:39:57 +0800
Subject: [PATCH] remove layernorm -> better RWKV

---
 src/model.py | 36 +++++++++++++++++++++++++++++++++---
 train.py     | 14 +++++++++++---
 2 files changed, 44 insertions(+), 6 deletions(-)

diff --git a/src/model.py b/src/model.py
index ef22713..35ea089 100644
--- a/src/model.py
+++ b/src/model.py
@@ -237,7 +237,7 @@ class RotaryMHA_Plus(nn.Module):
 # The GPT Model with our blocks
 ########################################################################################################
 
-class LabelSmoothingCrossEntropy(nn.Module): # might be able to avoid nan loss
+class LabelSmoothingCrossEntropy(nn.Module): # can avoid nan loss
     def __init__(self, smoothing=0.0):
         super().__init__()
         self.confidence = 1.0 - smoothing
@@ -251,6 +251,29 @@ class LabelSmoothingCrossEntropy(nn.Module): # might be able to avoid nan loss
         true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
         return torch.mean(torch.sum(-true_dist * pred, dim=-1))
 
+class RMSNorm(nn.Module):
+    def __init__(self, d):
+        super().__init__()
+        self.dd = d ** (-1. / 2)
+        self.weight = nn.Parameter(torch.ones(d))
+
+    def forward(self, x):
+        norm_x = x.norm(2, dim=-1, keepdim=True)
+        x_normed = x / (norm_x * self.dd + 1e-12)
+        return self.weight * x_normed
+
+class SimpleRMSNorm(nn.Module):
+    def __init__(self, d):
+        super().__init__()
+        self.dd = d ** (-1. / 2)
+
+    def forward(self, x):
+        norm_x = x.norm(2, dim=-1, keepdim=True)
+        x_normed = x / (norm_x * self.dd + 1e-12)
+        return x_normed
+
+########################################################################################################
+
 class GPTConfig:
     def __init__(self, vocab_size, ctx_size, **kwargs):
         self.vocab_size = vocab_size
@@ -266,6 +289,8 @@ class Block(nn.Module):
         self.ln2 = nn.LayerNorm(config.n_embd)
 
         if config.model_type == 'RWKV':
+            self.ln1 = nn.Identity() # remove first LayerNorm -> faster convergence for deep models
+            self.ln2 = SimpleRMSNorm(config.n_embd) # SimpleRMSNorm is good enough for RWKV -> less parameters
             self.attn = RWKV_TimeMix(config)
             self.mlp = RWKV_ChannelMix(config)
         elif config.model_type == 'RotaryMHA':
@@ -278,6 +303,7 @@ class Block(nn.Module):
     def forward(self, x):
         x = x + self.attn(self.ln1(x))
         x = x + self.mlp(self.ln2(x))
+
         return x
 
 class GPT(nn.Module):
@@ -288,7 +314,11 @@ class GPT(nn.Module):
 
         self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
 
-        self.ln_f = nn.LayerNorm(config.n_embd)
+        if config.model_type == 'RWKV':
+            self.ln_f = SimpleRMSNorm(config.n_embd) # SimpleRMSNorm is good enough for RWKV -> less parameters
+        else:
+            self.ln_f = nn.LayerNorm(config.n_embd)
+
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         self.ctx_size = config.ctx_size
@@ -311,7 +341,7 @@ class GPT(nn.Module):
         no_decay = set()
 
         whitelist_weight_modules = (nn.Linear, )
-        blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)
+        blacklist_weight_modules = (RMSNorm, nn.LayerNorm, nn.Embedding)
         for mn, m in self.named_modules():
             for pn, p in m.named_parameters():
                 fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
diff --git a/train.py b/train.py
index 429f741..8374232 100644
--- a/train.py
+++ b/train.py
@@ -19,6 +19,10 @@ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s
 model_type = 'RWKV' # 'RWKV' or 'RotaryMHA' or 'MHA-Plus'
 
 datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
+datafile_encoding = 'utf-8'
+# datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
+# datafile_encoding = 'utf-16'
+
 model_level = 'character' # 'character' or 'word'
 
 ctx_size = 256 if model_level == 'character' else 128
@@ -26,6 +30,10 @@ nLayers = 5
 nHead = 8
 nEmb = 512
 
+lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher LR
+lr_final = 2e-4
+betas = (0.9, 0.99)
+
 nepoch = 50 # just a quick test. the 'epoch' here is very short
 nbatchsz = 64
 epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
@@ -65,7 +73,7 @@ class Dataset(Dataset):
         y = torch.tensor(dix[1:], dtype=torch.long)
         return x, y
 
-train_dataset = Dataset(open(datafile, "r", encoding="utf-8").read(), model_level, ctx_size)
+train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_size)
 
 ########################################################################################################
 
@@ -74,8 +82,8 @@ model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_ty
 print('model', model_type, 'total epoch', nepoch, 'batchsz', nbatchsz, 'nLayers', nLayers, 'nHead', nHead, 'nEmb', nEmb, 'len', ctx_size)
 
 tconf = TrainerConfig(model_type=model_type, max_epochs=nepoch, batch_size=nbatchsz,
-                      learning_rate=6e-4 if model_type == 'RWKV' else 4e-4, betas=(0.9, 0.99), # RWKV can use higher LR
-                      lr_decay=True, lr_final=2e-4, warmup_tokens=0, final_tokens=nepoch*len(train_dataset)*ctx_size, num_workers=0)
+                      learning_rate=lr_initial, lr_decay=True, lr_final=lr_final, betas=betas,
+                      warmup_tokens=0, final_tokens=nepoch*len(train_dataset)*ctx_size, num_workers=0)
 
 trainer = Trainer(model, train_dataset, None, tconf)
 trainer.train()
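
Note (not part of the patch): the RMSNorm / SimpleRMSNorm modules added to
src/model.py divide the activations by ||x||_2 / sqrt(d), i.e. by their
root-mean-square, so the RWKV blocks keep per-layer normalization without
LayerNorm's mean-centering and bias (and, for SimpleRMSNorm, without any
learned weight at all). A minimal sketch of that equivalence, assuming only
PyTorch; the shapes and variable names are illustrative:

    import torch

    d = 512
    x = torch.randn(4, d)

    # SimpleRMSNorm as added above: L2 norm scaled by 1/sqrt(d), small epsilon for stability
    norm_x = x.norm(2, dim=-1, keepdim=True)
    x_normed = x / (norm_x * d ** (-0.5) + 1e-12)

    # ||x||_2 / sqrt(d) equals sqrt(mean(x**2)), so this is plain RMS normalization
    rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True))
    print(torch.allclose(x_normed, x / (rms + 1e-12)))  # True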