diff --git a/.gitignore b/.gitignore index b6e4761..2de160e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,10 @@ +*.txt +*.csv +*.pth +*.xlsb +*.xlsx +*.xls + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/RWKV-vs-MHA.png b/RWKV-vs-MHA.png new file mode 100644 index 0000000..21fad12 Binary files /dev/null and b/RWKV-vs-MHA.png differ diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000..40fdbe9 --- /dev/null +++ b/src/model.py @@ -0,0 +1,290 @@ +import math +import logging +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +logger = logging.getLogger(__name__) + +######################################################################################################## +# Block: RWKV Time-mix + RWKV Channel-mix +######################################################################################################## + +class RWKV_TimeMix(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + self.ctx_size = config.ctx_size + self.n_head = config.n_head + self.head_size = config.n_embd // config.n_head + + self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size)) + self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size)) + self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1)) + self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1)) + self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size))) + + self.time_shift = nn.ZeroPad2d((0,0,1,0)) + + self.key = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + self.receptance = nn.Linear(config.n_embd, config.n_embd) + + self.output = nn.Linear(config.n_embd, config.n_embd) + + def forward(self, x): + B, T, C = x.size() + TT = self.ctx_size + w = F.pad(self.time_w, (0, TT)) + w = torch.tile(w, [TT]) + w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1) + w = w[:, :, TT-1:] # w is now a circulant matrix + w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :] + w = w.masked_fill(self.mask[:T, :T] == 0, 0) + + x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) + k = self.key(x) + v = self.value(x) + r = self.receptance(x) + + k = torch.exp(k) + sum_k = torch.cumsum(k, dim=1) + + k = k.view(B, T, self.n_head, self.head_size) + v = v.view(B, T, self.n_head, self.head_size) + + wkv = (torch.einsum('htu,buhc->bthc', w, k * v)).contiguous().view(B, T, C) + y = torch.sigmoid(r) * wkv / sum_k + + y = self.output(y) * self.time_gamma[:T, :] + return y + +class RWKV_ChannelMix(nn.Module): + def __init__(self, config): + super().__init__() + self.time_shift = nn.ZeroPad2d((0,0,1,0)) + + self.key = nn.Linear(config.n_embd, 3 * config.n_embd) + self.value = nn.Linear(config.n_embd, 3 * config.n_embd) + self.weight = nn.Linear(3 * config.n_embd, config.n_embd) + self.receptance = nn.Linear(config.n_embd, config.n_embd) + + def forward(self, x): + B, T, C = x.size() + + x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) + k = self.key(x) + v = self.value(x) + r = self.receptance(x) + + wkv = self.weight(F.gelu(k) * v) + y = torch.sigmoid(r) * wkv + + return y + +######################################################################################################## +# Block: Multi-head Attention + Rotary Encoding + GeGLU FFN 
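+# (alternative block used when config.model_type != 'RWKV': standard causal self-attention with
+#  rotary position encoding applied to half of each head's channels, plus a GeGLU feed-forward layer)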
+######################################################################################################## + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + + def forward(self, x, seq_len=None): + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + t = torch.arange(seq_len, device=x.device) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos() + self.sin_cached = emb.sin() + return self.cos_cached, self.sin_cached + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), -1) + +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin): + cos, sin = cos[...,:q.shape[2],:], sin[...,:q.shape[2],:] + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + +class RotaryMHA(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + self.n_head = config.n_head + self.ctx_size = config.ctx_size + self.head_size = config.n_embd // config.n_head + + self.query = nn.Linear(config.n_embd, config.n_embd) + self.key = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + + self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size))) + + self.rotary_ndims = int(self.head_size * 0.5) + self.rotary_emb = RotaryEmbedding(self.rotary_ndims) + + self.output = nn.Linear(config.n_embd, config.n_embd) + + def forward(self, x): + B, T, C = x.size() + + q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) + k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) + + q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:] + k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:] + cos, sin = self.rotary_emb(q, seq_len=T) + q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding + q = torch.cat((q, query_pass), dim=-1) + k = torch.cat((k, key_pass), dim=-1) + + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T) + att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask + att = F.softmax(att, dim = -1) # softmax + + x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs) + x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C) + + x = self.output(x) # output projection + return x + +class GeGLU(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.key = nn.Linear(config.n_embd, 3 * config.n_embd) + self.value = nn.Linear(config.n_embd, 3 * config.n_embd) + self.weight = nn.Linear(3 * config.n_embd, config.n_embd) + + def forward(self, x): + k = self.key(x) + v = self.value(x) + y = self.weight(F.gelu(k) * v) + return y + +######################################################################################################## +# The GPT Model with our blocks +######################################################################################################## + +class 
LabelSmoothingCrossEntropy(nn.Module): # might be able to avoid nan loss + def __init__(self, smoothing=0.0): + super().__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + + def forward(self, pred, target): + pred = pred.log_softmax(dim=-1) + with torch.no_grad(): + true_dist = torch.zeros_like(pred) + true_dist.fill_(self.smoothing / (pred.size(-1) - 1)) + true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) + return torch.mean(torch.sum(-true_dist * pred, dim=-1)) + +class GPTConfig: + def __init__(self, vocab_size, ctx_size, **kwargs): + self.vocab_size = vocab_size + self.ctx_size = ctx_size + for k,v in kwargs.items(): + setattr(self, k, v) + +class Block(nn.Module): + def __init__(self, config): + super().__init__() + + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + + if config.model_type == 'RWKV': + self.attn = RWKV_TimeMix(config) + self.mlp = RWKV_ChannelMix(config) + else: + self.attn = RotaryMHA(config) + self.mlp = GeGLU(config) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class GPT(nn.Module): + def __init__(self, config): + super().__init__() + + self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) + + self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + self.ctx_size = config.ctx_size + self.apply(self._init_weights) + + logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + + def get_ctx_size(self): + return self.ctx_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.01) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def configure_optimizers(self, train_config): + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + + whitelist_weight_modules = (nn.Linear, ) + blacklist_weight_modules = (nn.LayerNorm, nn.Embedding) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if pn.endswith('bias') or ('time' in fpn) or ('head' in fpn): + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ + % (str(param_dict.keys() - union_params), ) + + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) + return optimizer + + def forward(self, idx, targets=None): + B, T = idx.size() + assert T <= self.ctx_size, "Cannot forward, model block size is exhausted." 
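+        # note: there is no separate positional embedding here; position information comes from the
+        # blocks themselves (time_w / time_shift in RWKV mode, rotary encoding in RotaryMHA mode)
+        # shapes: idx (B, T) int64 -> tok_emb (B, T, n_embd) -> blocks -> ln_f -> logits (B, T, vocab_size)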
+ + x = self.tok_emb(idx) + + x = self.blocks(x) + + x = self.ln_f(x) + logits = self.head(x) + + loss = None + if targets is not None: + loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(logits.view(-1, logits.size(-1)), targets.view(-1)) + + return logits, loss diff --git a/src/trainer.py b/src/trainer.py new file mode 100644 index 0000000..e9618f4 --- /dev/null +++ b/src/trainer.py @@ -0,0 +1,128 @@ +import math +import logging +import numpy as np +from tqdm.auto import tqdm +import torch +import torch.optim as optim +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data.dataloader import DataLoader +logger = logging.getLogger(__name__) + +class TrainerConfig: + max_epochs = 10 + batch_size = 64 + learning_rate = 3e-4 + betas = (0.9, 0.95) + grad_norm_clip = 1.0 + weight_decay = 0.01 + lr_decay = False # learning rate decay params: linear warmup followed by cosine decay + warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere + final_tokens = 260e9 # (at what point we reach 10% of original LR) + ckpt_path = None + num_workers = 0 # for DataLoader + + def __init__(self, **kwargs): + for k,v in kwargs.items(): + setattr(self, k, v) + +class Trainer: + + def __init__(self, model, train_dataset, test_dataset, config): + self.model = model + self.train_dataset = train_dataset + self.test_dataset = test_dataset + self.config = config + self.avg_loss = -1 + + # take over whatever gpus are on the system + self.device = 'cpu' + if torch.cuda.is_available(): + self.device = torch.cuda.current_device() + self.model = torch.nn.DataParallel(self.model).to(self.device) + + def save_checkpoint(self): + # DataParallel wrappers keep raw model object in .module attribute + raw_model = self.model.module if hasattr(self.model, "module") else self.model + logger.info("saving %s", self.config.ckpt_path) + torch.save(raw_model.state_dict(), self.config.ckpt_path) + + def train(self): + model, config = self.model, self.config + raw_model = model.module if hasattr(self.model, "module") else model + optimizer = raw_model.configure_optimizers(config) + + def run_epoch(split): + is_train = split == 'train' + model.train(is_train) + data = self.train_dataset if is_train else self.test_dataset + loader = DataLoader(data, shuffle=True, pin_memory=True, + batch_size=config.batch_size, + num_workers=config.num_workers) + + losses = [] + pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader) + for it, (x, y) in pbar: + + # place data on the correct device + x = x.to(self.device) + y = y.to(self.device) + + # forward the model + with torch.set_grad_enabled(is_train): + logits, loss = model(x, y) + loss = loss.mean() # collapse all losses if they are scattered on multiple gpus + losses.append(loss.item()) + + if is_train: + + # backprop and update the parameters + model.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) + optimizer.step() + + # decay the learning rate based on our progress + if config.lr_decay: + self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. 
label is not -100) + if self.tokens < config.warmup_tokens: + # linear warmup + lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens)) + progress = 0 + else: + # cosine learning rate decay + progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) + lr_final_factor = config.lr_final / config.learning_rate + lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1 + lr = config.learning_rate * lr_mult + for param_group in optimizer.param_groups: + param_group['lr'] = lr + else: + lr = config.learning_rate + + # report progress + now_loss = loss.item() + if self.avg_loss < 0: + self.avg_loss = now_loss + else: + factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1)) + self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor + pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}") + + if not is_train: + test_loss = float(np.mean(losses)) + logger.info("test loss: %f", test_loss) + return test_loss + + best_loss = float('inf') + self.tokens = 0 # counter used for learning rate decay + for epoch in range(config.max_epochs): + + run_epoch('train') + if self.test_dataset is not None: + test_loss = run_epoch('test') + + # supports early stopping based on the test loss, or just save always if no test set is provided + good_model = self.test_dataset is None or test_loss < best_loss + if self.config.ckpt_path is not None and good_model: + best_loss = test_loss + self.save_checkpoint() diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..6192589 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,46 @@ +import random +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +def top_k_logits(logits, k): + v, ix = torch.topk(logits, k) + out = logits.clone() + out[out < v[:, [-1]]] = -float('Inf') + return out + +def top_p_probs(probs, p): + out = probs.clone() + + sorted_probs, sorted_indices = torch.sort(out, descending=True) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + sorted_indices_to_remove = cumulative_probs > p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices[sorted_indices_to_remove] + out[indices_to_remove] = 0 + + return out + +# top-p + top-k + pow&ratio sampling +def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, min_p_pow=None, min_p_ratio=None): + logits = logits[:, pos, :] / temperature + probs = F.softmax(logits, dim=-1) + if min_p_ratio is not None: + limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio + logits[probs < limit] = -float('Inf') + if top_k is not None: + logits = top_k_logits(logits, top_k) + probs = F.softmax(logits, dim=-1) + if top_p is not None: + probs[0] = top_p_probs(probs[0], top_p) + ix = torch.multinomial(probs, num_samples=1) + + return ix[0][0].cpu() + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) diff --git a/train.py b/train.py new file mode 100644 index 0000000..82e5802 --- /dev/null +++ b/train.py @@ -0,0 +1,117 @@ +import os, sys, time, math, random, json, datetime +import logging +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.data import Dataset +from src.trainer import Trainer, 
TrainerConfig
+from src.model import GPT, GPTConfig
+from src.utils import set_seed
+
+set_seed(42)
+np.set_printoptions(precision=4, suppress=True, linewidth=200)
+logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)
+
+model_type = 'RWKV' # 'RWKV' or 'RotaryMHA'
+
+datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
+model_level = 'character' # 'character' or 'word'
+
+ctx_size = 256 if model_level == 'character' else 128
+nLayers = 5
+nHead = 8
+nEmb = 512
+
+nepoch = 50
+nbatchsz = 64
+epoch_length_fixed = 10000 # make an epoch very short, so we can see the training progress
+
+########################################################################################################
+
+print("loading data...", end="")
+
+class Dataset(Dataset):
+    def __init__(self, data, model_level, ctx_size):
+        if model_level == 'word':
+            data = data.replace('\n', ' \n ').replace(' ', ' ').split(' ')
+
+        unique = sorted(list(set(data)))
+        data_size, vocab_size = len(data), len(unique)
+        self.stoi = { ch:i for i,ch in enumerate(unique) }
+        self.itos = { i:ch for i,ch in enumerate(unique) }
+        print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
+        self.ctx_size = ctx_size
+        self.vocab_size = vocab_size
+        self.data = data
+
+    def __len__(self):
+        return epoch_length_fixed
+
+    def __getitem__(self, idx):
+        i = np.random.randint(0, len(self.data) - (self.ctx_size + 1)) # CHEAT: pick a spot in the dataset at random
+        chunk = self.data[i:i+self.ctx_size+1]
+        dix = [self.stoi[s] for s in chunk]
+        x = torch.tensor(dix[:-1], dtype=torch.long)
+        y = torch.tensor(dix[1:], dtype=torch.long)
+        return x, y
+
+train_dataset = Dataset(open(datafile, "r", encoding="utf-8").read(), model_level, ctx_size)
+
+########################################################################################################
+
+model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_type=model_type,
+                      n_layer=nLayers, n_head=nHead, n_embd=nEmb))
+
+print('model', model_type, 'total epoch', nepoch, 'batchsz', nbatchsz, 'nLayers', nLayers, 'nHead', nHead, 'nEmb', nEmb, 'len', ctx_size)
+tconf = TrainerConfig(model_type=model_type, max_epochs=nepoch, batch_size=nbatchsz,
+                      learning_rate=6e-4 if model_type == 'RWKV' else 4e-4, betas=(0.9, 0.99), # RWKV can use higher LR
+                      lr_decay=True, lr_final=2e-4, warmup_tokens=0, final_tokens=nepoch*len(train_dataset)*ctx_size, num_workers=0)
+trainer = Trainer(model, train_dataset, None, tconf)
+
+trainer.train()
+
+torch.save(model, 'trained-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
+
+########################################################################################################
+
+from src.utils import sample_logits
+
+MAX_LEN = ctx_size
+NUM_OF_RUNS = 5
+LENGTH_OF_EACH = 300
+
+for run in range(NUM_OF_RUNS):
+    context = "It was"
+
+    x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
+
+    real_len = len(x)
+    if real_len < MAX_LEN:
+        x = np.pad(x, (0, MAX_LEN - real_len))
+    print_begin = 0
+
+    for i in range(LENGTH_OF_EACH):
+
+        if i == 0:
+            print(('-' * 80) + '\n' + context, end = '')
+            print_begin = real_len
+
+        with torch.no_grad():
+            xxx = torch.tensor(x[-MAX_LEN:], dtype=torch.long)[None,...].to("cuda:0")
+            out, _ = model(xxx)
+            pos = -1 if real_len >= MAX_LEN else real_len - 1
+
+            char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02)
+
+            if real_len < MAX_LEN:
+                x[real_len] = char
+            else:
+                x = np.append(x, char)
+            real_len += 1
+
+            if i % 10 == 9 or i == LENGTH_OF_EACH-1:
+                completion = ''.join([train_dataset.itos[int(j)] for j in x[print_begin:real_len]])
+                print(completion, end = '')
+                print_begin = real_len
+        print()
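+
+# A minimal reload sketch (commented out; not part of the run above). The filename is a placeholder
+# for whatever timestamped .pth the torch.save() call produced. Since torch.save() pickles the whole
+# module, torch.load() needs the src/ package importable from the working directory:
+#
+#   model = torch.load('trained-<timestamp>.pth').to("cuda:0")
+#   model.eval()
+#   logits, _ = model(torch.tensor([[train_dataset.stoi[s] for s in "It was"]], dtype=torch.long).to("cuda:0"))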