clean init code

main
BlinkDL 4 years ago
parent c675b47705
commit ad627311f4

src/model.py
@@ -12,17 +12,74 @@ logger = logging.getLogger(__name__)
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
+#
+# fancy initialization of lin & emb layers, for faster convergence
+# note it will change ALL lin & emb layers in the module (including token emb & final projection)
+#
+def RWKV_Init(module, config):
+    for m in module.modules():
+        if not isinstance(m, (nn.Linear, nn.Embedding)):
+            continue
+        name = '[unknown weight]'
+        for name, parameter in module.named_parameters(): # find the name of the weight
+            if id(m.weight) == id(parameter):
+                break
+        shape = m.weight.data.shape
+        gain = 1.0 # positive: gain for orthogonal, negative: std for normal
+        scale = 1.0 # extra scale for gain
+        if isinstance(m, nn.Linear):
+            if m.bias is not None:
+                m.bias.data.zero_()
+            if shape[0] > shape[1]:
+                gain = math.sqrt(shape[0] / shape[1])
+            if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
+                scale = 0.4 # 0.4 is a safe choice, 0.8 is better for chinese
+        if isinstance(m, nn.Embedding):
+            gain = math.sqrt(max(shape[0], shape[1]))
+            if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
+                scale = 0.4 # 0.4 is a safe choice, 0.8 is better for chinese
+        if hasattr(m, 'scale_init'):
+            scale = m.scale_init
+        print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
+        gain *= scale
+        if gain > 0:
+            nn.init.orthogonal_(m.weight, gain=gain)
+        else:
+            nn.init.normal_(m.weight, mean=0, std=-gain)
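A note on the convention above: a positive gain selects orthogonal init with that gain, while a non-positive gain selects normal init with std = -gain; because gain is multiplied by scale, a layer tagged with scale_init = 0 ends up exactly zero. A minimal standalone sketch (toy sizes, not part of the commit):

import torch
import torch.nn as nn

def init_like_rwkv(weight, gain):
    if gain > 0:
        nn.init.orthogonal_(weight, gain=gain)      # positive: orthogonal with this gain
    else:
        nn.init.normal_(weight, mean=0, std=-gain)  # otherwise: normal with std = -gain

w = torch.empty(8, 8)
init_like_rwkv(w, 1.0)  # orthogonal rows/columns
init_like_rwkv(w, 0.0)  # the scale_init = 0 path: std = 0 leaves w all zeros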
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
-        assert config.n_embd % config.n_head == 0
+        assert config.n_attn % config.n_head == 0
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_head = config.n_head
-        self.head_size = config.n_embd // config.n_head
+        self.head_size = config.n_attn // config.n_head
+        with torch.no_grad(): # build initial time_w curves for better convergence
+            ww = torch.zeros(config.n_head, config.ctx_len)
+            curve = torch.tensor([0.9 ** (config.ctx_len - 1 - i) for i in range(config.ctx_len)])
+            curve = curve * 2 + 0.7
+            for h in range(config.n_head):
+                if config.n_head > 1:
+                    mix_strength = 1 - 1.2 * h / (config.n_head - 1) # mix_strength from 1 to -0.2
+                else:
+                    mix_strength = 0.5
+                ww[h] = (1 - mix_strength) + curve * mix_strength
+                # special tweaks because of time_shift
+                ww[h][config.ctx_len - 3] = (ww[h][config.ctx_len - 3] * 2 + 1) / 3
+                ww[h][config.ctx_len - 2] = (ww[h][config.ctx_len - 2] * 1 + 2) / 3
+                ww[h][config.ctx_len - 1] = 1
+                # print(h, mix_strength, ww[h])
+        self.time_w = nn.Parameter(ww)
-        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
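To see what those loops build, the same time_w curve can be recomputed standalone for a toy size (ctx_len = 8 and n_head = 4 are illustrative, not from the config):

import torch

ctx_len, n_head = 8, 4
ww = torch.zeros(n_head, ctx_len)
curve = torch.tensor([0.9 ** (ctx_len - 1 - i) for i in range((ctx_len))]) * 2 + 0.7
for h in range(n_head):
    mix_strength = 1 - 1.2 * h / (n_head - 1)  # ramps from 1.0 (head 0) down to -0.2
    ww[h] = (1 - mix_strength) + curve * mix_strength
    ww[h][-3] = (ww[h][-3] * 2 + 1) / 3  # soften the last few weights toward 1
    ww[h][-2] = (ww[h][-2] + 2) / 3      # (compensates for the one-step time_shift)
    ww[h][-1] = 1
print(ww)  # early heads decay toward the past; later heads flatten out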
@@ -30,11 +87,15 @@ class RWKV_TimeMix(nn.Module):
self.time_shift = nn.ZeroPad2d((0,0,1,0))
-        self.key = nn.Linear(config.n_embd, config.n_embd)
-        self.value = nn.Linear(config.n_embd, config.n_embd)
-        self.receptance = nn.Linear(config.n_embd, config.n_embd)
+        self.key = nn.Linear(config.n_embd, config.n_attn)
+        self.value = nn.Linear(config.n_embd, config.n_attn)
+        self.receptance = nn.Linear(config.n_embd, config.n_attn)
-        self.output = nn.Linear(config.n_embd, config.n_embd)
+        self.output = nn.Linear(config.n_attn, config.n_embd)
+        self.key.scale_init = 0
+        self.receptance.scale_init = 0
+        self.output.scale_init = 1 / pow(1+layer_id, 0.5) # 0.5 ~ 0.7 gives similar results
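Since key and receptance carry scale_init = 0, RWKV_Init zeroes both, so at step 0 the gate sigmoid(r) is 0.5 everywhere and, assuming the exp/cumsum normalization that sum_k in the forward pass implies, the time-mix starts out as a plain running average of v:

import torch

k = torch.zeros(1, 4, 1)  # what a zero-initialized key layer produces
w = torch.exp(k) / torch.cumsum(torch.exp(k), dim=1)
print(w.squeeze())  # tensor([1.0000, 0.5000, 0.3333, 0.2500]): a uniform running mean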
def forward(self, x):
B, T, C = x.size()
@@ -57,7 +118,7 @@ class RWKV_TimeMix(nn.Module):
kv = (k * v).view(B, T, self.n_head, self.head_size)
-        wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, C)
+        wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, -1)
rwkv = torch.sigmoid(r) * wkv / sum_k
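For readers decoding the einsum above: w acts as a per-head (T, T) time-weight matrix applied over the sequence axis u. An equivalent (slower) loop form, under the shapes implied above:

# w: (n_head, T, T), kv: (B, T, n_head, head_size)
out = torch.stack([w[h] @ kv[:, :, h, :] for h in range(w.shape[0])], dim=2)
# out: (B, T, n_head, head_size), identical to torch.einsum('htu,buhc->bthc', w, kv)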
@@ -69,12 +130,15 @@ class RWKV_ChannelMix(nn.Module):
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,0))
-        hidden_sz = 5 * config.n_embd // 2 # can use smaller hidden_sz because of R
+        hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of R
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)
self.receptance = nn.Linear(config.n_embd, config.n_embd)
+        self.receptance.scale_init = 0
+        self.weight.scale_init = 1 / pow(1+layer_id, 0.5) # 0.5 ~ 0.7 gives similar results
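For context, the R gate that justifies the smaller hidden size composes roughly like this (a sketch of the usual channel-mix pattern; the exact activation is not visible in this hunk and is an assumption):

k = self.key(x)         # (B, T, hidden_sz)
v = self.value(x)       # (B, T, hidden_sz)
r = self.receptance(x)  # (B, T, n_embd); zero-initialized, so sigmoid(r) starts at 0.5
kv = self.weight(F.mish(k) * v)  # project back to n_embd (activation choice assumed)
out = torch.sigmoid(r) * kv      # gating is why 2.5x n_ffn suffices instead of the usual 4x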
def forward(self, x):
B, T, C = x.size()
@@ -125,24 +189,24 @@ class MHA_rotary(nn.Module):
def __init__(self, config, layer_id, time_shift = False):
super().__init__()
self.layer_id = layer_id
-        assert config.n_embd % config.n_head == 0
+        assert config.n_attn % config.n_head == 0
self.n_head = config.n_head
self.ctx_len = config.ctx_len
-        self.head_size = config.n_embd // config.n_head
+        self.head_size = config.n_attn // config.n_head
if time_shift:
self.time_shift = nn.ZeroPad2d((0,0,1,0))
-        self.query = nn.Linear(config.n_embd, config.n_embd)
-        self.key = nn.Linear(config.n_embd, config.n_embd)
-        self.value = nn.Linear(config.n_embd, config.n_embd)
+        self.query = nn.Linear(config.n_embd, config.n_attn)
+        self.key = nn.Linear(config.n_embd, config.n_attn)
+        self.value = nn.Linear(config.n_embd, config.n_attn)
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.rotary_ndims = int(self.head_size * 0.5)
self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
-        self.output = nn.Linear(config.n_embd, config.n_embd)
+        self.output = nn.Linear(config.n_attn, config.n_embd)
def forward(self, x):
B, T, C = x.size()
@@ -166,7 +230,7 @@ class MHA_rotary(nn.Module):
att = F.softmax(att, dim = -1) # softmax
x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
-        x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
+        x = x.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
x = self.output(x)
return x
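The switch to .view(B, T, -1) matters because the concatenated heads now span n_attn channels, which need not equal the input width C = n_embd. With illustrative sizes (not this commit's config):

import torch

B, T, n_attn, n_head = 2, 16, 256, 8  # n_embd could be 512 here
head_size = n_attn // n_head          # 32
x = torch.randn(B, n_head, T, head_size)
y = x.transpose(1, 2).contiguous().view(B, T, -1)
print(y.shape)  # torch.Size([2, 16, 256]): n_attn wide; self.output then maps back to n_embd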
@@ -179,7 +243,7 @@ class GeGLU(torch.nn.Module):
if time_shift:
self.time_shift = nn.ZeroPad2d((0,0,1,0))
-        hidden_sz = 3 * config.n_embd
+        hidden_sz = 3 * config.n_ffn
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)
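GeGLU then combines a GELU-activated half with a linear gate before projecting back down, along the lines of the sketch below (the forward pass is not shown in this hunk, so the pattern is an assumption):

k = self.key(x)                   # (B, T, 3 * n_ffn)
v = self.value(x)                 # (B, T, 3 * n_ffn)
out = self.weight(F.gelu(k) * v)  # gated product, back to (B, T, n_embd)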
@@ -202,10 +266,10 @@ class MHA_pro(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
-        assert config.n_embd % config.n_head == 0
+        assert config.n_attn % config.n_head == 0
self.n_head = config.n_head
self.ctx_len = config.ctx_len
-        self.head_size = config.n_embd // config.n_head
+        self.head_size = config.n_attn // config.n_head
self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
@@ -214,16 +278,16 @@ class MHA_pro(nn.Module):
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.time_shift = nn.ZeroPad2d((0,0,1,0))
-        self.query = nn.Linear(config.n_embd, config.n_embd)
-        self.key = nn.Linear(config.n_embd, config.n_embd)
-        self.value = nn.Linear(config.n_embd, config.n_embd)
+        self.query = nn.Linear(config.n_embd, config.n_attn)
+        self.key = nn.Linear(config.n_embd, config.n_attn)
+        self.value = nn.Linear(config.n_embd, config.n_attn)
self.rotary_ndims = int(self.head_size * 0.5)
self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads
-        self.output = nn.Linear(config.n_embd, config.n_embd)
+        self.output = nn.Linear(config.n_attn, config.n_embd)
def forward(self, x):
B, T, C = x.size()
@@ -248,12 +312,12 @@ class MHA_pro(nn.Module):
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
        att = F.softmax(att, dim = -1) # softmax
att = att * w # time-weighting
att = self.head_mix(att) # talking heads
x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
-        x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
+        x = x.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
x = self.output(x) * self.time_gamma[:T, :]
return x
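The 1x1 Conv2d is the "talking heads" trick: treating the head axis as channels, it linearly remixes the (T, T) attention maps across heads at every position pair:

import torch
import torch.nn as nn

att = torch.randn(2, 8, 16, 16)  # (B, n_head, T, T), illustrative sizes
head_mix = nn.Conv2d(8, 8, kernel_size=1, bias=False)
print(head_mix(att).shape)  # torch.Size([2, 8, 16, 16]): same shape, heads mixed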
@@ -338,43 +402,11 @@ class GPT(nn.Module):
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.ctx_len = config.ctx_len
-        self.apply(self._init_weights)
-        if self.config.model_type == 'RWKV': # improve orthogonal weight init
-            ww = self.state_dict()
-            for k in ww:
-                if 'tok_emb' in k:
-                    if self.config.vocab_size > self.config.n_embd:
-                        ww[k] *= math.sqrt(self.config.vocab_size)
-                    else:
-                        ww[k] *= math.sqrt(self.config.n_embd)
-                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
-                elif 'head.weight' in k:
-                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
-                elif 'blocks.' in k:
-                    block_id = int(k.split('.')[1])
-                    if 'receptance.weight' in k:
-                        ww[k] *= 0 # init with zero matrix
-                    elif 'attn.key.weight' in k:
-                        ww[k] *= 0 # init with zero matrix
-                    elif 'attn.output.weight' in k:
-                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
-                    elif 'mlp.weight.weight' in k:
-                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
-                    elif 'attn.time_w' in k:
-                        curve = torch.tensor([0.9 ** (self.config.ctx_len - 1 - i) for i in range(self.config.ctx_len)])
-                        curve = curve * 2 + 0.7
-                        for h in range(self.config.n_head):
-                            if self.config.n_head > 1:
-                                mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
-                            else:
-                                mix_strength = 0.5
-                            ww[k][h] = (1 - mix_strength) + curve * mix_strength
-                            # special tweaks because of time_shift
-                            ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 3] * 2 + 1) / 3
-                            ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] * 1 + 2) / 3
-                            ww[k][h][self.config.ctx_len - 1] = 1
-                            # print(k, h, mix_strength, ww[k][h])
+        if self.config.model_type == 'RWKV':
+            RWKV_Init(self, config)
+        else:
+            self.apply(self._init_weights)
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
@@ -383,15 +415,7 @@ class GPT(nn.Module):
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
-            if self.config.model_type == 'RWKV':
-                gain = 1.0
-                if isinstance(module, nn.Linear):
-                    if module.weight.data.shape[0] > module.weight.data.shape[1]:
-                        gain = math.sqrt(module.weight.data.shape[0] / module.weight.data.shape[1])
-                nn.init.orthogonal_(module.weight, gain=gain)
-            else:
-                module.weight.data.normal_(mean=0.0, std=0.01)
+            module.weight.data.normal_(mean=0.0, std=0.01)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()

src/trainer.py
@@ -1,4 +1,4 @@
-import math, sys
+import math, sys, datetime
import logging
import numpy as np
from tqdm.auto import tqdm
@@ -43,8 +43,7 @@ class Trainer:
cfg = model.config
for k in config.__dict__:
setattr(cfg, k, config.__dict__[k]) # combine cfg
-            run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
-            wandb.init(project="RWKV-LM", name=run_name + '-' + wandb.util.generate_id(), config=cfg, save_code=False)
+            wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
# take over whatever gpus are on the system
self.device = 'cpu'
@@ -52,6 +51,12 @@ class Trainer:
self.device = torch.cuda.current_device()
self.model = torch.nn.DataParallel(self.model).to(self.device)
+    def get_run_name(self):
+        raw_model = self.model.module if hasattr(self.model, "module") else self.model
+        cfg = raw_model.config
+        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
+        return run_name
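The run name is just the key hyperparameters joined with dashes; for example (illustrative values, not from this commit):

run_name = '5000-256-RWKV-5-512'  # vocab_size=5000, ctx_len=256, model_type='RWKV', n_layer=5, n_embd=512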
def save_checkpoint(self):
# DataParallel wrappers keep raw model object in .module attribute
raw_model = self.model.module if hasattr(self.model, "module") else self.model

train.py
@@ -13,10 +13,10 @@ from src.utils import set_seed
set_seed(42)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
-logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)
+logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
# RWKV : our new model - fastest when ctx_len is long - VRAM friendly - good performance
-# MHA_rotary : usual Multi-head Attention+Rotary+GeGLU - not as good
+# MHA_rotary : usual MultiheadAttention+Rotary+GeGLU - not as good
# MHA_shift : with time-shift - good performance
# MHA_pro : slow (lots of tricks) - VRAM hungry - very good performance
model_type = 'RWKV'
@@ -34,6 +34,8 @@ ctx_len = 256 # context length
n_layer = 5
n_head = 8
n_embd = n_head * 64
+n_attn = n_embd
+n_ffn = n_embd
batch_size = 64
@@ -54,7 +56,7 @@ print('loading data... ' + datafile)
class Dataset(Dataset):
def __init__(self, data, model_level, ctx_len):
-        print('building token list...')
+        print('building token list...', end=' ')
if model_level == 'word':
import re
data = re.sub(r'(\n|\.|\,|\?|\!|\:|\;|\-|\—|\||\'|\"|\`|\(|\)|[0-9]|\[|\]|\{|\}|\=|\+|\*|\\|\/|\~|\&|\$|\#|\%)', r' \g<0> ', data)
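The regex pads every punctuation mark and digit with spaces so each becomes its own word-level token, e.g. (illustrative input):

import re

s = 'Hello, world! (42)'
s = re.sub(r'(\n|\.|\,|\?|\!|\:|\;|\-|\—|\||\'|\"|\`|\(|\)|[0-9]|\[|\]|\{|\}|\=|\+|\*|\\|\/|\~|\&|\$|\#|\%)', r' \g<0> ', s)
print(s.lower().split(' '))  # 'hello', ',', 'world', '!', '(', '4', '2', ')' plus '' padding tokens from split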
@@ -62,10 +64,12 @@ class Dataset(Dataset):
print('splitting token...')
data = data.lower().split(' ')
unique = sorted(list(set(data)))
+        # print()
+        # for u in unique:
+        #     print(u, end=' ')
+        # print('\n\n')
        data_size, vocab_size = len(data), len(unique)
-        print('\n\ndata has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
+        print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
self.stoi = { ch:i for i,ch in enumerate(unique) }
self.itos = { i:ch for i,ch in enumerate(unique) }
self.ctx_len = ctx_len
@@ -90,9 +94,9 @@ train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(),
########################################################################################################
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
-                    n_layer=n_layer, n_head=n_head, n_embd=n_embd))
+                    n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
-print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'ctx', ctx_len)
+print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0)
@@ -100,7 +104,7 @@ trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()
-torch.save(model, 'trained-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
+torch.save(model, 'trained-' + trainer.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
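With the run name folded into the filename, checkpoints become self-describing, e.g. (illustrative):

# trained-5000-256-RWKV-5-512-2021-08-25-14-30-00.pth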
########################################################################################################
# Run model to generate text
