add wandb, and rename variables

main
BlinkDL 4 years ago
parent 440bebff1a
commit 3b60c5b266

1
.gitignore vendored

@ -4,6 +4,7 @@
*.xlsb
*.xlsx
*.xls
wandb/
# Byte-compiled / optimized / DLL files
__pycache__/

@ -10,7 +10,7 @@ from torch.nn import functional as F
logger = logging.getLogger(__name__)
########################################################################################################
# Block: RWKV Time-mix + RWKV Channel-mix
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
class RWKV_TimeMix(nn.Module):
@ -18,15 +18,15 @@ class RWKV_TimeMix(nn.Module):
super().__init__()
assert config.n_embd % config.n_head == 0
self.layer_id = layer_id
self.ctx_size = config.ctx_size
self.ctx_len = config.ctx_len
self.n_head = config.n_head
self.head_size = config.n_embd // config.n_head
self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.time_shift = nn.ZeroPad2d((0,0,1,0))
@ -38,7 +38,7 @@ class RWKV_TimeMix(nn.Module):
def forward(self, x):
B, T, C = x.size()
TT = self.ctx_size
TT = self.ctx_len
w = F.pad(self.time_w, (0, TT))
w = torch.tile(w, [TT])
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
@ -88,7 +88,7 @@ class RWKV_ChannelMix(nn.Module):
return y
########################################################################################################
# Block: Multi-head Attention + Rotary Encoding + GeGLU FFN
# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
########################################################################################################
class RotaryEmbedding(torch.nn.Module):
@ -119,19 +119,20 @@ def apply_rotary_pos_emb(q, k, cos, sin):
cos, sin = cos[...,:q.shape[2],:], sin[...,:q.shape[2],:]
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
class RotaryMHA(nn.Module):
def __init__(self, config):
class MHA_rotary(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
assert config.n_embd % config.n_head == 0
self.n_head = config.n_head
self.ctx_size = config.ctx_size
self.ctx_len = config.ctx_len
self.head_size = config.n_embd // config.n_head
self.query = nn.Linear(config.n_embd, config.n_embd)
self.key = nn.Linear(config.n_embd, config.n_embd)
self.value = nn.Linear(config.n_embd, config.n_embd)
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.rotary_ndims = int(self.head_size * 0.5)
self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
@ -163,8 +164,9 @@ class RotaryMHA(nn.Module):
return x
class GeGLU(torch.nn.Module):
def __init__(self, config):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
@ -176,22 +178,23 @@ class GeGLU(torch.nn.Module):
return y
########################################################################################################
# Block: MHA+ (with even more tricks)
# MHA_pro: with more tricks
########################################################################################################
class RotaryMHA_Plus(nn.Module):
def __init__(self, config):
class MHA_pro(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
assert config.n_embd % config.n_head == 0
self.n_head = config.n_head
self.ctx_size = config.ctx_size
self.ctx_len = config.ctx_len
self.head_size = config.n_embd // config.n_head
self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.time_shift = nn.ZeroPad2d((0,0,1,0))
self.query = nn.Linear(config.n_embd, config.n_embd)
@ -207,7 +210,7 @@ class RotaryMHA_Plus(nn.Module):
def forward(self, x):
B, T, C = x.size()
TT = self.ctx_size
TT = self.ctx_len
w = F.pad(self.time_w, (0, TT))
w = torch.tile(w, [TT])
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
@ -280,9 +283,9 @@ class FixedNorm(nn.Module):
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_size, **kwargs):
def __init__(self, vocab_size, ctx_len, **kwargs):
self.vocab_size = vocab_size
self.ctx_size = ctx_size
self.ctx_len = ctx_len
for k,v in kwargs.items():
setattr(self, k, v)
@ -298,12 +301,12 @@ class Block(nn.Module):
self.ln2 = FixedNorm(config.n_embd)
self.attn = RWKV_TimeMix(config, layer_id)
self.mlp = RWKV_ChannelMix(config, layer_id)
elif config.model_type == 'RotaryMHA':
self.attn = RotaryMHA(config)
self.mlp = GeGLU(config)
elif config.model_type == 'MHA-Plus':
self.attn = RotaryMHA_Plus(config)
self.mlp = RWKV_ChannelMix(config)
elif config.model_type == 'MHA_rotary':
self.attn = MHA_rotary(config, layer_id)
self.mlp = GeGLU(config, layer_id)
elif config.model_type == 'MHA_pro':
self.attn = MHA_pro(config, layer_id)
self.mlp = RWKV_ChannelMix(config, layer_id)
def forward(self, x):
@ -328,10 +331,15 @@ class GPT(nn.Module):
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.ctx_size = config.ctx_size
self.ctx_len = config.ctx_len
self.apply(self._init_weights)
if self.config.model_type == 'RWKV': # improve orthogonal weight init
token_diversity = pow(self.config.vocab_size / 200, 1/3)
token_diversity = 0.4 * min(max(token_diversity, 1), 2) # 200 -> 0.4, 1600 -> 0.8. ENG-char 0.4 CHN-char 0.8
print('token_diversity', token_diversity)
ww = self.state_dict()
for k in ww:
if 'tok_emb' in k:
@ -339,20 +347,24 @@ class GPT(nn.Module):
ww[k] *= math.sqrt(self.config.vocab_size)
else:
ww[k] *= math.sqrt(self.config.n_embd)
ww[k] *= 0.4
ww[k] *= token_diversity
elif 'head.weight' in k:
ww[k] *= 0.2
ww[k] *= token_diversity
elif 'blocks.' in k:
block_id = int(k.split('.')[1])
if 'receptance.weight' in k:
ww[k] *= 0.5
ww[k] *= 0.2 # 0.2 ~ 0.5
elif 'attn.key.weight' in k:
ww[k] *= 0.2
ww[k] *= 0.2 # 0.2 ~ 0.5
elif 'attn.output.weight' in k:
ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
elif 'mlp.weight.weight' in k:
ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def get_ctx_size(self):
return self.ctx_size
def get_ctx_len(self):
return self.ctx_len
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
@ -403,7 +415,7 @@ class GPT(nn.Module):
def forward(self, idx, targets=None):
B, T = idx.size()
assert T <= self.ctx_size, "Cannot forward, model block size is exhausted."
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.tok_emb(idx)

@ -1,4 +1,4 @@
import math
import math, sys
import logging
import numpy as np
from tqdm.auto import tqdm
@ -8,16 +8,19 @@ from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
logger = logging.getLogger(__name__)
print('logging to wandb... (comment it if you don\'t have wandb)')
import wandb # comment it if you don't have wandb
class TrainerConfig:
max_epochs = 10
batch_size = 64
learning_rate = 3e-4
learning_rate = 4e-4
betas = (0.9, 0.95)
grad_norm_clip = 1.0
weight_decay = 0.01
lr_decay = False # learning rate decay params: linear warmup followed by cosine decay
warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
final_tokens = 260e9 # (at what point we reach 10% of original LR)
lr_decay = False # linear warmup followed by cosine decay
warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper
final_tokens = 260e9 # at which point do we reach lr_final
ckpt_path = None
num_workers = 0 # for DataLoader
@ -33,6 +36,12 @@ class Trainer:
self.test_dataset = test_dataset
self.config = config
self.avg_loss = -1
self.steps = 0
if 'wandb' in sys.modules:
cfg = model.config
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
wandb.init(project="RWKV-LM", name=run_name + '-' + wandb.util.generate_id(), config=config, save_code=False)
# take over whatever gpus are on the system
self.device = 'cpu'
@ -101,6 +110,11 @@ class Trainer:
# report progress
now_loss = loss.item()
if 'wandb' in sys.modules:
wandb.log({"loss": now_loss}, step = self.steps * self.config.batch_size)
self.steps += 1
if self.avg_loss < 0:
self.avg_loss = now_loss
else:

@ -15,10 +15,10 @@ set_seed(42)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)
# RWKV is our proposed model - fastest when the ctx window is long - good performance
# RotaryMHA is usual Multi-head Attention + Rotary Encoding + GeGLU FFN
# MHA-Plus is a bit slow (lots of tricks), with excellent performance
model_type = 'RWKV' # 'RWKV' or 'RotaryMHA' or 'MHA-Plus'
# RWKV - our new model - fastest when ctx_len is long - VRAM friendly - good performance
# MHA_rotary - usual Multi-head Attention+Rotary+GeGLU - not as good
# MHA_pro - slow (lots of tricks) - VRAM hungry - good performance
model_type = 'RWKV' # 'RWKV' or 'MHA_rotary' or 'MHA_pro'
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
datafile_encoding = 'utf-8'
@ -27,22 +27,19 @@ datafile_encoding = 'utf-8'
model_level = 'character' # 'character' or 'word'
ctx_size = 256 if model_level == 'character' else 128
nLayers = 5
nHead = 8
nEmb = nHead * 64
ctx_len = 256 # length of ctx window
n_layer = 5
n_head = 8
n_embd = n_head * 64
lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
lr_final = 2e-4
batch_size = 64
lr_initial /= math.sqrt(nLayers / 5) # lower lr for deep models; higher lr for shallow models
lr_final /= math.sqrt(nLayers / 5)
n_epoch = 50 # the 'epoch' here is very short
lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
lr_final = 2e-4
betas = (0.9, 0.99)
weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when we have enough data
nepoch = 50 # just a quick test. the 'epoch' here is very short
nbatchsz = 64
epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
########################################################################################################
@ -52,7 +49,7 @@ epoch_length_fixed = 10000 # make an 'epoch' ve
print('loading data... ' + datafile)
class Dataset(Dataset):
def __init__(self, data, model_level, ctx_size):
def __init__(self, data, model_level, ctx_len):
print('building token list...')
if model_level == 'word':
import re
@ -67,7 +64,7 @@ class Dataset(Dataset):
print('\n\ndata has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
self.stoi = { ch:i for i,ch in enumerate(unique) }
self.itos = { i:ch for i,ch in enumerate(unique) }
self.ctx_size = ctx_size
self.ctx_len = ctx_len
self.vocab_size = vocab_size
self.data = data
@ -75,26 +72,26 @@ class Dataset(Dataset):
return epoch_length_fixed
def __getitem__(self, idx):
i = np.random.randint(0, len(self.data) - (self.ctx_size + 1)) # CHEAT: pick a spot in the dataset at random
chunk = self.data[i:i+self.ctx_size+1]
i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) # CHEAT: pick a spot in the dataset at random
chunk = self.data[i:i+self.ctx_len+1]
dix = [self.stoi[s] for s in chunk]
x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long)
return x, y
train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_size)
train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_len)
########################################################################################################
# Train model
########################################################################################################
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_type=model_type,
n_layer=nLayers, n_head=nHead, n_embd=nEmb))
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
n_layer=n_layer, n_head=n_head, n_embd=n_embd))
print('model', model_type, 'total epoch', nepoch, 'batchsz', nbatchsz, 'nLayers', nLayers, 'nHead', nHead, 'nEmb', nEmb, 'len', ctx_size)
tconf = TrainerConfig(model_type=model_type, max_epochs=nepoch, batch_size=nbatchsz, weight_decay=weight_decay,
learning_rate=lr_initial, lr_decay=True, lr_final=lr_final, betas=betas,
warmup_tokens=0, final_tokens=nepoch*len(train_dataset)*ctx_size, num_workers=0)
print('model', model_type, 'total epoch', n_epoch, 'batch_size', batch_size, 'n_layer', n_layer, 'n_head', n_head, 'n_embd', n_embd, 'len', ctx_len)
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas,
warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()
@ -119,8 +116,8 @@ for run in range(NUM_OF_RUNS):
x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
real_len = len(x)
if real_len < ctx_size:
x = np.pad(x, (0, ctx_size - real_len))
if real_len < ctx_len:
x = np.pad(x, (0, ctx_len - real_len))
print_begin = 0
for i in range(LENGTH_OF_EACH):
@ -130,13 +127,13 @@ for run in range(NUM_OF_RUNS):
print_begin = real_len
with torch.no_grad():
xxx = torch.tensor(x[-ctx_size:], dtype=torch.long)[None,...].to("cuda:0")
xxx = torch.tensor(x[-ctx_len:], dtype=torch.long)[None,...].to("cuda:0")
out, _ = model(xxx)
pos = -1 if real_len >= ctx_size else real_len - 1
pos = -1 if real_len >= ctx_len else real_len - 1
char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) # our special sampling method
if real_len < ctx_size:
if real_len < ctx_len:
x[real_len] = char
else:
x = np.append(x, char)

Loading…
Cancel
Save