From 3329161ed70c892274d4427c8154a32998f0fca7 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Fri, 20 Aug 2021 03:07:23 +0800
Subject: [PATCH] rapid convergence using ZERO initialization

---
 src/model.py   | 72 ++++++++++++++++++++++++++------------------------
 src/trainer.py | 18 +++----------
 train.py       |  8 +++---
 3 files changed, 45 insertions(+), 53 deletions(-)

diff --git a/src/model.py b/src/model.py
index 2f9e79c..18b6250 100644
--- a/src/model.py
+++ b/src/model.py
@@ -17,39 +17,41 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in
     for m in module.modules():
         if not isinstance(m, (nn.Linear, nn.Embedding)):
             continue
-
-        name = '[unknown weight]'
-        for name, parameter in module.named_parameters(): # find the name of the weight
-            if id(m.weight) == id(parameter):
-                break
-
-        shape = m.weight.data.shape
-        gain = 1.0 # positive: gain for orthogonal, negative: std for normal
-        scale = 1.0 # extra scale for gain
-
-        if isinstance(m, nn.Linear):
-            if m.bias is not None:
-                m.bias.data.zero_()
-            if shape[0] > shape[1]:
-                gain = math.sqrt(shape[0] / shape[1])
-            if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
-                scale = config.rwkv_emb_scale
-
-        if isinstance(m, nn.Embedding):
-            gain = math.sqrt(max(shape[0], shape[1]))
-            if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
-                scale = config.rwkv_emb_scale
-
-        if hasattr(m, 'scale_init'):
-            scale = m.scale_init
-
-        print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
-
-        gain *= scale
-        if gain > 0:
-            nn.init.orthogonal_(m.weight, gain=gain)
-        else:
-            nn.init.normal_(m.weight, mean=0, std=-gain)
+        with torch.no_grad():
+            name = '[unknown weight]'
+            for name, parameter in module.named_parameters(): # find the name of the weight
+                if id(m.weight) == id(parameter):
+                    break
+
+            shape = m.weight.data.shape
+            gain = 1.0 # positive: gain for orthogonal, negative: std for normal
+            scale = 1.0 # extra scale for gain
+
+            if isinstance(m, nn.Linear):
+                if m.bias is not None:
+                    m.bias.data.zero_()
+                if shape[0] > shape[1]:
+                    gain = math.sqrt(shape[0] / shape[1])
+                if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
+                    scale = config.rwkv_emb_scale
+
+            if isinstance(m, nn.Embedding):
+                gain = math.sqrt(max(shape[0], shape[1]))
+                if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
+                    scale = config.rwkv_emb_scale
+
+            if hasattr(m, 'scale_init'):
+                scale = m.scale_init
+
+            print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
+
+            gain *= scale
+            if gain == 0:
+                nn.init.zeros_(m.weight) # zero init is great for some RWKV matrices
+            elif gain > 0:
+                nn.init.orthogonal_(m.weight, gain=gain)
+            else:
+                nn.init.normal_(m.weight, mean=0, std=-gain)
 
 class RWKV_TimeMix(nn.Module):
     def __init__(self, config, layer_id):
@@ -95,7 +97,7 @@ class RWKV_TimeMix(nn.Module):
 
         self.key.scale_init = 0
         self.receptance.scale_init = 0
-        self.output.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers
+        self.output.scale_init = 0
 
     def forward(self, x):
         B, T, C = x.size()
@@ -145,7 +147,7 @@ class RWKV_ChannelMix(nn.Module):
 
         self.receptance = nn.Linear(config.n_embd, config.n_embd)
         self.receptance.scale_init = 0
-        self.weight.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers
+        self.weight.scale_init = 0
 
     def forward(self, x):
         B, T, C = x.size()
diff --git a/src/trainer.py b/src/trainer.py
index 60978d1..56af8e1 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -45,9 +45,8 @@ class Trainer:
                 setattr(cfg, k, config.__dict__[k]) # combine cfg
             wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
 
-        # take over whatever gpus are on the system
         self.device = 'cpu'
-        if torch.cuda.is_available():
+        if torch.cuda.is_available(): # take over whatever gpus are on the system
             self.device = torch.cuda.current_device()
             self.model = torch.nn.DataParallel(self.model).to(self.device)
 
@@ -57,8 +56,7 @@ class Trainer:
         run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
         return run_name
 
-    def save_checkpoint(self):
-        # DataParallel wrappers keep raw model object in .module attribute
+    def save_checkpoint(self): # DataParallel wrappers keep raw model object in .module attribute
         raw_model = self.model.module if hasattr(self.model, "module") else self.model
         logger.info("saving %s", self.config.ckpt_path)
         torch.save(raw_model.state_dict(), self.config.ckpt_path)
@@ -94,14 +92,7 @@ class Trainer:
                     torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                     optimizer.step()
 
-                    # try:
-                    #     torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip, error_if_nonfinite=True)
-                    #     optimizer.step()
-                    # except:
-                    #     pass # ignore nan sample -> sometimes can continue
-
-                    # decay the learning rate based on our progress
-                    if config.lr_decay:
+                    if config.lr_decay: # decay the learning rate based on our progress
                         self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
                         if self.tokens < config.warmup_tokens:
                             # linear warmup
@@ -118,8 +109,7 @@ class Trainer:
                     else:
                         lr = config.learning_rate
 
-                    # report progress
-                    now_loss = loss.item()
+                    now_loss = loss.item() # report progress
                     if 'wandb' in sys.modules:
                         wandb.log({"loss": now_loss}, step = self.steps * self.config.batch_size)
 
diff --git a/train.py b/train.py
index 53a639f..f16f5e0 100644
--- a/train.py
+++ b/train.py
@@ -24,6 +24,7 @@ model_type = 'RWKV'
 # datafile = u"V:\\NLP\\enwik8"
 datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt"
 datafile_encoding = 'utf-8'
+# datafile = u"D:\\NLP-Data\\ww100M.txt"
 # datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
 # datafile_encoding = 'utf-16'
 
@@ -51,10 +52,9 @@ weight_decay = 0 if model_type == 'RWKV' else 0.01 # wd is not useful when we h
 epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
 
 ######## special hyperparameters for RWKV model ########
-rwkv_layer_decay = 1.0 # reduce initial weight in higher layers. try 0.5 ~ 1.0
-rwkv_emb_scale = 0.4 if datafile_type == 0 else 0.8 # use 0.4 for char-level english, 0.8 for chinese
+rwkv_emb_scale = 0.4 # scale of initial embedding. 0.4 is a good choice
 rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
-rwkv_tiny_head = 1 # 1 is good enough
+rwkv_tiny_head = 1 # 1 is good enough. 8 is slow
 
 ########################################################################################################
 # Load data
@@ -102,7 +102,7 @@ train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(),
 ########################################################################################################
 
 model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
-                      rwkv_emb_scale=rwkv_emb_scale, rwkv_layer_decay=rwkv_layer_decay, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head,
+                      rwkv_emb_scale=rwkv_emb_scale, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head,
                       n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
 
 print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)
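
Note (not part of the patch): a minimal, self-contained sketch of the initialization rule that RWKV_Init applies after this change, assuming plain PyTorch. The helper name init_linear and the layer sizes below are hypothetical, for illustration only. The rule is: gain == 0 selects zero init, gain > 0 selects orthogonal init with that gain, and gain < 0 selects normal init with std = -gain.

import torch
import torch.nn as nn

def init_linear(layer: nn.Linear, gain: float) -> None:
    # hypothetical helper mirroring the gain-based branch added to RWKV_Init
    with torch.no_grad():
        if layer.bias is not None:
            layer.bias.data.zero_()
        if gain == 0:
            nn.init.zeros_(layer.weight)      # layers with scale_init = 0 end up here
        elif gain > 0:
            nn.init.orthogonal_(layer.weight, gain=gain)
        else:
            nn.init.normal_(layer.weight, mean=0, std=-gain)

out_proj = nn.Linear(512, 512)  # hypothetical size, stands in for a block's output projection
init_linear(out_proj, gain=0.0)
assert out_proj.weight.abs().sum() == 0  # the block contributes nothing at initialization

Because self.output.scale_init and self.weight.scale_init are now 0, each TimeMix and ChannelMix block writes zeros into the residual stream at initialization, which is the "ZERO initialization" the commit subject credits for the rapid early convergence.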