diff --git a/.gitignore b/.gitignore
index 2de160e..19616df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,7 @@
 *.xlsb
 *.xlsx
 *.xls
+wandb/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/src/model.py b/src/model.py
index ca5adab..03021e7 100644
--- a/src/model.py
+++ b/src/model.py
@@ -10,7 +10,7 @@ from torch.nn import functional as F
 logger = logging.getLogger(__name__)
 
 ########################################################################################################
-# Block: RWKV Time-mix + RWKV Channel-mix
+# RWKV: RWKV Time-mix + RWKV Channel-mix
 ########################################################################################################
 
 class RWKV_TimeMix(nn.Module):
@@ -18,15 +18,15 @@ class RWKV_TimeMix(nn.Module):
         super().__init__()
         assert config.n_embd % config.n_head == 0
         self.layer_id = layer_id
-        self.ctx_size = config.ctx_size
+        self.ctx_len = config.ctx_len
         self.n_head = config.n_head
         self.head_size = config.n_embd // config.n_head
 
-        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
-        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
-        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
-        self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
+        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
+        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
+        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
+        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
+        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
@@ -38,7 +38,7 @@ class RWKV_TimeMix(nn.Module):
 
     def forward(self, x):
         B, T, C = x.size()
-        TT = self.ctx_size
+        TT = self.ctx_len
         w = F.pad(self.time_w, (0, TT))
         w = torch.tile(w, [TT])
         w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
@@ -88,7 +88,7 @@ class RWKV_ChannelMix(nn.Module):
         return y
 
 ########################################################################################################
-# Block: Multi-head Attention + Rotary Encoding + GeGLU FFN
+# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
 ########################################################################################################
 
 class RotaryEmbedding(torch.nn.Module):
@@ -119,19 +119,20 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     cos, sin = cos[...,:q.shape[2],:], sin[...,:q.shape[2],:]
     return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
-class RotaryMHA(nn.Module):
-    def __init__(self, config):
+class MHA_rotary(nn.Module):
+    def __init__(self, config, layer_id):
         super().__init__()
+        self.layer_id = layer_id
         assert config.n_embd % config.n_head == 0
         self.n_head = config.n_head
-        self.ctx_size = config.ctx_size
+        self.ctx_len = config.ctx_len
         self.head_size = config.n_embd // config.n_head
 
         self.query = nn.Linear(config.n_embd, config.n_embd)
         self.key = nn.Linear(config.n_embd, config.n_embd)
         self.value = nn.Linear(config.n_embd, config.n_embd)
 
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
+        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
         self.rotary_ndims = int(self.head_size * 0.5)
         self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
@@ -163,8 +164,9 @@ class RotaryMHA(nn.Module):
         return x
 
 class GeGLU(torch.nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_id):
         super().__init__()
+        self.layer_id = layer_id
         self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
         self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
         self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
@@ -176,22 +178,23 @@ class GeGLU(torch.nn.Module):
         return y
 
 ########################################################################################################
-# Block: MHA+ (with even more tricks)
+# MHA_pro: with more tricks
 ########################################################################################################
 
-class RotaryMHA_Plus(nn.Module):
-    def __init__(self, config):
+class MHA_pro(nn.Module):
+    def __init__(self, config, layer_id):
         super().__init__()
+        self.layer_id = layer_id
         assert config.n_embd % config.n_head == 0
         self.n_head = config.n_head
-        self.ctx_size = config.ctx_size
+        self.ctx_len = config.ctx_len
         self.head_size = config.n_embd // config.n_head
 
-        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
-        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
-        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
-        self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
+        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
+        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
+        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
+        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
+        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
         self.query = nn.Linear(config.n_embd, config.n_embd)
@@ -207,7 +210,7 @@ class RotaryMHA_Plus(nn.Module):
 
     def forward(self, x):
         B, T, C = x.size()
-        TT = self.ctx_size
+        TT = self.ctx_len
         w = F.pad(self.time_w, (0, TT))
         w = torch.tile(w, [TT])
         w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
@@ -280,9 +283,9 @@ class FixedNorm(nn.Module):
 ########################################################################################################
 
 class GPTConfig:
-    def __init__(self, vocab_size, ctx_size, **kwargs):
+    def __init__(self, vocab_size, ctx_len, **kwargs):
         self.vocab_size = vocab_size
-        self.ctx_size = ctx_size
+        self.ctx_len = ctx_len
         for k,v in kwargs.items():
             setattr(self, k, v)
 
@@ -298,12 +301,12 @@ class Block(nn.Module):
             self.ln2 = FixedNorm(config.n_embd)
             self.attn = RWKV_TimeMix(config, layer_id)
             self.mlp = RWKV_ChannelMix(config, layer_id)
-        elif config.model_type == 'RotaryMHA':
-            self.attn = RotaryMHA(config)
-            self.mlp = GeGLU(config)
-        elif config.model_type == 'MHA-Plus':
-            self.attn = RotaryMHA_Plus(config)
-            self.mlp = RWKV_ChannelMix(config)
+        elif config.model_type == 'MHA_rotary':
+            self.attn = MHA_rotary(config, layer_id)
+            self.mlp = GeGLU(config, layer_id)
+        elif config.model_type == 'MHA_pro':
+            self.attn = MHA_pro(config, layer_id)
+            self.mlp = RWKV_ChannelMix(config, layer_id)
 
 
     def forward(self, x):
@@ -328,31 +331,40 @@ class GPT(nn.Module):
 
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
-        self.ctx_size = config.ctx_size
+        self.ctx_len = config.ctx_len
         self.apply(self._init_weights)
 
         if self.config.model_type == 'RWKV': # improve orthogonal weight init
+
+            token_diversity = pow(self.config.vocab_size / 200, 1/3)
+            token_diversity = 0.4 * min(max(token_diversity, 1), 2) # 200 -> 0.4, 1600 -> 0.8. ENG-char 0.4 CHN-char 0.8
+            print('token_diversity', token_diversity)
+
             ww = self.state_dict()
-            for k in ww:
+            for k in ww:
                 if 'tok_emb' in k:
                     if self.config.vocab_size > self.config.n_embd:
                         ww[k] *= math.sqrt(self.config.vocab_size)
                     else:
                         ww[k] *= math.sqrt(self.config.n_embd)
-                    ww[k] *= 0.4
+                    ww[k] *= token_diversity
                 elif 'head.weight' in k:
-                    ww[k] *= 0.2
+                    ww[k] *= token_diversity
                 elif 'blocks.' in k:
                     block_id = int(k.split('.')[1])
                     if 'receptance.weight' in k:
-                        ww[k] *= 0.5
+                        ww[k] *= 0.2 # 0.2 ~ 0.5
                     elif 'attn.key.weight' in k:
-                        ww[k] *= 0.2
+                        ww[k] *= 0.2 # 0.2 ~ 0.5
+                    elif 'attn.output.weight' in k:
+                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
+                    elif 'mlp.weight.weight' in k:
+                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
 
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
 
-    def get_ctx_size(self):
-        return self.ctx_size
+    def get_ctx_len(self):
+        return self.ctx_len
 
     def _init_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
@@ -403,7 +415,7 @@ class GPT(nn.Module):
 
     def forward(self, idx, targets=None):
         B, T = idx.size()
-        assert T <= self.ctx_size, "Cannot forward, model block size is exhausted."
+        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
 
         x = self.tok_emb(idx)
 
diff --git a/src/trainer.py b/src/trainer.py
index e9618f4..fe54bbd 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -1,4 +1,4 @@
-import math
+import math, sys
 import logging
 import numpy as np
 from tqdm.auto import tqdm
@@ -8,16 +8,19 @@ from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data.dataloader import DataLoader
 
 logger = logging.getLogger(__name__)
 
+print('logging to wandb... (comment it if you don\'t have wandb)')
+import wandb # comment it if you don't have wandb
+
 class TrainerConfig:
     max_epochs = 10
     batch_size = 64
-    learning_rate = 3e-4
+    learning_rate = 4e-4
     betas = (0.9, 0.95)
     grad_norm_clip = 1.0
     weight_decay = 0.01
-    lr_decay = False # learning rate decay params: linear warmup followed by cosine decay
-    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
-    final_tokens = 260e9 # (at what point we reach 10% of original LR)
+    lr_decay = False # linear warmup followed by cosine decay
+    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper
+    final_tokens = 260e9 # at which point do we reach lr_final
     ckpt_path = None
     num_workers = 0 # for DataLoader
@@ -33,6 +36,12 @@ class Trainer:
         self.test_dataset = test_dataset
         self.config = config
         self.avg_loss = -1
+        self.steps = 0
+
+        if 'wandb' in sys.modules:
+            cfg = model.config
+            run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
+            wandb.init(project="RWKV-LM", name=run_name + '-' + wandb.util.generate_id(), config=config, save_code=False)
 
         # take over whatever gpus are on the system
         self.device = 'cpu'
@@ -101,6 +110,11 @@ class Trainer:
 
                 # report progress
                 now_loss = loss.item()
+
+                if 'wandb' in sys.modules:
+                    wandb.log({"loss": now_loss}, step = self.steps * self.config.batch_size)
+                    self.steps += 1
+
                 if self.avg_loss < 0:
                     self.avg_loss = now_loss
                 else:
diff --git a/train.py b/train.py
index 2734b4e..6ba94e1 100644
--- a/train.py
+++ b/train.py
@@ -15,10 +15,10 @@ set_seed(42)
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)
 
-# RWKV is our proposed model - fastest when the ctx window is long - good performance
-# RotaryMHA is usual Multi-head Attention + Rotary Encoding + GeGLU FFN
-# MHA-Plus is a bit slow (lots of tricks), with excellent performance
-model_type = 'RWKV' # 'RWKV' or 'RotaryMHA' or 'MHA-Plus'
+# RWKV - our new model - fastest when ctx_len is long - VRAM friendly - good performance
+# MHA_rotary - usual Multi-head Attention+Rotary+GeGLU - not as good
+# MHA_pro - slow (lots of tricks) - VRAM hungry - good performance
+model_type = 'RWKV' # 'RWKV' or 'MHA_rotary' or 'MHA_pro'
 
 datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
 datafile_encoding = 'utf-8'
@@ -27,23 +27,20 @@ datafile_encoding = 'utf-8'
 
 model_level = 'character' # 'character' or 'word'
 
-ctx_size = 256 if model_level == 'character' else 128
-nLayers = 5
-nHead = 8
-nEmb = nHead * 64
+ctx_len = 256 # length of ctx window
+n_layer = 5
+n_head = 8
+n_embd = n_head * 64
 
-lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
-lr_final = 2e-4
+batch_size = 64
 
-lr_initial /= math.sqrt(nLayers / 5) # lower lr for deep models; higher lr for shallow models
-lr_final /= math.sqrt(nLayers / 5)
+n_epoch = 50 # the 'epoch' here is very short
 
+lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
+lr_final = 2e-4
 betas = (0.9, 0.99)
-weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when we have enough data
-
-nepoch = 50 # just a quick test. the 'epoch' here is very short
-nbatchsz = 64
-epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
+weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when we have enough data
+epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
 
 ########################################################################################################
 # Load data
 ########################################################################################################
@@ -52,7 +49,7 @@ epoch_length_fixed = 10000 # make an 'epoch' ve
 print('loading data... ' + datafile)
 
 class Dataset(Dataset):
-    def __init__(self, data, model_level, ctx_size):
+    def __init__(self, data, model_level, ctx_len):
         print('building token list...')
         if model_level == 'word':
             import re
@@ -67,7 +64,7 @@ class Dataset(Dataset):
         print('\n\ndata has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
         self.stoi = { ch:i for i,ch in enumerate(unique) }
         self.itos = { i:ch for i,ch in enumerate(unique) }
-        self.ctx_size = ctx_size
+        self.ctx_len = ctx_len
         self.vocab_size = vocab_size
         self.data = data
 
@@ -75,26 +72,26 @@ class Dataset(Dataset):
         return epoch_length_fixed
 
     def __getitem__(self, idx):
-        i = np.random.randint(0, len(self.data) - (self.ctx_size + 1)) # CHEAT: pick a spot in the dataset at random
-        chunk = self.data[i:i+self.ctx_size+1]
+        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) # CHEAT: pick a spot in the dataset at random
+        chunk = self.data[i:i+self.ctx_len+1]
         dix = [self.stoi[s] for s in chunk]
         x = torch.tensor(dix[:-1], dtype=torch.long)
         y = torch.tensor(dix[1:], dtype=torch.long)
         return x, y
 
-train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_size)
+train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_len)
 
 ########################################################################################################
 # Train model
 ########################################################################################################
 
-model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_type=model_type,
-                      n_layer=nLayers, n_head=nHead, n_embd=nEmb))
+model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
+                      n_layer=n_layer, n_head=n_head, n_embd=n_embd))
 
-print('model', model_type, 'total epoch', nepoch, 'batchsz', nbatchsz, 'nLayers', nLayers, 'nHead', nHead, 'nEmb', nEmb, 'len', ctx_size)
-tconf = TrainerConfig(model_type=model_type, max_epochs=nepoch, batch_size=nbatchsz, weight_decay=weight_decay,
-                      learning_rate=lr_initial, lr_decay=True, lr_final=lr_final, betas=betas,
-                      warmup_tokens=0, final_tokens=nepoch*len(train_dataset)*ctx_size, num_workers=0)
+print('model', model_type, 'total epoch', n_epoch, 'batch_size', batch_size, 'n_layer', n_layer, 'n_head', n_head, 'n_embd', n_embd, 'len', ctx_len)
+tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
+                      learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas,
+                      warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0)
 
 trainer = Trainer(model, train_dataset, None, tconf)
 trainer.train()
@@ -119,8 +116,8 @@ for run in range(NUM_OF_RUNS):
 
     x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
     real_len = len(x)
-    if real_len < ctx_size:
-        x = np.pad(x, (0, ctx_size - real_len))
+    if real_len < ctx_len:
+        x = np.pad(x, (0, ctx_len - real_len))
 
     print_begin = 0
    for i in range(LENGTH_OF_EACH):
@@ -130,13 +127,13 @@ for run in range(NUM_OF_RUNS):
            print_begin = real_len
 
        with torch.no_grad():
-            xxx = torch.tensor(x[-ctx_size:], dtype=torch.long)[None,...].to("cuda:0")
+            xxx = torch.tensor(x[-ctx_len:], dtype=torch.long)[None,...].to("cuda:0")
             out, _ = model(xxx)
 
-        pos = -1 if real_len >= ctx_size else real_len - 1
+        pos = -1 if real_len >= ctx_len else real_len - 1
         char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) # our special sampling method
 
-        if real_len < ctx_size:
+        if real_len < ctx_len:
             x[real_len] = char
         else:
             x = np.append(x, char)
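
A few notes on the changes above.

The new token_diversity factor in GPT.__init__ (src/model.py) replaces the fixed 0.4 / 0.2 multipliers on the embedding and head weights with a value that grows with vocabulary size as a cube root, clamped so it stays in [0.4, 0.8]. A minimal sketch of just that formula (the helper name below is ours, not part of the patch):

def token_diversity(vocab_size):
    # cube root of (vocab_size / 200), clamped to [1, 2], then scaled by 0.4
    t = pow(vocab_size / 200, 1 / 3)
    return 0.4 * min(max(t, 1), 2)

# vocab_size 200 (ENG char-level) -> 0.4, vocab_size 1600 (CHN char-level) -> 0.8,
# anything larger stays capped at 0.8
print(token_diversity(200), token_diversity(1600), token_diversity(5000))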
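
Also new in the RWKV init: attn.output.weight and mlp.weight.weight are scaled by 1 / pow(1 + block_id, 0.5), so the output projections of deeper blocks start smaller, roughly 1/sqrt(depth). For the default n_layer = 5 in train.py the factors work out as follows (a quick check, not part of the patch):

n_layer = 5
print([round(1 / pow(1 + block_id, 0.5), 3) for block_id in range(n_layer)])
# -> [1.0, 0.707, 0.577, 0.5, 0.447]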
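
For context on the renamed ctx_len buffers: the time_w handling in RWKV_TimeMix.forward (and MHA_pro.forward), unchanged here apart from the rename, turns one weight vector per head into a causal Toeplitz matrix, i.e. entry (i, j) equals time_w[..., ctx_len - 1 + j - i] for j <= i and 0 for future positions. A small standalone sketch of that pad/tile/reshape trick, with made-up toy sizes:

import torch
from torch.nn import functional as F

n_head, TT = 2, 4                                      # toy sizes for illustration
time_w = torch.arange(1., TT + 1).repeat(n_head, 1)    # (n_head, TT)

w = F.pad(time_w, (0, TT))                             # (n_head, 2*TT), right half is zeros
w = torch.tile(w, [TT])                                # (n_head, 2*TT*TT)
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)             # (n_head, TT, 2*TT-1)
w = w[:, :, TT - 1:]                                   # (n_head, TT, TT)

# w[h, i, j] == time_w[h, TT-1 + j - i] for j <= i, and 0 for j > i (future positions)
print(w[0])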
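
Finally, the trainer now imports wandb unconditionally and asks you to comment the import out if you don't have it, while the later `if 'wandb' in sys.modules:` checks keep the logging itself optional. If you would rather not edit the file, a guarded import is one possible alternative (a sketch under that assumption, not part of the patch):

import sys

try:
    import wandb  # optional dependency
except ImportError:
    # wandb never enters sys.modules, so the existing
    # `if 'wandb' in sys.modules:` checks silently skip logging
    pass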