From ad627311f4ad37adfd04e8434346cdf4c4205591 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Tue, 17 Aug 2021 14:33:33 +0800
Subject: [PATCH] clean init code

---
 src/model.py   | 166 ++++++++++++++++++++++++++++---------------------
 src/trainer.py |  11 +++-
 train.py       |  18 +++---
 3 files changed, 114 insertions(+), 81 deletions(-)

diff --git a/src/model.py b/src/model.py
index 3fdaddb..dcd8153 100644
--- a/src/model.py
+++ b/src/model.py
@@ -12,17 +12,74 @@ logger = logging.getLogger(__name__)
 ########################################################################################################
 # RWKV: RWKV Time-mix + RWKV Channel-mix
 ########################################################################################################
+#
+# fancy initialization of lin & emb layers, for faster convergence
+# note it will change ALL lin & emb layers in the module (including token emb & final projection)
+#
+def RWKV_Init(module, config):
+    for m in module.modules():
+        if not isinstance(m, (nn.Linear, nn.Embedding)):
+            continue
+
+        name = '[unknown weight]'
+        for name, parameter in module.named_parameters(): # find the name of the weight
+            if id(m.weight) == id(parameter):
+                break
+
+        shape = m.weight.data.shape
+        gain = 1.0  # positive: gain for orthogonal, negative: std for normal
+        scale = 1.0 # extra scale for gain
+
+        if isinstance(m, nn.Linear):
+            if m.bias is not None:
+                m.bias.data.zero_()
+            if shape[0] > shape[1]:
+                gain = math.sqrt(shape[0] / shape[1])
+            if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
+                scale = 0.4 # 0.4 is a safe choice, 0.8 is better for chinese
+
+        if isinstance(m, nn.Embedding):
+            gain = math.sqrt(max(shape[0], shape[1]))
+            if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
+                scale = 0.4 # 0.4 is a safe choice, 0.8 is better for chinese
+
+        if hasattr(m, 'scale_init'):
+            scale = m.scale_init
+
+        print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
+
+        gain *= scale
+        if gain > 0:
+            nn.init.orthogonal_(m.weight, gain=gain)
+        else:
+            nn.init.normal_(m.weight, mean=0, std=-gain)
 
 class RWKV_TimeMix(nn.Module):
     def __init__(self, config, layer_id):
         super().__init__()
-        assert config.n_embd % config.n_head == 0
+        assert config.n_attn % config.n_head == 0
         self.layer_id = layer_id
         self.ctx_len = config.ctx_len
         self.n_head = config.n_head
-        self.head_size = config.n_embd // config.n_head
+        self.head_size = config.n_attn // config.n_head
+
+        with torch.no_grad(): # build initial time_w curves for better convergence
+            ww = torch.zeros(config.n_head, config.ctx_len)
+            curve = torch.tensor([0.9 ** (config.ctx_len - 1 - i) for i in range(config.ctx_len)])
+            curve = curve * 2 + 0.7
+            for h in range(config.n_head):
+                if config.n_head > 1:
+                    mix_strength = 1 - 1.2 * h / (config.n_head - 1) # mix_strength from 1 to -0.2
+                else:
+                    mix_strength = 0.5
+                ww[h] = (1 - mix_strength) + curve * mix_strength
+                # special tweaks because of time_shift
+                ww[h][config.ctx_len - 3] = (ww[h][config.ctx_len - 3] * 2 + 1) / 3
+                ww[h][config.ctx_len - 2] = (ww[h][config.ctx_len - 2] * 1 + 2) / 3
+                ww[h][config.ctx_len - 1] = 1
+                # print(h, mix_strength, ww[h])
+        self.time_w = nn.Parameter(ww)
 
-        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
         self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
         self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
         self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
@@ -30,11 +87,15 @@ class RWKV_TimeMix(nn.Module):
 
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
-        self.key = nn.Linear(config.n_embd, config.n_embd)
-        self.value = nn.Linear(config.n_embd, config.n_embd)
-        self.receptance = nn.Linear(config.n_embd, config.n_embd)
+        self.key = nn.Linear(config.n_embd, config.n_attn)
+        self.value = nn.Linear(config.n_embd, config.n_attn)
+        self.receptance = nn.Linear(config.n_embd, config.n_attn)
 
-        self.output = nn.Linear(config.n_embd, config.n_embd)
+        self.output = nn.Linear(config.n_attn, config.n_embd)
+
+        self.key.scale_init = 0
+        self.receptance.scale_init = 0
+        self.output.scale_init = 1 / pow(1+layer_id, 0.5) # 0.5 ~ 0.7 gives similar results
 
     def forward(self, x):
         B, T, C = x.size()
@@ -57,7 +118,7 @@ class RWKV_TimeMix(nn.Module):
 
         kv = (k * v).view(B, T, self.n_head, self.head_size)
 
-        wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, C)
+        wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, -1)
 
         rwkv = torch.sigmoid(r) * wkv / sum_k
 
@@ -69,12 +130,15 @@ class RWKV_ChannelMix(nn.Module):
         self.layer_id = layer_id
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
-        hidden_sz = 5 * config.n_embd // 2 # can use smaller hidden_sz because of R
+        hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of R
         self.key = nn.Linear(config.n_embd, hidden_sz)
         self.value = nn.Linear(config.n_embd, hidden_sz)
         self.weight = nn.Linear(hidden_sz, config.n_embd)
         self.receptance = nn.Linear(config.n_embd, config.n_embd)
 
+        self.receptance.scale_init = 0
+        self.weight.scale_init = 1 / pow(1+layer_id, 0.5) # 0.5 ~ 0.7 gives similar results
+
     def forward(self, x):
         B, T, C = x.size()
 
@@ -125,24 +189,24 @@ class MHA_rotary(nn.Module):
     def __init__(self, config, layer_id, time_shift = False):
         super().__init__()
         self.layer_id = layer_id
-        assert config.n_embd % config.n_head == 0
+        assert config.n_attn % config.n_head == 0
         self.n_head = config.n_head
         self.ctx_len = config.ctx_len
-        self.head_size = config.n_embd // config.n_head
+        self.head_size = config.n_attn // config.n_head
 
         if time_shift:
             self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
-        self.query = nn.Linear(config.n_embd, config.n_embd)
-        self.key = nn.Linear(config.n_embd, config.n_embd)
-        self.value = nn.Linear(config.n_embd, config.n_embd)
+        self.query = nn.Linear(config.n_embd, config.n_attn)
+        self.key = nn.Linear(config.n_embd, config.n_attn)
+        self.value = nn.Linear(config.n_embd, config.n_attn)
 
         self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
         self.rotary_ndims = int(self.head_size * 0.5)
         self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
 
-        self.output = nn.Linear(config.n_embd, config.n_embd)
+        self.output = nn.Linear(config.n_attn, config.n_embd)
 
     def forward(self, x):
         B, T, C = x.size()
@@ -166,7 +230,7 @@ class MHA_rotary(nn.Module):
         att = F.softmax(att, dim = -1)                                      # softmax
 
         x = att @ v                                                         # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
-        x = x.transpose(1, 2).contiguous().view(B, T, C)                    # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
+        x = x.transpose(1, 2).contiguous().view(B, T, -1)                   # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
 
         x = self.output(x)
         return x
@@ -179,7 +243,7 @@ class GeGLU(torch.nn.Module):
         if time_shift:
             self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
-        hidden_sz = 3 * config.n_embd
+        hidden_sz = 3 * config.n_ffn
         self.key = nn.Linear(config.n_embd, hidden_sz)
         self.value = nn.Linear(config.n_embd, hidden_sz)
         self.weight = nn.Linear(hidden_sz, config.n_embd)
@@ -202,10 +266,10 @@ class MHA_pro(nn.Module):
     def __init__(self, config, layer_id):
         super().__init__()
         self.layer_id = layer_id
-        assert config.n_embd % config.n_head == 0
+        assert config.n_attn % config.n_head == 0
        self.n_head = config.n_head
        self.ctx_len = config.ctx_len
-        self.head_size = config.n_embd // config.n_head
+        self.head_size = config.n_attn // config.n_head
 
         self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
         self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
@@ -214,16 +278,16 @@ class MHA_pro(nn.Module):
         self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
-        self.query = nn.Linear(config.n_embd, config.n_embd)
-        self.key = nn.Linear(config.n_embd, config.n_embd)
-        self.value = nn.Linear(config.n_embd, config.n_embd)
+        self.query = nn.Linear(config.n_embd, config.n_attn)
+        self.key = nn.Linear(config.n_embd, config.n_attn)
+        self.value = nn.Linear(config.n_embd, config.n_attn)
 
         self.rotary_ndims = int(self.head_size * 0.5)
         self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
 
         self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads
 
-        self.output = nn.Linear(config.n_embd, config.n_embd)
+        self.output = nn.Linear(config.n_attn, config.n_embd)
 
     def forward(self, x):
         B, T, C = x.size()
@@ -248,12 +312,12 @@ class MHA_pro(nn.Module):
 
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))     # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
         att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf'))         # causal mask
-        att = F.softmax(att, dim = -1)                                      # softmax
+        att = F.softmax(att, dim = -1)                                      # softmax
         att = att * w                                                       # time-weighting
         att = self.head_mix(att)                                            # talking heads
 
         x = att @ v                                                         # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
-        x = x.transpose(1, 2).contiguous().view(B, T, C)                   # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
+        x = x.transpose(1, 2).contiguous().view(B, T, -1)                  # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
 
         x = self.output(x) * self.time_gamma[:T, :]
         return x
@@ -338,43 +402,11 @@ class GPT(nn.Module):
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         self.ctx_len = config.ctx_len
-        self.apply(self._init_weights)
-
-        if self.config.model_type == 'RWKV': # improve orthogonal weight init
-            ww = self.state_dict()
-            for k in ww:
-                if 'tok_emb' in k:
-                    if self.config.vocab_size > self.config.n_embd:
-                        ww[k] *= math.sqrt(self.config.vocab_size)
-                    else:
-                        ww[k] *= math.sqrt(self.config.n_embd)
-                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
-                elif 'head.weight' in k:
-                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
-                elif 'blocks.' in k:
-                    block_id = int(k.split('.')[1])
-                    if 'receptance.weight' in k:
-                        ww[k] *= 0 # init with zero matrix
-                    elif 'attn.key.weight' in k:
-                        ww[k] *= 0 # init with zero matrix
-                    elif 'attn.output.weight' in k:
-                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
-                    elif 'mlp.weight.weight' in k:
-                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
-                    elif 'attn.time_w' in k:
-                        curve = torch.tensor([0.9 ** (self.config.ctx_len - 1 - i) for i in range(self.config.ctx_len)])
-                        curve = curve * 2 + 0.7
-                        for h in range(self.config.n_head):
-                            if self.config.n_head > 1:
-                                mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
-                            else:
-                                mix_strength = 0.5
-                            ww[k][h] = (1 - mix_strength) + curve * mix_strength
-                            # special tweaks because of time_shift
-                            ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 3] * 2 + 1) / 3
-                            ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] * 1 + 2) / 3
-                            ww[k][h][self.config.ctx_len - 1] = 1
-                            # print(k, h, mix_strength, ww[k][h])
+
+        if self.config.model_type == 'RWKV':
+            RWKV_Init(self, config)
+        else:
+            self.apply(self._init_weights)
 
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
 
@@ -383,15 +415,7 @@ class GPT(nn.Module):
 
     def _init_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
-            if self.config.model_type == 'RWKV':
-                gain = 1.0
-                if isinstance(module, nn.Linear):
-                    if module.weight.data.shape[0] > module.weight.data.shape[1]:
-                        gain = math.sqrt(module.weight.data.shape[0] / module.weight.data.shape[1])
-                nn.init.orthogonal_(module.weight, gain=gain)
-            else:
-                module.weight.data.normal_(mean=0.0, std=0.01)
-
+            module.weight.data.normal_(mean=0.0, std=0.01)
             if isinstance(module, nn.Linear) and module.bias is not None:
                 module.bias.data.zero_()
 
diff --git a/src/trainer.py b/src/trainer.py
index ec1ff66..60978d1 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -1,4 +1,4 @@
-import math, sys
+import math, sys, datetime
 import logging
 import numpy as np
 from tqdm.auto import tqdm
@@ -43,8 +43,7 @@ class Trainer:
             cfg = model.config
             for k in config.__dict__:
                 setattr(cfg, k, config.__dict__[k]) # combine cfg
-            run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
-            wandb.init(project="RWKV-LM", name=run_name + '-' + wandb.util.generate_id(), config=cfg, save_code=False)
+            wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
 
         # take over whatever gpus are on the system
         self.device = 'cpu'
@@ -52,6 +51,12 @@
             self.device = torch.cuda.current_device()
             self.model = torch.nn.DataParallel(self.model).to(self.device)
 
+    def get_run_name(self):
+        raw_model = self.model.module if hasattr(self.model, "module") else self.model
+        cfg = raw_model.config
+        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
+        return run_name
+
     def save_checkpoint(self):
         # DataParallel wrappers keep raw model object in .module attribute
         raw_model = self.model.module if hasattr(self.model, "module") else self.model
diff --git a/train.py b/train.py
index 188b438..f4b6a87 100644
--- a/train.py
+++ b/train.py
@@ -13,10 +13,10 @@ from src.utils import set_seed
 set_seed(42)
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 
-logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)
+logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
 
 # RWKV       : our new model - fastest when ctx_len is long - VRAM friendly - good performance
-# MHA_rotary : usual Multi-head Attention+Rotary+GeGLU - not as good
+# MHA_rotary : usual MultiheadAttention+Rotary+GeGLU - not as good
 # MHA_shift  : with time-shift - good performance
 # MHA_pro    : slow (lots of tricks) - VRAM hungry - very good performance
 model_type = 'RWKV'
@@ -34,6 +34,8 @@ ctx_len = 256 # context length
 n_layer = 5
 n_head = 8
 n_embd = n_head * 64
+n_attn = n_embd
+n_ffn = n_embd
 
 batch_size = 64
 
@@ -54,7 +56,7 @@ print('loading data... ' + datafile)
 
 class Dataset(Dataset):
     def __init__(self, data, model_level, ctx_len):
-        print('building token list...')
+        print('building token list...', end=' ')
         if model_level == 'word':
             import re
             data = re.sub(r'(\n|\.|\,|\?|\!|\:|\;|\-|\—|\||\'|\"|\`|\(|\)|[0-9]|\[|\]|\{|\}|\=|\+|\*|\\|\/|\~|\&|\$|\#|\%)', r' \g<0> ', data)
@@ -62,10 +64,12 @@
             print('splitting token...')
             data = data.lower().split(' ')
         unique = sorted(list(set(data)))
+        # print()
         # for u in unique:
         #     print(u, end=' ')
+        # print('\n\n')
         data_size, vocab_size = len(data), len(unique)
-        print('\n\ndata has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
+        print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
         self.stoi = { ch:i for i,ch in enumerate(unique) }
         self.itos = { i:ch for i,ch in enumerate(unique) }
         self.ctx_len = ctx_len
@@ -90,9 +94,9 @@ train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(),
 ########################################################################################################
 
 model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
-                n_layer=n_layer, n_head=n_head, n_embd=n_embd))
+                n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
 
-print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'ctx', ctx_len)
+print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)
 
 tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
                       learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0)
@@ -100,7 +104,7 @@ trainer = Trainer(model, train_dataset, None, tconf)
 
 trainer.train()
 
-torch.save(model, 'trained-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
+torch.save(model, 'trained-' + trainer.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
 
 ########################################################################################################
 # Run model to generate text
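
For reference, a minimal self-contained sketch of what the new RWKV_Init amounts to for two representative layers. It only assumes PyTorch; the layer sizes below are made up for illustration, and the scaling rules are transcribed from the patch rather than imported from src/model.py.

import math
import torch
import torch.nn as nn

vocab_size, n_embd = 1000, 512       # hypothetical sizes, not taken from the patch

key  = nn.Linear(n_embd, n_embd)     # e.g. the attn key: the patch sets key.scale_init = 0
head = nn.Linear(n_embd, vocab_size) # final projection: the patch scales it by 0.4

# key: gain = 1.0 * scale_init 0 = 0, so RWKV_Init falls through to normal_(std=-gain),
# i.e. the weight starts as a zero matrix (zeros_ is used here for clarity)
nn.init.zeros_(key.weight)
print(key.weight.abs().max().item()) # 0.0 -> key/receptance start "closed"

# head: a "tall" Linear (out > in) gets gain = sqrt(out/in), then the extra 0.4 scale,
# so after orthogonal init its entries end up on the order of 0.4 / sqrt(n_embd)
gain = math.sqrt(vocab_size / n_embd) * 0.4
nn.init.orthogonal_(head.weight, gain=gain)
print(head.weight.std().item())      # roughly 0.4 / sqrt(512) ~ 0.018

# The patch also pre-shapes each head's time_w as a mix of a flat value and a
# 0.9**distance decay curve; with a single head it uses mix_strength = 0.5:
ctx_len, mix_strength = 8, 0.5       # tiny ctx_len just to show the shape
curve = torch.tensor([0.9 ** (ctx_len - 1 - i) for i in range(ctx_len)]) * 2 + 0.7
w = (1 - mix_strength) + curve * mix_strength
print(w)                             # rises toward the most recent positions
# (the patch additionally tweaks the last three positions because of time_shift)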