rapid convergence using ZERO initialization

main
BlinkDL 4 years ago
parent 7f391c5758
commit 3329161ed7

@ -17,39 +17,41 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in
for m in module.modules(): for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)): if not isinstance(m, (nn.Linear, nn.Embedding)):
continue continue
with torch.no_grad():
name = '[unknown weight]' name = '[unknown weight]'
for name, parameter in module.named_parameters(): # find the name of the weight for name, parameter in module.named_parameters(): # find the name of the weight
if id(m.weight) == id(parameter): if id(m.weight) == id(parameter):
break break
shape = m.weight.data.shape shape = m.weight.data.shape
gain = 1.0 # positive: gain for orthogonal, negative: std for normal gain = 1.0 # positive: gain for orthogonal, negative: std for normal
scale = 1.0 # extra scale for gain scale = 1.0 # extra scale for gain
if isinstance(m, nn.Linear): if isinstance(m, nn.Linear):
if m.bias is not None: if m.bias is not None:
m.bias.data.zero_() m.bias.data.zero_()
if shape[0] > shape[1]: if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1]) gain = math.sqrt(shape[0] / shape[1])
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection? if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
scale = config.rwkv_emb_scale scale = config.rwkv_emb_scale
if isinstance(m, nn.Embedding): if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1])) gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb? if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
scale = config.rwkv_emb_scale scale = config.rwkv_emb_scale
if hasattr(m, 'scale_init'): if hasattr(m, 'scale_init'):
scale = m.scale_init scale = m.scale_init
print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name) print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
gain *= scale gain *= scale
if gain > 0: if gain == 0:
nn.init.orthogonal_(m.weight, gain=gain) nn.init.zeros_(m.weight) # zero init is great for some RWKV matrices
else: elif gain > 0:
nn.init.normal_(m.weight, mean=0, std=-gain) nn.init.orthogonal_(m.weight, gain=gain)
else:
nn.init.normal_(m.weight, mean=0, std=-gain)
class RWKV_TimeMix(nn.Module): class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id): def __init__(self, config, layer_id):
@ -95,7 +97,7 @@ class RWKV_TimeMix(nn.Module):
self.key.scale_init = 0 self.key.scale_init = 0
self.receptance.scale_init = 0 self.receptance.scale_init = 0
self.output.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers self.output.scale_init = 0
def forward(self, x): def forward(self, x):
B, T, C = x.size() B, T, C = x.size()
@ -145,7 +147,7 @@ class RWKV_ChannelMix(nn.Module):
self.receptance = nn.Linear(config.n_embd, config.n_embd) self.receptance = nn.Linear(config.n_embd, config.n_embd)
self.receptance.scale_init = 0 self.receptance.scale_init = 0
self.weight.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers self.weight.scale_init = 0
def forward(self, x): def forward(self, x):
B, T, C = x.size() B, T, C = x.size()

@ -45,9 +45,8 @@ class Trainer:
setattr(cfg, k, config.__dict__[k]) # combine cfg setattr(cfg, k, config.__dict__[k]) # combine cfg
wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False) wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
# take over whatever gpus are on the system
self.device = 'cpu' self.device = 'cpu'
if torch.cuda.is_available(): if torch.cuda.is_available(): # take over whatever gpus are on the system
self.device = torch.cuda.current_device() self.device = torch.cuda.current_device()
self.model = torch.nn.DataParallel(self.model).to(self.device) self.model = torch.nn.DataParallel(self.model).to(self.device)
@ -57,8 +56,7 @@ class Trainer:
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd) run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name return run_name
def save_checkpoint(self): def save_checkpoint(self): # DataParallel wrappers keep raw model object in .module attribute
# DataParallel wrappers keep raw model object in .module attribute
raw_model = self.model.module if hasattr(self.model, "module") else self.model raw_model = self.model.module if hasattr(self.model, "module") else self.model
logger.info("saving %s", self.config.ckpt_path) logger.info("saving %s", self.config.ckpt_path)
torch.save(raw_model.state_dict(), self.config.ckpt_path) torch.save(raw_model.state_dict(), self.config.ckpt_path)
@ -94,14 +92,7 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
optimizer.step() optimizer.step()
# try: if config.lr_decay: # decay the learning rate based on our progress
# torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip, error_if_nonfinite=True)
# optimizer.step()
# except:
# pass # ignore nan sample -> sometimes can continue
# decay the learning rate based on our progress
if config.lr_decay:
self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100) self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
if self.tokens < config.warmup_tokens: if self.tokens < config.warmup_tokens:
# linear warmup # linear warmup
@ -118,8 +109,7 @@ class Trainer:
else: else:
lr = config.learning_rate lr = config.learning_rate
# report progress now_loss = loss.item() # report progress
now_loss = loss.item()
if 'wandb' in sys.modules: if 'wandb' in sys.modules:
wandb.log({"loss": now_loss}, step = self.steps * self.config.batch_size) wandb.log({"loss": now_loss}, step = self.steps * self.config.batch_size)

@ -24,6 +24,7 @@ model_type = 'RWKV'
# datafile = u"V:\\NLP\\enwik8" # datafile = u"V:\\NLP\\enwik8"
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt"
datafile_encoding = 'utf-8' datafile_encoding = 'utf-8'
# datafile = u"D:\\NLP-Data\\ww100M.txt"
# datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt" # datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
# datafile_encoding = 'utf-16' # datafile_encoding = 'utf-16'
@ -51,10 +52,9 @@ weight_decay = 0 if model_type == 'RWKV' else 0.01 # wd is not useful when we h
epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
######## special hyperparameters for RWKV model ######## ######## special hyperparameters for RWKV model ########
rwkv_layer_decay = 1.0 # reduce initial weight in higher layers. try 0.5 ~ 1.0 rwkv_emb_scale = 0.4 # scale of initial embedding. 0.4 is a good choice
rwkv_emb_scale = 0.4 if datafile_type == 0 else 0.8 # use 0.4 for char-level english, 0.8 for chinese
rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
rwkv_tiny_head = 1 # 1 is good enough rwkv_tiny_head = 1 # 1 is good enough. 8 is slow
######################################################################################################## ########################################################################################################
# Load data # Load data
@ -102,7 +102,7 @@ train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(),
######################################################################################################## ########################################################################################################
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type, model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
rwkv_emb_scale=rwkv_emb_scale, rwkv_layer_decay=rwkv_layer_decay, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head, rwkv_emb_scale=rwkv_emb_scale, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head,
n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn)) n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn) print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)

Loading…
Cancel
Save