first commit
parent d21af78c97
commit aa4e2a68f4
(binary image file added, 9.4 KiB; not shown)
@@ -0,0 +1,290 @@
import math
import logging
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)

########################################################################################################
# Block: RWKV Time-mix + RWKV Channel-mix
########################################################################################################

class RWKV_TimeMix(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.ctx_size = config.ctx_size
        self.n_head = config.n_head
        self.head_size = config.n_embd // config.n_head

        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
        self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        self.receptance = nn.Linear(config.n_embd, config.n_embd)

        self.output = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        TT = self.ctx_size
        w = F.pad(self.time_w, (0, TT))
        w = torch.tile(w, [TT])
        w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
        w = w[:, :, TT - 1:]  # w is now a circulant matrix
        w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
        w = w.masked_fill(self.mask[:T, :T] == 0, 0)

        x = torch.cat([self.time_shift(x)[:, :-1, :C // 2], x[:, :, C // 2:]], dim=-1)
        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)

        k = torch.exp(k)
        sum_k = torch.cumsum(k, dim=1)

        k = k.view(B, T, self.n_head, self.head_size)
        v = v.view(B, T, self.n_head, self.head_size)

        wkv = (torch.einsum('htu,buhc->bthc', w, k * v)).contiguous().view(B, T, C)
        y = torch.sigmoid(r) * wkv / sum_k

        y = self.output(y) * self.time_gamma[:T, :]
        return y

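# Summary of RWKV_TimeMix as implemented above:
# - time_shift mixes each position with the previous one: the first C/2 channels come from the
#   shifted (previous-token) input, the last C/2 channels from the current token.
# - time_w is unrolled into a per-head causal T x T weighting (a circulant matrix, further scaled
#   by time_alpha / time_beta and masked to be lower-triangular).
# - the output is that weighting applied to exp(k) * v, gated by sigmoid(receptance), normalized
#   by the cumulative sum of exp(k), and finally scaled per position by time_gamma.
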
class RWKV_ChannelMix(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
        self.receptance = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()

        x = torch.cat([self.time_shift(x)[:, :-1, :C // 2], x[:, :, C // 2:]], dim=-1)
        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)

        wkv = self.weight(F.gelu(k) * v)
        y = torch.sigmoid(r) * wkv

        return y

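# Summary of RWKV_ChannelMix as implemented above: a time-shifted, GeGLU-style feed-forward block
# (gelu(key) * value projected back down to n_embd), gated by sigmoid(receptance).
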
########################################################################################################
# Block: Multi-head Attention + Rotary Encoding + GeGLU FFN
########################################################################################################

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_len=None):
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return self.cos_cached, self.sin_cached

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), -1)

@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    cos, sin = cos[..., :q.shape[2], :], sin[..., :q.shape[2], :]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

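# Note on the rotary helpers: RotaryMHA below applies the rotary position encoding only to the
# first rotary_ndims channels of each head (half of head_size); the remaining channels pass
# through unrotated.
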
class RotaryMHA(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.ctx_size = config.ctx_size
        self.head_size = config.n_embd // config.n_head

        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)

        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))

        self.rotary_ndims = int(self.head_size * 0.5)
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()

        q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
        cos, sin = self.rotary_emb(q, seq_len=T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)  # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        att = att.masked_fill(self.mask[:T, :T] == 0, float('-inf'))     # causal mask
        att = F.softmax(att, dim=-1)                                     # softmax

        x = att @ v  # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, C)  # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        x = self.output(x)  # output projection
        return x

class GeGLU(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.weight = nn.Linear(3 * config.n_embd, config.n_embd)

    def forward(self, x):
        k = self.key(x)
        v = self.value(x)
        y = self.weight(F.gelu(k) * v)
        return y

########################################################################################################
# The GPT Model with our blocks
########################################################################################################

class LabelSmoothingCrossEntropy(nn.Module):  # might be able to avoid nan loss
    def __init__(self, smoothing=0.0):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (pred.size(-1) - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=-1))

class GPTConfig:
    def __init__(self, vocab_size, ctx_size, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_size = ctx_size
        for k, v in kwargs.items():
            setattr(self, k, v)

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if config.model_type == 'RWKV':
            self.attn = RWKV_TimeMix(config)
            self.mlp = RWKV_ChannelMix(config)
        else:
            self.attn = RotaryMHA(config)
            self.mlp = GeGLU(config)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])

        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.ctx_size = config.ctx_size
        self.apply(self._init_weights)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_ctx_size(self):
        return self.ctx_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()

        whitelist_weight_modules = (nn.Linear, )
        blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias') or ('time' in fpn) or ('head' in fpn):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.ctx_size, "Cannot forward, model block size is exhausted."

        x = self.tok_emb(idx)

        x = self.blocks(x)

        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        if targets is not None:
            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss
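
# A minimal smoke-test sketch, assuming only the classes defined above; the config values are
# arbitrary small numbers chosen for speed. It builds a tiny RWKV-flavoured GPT and checks that a
# forward pass returns logits of shape (B, T, vocab_size).
if __name__ == '__main__':
    _config = GPTConfig(vocab_size=100, ctx_size=16, model_type='RWKV', n_layer=2, n_head=2, n_embd=32)
    _model = GPT(_config)
    _idx = torch.randint(0, 100, (2, 16))  # a batch of 2 random token sequences of length 16
    _logits, _loss = _model(_idx)          # no targets given, so _loss is None
    print(_logits.shape)                   # torch.Size([2, 16, 100])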
@@ -0,0 +1,128 @@
import math
import logging
import numpy as np
from tqdm.auto import tqdm
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
logger = logging.getLogger(__name__)

class TrainerConfig:
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.01
    lr_decay = False       # learning rate decay params: linear warmup followed by cosine decay
    warmup_tokens = 375e6  # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
    final_tokens = 260e9   # (at what point we reach 10% of original LR)
    ckpt_path = None
    num_workers = 0        # for DataLoader

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

class Trainer:

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.avg_loss = -1

        # take over whatever gpus are on the system
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def save_checkpoint(self):
        # DataParallel wrappers keep raw model object in .module attribute
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        logger.info("saving %s", self.config.ckpt_path)
        torch.save(raw_model.state_dict(), self.config.ckpt_path)

    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
            loader = DataLoader(data, shuffle=True, pin_memory=True,
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

            losses = []
            pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
            for it, (x, y) in pbar:

                # place data on the correct device
                x = x.to(self.device)
                y = y.to(self.device)

                # forward the model
                with torch.set_grad_enabled(is_train):
                    logits, loss = model(x, y)
                    loss = loss.mean()  # collapse all losses if they are scattered on multiple gpus
                    losses.append(loss.item())

                if is_train:

                    # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()

                    # decay the learning rate based on our progress
                    if config.lr_decay:
                        self.tokens += (y >= 0).sum()  # number of tokens processed this step (i.e. label is not -100)
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                            progress = 0
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            lr_final_factor = config.lr_final / config.learning_rate
                            lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress)  # decays lr_mult from 1.0 down to lr_final_factor instead of all the way to 0
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate
                        progress = 0  # keep the progress report below well-defined when lr_decay is off

                    # report progress
                    now_loss = loss.item()
                    if self.avg_loss < 0:
                        self.avg_loss = now_loss
                    else:
                        factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1))
                        self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor
                    pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")

            if not is_train:
                test_loss = float(np.mean(losses))
                logger.info("test loss: %f", test_loss)
                return test_loss

        best_loss = float('inf')
        self.tokens = 0  # counter used for learning rate decay
        for epoch in range(config.max_epochs):

            run_epoch('train')
            if self.test_dataset is not None:
                test_loss = run_epoch('test')

            # supports early stopping based on the test loss, or just save always if no test set is provided
            good_model = self.test_dataset is None or test_loss < best_loss
            if self.config.ckpt_path is not None and good_model:
                best_loss = test_loss
                self.save_checkpoint()
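
# Usage sketch, assuming only the classes defined above: TrainerConfig simply copies keyword
# arguments over its class-level defaults, so any field can be overridden per run and extra
# fields (e.g. lr_final, read by Trainer.train when lr_decay is on) can be attached.
if __name__ == '__main__':
    _tconf = TrainerConfig(max_epochs=2, batch_size=8, learning_rate=1e-3, lr_decay=True, lr_final=1e-4)
    print(_tconf.max_epochs, _tconf.learning_rate, _tconf.lr_final, _tconf.betas)  # betas keeps its default (0.9, 0.95)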
@@ -0,0 +1,46 @@
import random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

def top_k_logits(logits, k):
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out

def top_p_probs(probs, p):
    out = probs.clone()

    sorted_probs, sorted_indices = torch.sort(out, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    out[indices_to_remove] = 0

    return out

# top-p + top-k + pow&ratio sampling
def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, min_p_pow=None, min_p_ratio=None):
    logits = logits[:, pos, :] / temperature
    probs = F.softmax(logits, dim=-1)
    if min_p_ratio is not None:
        limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
        logits[probs < limit] = -float('Inf')
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = F.softmax(logits, dim=-1)
    if top_p is not None:
        probs[0] = top_p_probs(probs[0], top_p)
    ix = torch.multinomial(probs, num_samples=1)

    return ix[0][0].cpu()

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
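
# Minimal usage sketch (shapes and values here are arbitrary, chosen only for illustration):
# draw one token id from random logits of shape (B=1, T=4, vocab=10) at the last position,
# combining the top-k and top-p filters defined above.
if __name__ == '__main__':
    _logits = torch.randn(1, 4, 10)
    _token = sample_logits(_logits, pos=-1, temperature=1.0, top_k=5, top_p=0.9)
    print(int(_token))  # an index in [0, 10)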
@@ -0,0 +1,117 @@
import os, sys, time, math, random, json, datetime
import logging
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from src.trainer import Trainer, TrainerConfig
from src.model import GPT, GPTConfig
from src.utils import set_seed

set_seed(42)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)

model_type = 'RWKV'  # 'RWKV' or 'RotaryMHA'

datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt"  # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
model_level = 'character'  # 'character' or 'word'

ctx_size = 256 if model_level == 'character' else 128
nLayers = 5
nHead = 8
nEmb = 512

nepoch = 50
nbatchsz = 64
epoch_length_fixed = 10000  # make an epoch very short, so we can see the training progress

########################################################################################################

print("loading data...", end="")

class Dataset(Dataset):
    def __init__(self, data, model_level, ctx_size):
        if model_level == 'word':
            data = data.replace('\n', ' \n ').replace('  ', ' ').split(' ')

        unique = sorted(list(set(data)))
        data_size, vocab_size = len(data), len(unique)
        self.stoi = {ch: i for i, ch in enumerate(unique)}
        self.itos = {i: ch for i, ch in enumerate(unique)}
        print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
        self.ctx_size = ctx_size
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return epoch_length_fixed

    def __getitem__(self, idx):
        i = np.random.randint(0, len(self.data) - (self.ctx_size + 1))  # CHEAT: pick a spot in the dataset at random
        chunk = self.data[i:i + self.ctx_size + 1]
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y

train_dataset = Dataset(open(datafile, "r", encoding="utf-8").read(), model_level, ctx_size)

########################################################################################################

model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_type=model_type,
                      n_layer=nLayers, n_head=nHead, n_embd=nEmb))

print('model', model_type, 'total epoch', nepoch, 'batchsz', nbatchsz, 'nLayers', nLayers, 'nHead', nHead, 'nEmb', nEmb, 'len', ctx_size)
tconf = TrainerConfig(model_type=model_type, max_epochs=nepoch, batch_size=nbatchsz,
                      learning_rate=6e-4 if model_type == 'RWKV' else 4e-4, betas=(0.9, 0.99),  # RWKV can use higher LR
                      lr_decay=True, lr_final=2e-4, warmup_tokens=0, final_tokens=nepoch*len(train_dataset)*ctx_size, num_workers=0)
trainer = Trainer(model, train_dataset, None, tconf)

trainer.train()

torch.save(model, 'trained-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')

from src.utils import sample_logits

MAX_LEN = ctx_size
NUM_OF_RUNS = 5
LENGTH_OF_EACH = 300

for run in range(NUM_OF_RUNS):
    context = "It was"

    x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)

    real_len = len(x)
    if real_len < MAX_LEN:
        x = np.pad(x, (0, MAX_LEN - real_len))
    print_begin = 0

    for i in range(LENGTH_OF_EACH):

        if i == 0:
            print(('-' * 80) + '\n' + context, end='')
            print_begin = real_len

        with torch.no_grad():
            xxx = torch.tensor(x[-MAX_LEN:], dtype=torch.long)[None, ...].to("cuda:0")
            out, _ = model(xxx)
        pos = -1 if real_len >= MAX_LEN else real_len - 1

        char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02)

        if real_len < MAX_LEN:
            x[real_len] = char
        else:
            x = np.append(x, char)
        real_len += 1

        if i % 10 == 9 or i == LENGTH_OF_EACH - 1:
            completion = ''.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
            print(completion, end='')
            print_begin = real_len
    print()