add wandb, and rename variables

main
BlinkDL 4 years ago
parent 440bebff1a
commit 3b60c5b266

.gitignore vendored

@@ -4,6 +4,7 @@
 *.xlsb
 *.xlsx
 *.xls
+wandb/
 # Byte-compiled / optimized / DLL files
 __pycache__/

@@ -10,7 +10,7 @@ from torch.nn import functional as F
 logger = logging.getLogger(__name__)
 ########################################################################################################
-# Block: RWKV Time-mix + RWKV Channel-mix
+# RWKV: RWKV Time-mix + RWKV Channel-mix
 ########################################################################################################
 class RWKV_TimeMix(nn.Module):
@@ -18,15 +18,15 @@ class RWKV_TimeMix(nn.Module):
         super().__init__()
         assert config.n_embd % config.n_head == 0
         self.layer_id = layer_id
-        self.ctx_size = config.ctx_size
+        self.ctx_len = config.ctx_len
         self.n_head = config.n_head
         self.head_size = config.n_embd // config.n_head
-        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
-        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
-        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
-        self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
+        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
+        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
+        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
+        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
+        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
@@ -38,7 +38,7 @@ class RWKV_TimeMix(nn.Module):
     def forward(self, x):
         B, T, C = x.size()
-        TT = self.ctx_size
+        TT = self.ctx_len
         w = F.pad(self.time_w, (0, TT))
         w = torch.tile(w, [TT])
         w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
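
A note on the three context lines above: the pad / tile / reshape sequence turns time_w of shape (n_head, ctx_len) into a (n_head, TT, 2*TT-1) tensor whose row t is the zero-padded time_w vector rotated right by t, i.e. a Toeplitz layout of per-distance weights. A minimal sketch with made-up values (illustration only, not part of the commit):

    import torch
    import torch.nn.functional as F

    TT = 4                                          # pretend ctx_len = 4, one head
    time_w = torch.arange(1., TT + 1).unsqueeze(0)  # shape (1, TT): [[1, 2, 3, 4]]

    w = F.pad(time_w, (0, TT))                      # (1, 2*TT): append TT zeros
    w = torch.tile(w, [TT])                         # (1, 2*TT*TT): repeat the padded row TT times
    w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)      # (1, TT, 2*TT-1)

    print(w[0])
    # tensor([[1., 2., 3., 4., 0., 0., 0.],
    #         [0., 1., 2., 3., 4., 0., 0.],
    #         [0., 0., 1., 2., 3., 4., 0.],
    #         [0., 0., 0., 1., 2., 3., 4.]])
    # the right half, w[:, :, TT-1:], is a causal TT x TT matrix of per-distance weights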
@@ -88,7 +88,7 @@ class RWKV_ChannelMix(nn.Module):
         return y
 ########################################################################################################
-# Block: Multi-head Attention + Rotary Encoding + GeGLU FFN
+# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
 ########################################################################################################
 class RotaryEmbedding(torch.nn.Module):
@@ -119,19 +119,20 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     cos, sin = cos[...,:q.shape[2],:], sin[...,:q.shape[2],:]
     return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
-class RotaryMHA(nn.Module):
-    def __init__(self, config):
+class MHA_rotary(nn.Module):
+    def __init__(self, config, layer_id):
         super().__init__()
+        self.layer_id = layer_id
         assert config.n_embd % config.n_head == 0
         self.n_head = config.n_head
-        self.ctx_size = config.ctx_size
+        self.ctx_len = config.ctx_len
         self.head_size = config.n_embd // config.n_head
         self.query = nn.Linear(config.n_embd, config.n_embd)
         self.key = nn.Linear(config.n_embd, config.n_embd)
         self.value = nn.Linear(config.n_embd, config.n_embd)
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
+        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
         self.rotary_ndims = int(self.head_size * 0.5)
         self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
@@ -163,8 +164,9 @@ class RotaryMHA(nn.Module):
         return x
 class GeGLU(torch.nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_id):
         super().__init__()
+        self.layer_id = layer_id
         self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
         self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
         self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
@@ -176,22 +178,23 @@ class GeGLU(torch.nn.Module):
         return y
 ########################################################################################################
-# Block: MHA+ (with even more tricks)
+# MHA_pro: with more tricks
 ########################################################################################################
-class RotaryMHA_Plus(nn.Module):
-    def __init__(self, config):
+class MHA_pro(nn.Module):
+    def __init__(self, config, layer_id):
         super().__init__()
+        self.layer_id = layer_id
         assert config.n_embd % config.n_head == 0
         self.n_head = config.n_head
-        self.ctx_size = config.ctx_size
+        self.ctx_len = config.ctx_len
         self.head_size = config.n_embd // config.n_head
-        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
-        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
-        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
-        self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
+        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
+        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
+        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
+        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
+        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
         self.query = nn.Linear(config.n_embd, config.n_embd)
@@ -207,7 +210,7 @@ class RotaryMHA_Plus(nn.Module):
     def forward(self, x):
         B, T, C = x.size()
-        TT = self.ctx_size
+        TT = self.ctx_len
         w = F.pad(self.time_w, (0, TT))
         w = torch.tile(w, [TT])
         w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
@@ -280,9 +283,9 @@ class FixedNorm(nn.Module):
 ########################################################################################################
 class GPTConfig:
-    def __init__(self, vocab_size, ctx_size, **kwargs):
+    def __init__(self, vocab_size, ctx_len, **kwargs):
         self.vocab_size = vocab_size
-        self.ctx_size = ctx_size
+        self.ctx_len = ctx_len
         for k,v in kwargs.items():
             setattr(self, k, v)
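
For reference, GPTConfig stores only vocab_size and ctx_len explicitly and turns every other keyword argument into an attribute via setattr; that is how model_type, n_layer, n_head and n_embd reach the blocks. An illustrative call with made-up values:

    cfg = GPTConfig(vocab_size=1000, ctx_len=256, model_type='RWKV',
                    n_layer=5, n_head=8, n_embd=512)
    print(cfg.model_type, cfg.n_layer, cfg.n_embd)   # RWKV 5 512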
@@ -298,12 +301,12 @@ class Block(nn.Module):
             self.ln2 = FixedNorm(config.n_embd)
             self.attn = RWKV_TimeMix(config, layer_id)
             self.mlp = RWKV_ChannelMix(config, layer_id)
-        elif config.model_type == 'RotaryMHA':
-            self.attn = RotaryMHA(config)
-            self.mlp = GeGLU(config)
-        elif config.model_type == 'MHA-Plus':
-            self.attn = RotaryMHA_Plus(config)
-            self.mlp = RWKV_ChannelMix(config)
+        elif config.model_type == 'MHA_rotary':
+            self.attn = MHA_rotary(config, layer_id)
+            self.mlp = GeGLU(config, layer_id)
+        elif config.model_type == 'MHA_pro':
+            self.attn = MHA_pro(config, layer_id)
+            self.mlp = RWKV_ChannelMix(config, layer_id)
     def forward(self, x):
@@ -328,10 +331,15 @@ class GPT(nn.Module):
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-        self.ctx_size = config.ctx_size
+        self.ctx_len = config.ctx_len
         self.apply(self._init_weights)
         if self.config.model_type == 'RWKV': # improve orthogonal weight init
+            token_diversity = pow(self.config.vocab_size / 200, 1/3)
+            token_diversity = 0.4 * min(max(token_diversity, 1), 2) # 200 -> 0.4, 1600 -> 0.8. ENG-char 0.4 CHN-char 0.8
+            print('token_diversity', token_diversity)
             ww = self.state_dict()
             for k in ww:
                 if 'tok_emb' in k:
@ -339,20 +347,24 @@ class GPT(nn.Module):
ww[k] *= math.sqrt(self.config.vocab_size) ww[k] *= math.sqrt(self.config.vocab_size)
else: else:
ww[k] *= math.sqrt(self.config.n_embd) ww[k] *= math.sqrt(self.config.n_embd)
ww[k] *= 0.4 ww[k] *= token_diversity
elif 'head.weight' in k: elif 'head.weight' in k:
ww[k] *= 0.2 ww[k] *= token_diversity
elif 'blocks.' in k: elif 'blocks.' in k:
block_id = int(k.split('.')[1]) block_id = int(k.split('.')[1])
if 'receptance.weight' in k: if 'receptance.weight' in k:
ww[k] *= 0.5 ww[k] *= 0.2 # 0.2 ~ 0.5
elif 'attn.key.weight' in k: elif 'attn.key.weight' in k:
ww[k] *= 0.2 ww[k] *= 0.2 # 0.2 ~ 0.5
elif 'attn.output.weight' in k:
ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
elif 'mlp.weight.weight' in k:
ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def get_ctx_size(self): def get_ctx_len(self):
return self.ctx_size return self.ctx_len
def _init_weights(self, module): def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)): if isinstance(module, (nn.Linear, nn.Embedding)):
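
The new token_diversity factor replaces the hard-coded 0.4 / 0.2 embedding and head scales: it grows with the cube root of vocab_size / 200 and is clamped to the range [0.4, 0.8], so a small English character vocab gets 0.4 while a large Chinese character vocab saturates at 0.8, matching the inline comment. A quick illustrative check of the formula (not part of the commit):

    def token_diversity(vocab_size):
        t = pow(vocab_size / 200, 1/3)
        return 0.4 * min(max(t, 1), 2)

    for v in (100, 200, 1600, 50000):
        print(v, round(token_diversity(v), 3))   # 0.4, 0.4, 0.8, 0.8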
@@ -403,7 +415,7 @@ class GPT(nn.Module):
     def forward(self, idx, targets=None):
         B, T = idx.size()
-        assert T <= self.ctx_size, "Cannot forward, model block size is exhausted."
+        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
         x = self.tok_emb(idx)

@@ -1,4 +1,4 @@
-import math
+import math, sys
 import logging
 import numpy as np
 from tqdm.auto import tqdm
@@ -8,16 +8,19 @@ from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data.dataloader import DataLoader
 logger = logging.getLogger(__name__)
+print('logging to wandb... (comment it if you don\'t have wandb)')
+import wandb # comment it if you don't have wandb
 class TrainerConfig:
     max_epochs = 10
     batch_size = 64
-    learning_rate = 3e-4
+    learning_rate = 4e-4
     betas = (0.9, 0.95)
     grad_norm_clip = 1.0
     weight_decay = 0.01
-    lr_decay = False # learning rate decay params: linear warmup followed by cosine decay
-    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
-    final_tokens = 260e9 # (at what point we reach 10% of original LR)
+    lr_decay = False # linear warmup followed by cosine decay
+    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper
+    final_tokens = 260e9 # at which point do we reach lr_final
     ckpt_path = None
     num_workers = 0 # for DataLoader
@@ -33,6 +36,12 @@ class Trainer:
         self.test_dataset = test_dataset
         self.config = config
         self.avg_loss = -1
+        self.steps = 0
+        if 'wandb' in sys.modules:
+            cfg = model.config
+            run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
+            wandb.init(project="RWKV-LM", name=run_name + '-' + wandb.util.generate_id(), config=config, save_code=False)
         # take over whatever gpus are on the system
         self.device = 'cpu'
@@ -101,6 +110,11 @@ class Trainer:
                 # report progress
                 now_loss = loss.item()
+                if 'wandb' in sys.modules:
+                    wandb.log({"loss": now_loss}, step = self.steps * self.config.batch_size)
+                self.steps += 1
                 if self.avg_loss < 0:
                     self.avg_loss = now_loss
                 else:
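
The two hunks above share one pattern: wandb is imported at module level (with advice to comment the import out if wandb is not installed), every call site is guarded by 'wandb' in sys.modules, and the logging step is steps * batch_size so the x-axis roughly tracks samples seen rather than raw iterations. A standalone sketch of the same pattern, with the import made optional via try/except instead of hand-commenting (a variation, not what the commit itself does):

    import sys

    try:
        import wandb                      # if the package is missing, logging is skipped below
    except ImportError:
        pass

    def log_loss(now_loss, steps, batch_size):
        if 'wandb' in sys.modules:        # same guard the Trainer uses
            wandb.log({"loss": now_loss}, step=steps * batch_size)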

@@ -15,10 +15,10 @@ set_seed(42)
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)
-# RWKV is our proposed model - fastest when the ctx window is long - good performance
-# RotaryMHA is usual Multi-head Attention + Rotary Encoding + GeGLU FFN
-# MHA-Plus is a bit slow (lots of tricks), with excellent performance
-model_type = 'RWKV' # 'RWKV' or 'RotaryMHA' or 'MHA-Plus'
+# RWKV - our new model - fastest when ctx_len is long - VRAM friendly - good performance
+# MHA_rotary - usual Multi-head Attention+Rotary+GeGLU - not as good
+# MHA_pro - slow (lots of tricks) - VRAM hungry - good performance
+model_type = 'RWKV' # 'RWKV' or 'MHA_rotary' or 'MHA_pro'
 datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
 datafile_encoding = 'utf-8'
@@ -27,22 +27,19 @@ datafile_encoding = 'utf-8'
 model_level = 'character' # 'character' or 'word'
-ctx_size = 256 if model_level == 'character' else 128
-nLayers = 5
-nHead = 8
-nEmb = nHead * 64
+ctx_len = 256 # length of ctx window
+n_layer = 5
+n_head = 8
+n_embd = n_head * 64
-lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
-lr_final = 2e-4
-lr_initial /= math.sqrt(nLayers / 5) # lower lr for deep models; higher lr for shallow models
-lr_final /= math.sqrt(nLayers / 5)
+batch_size = 64
+n_epoch = 50 # the 'epoch' here is very short
+lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
+lr_final = 2e-4
 betas = (0.9, 0.99)
 weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when we have enough data
-nepoch = 50 # just a quick test. the 'epoch' here is very short
-nbatchsz = 64
 epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
 ########################################################################################################
@@ -52,7 +49,7 @@ epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
 print('loading data... ' + datafile)
 class Dataset(Dataset):
-    def __init__(self, data, model_level, ctx_size):
+    def __init__(self, data, model_level, ctx_len):
         print('building token list...')
         if model_level == 'word':
             import re
@@ -67,7 +64,7 @@ class Dataset(Dataset):
         print('\n\ndata has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
         self.stoi = { ch:i for i,ch in enumerate(unique) }
         self.itos = { i:ch for i,ch in enumerate(unique) }
-        self.ctx_size = ctx_size
+        self.ctx_len = ctx_len
         self.vocab_size = vocab_size
         self.data = data
@@ -75,26 +72,26 @@ class Dataset(Dataset):
         return epoch_length_fixed
     def __getitem__(self, idx):
-        i = np.random.randint(0, len(self.data) - (self.ctx_size + 1)) # CHEAT: pick a spot in the dataset at random
-        chunk = self.data[i:i+self.ctx_size+1]
+        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) # CHEAT: pick a spot in the dataset at random
+        chunk = self.data[i:i+self.ctx_len+1]
         dix = [self.stoi[s] for s in chunk]
         x = torch.tensor(dix[:-1], dtype=torch.long)
         y = torch.tensor(dix[1:], dtype=torch.long)
         return x, y
-train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_size)
+train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_len)
 ########################################################################################################
 # Train model
 ########################################################################################################
-model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_type=model_type,
-                      n_layer=nLayers, n_head=nHead, n_embd=nEmb))
+model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
+                      n_layer=n_layer, n_head=n_head, n_embd=n_embd))
-print('model', model_type, 'total epoch', nepoch, 'batchsz', nbatchsz, 'nLayers', nLayers, 'nHead', nHead, 'nEmb', nEmb, 'len', ctx_size)
+print('model', model_type, 'total epoch', n_epoch, 'batch_size', batch_size, 'n_layer', n_layer, 'n_head', n_head, 'n_embd', n_embd, 'len', ctx_len)
-tconf = TrainerConfig(model_type=model_type, max_epochs=nepoch, batch_size=nbatchsz, weight_decay=weight_decay,
-                      learning_rate=lr_initial, lr_decay=True, lr_final=lr_final, betas=betas,
-                      warmup_tokens=0, final_tokens=nepoch*len(train_dataset)*ctx_size, num_workers=0)
+tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
+                      learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas,
+                      warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0)
 trainer = Trainer(model, train_dataset, None, tconf)
 trainer.train()
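
Dataset.__getitem__ is unchanged apart from the rename: it still samples a random chunk of ctx_len + 1 tokens and returns the first ctx_len as inputs and the last ctx_len as next-token targets. A tiny illustration with made-up text and ctx_len = 6 (not part of the commit):

    chunk = "hello w"                                 # ctx_len + 1 = 7 characters
    stoi = {ch: i for i, ch in enumerate(sorted(set(chunk)))}
    dix = [stoi[s] for s in chunk]
    x, y = dix[:-1], dix[1:]                          # y[t] is the target after seeing x[:t+1]
    print(x, y)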
@@ -119,8 +116,8 @@ for run in range(NUM_OF_RUNS):
     x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
     real_len = len(x)
-    if real_len < ctx_size:
-        x = np.pad(x, (0, ctx_size - real_len))
+    if real_len < ctx_len:
+        x = np.pad(x, (0, ctx_len - real_len))
     print_begin = 0
     for i in range(LENGTH_OF_EACH):
@@ -130,13 +127,13 @@ for run in range(NUM_OF_RUNS):
             print_begin = real_len
         with torch.no_grad():
-            xxx = torch.tensor(x[-ctx_size:], dtype=torch.long)[None,...].to("cuda:0")
+            xxx = torch.tensor(x[-ctx_len:], dtype=torch.long)[None,...].to("cuda:0")
             out, _ = model(xxx)
-        pos = -1 if real_len >= ctx_size else real_len - 1
+        pos = -1 if real_len >= ctx_len else real_len - 1
         char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) # our special sampling method
-        if real_len < ctx_size:
+        if real_len < ctx_len:
             x[real_len] = char
         else:
             x = np.append(x, char)
