From 440bebff1a097d31d2017d4c2ff8421518c0a56b Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Thu, 12 Aug 2021 12:15:27 +0800
Subject: [PATCH] fixed nan in large models

---
 src/model.py | 70 ++++++++++++++++++++++++++++++----------------------
 train.py     | 12 ++++-----
 2 files changed, 46 insertions(+), 36 deletions(-)

diff --git a/src/model.py b/src/model.py
index ca1d91c..ca5adab 100644
--- a/src/model.py
+++ b/src/model.py
@@ -14,9 +14,10 @@ logger = logging.getLogger(__name__)
 ########################################################################################################
 
 class RWKV_TimeMix(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_id):
         super().__init__()
         assert config.n_embd % config.n_head == 0
+        self.layer_id = layer_id
         self.ctx_size = config.ctx_size
         self.n_head = config.n_head
         self.head_size = config.n_embd // config.n_head
@@ -58,13 +59,14 @@ class RWKV_TimeMix(nn.Module):
         wkv = (torch.einsum('htu,buhc->bthc', w, k * v)).contiguous().view(B, T, C)
 
         y = torch.sigmoid(r) * wkv / sum_k
-        
+
         y = self.output(y) * self.time_gamma[:T, :]
         return y
 
 class RWKV_ChannelMix(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_id):
         super().__init__()
+        self.layer_id = layer_id
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
         self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
@@ -265,7 +267,7 @@ class RMSNorm(nn.Module):
         x_normed = x / (norm_x * self.dd + 1e-12)
         return self.weight * x_normed
 
-class SimpleRMSNorm(nn.Module):
+class FixedNorm(nn.Module):
     def __init__(self, d):
         super().__init__()
         self.dd = d ** (-1. / 2)
@@ -285,18 +287,17 @@ class GPTConfig:
             setattr(self, k, v)
 
 class Block(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_id):
         super().__init__()
 
         self.ln1 = nn.LayerNorm(config.n_embd)
         self.ln2 = nn.LayerNorm(config.n_embd)
 
         if config.model_type == 'RWKV':
-            self.ln1 = nn.Identity()
-            # self.ln1 = SimpleRMSNorm(config.n_embd) # turn on this if you see nan in large RWKV models
-            self.ln2 = SimpleRMSNorm(config.n_embd)
-            self.attn = RWKV_TimeMix(config)
-            self.mlp = RWKV_ChannelMix(config)
+            self.ln1 = FixedNorm(config.n_embd)
+            self.ln2 = FixedNorm(config.n_embd)
+            self.attn = RWKV_TimeMix(config, layer_id)
+            self.mlp = RWKV_ChannelMix(config, layer_id)
         elif config.model_type == 'RotaryMHA':
             self.attn = RotaryMHA(config)
             self.mlp = GeGLU(config)
@@ -305,6 +306,7 @@ class Block(nn.Module):
             self.mlp = RWKV_ChannelMix(config)
 
     def forward(self, x):
+
         x = x + self.attn(self.ln1(x))
         x = x + self.mlp(self.ln2(x))
 
@@ -317,10 +319,10 @@ class GPT(nn.Module):
 
         self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
 
-        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+        self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)])
 
         if config.model_type == 'RWKV':
-            self.ln_f = SimpleRMSNorm(config.n_embd)
+            self.ln_f = FixedNorm(config.n_embd)
         else:
             self.ln_f = nn.LayerNorm(config.n_embd)
 
@@ -329,15 +331,23 @@ class GPT(nn.Module):
         self.ctx_size = config.ctx_size
         self.apply(self._init_weights)
 
-        if self.config.model_type == 'RWKV':
+        if self.config.model_type == 'RWKV': # improve orthogonal weight init
             ww = self.state_dict()
-            for k in ww: # reduce weight to avoid nan
-                if 'receptance.weight' in k:
-                    ww[k] /= math.pow(config.n_embd, 0.5)
-                elif 'key.weight' in k:
-                    ww[k] /= math.pow(config.n_embd, 0.25)
-                elif 'value.weight' in k:
-                    ww[k] /= math.pow(config.n_embd, 0.25)
+            for k in ww:
+                if 'tok_emb' in k:
+                    if self.config.vocab_size > self.config.n_embd:
+                        ww[k] *= math.sqrt(self.config.vocab_size)
+                    else:
+                        ww[k] *= math.sqrt(self.config.n_embd)
+                    ww[k] *= 0.4
+                elif 'head.weight' in k:
+                    ww[k] *= 0.2
+                elif 'blocks.' in k:
+                    block_id = int(k.split('.')[1])
+                    if 'receptance.weight' in k:
+                        ww[k] *= 0.5
+                    elif 'attn.key.weight' in k:
+                        ww[k] *= 0.2
 
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
 
@@ -347,14 +357,14 @@ class GPT(nn.Module):
 
     def _init_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
             if self.config.model_type == 'RWKV':
+                gain = 1.0
                 if isinstance(module, nn.Linear):
-                    gain_layer = min(3, module.weight.shape[0] / module.weight.shape[1])
-                    depth_factor = 1 # min(1, 1 / math.sqrt(self.config.n_layer / 5))
-                    nn.init.orthogonal_(module.weight, gain = gain_layer * depth_factor)
-                else:
-                    nn.init.orthogonal_(module.weight, gain = 1.0)
+                    if module.weight.data.shape[0] > module.weight.data.shape[1]:
+                        gain = math.sqrt(module.weight.data.shape[0] / module.weight.data.shape[1])
+                nn.init.orthogonal_(module.weight, gain=gain)
             else:
                 module.weight.data.normal_(mean=0.0, std=0.01)
+
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
@@ -396,14 +406,14 @@ class GPT(nn.Module):
         assert T <= self.ctx_size, "Cannot forward, model block size is exhausted."
 
         x = self.tok_emb(idx)
-        
+
         x = self.blocks(x)
-        
+
         x = self.ln_f(x)
-        logits = self.head(x)
+        x = self.head(x)
 
         loss = None
         if targets is not None:
-            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(logits.view(-1, logits.size(-1)), targets.view(-1))
+            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(x.view(-1, x.size(-1)), targets.view(-1))
 
-        return logits, loss
+        return x, loss
diff --git a/train.py b/train.py
index 0b1a7f2..2734b4e 100644
--- a/train.py
+++ b/train.py
@@ -29,21 +29,21 @@ model_level = 'character' # 'character' or 'word'
 
 ctx_size = 256 if model_level == 'character' else 128
 nLayers = 5
-nHead = 8 # if you see nan in large RWKV models, turn on 'self.ln1' in model.py
+nHead = 8
 nEmb = nHead * 64
 
-lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
+lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4   # RWKV can use higher lr
 lr_final = 2e-4
 
-lr_initial /= math.sqrt(nLayers / 5) # lower lr for deep models; higher lr for shallow models
+lr_initial /= math.sqrt(nLayers / 5)  # lower lr for deep models; higher lr for shallow models
 lr_final /= math.sqrt(nLayers / 5)
 
 betas = (0.9, 0.99)
-weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when we have enough data
+weight_decay = 0 if model_type == 'RWKV' else 0.01    # seems wd is not very useful when we have enough data
 
-nepoch = 50 # just a quick test. the 'epoch' here is very short
+nepoch = 50    # just a quick test. the 'epoch' here is very short
 nbatchsz = 64
-epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
+epoch_length_fixed = 10000    # make an 'epoch' very short, so we can see the training progress
 
 ########################################################################################################
 # Load data