rapid convergence using ZERO initialization

Branch: main
BlinkDL committed 4 years ago
parent 7f391c5758
commit 3329161ed7

@@ -17,7 +17,7 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in
for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
with torch.no_grad():
name = '[unknown weight]'
for name, parameter in module.named_parameters(): # find the name of the weight
if id(m.weight) == id(parameter):
@@ -46,7 +46,9 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in
print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
gain *= scale
if gain > 0:
if gain == 0:
nn.init.zeros_(m.weight) # zero init is great for some RWKV matrices
elif gain > 0:
nn.init.orthogonal_(m.weight, gain=gain)
else:
nn.init.normal_(m.weight, mean=0, std=-gain)
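Pulled out of the hunk above, the sign-of-gain dispatch now has three branches. The helper name `apply_rwkv_init` below is hypothetical, and `gain` is assumed to already include the per-layer `scale_init` factor multiplied in a few lines earlier (`gain *= scale`). A minimal sketch:

```python
import torch
import torch.nn as nn

def apply_rwkv_init(weight: torch.Tensor, gain: float) -> None:
    # gain == 0 -> zero init, the branch added in this commit
    # gain > 0  -> orthogonal init scaled by gain
    # gain < 0  -> plain normal init, reusing -gain as the std
    with torch.no_grad():
        if gain == 0:
            nn.init.zeros_(weight)   # zero init is great for some RWKV matrices
        elif gain > 0:
            nn.init.orthogonal_(weight, gain=gain)
        else:
            nn.init.normal_(weight, mean=0, std=-gain)
```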
@@ -95,7 +97,7 @@ class RWKV_TimeMix(nn.Module):
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers
self.output.scale_init = 0
def forward(self, x):
B, T, C = x.size()
@@ -145,7 +147,7 @@ class RWKV_ChannelMix(nn.Module):
self.receptance = nn.Linear(config.n_embd, config.n_embd)
self.receptance.scale_init = 0
self.weight.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers
self.weight.scale_init = 0
def forward(self, x):
B, T, C = x.size()
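Both `RWKV_TimeMix` and `RWKV_ChannelMix` now tag their projection layers with `scale_init = 0`, and `RWKV_Init` multiplies that attribute into `gain` (`gain *= scale` in the first hunk), which routes the weight into the new `nn.init.zeros_` branch. A minimal sketch of that interaction, with the helper name `init_linear` and the default value of `scale_init` assumed for illustration:

```python
import torch.nn as nn

def init_linear(layer: nn.Linear, base_gain: float = 1.0) -> None:
    scale = getattr(layer, 'scale_init', 1.0)  # assumed default when a layer sets nothing
    gain = base_gain * scale                   # mirrors `gain *= scale` in RWKV_Init
    if gain == 0:
        nn.init.zeros_(layer.weight)           # key / receptance / output / weight start at zero
    elif gain > 0:
        nn.init.orthogonal_(layer.weight, gain=gain)
    else:
        nn.init.normal_(layer.weight, mean=0, std=-gain)

receptance = nn.Linear(768, 768)
receptance.scale_init = 0    # as in RWKV_ChannelMix above
init_linear(receptance)      # receptance.weight is now all zeros
```

With the output projection of every block starting at zero, each residual branch initially acts as the identity, which is presumably what the commit title means by rapid convergence.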

@@ -45,9 +45,8 @@ class Trainer:
setattr(cfg, k, config.__dict__[k]) # combine cfg
wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
# take over whatever gpus are on the system
self.device = 'cpu'
if torch.cuda.is_available():
if torch.cuda.is_available(): # take over whatever gpus are on the system
self.device = torch.cuda.current_device()
self.model = torch.nn.DataParallel(self.model).to(self.device)
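For context, the device setup in this hunk reads as below when assembled into one piece; `TrainerSketch` is a stand-in name and the surrounding Trainer attributes are assumed. It falls back to CPU when CUDA is unavailable:

```python
import torch

class TrainerSketch:
    def __init__(self, model):
        self.model = model
        self.device = 'cpu'
        if torch.cuda.is_available():  # take over whatever gpus are on the system
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)
```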
@@ -57,8 +56,7 @@ class Trainer:
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name
def save_checkpoint(self):
# DataParallel wrappers keep raw model object in .module attribute
def save_checkpoint(self): # DataParallel wrappers keep raw model object in .module attribute
raw_model = self.model.module if hasattr(self.model, "module") else self.model
logger.info("saving %s", self.config.ckpt_path)
torch.save(raw_model.state_dict(), self.config.ckpt_path)
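The checkpoint path comes from `self.config.ckpt_path`; unwrapping `.module` keeps the saved state_dict keys free of the `module.` prefix that `DataParallel` adds. A standalone sketch (the free-function form is mine):

```python
import torch

def save_checkpoint(model: torch.nn.Module, ckpt_path: str) -> None:
    # DataParallel wrappers keep the raw model object in the .module attribute
    raw_model = model.module if hasattr(model, "module") else model
    torch.save(raw_model.state_dict(), ckpt_path)
```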
@@ -94,14 +92,7 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
optimizer.step()
# try:
# torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip, error_if_nonfinite=True)
# optimizer.step()
# except:
# pass # ignore nan sample -> sometimes can continue
# decay the learning rate based on our progress
if config.lr_decay:
if config.lr_decay: # decay the learning rate based on our progress
self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
if self.tokens < config.warmup_tokens:
# linear warmup
@@ -118,8 +109,7 @@ class Trainer:
else:
lr = config.learning_rate
# report progress
now_loss = loss.item()
now_loss = loss.item() # report progress
if 'wandb' in sys.modules:
wandb.log({"loss": now_loss}, step = self.steps * self.config.batch_size)
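The learning-rate schedule spread across the last two hunks counts processed tokens (`(y >= 0).sum()`, i.e. labels that are not -100) and warms up linearly; the post-warmup decay is not visible here, so the cosine shape and the `final_tokens` field in the sketch below are assumptions borrowed from similar minGPT-style trainers:

```python
import math

def current_lr(config, tokens_so_far: int) -> float:
    if not config.lr_decay:
        return config.learning_rate
    if tokens_so_far < config.warmup_tokens:
        # linear warmup, as in the hunk above
        lr_mult = tokens_so_far / max(1, config.warmup_tokens)
    else:
        # decay after warmup: cosine shape assumed, not shown in this diff
        progress = (tokens_so_far - config.warmup_tokens) / max(
            1, config.final_tokens - config.warmup_tokens)
        lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return config.learning_rate * lr_mult
```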

@@ -24,6 +24,7 @@ model_type = 'RWKV'
# datafile = u"V:\\NLP\\enwik8"
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt"
datafile_encoding = 'utf-8'
# datafile = u"D:\\NLP-Data\\ww100M.txt"
# datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
# datafile_encoding = 'utf-16'
@@ -51,10 +52,9 @@ weight_decay = 0 if model_type == 'RWKV' else 0.01 # wd is not useful when we h
epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
######## special hyperparameters for RWKV model ########
rwkv_layer_decay = 1.0 # reduce initial weight in higher layers. try 0.5 ~ 1.0
rwkv_emb_scale = 0.4 if datafile_type == 0 else 0.8 # use 0.4 for char-level english, 0.8 for chinese
rwkv_emb_scale = 0.4 # scale of initial embedding. 0.4 is a good choice
rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
rwkv_tiny_head = 1 # 1 is good enough
rwkv_tiny_head = 1 # 1 is good enough. 8 is slow
########################################################################################################
# Load data
@@ -102,7 +102,7 @@ train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(),
########################################################################################################
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
rwkv_emb_scale=rwkv_emb_scale, rwkv_layer_decay=rwkv_layer_decay, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head,
rwkv_emb_scale=rwkv_emb_scale, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head,
n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)
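With `rwkv_layer_decay` removed from the hyperparameters and from the `GPTConfig` call, the old per-layer damping `1 / (1 + layer_id) ** rwkv_layer_decay` is replaced outright by zero-initialized projections. A small arithmetic comparison of the two schemes (6 layers, decay 1.0, purely illustrative):

```python
# old scheme (removed in this commit): shrink initial weights in higher layers
rwkv_layer_decay = 1.0
old_scales = [1 / (1 + layer_id) ** rwkv_layer_decay for layer_id in range(6)]
print([round(s, 3) for s in old_scales])  # [1.0, 0.5, 0.333, 0.25, 0.2, 0.167]

# new scheme: the same projections simply start at exactly zero
new_scales = [0.0 for _ in range(6)]
print(new_scales)                         # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```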
