fixing nan in large models

main
BlinkDL 4 years ago
parent d699a69169
commit 62e2cb06d6

@@ -292,8 +292,9 @@ class Block(nn.Module):
         self.ln2 = nn.LayerNorm(config.n_embd)
         if config.model_type == 'RWKV':
-            self.ln1 = nn.Identity() # remove first LayerNorm -> faster convergence for deep models
-            self.ln2 = SimpleRMSNorm(config.n_embd) # SimpleRMSNorm is good enough for RWKV -> less parameters
+            self.ln1 = nn.Identity()
+            # self.ln1 = SimpleRMSNorm(config.n_embd) # turn on this if you see nan in large RWKV models
+            self.ln2 = SimpleRMSNorm(config.n_embd)
             self.attn = RWKV_TimeMix(config)
             self.mlp = RWKV_ChannelMix(config)
         elif config.model_type == 'RotaryMHA':
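Note (not part of the diff): SimpleRMSNorm is defined elsewhere in the model file and is not shown here. As a rough sketch of the idea only, assuming it is a parameter-free RMS normalization (which would match the "less parameters" comment removed above):

import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    # Hypothetical sketch, not the repo's exact code: RMS norm with no
    # learned weight or bias, unlike nn.LayerNorm.
    def __init__(self, d, eps=1e-8):
        super().__init__()
        self.scale = d ** (-0.5)   # 1/sqrt(d) turns an L2 norm into an RMS
        self.eps = eps

    def forward(self, x):
        rms = x.norm(2, dim=-1, keepdim=True) * self.scale   # RMS over the channel dim
        return x / (rms + self.eps)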
@@ -319,7 +320,7 @@ class GPT(nn.Module):
         self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
         if config.model_type == 'RWKV':
-            self.ln_f = SimpleRMSNorm(config.n_embd) # SimpleRMSNorm is good enough for RWKV -> less parameters
+            self.ln_f = SimpleRMSNorm(config.n_embd)
         else:
             self.ln_f = nn.LayerNorm(config.n_embd)
@@ -328,6 +329,16 @@ class GPT(nn.Module):
         self.ctx_size = config.ctx_size
         self.apply(self._init_weights)
+        if self.config.model_type == 'RWKV':
+            ww = self.state_dict()
+            for k in ww: # reduce weight to avoid nan
+                if 'receptance.weight' in k:
+                    ww[k] /= math.pow(config.n_embd, 0.5)
+                elif 'key.weight' in k:
+                    ww[k] /= math.pow(config.n_embd, 0.25)
+                elif 'value.weight' in k:
+                    ww[k] /= math.pow(config.n_embd, 0.25)
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

     def get_ctx_size(self):
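Note (not part of the diff): the block added above shrinks the just-initialized receptance/key/value projection weights, so early activations stay small instead of overflowing to nan in wide models. A minimal standalone illustration of the same scaling rule, with made-up sizes and names:

import math
import torch
import torch.nn as nn

n_embd = 512                                  # example width, not from the diff
proj = nn.Linear(n_embd, n_embd, bias=False)  # stands in for a receptance/key/value projection

with torch.no_grad():
    x = torch.randn(n_embd)
    before = proj(x).std()
    proj.weight /= math.pow(n_embd, 0.5)      # same rule the diff applies to 'receptance.weight'
    after = proj(x).std()
    print(before, after)                      # outputs shrink by roughly sqrt(n_embd)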
@@ -335,11 +346,14 @@ class GPT(nn.Module):
     def _init_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
-            # if self.config.model_type == 'RWKV' and isinstance(module, nn.Linear):
-            #     gain_layer = min(3, module.weight.shape[0] / module.weight.shape[1])
-            #     depth_factor = min(1, 1 / math.sqrt(self.config.n_layer / 5))
-            #     nn.init.orthogonal_(module.weight, gain = gain_layer * depth_factor) # will nan for large models
-            # else:
+            if self.config.model_type == 'RWKV':
+                if isinstance(module, nn.Linear):
+                    gain_layer = min(3, module.weight.shape[0] / module.weight.shape[1])
+                    depth_factor = 1 # min(1, 1 / math.sqrt(self.config.n_layer / 5))
+                    nn.init.orthogonal_(module.weight, gain = gain_layer * depth_factor)
+                else:
+                    nn.init.orthogonal_(module.weight, gain = 1.0)
+            else:
                 module.weight.data.normal_(mean=0.0, std=0.01)
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
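Note: the previously commented-out orthogonal init is now active for RWKV, with depth_factor pinned to 1; the gain grows with how much wider a layer's output is than its input, capped at 3. A quick sketch of the gains this formula yields for a few hypothetical weight shapes (nn.Linear stores weight as (out_features, in_features)):

import torch.nn as nn

# hypothetical (out_features, in_features) shapes, not taken from the repo
for out_f, in_f in [(512, 512), (2048, 512), (512, 2048)]:
    layer = nn.Linear(in_f, out_f, bias=False)
    gain_layer = min(3, layer.weight.shape[0] / layer.weight.shape[1])
    nn.init.orthogonal_(layer.weight, gain=gain_layer)
    print((out_f, in_f), gain_layer)   # 1.0, then 3 (capped from 4.0), then 0.25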

@@ -29,8 +29,8 @@ model_level = 'character' # 'character' or 'word'
 ctx_size = 256 if model_level == 'character' else 128
 nLayers = 5
-nHead = 8
-nEmb = 512
+nHead = 8 # if you see nan in large RWKV models, turn on 'self.ln1' in model.py
+nEmb = nHead * 64
 lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
 lr_final = 2e-4
@@ -39,7 +39,7 @@ lr_initial /= math.sqrt(nLayers / 5) # lower lr for deep models; higher lr for s
 lr_final /= math.sqrt(nLayers / 5)
 betas = (0.9, 0.99)
-weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when you have enough data
+weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when we have enough data
 nepoch = 50 # just a quick test. the 'epoch' here is very short
 nbatchsz = 64
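Note: with nEmb now tied to nHead (64 dims per head) and the sqrt(nLayers / 5) divisor above, deeper configs automatically train with a lower learning rate. A worked example using the same formula:

import math

lr_initial, lr_final = 6e-4, 2e-4      # RWKV values from the config above
for nLayers in (5, 10, 20):
    s = math.sqrt(nLayers / 5)         # deeper model -> larger divisor -> lower lr
    print(nLayers, lr_initial / s, lr_final / s)
# 5  -> 6.0e-4, 2.0e-4
# 10 -> ~4.2e-4, ~1.4e-4
# 20 -> 3.0e-4, 1.0e-4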
