From 62e2cb06d677ab8fd4c865a7c6b860bd1349c7eb Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Wed, 11 Aug 2021 22:11:12 +0800
Subject: [PATCH] fixing nan in large models

---
 src/model.py | 38 ++++++++++++++++++++++++++------------
 train.py     | 14 +++++++-------
 2 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/src/model.py b/src/model.py
index 20131a0..ca1d91c 100644
--- a/src/model.py
+++ b/src/model.py
@@ -80,7 +80,7 @@ class RWKV_ChannelMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
 
-        wkv = self.weight(F.mish(k) * v) # mish is a bit better than gelu
+        wkv = self.weight(F.mish(k) * v)            # mish is a bit better than gelu
         y = torch.sigmoid(r) * wkv
 
         return y
@@ -292,8 +292,9 @@ class Block(nn.Module):
         self.ln2 = nn.LayerNorm(config.n_embd)
 
         if config.model_type == 'RWKV':
-            self.ln1 = nn.Identity() # remove first LayerNorm -> faster convergence for deep models
-            self.ln2 = SimpleRMSNorm(config.n_embd) # SimpleRMSNorm is good enough for RWKV -> less parameters
+            self.ln1 = nn.Identity()
+            # self.ln1 = SimpleRMSNorm(config.n_embd) # turn on this if you see nan in large RWKV models
+            self.ln2 = SimpleRMSNorm(config.n_embd)
             self.attn = RWKV_TimeMix(config)
             self.mlp = RWKV_ChannelMix(config)
         elif config.model_type == 'RotaryMHA':
@@ -319,7 +320,7 @@ class GPT(nn.Module):
         self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
 
         if config.model_type == 'RWKV':
-            self.ln_f = SimpleRMSNorm(config.n_embd) # SimpleRMSNorm is good enough for RWKV -> less parameters
+            self.ln_f = SimpleRMSNorm(config.n_embd)
         else:
             self.ln_f = nn.LayerNorm(config.n_embd)
 
@@ -328,6 +329,16 @@ class GPT(nn.Module):
         self.ctx_size = config.ctx_size
         self.apply(self._init_weights)
 
+        if self.config.model_type == 'RWKV':
+            ww = self.state_dict()
+            for k in ww:   # reduce weight to avoid nan
+                if 'receptance.weight' in k:
+                    ww[k] /= math.pow(config.n_embd, 0.5)
+                elif 'key.weight' in k:
+                    ww[k] /= math.pow(config.n_embd, 0.25)
+                elif 'value.weight' in k:
+                    ww[k] /= math.pow(config.n_embd, 0.25)
+
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
 
     def get_ctx_size(self):
@@ -335,14 +346,17 @@ class GPT(nn.Module):
 
     def _init_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
-            # if self.config.model_type == 'RWKV' and isinstance(module, nn.Linear):
-            #     gain_layer = min(3, module.weight.shape[0] / module.weight.shape[1])
-            #     depth_factor = min(1, 1 / math.sqrt(self.config.n_layer / 5))
-            #     nn.init.orthogonal_(module.weight, gain = gain_layer * depth_factor) # will nan for large models
-            # else:
-            module.weight.data.normal_(mean=0.0, std=0.01)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
+            if self.config.model_type == 'RWKV':
+                if isinstance(module, nn.Linear):
+                    gain_layer = min(3, module.weight.shape[0] / module.weight.shape[1])
+                    depth_factor = 1 # min(1, 1 / math.sqrt(self.config.n_layer / 5))
+                    nn.init.orthogonal_(module.weight, gain = gain_layer * depth_factor)
+                else:
+                    nn.init.orthogonal_(module.weight, gain = 1.0)
+            else:
+                module.weight.data.normal_(mean=0.0, std=0.01)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
 
     def configure_optimizers(self, train_config):
         # separate out all parameters to those that will and won't experience regularizing weight decay
diff --git a/train.py b/train.py
index bb4498c..0b1a7f2 100644
--- a/train.py
+++ b/train.py
@@ -29,21 +29,21 @@ model_level = 'character' # 'character' or 'word'
 ctx_size = 256 if model_level == 'character' else 128
 
 nLayers = 5
-nHead = 8
-nEmb = 512
+nHead = 8                                             # if you see nan in large RWKV models, turn on 'self.ln1' in model.py
+nEmb = nHead * 64
 
-lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
+lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4   # RWKV can use higher lr
 lr_final = 2e-4
 
-lr_initial /= math.sqrt(nLayers / 5) # lower lr for deep models; higher lr for shallow models
+lr_initial /= math.sqrt(nLayers / 5)                  # lower lr for deep models; higher lr for shallow models
 lr_final /= math.sqrt(nLayers / 5)
 
 betas = (0.9, 0.99)
-weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when you have enough data
+weight_decay = 0 if model_type == 'RWKV' else 0.01    # seems wd is not very useful when we have enough data
 
-nepoch = 50 # just a quick test. the 'epoch' here is very short
+nepoch = 50                                           # just a quick test. the 'epoch' here is very short
 nbatchsz = 64
-epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
+epoch_length_fixed = 10000                            # make an 'epoch' very short, so we can see the training progress
 
 ########################################################################################################
 # Load data
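
Notes on the patch (illustrative sketches, not part of the diff):

The patch keeps SimpleRMSNorm for ln2 and ln_f, but its definition is not part
of this diff. A minimal sketch of a parameter-free RMSNorm, assuming it simply
rescales each vector by its root-mean-square over the channel dimension (the
method body and the eps value are assumptions, not copied from src/model.py):

    import torch
    import torch.nn as nn

    class SimpleRMSNorm(nn.Module):
        def __init__(self, d, eps=1e-8):
            super().__init__()
            self.d = d        # kept to match the SimpleRMSNorm(config.n_embd) call; unused in this sketch
            self.eps = eps    # small constant to avoid division by zero

        def forward(self, x):
            # normalize by the RMS of the last dimension; no learnable gain or bias,
            # which is why it has fewer parameters than nn.LayerNorm
            rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
            return x / rms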
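
The new block in GPT.__init__ shrinks the RWKV weights right after the
orthogonal init: receptance by n_embd ** 0.5, key and value by n_embd ** 0.25
("reduce weight to avoid nan"). For a square layer, an orthogonal matrix with
gain 1 has entries with std of roughly 1 / sqrt(n_embd), so the extra division
makes the initial k, v and r projections smaller still. A quick standalone
check of the resulting scale (n_embd = 512 is only an example value, not taken
from the patch):

    import math
    import torch
    import torch.nn as nn

    n_embd = 512
    w = torch.empty(n_embd, n_embd)
    nn.init.orthogonal_(w, gain=1.0)   # orthonormal columns -> element std ~ 1 / sqrt(n_embd) ~ 0.044
    w /= math.pow(n_embd, 0.25)        # the extra shrink the patch applies to key/value weights
    print(w.std().item())              # ~ 0.0093, i.e. roughly n_embd ** -0.75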