RWKV v2 RNN is here. Probably the strongest LM as of now.

main
BlinkDL 4 years ago
parent 1f189a4034
commit 5f21ddf20d

@ -0,0 +1,172 @@
#include <stdio.h>
// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB === 0 (Tmax and BF and BB are passed by compiler)
#define F4(A, B) ((float4 *)(A))[(B) >> 2]
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
const F eps, const int B, const int C, const int T) {
const int i = blockIdx.y;
const int ij = (B * C) / BF;
const int t = threadIdx.x << 2;
__shared__ F ww[Tmax];
__shared__ F kk[Tmax * BF];
F4(ww, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BF; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
}
__syncthreads();
float4 s[BF];
#pragma unroll
for (int j = 0; j < BF; j++) {
s[j] = {eps, eps, eps, eps};
}
const F *__restrict__ const w = ww + T - t - 4;
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BF; j++) {
const F x = kk[u + Tmax * j];
s[j].x += w[u + 3] * x;
s[j].y += w[u + 2] * x;
s[j].z += w[u + 1] * x;
s[j].w += w[u + 0] * x;
}
}
#pragma unroll
for (int j = 0; j < BF; j++) {
const F *__restrict__ const k = kk + Tmax * j;
s[j].y += w[t + 3] * k[t + 1];
s[j].z += w[t + 2] * k[t + 1];
s[j].z += w[t + 3] * k[t + 2];
s[j].w += w[t + 1] * k[t + 1];
s[j].w += w[t + 2] * k[t + 2];
s[j].w += w[t + 3] * k[t + 3];
F4(x, t + T * (i + ij * j)) = s[j];
}
}
template <typename F>
__global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
const int i = blockIdx.y;
const int t = threadIdx.x << 2;
__shared__ F k[Tmax];
__shared__ F gg[Tmax];
F4(k, t) = F4(__k, t + T * i);
F4(gg, t) = F4(__gwk, t + T * i);
__syncthreads();
float4 s = {0, 0, 0, 0};
const F *__restrict__ const g = gg + T - t - 4;
for (int u = 0; u <= t; u++) {
F x = k[u];
s.x += g[u + 3] * x;
s.y += g[u + 2] * x;
s.z += g[u + 1] * x;
s.w += g[u + 0] * x;
}
s.y += g[t + 3] * k[t + 1];
s.z += g[t + 2] * k[t + 1];
s.z += g[t + 3] * k[t + 2];
s.w += g[t + 1] * k[t + 1];
s.w += g[t + 2] * k[t + 2];
s.w += g[t + 3] * k[t + 3];
F4(gw, t + T * i) = s;
}
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
dim3 gridDim(1, B * C / BF);
dim3 blockDim(T >> 2);
kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
}
template <typename F>
__global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
const int i = blockIdx.y;
const int ij = (B * C) / BB;
const int t = threadIdx.x << 2;
__shared__ F w[Tmax];
__shared__ F kk[Tmax * BB];
__shared__ F gg[Tmax * BB];
F4(w, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BB; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
}
__syncthreads();
float4 s[BB];
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
F x = kk[u + Tmax * j];
s[j].x += g[u + 3] * x;
s[j].y += g[u + 2] * x;
s[j].z += g[u + 1] * x;
s[j].w += g[u + 0] * x;
}
}
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const k = kk + Tmax * j;
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
s[j].y += g[t + 3] * k[t + 1];
s[j].z += g[t + 2] * k[t + 1];
s[j].z += g[t + 3] * k[t + 2];
s[j].w += g[t + 1] * k[t + 1];
s[j].w += g[t + 2] * k[t + 2];
s[j].w += g[t + 3] * k[t + 3];
F4(gw, t + T * (i + ij * j)) = s[j];
}
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = t + 3; u < T; u++) {
F x = w[u];
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - u] * x;
s[j].y += g[3 - u] * x;
s[j].z += g[4 - u] * x;
s[j].w += g[5 - u] * x;
}
}
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - t] * w[t + 0];
s[j].x += g[1 - t] * w[t + 1];
s[j].x += g[0 - t] * w[t + 2];
s[j].y += g[2 - t] * w[t + 1];
s[j].y += g[1 - t] * w[t + 2];
s[j].z += g[2 - t] * w[t + 2];
F4(gk, t + T * (i + ij * j)) = s[j];
}
}
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
dim3 gridDim(1, B * C / BB);
dim3 blockDim(T >> 2);
kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
}

@ -0,0 +1,21 @@
#include <torch/extension.h>
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);
void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T);
}
void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "timex forward");
m.def("backward", &backward, "timex backward");
}
TORCH_LIBRARY(timex, m) {
m.def("forward", forward);
m.def("backward", backward);
}

@ -0,0 +1,236 @@
# -*- coding:utf-8 -*-
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
### Step 1: set model ##################################################################################
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
MODEL_NAME = 'trained-31' # your trained model
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False # True False - show softmax output
DEBUG_TIME = False # True False - show trained time-coeffs
### Step 2: set context ################################################################################
context = "\n" # ==> this is your prompt
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 500
TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9
########################################################################################################
np.set_printoptions(precision=4, suppress=True, linewidth=200)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
context = tokenizer.refine_context(context)
print('Your context has ' + str(len(context)) + ' tokens')
print(f'Loading {MODEL_NAME}...')
##############################################################################################################
RWKV_K_CLAMP = 60
RWKV_K_EPS = 1e-16
RWKV_HEAD_QK_DIM = 256
class RWKV_RNN():
def __init__(self, MODEL_NAME):
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth',
map_location=torch.device(RUN_DEVICE)) # .state_dict()
for x in w.keys():
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = torch.exp(-torch.exp(w[x]))
if '.time_first' in x:
w[x] = torch.exp(w[x])
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
self.hk = None
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
target.hk = copy.deepcopy(self.hk)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
self.hk = copy.deepcopy(target.hk)
def LN(self, xx, w):
return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if DEBUG_TIME:
print(name+'.time_mix', w.time_mix.squeeze().numpy())
if name not in self.xx:
self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ x)
k = torch.square(torch.relu(w.key.weight @ x))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if DEBUG_TIME:
print(name+'.time_mix', w.time_mix.squeeze().numpy())
print(name+'.time_decay', w.time_decay.squeeze().numpy())
print(name+'.time_first', w.time_first.squeeze().numpy())
if name not in self.xx:
self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)
self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)
x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ x)
k = torch.exp(torch.clamp(w.key.weight @ x, max=RWKV_K_CLAMP))
v = w.value.weight @ x
kv = k * v
a = self.aa[name] + w.time_first * kv
b = self.bb[name] + w.time_first * k
self.aa[name] = w.time_decay * self.aa[name] + kv
self.bb[name] = w.time_decay * self.bb[name] + k
rwkv = r * a / (b + RWKV_K_EPS)
return w.output.weight @ rwkv
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
for i in range(n_layer):
x = self.LN(x, w.blocks[i].ln1)
if i == 0 and model_type == 'RWKV-ffnPre':
x = x + self.FF(x, w.blocks[i].ffnPre, f'ffnPre.{i}')
else:
x = x + self.SA(x, w.blocks[i].att, f'att.{i}')
x = self.LN(x, w.blocks[i].ln2)
x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
if self.hk == None:
self.hk = (w.head_k.weight @ x).unsqueeze(0)
else:
self.hk = torch.cat(
[self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
if self.hk.shape[0] > ctx_len:
self.hk = self.hk[-ctx_len:, :]
q = w.head_q.weight @ x
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
c = (self.hk @ q) / RWKV_HEAD_QK_DIM
for i in range(len(c)):
x[ctx[i]] += c[i]
return x
##############################################################################################################
model = RWKV_RNN(MODEL_NAME)
print('\n\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
src_len = len(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(context.replace('\n', '\n '), end='')
model.clear()
if TRIAL == 0:
init_state = types.SimpleNamespace()
for i in range(src_len):
x = ctx[:i+1]
if i == src_len - 1:
init_state.out = model.run(x)
else:
model.run(x)
model.save(init_state)
else:
model.load(init_state)
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
x = ctx[:i+1]
x = x[-ctx_len:]
if i == src_len:
out = copy.deepcopy(init_state.out)
else:
out = model.run(x)
if DEBUG_DEBUG:
print('model', np.array(x), '==>', np.array(
out), np.max(out), np.min(out))
char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
top_p_usual=top_p, top_p_newline=top_p_newline)
char = char.item()
print(tokenizer.itos[int(char)].replace(
'\n', '\n '), end='', flush=True)
ctx += [char]
print('\n' + '-' * 40, end='')

@ -0,0 +1,349 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.cpp_extension import load
import math
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
########################################################################################################
# CUDA Kernel
########################################################################################################
T_MAX = 1024 # increase this if your ctx_len > 1024
B_GROUP_FORWARD = 8 # set to 8 for best performance
B_GROUP_BACKWARD = 2 # set to 2 for best performance
timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'])
class TimeX(torch.autograd.Function):
@staticmethod
def forward(ctx, w, k, B, C, T, eps):
ctx.B = B
ctx.C = C
ctx.T = T
assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
w = w.contiguous()
k = k.contiguous()
ctx.save_for_backward(w, k)
wk = torch.empty((B, C, T), device='cuda',
memory_format=torch.contiguous_format)
timex_cuda.forward(w, k, wk, eps, B, C, T)
return wk
@staticmethod
def backward(ctx, gwk):
assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
w, k = ctx.saved_tensors
gw = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
memory_format=torch.contiguous_format)
gk = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
memory_format=torch.contiguous_format)
timex_cuda.backward(w, k, gwk.contiguous(), gw,
gk, ctx.B, ctx.C, ctx.T)
return (gw.sum(dim=0), gk, None, None, None, None)
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
RWKV_K_CLAMP = 60 # e^60 = 1e26
RWKV_K_EPS = 1e-16
RWKV_HEAD_QK_DIM = 256
def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
with torch.no_grad():
name = '[unknown weight]'
for name, parameter in module.named_parameters(): # find the name of the weight
if id(m.weight) == id(parameter):
break
shape = m.weight.data.shape
gain = 1.0
scale = 1.0 # extra scale for gain
if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
scale = 1e-4
else:
scale = 0
if isinstance(m, nn.Linear):
if m.bias is not None:
m.bias.data.zero_()
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
scale = 0.5
if hasattr(m, 'scale_init'):
scale = m.scale_init
# print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
gain *= scale
if scale == -999:
nn.init.eye_(m.weight)
elif gain == 0:
# zero init is great for some RWKV matrices
nn.init.zeros_(m.weight)
elif gain > 0:
nn.init.orthogonal_(m.weight, gain=gain)
else:
nn.init.normal_(m.weight, mean=0.0, std=-scale)
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_embd = config.n_embd
attn_sz = config.n_embd
############# fancy init of time_w curves ###################################
f1_begin = 3.0
f1_end = 1.2
f2_begin = 0.65
f2_end = 0.4
with torch.no_grad(): # initial time_w curves for better convergence
decay_speed = torch.ones(attn_sz, 1)
first_sa_layer_id = 1
for h in range(attn_sz):
f1 = f1_begin + (layer_id-first_sa_layer_id) / \
(config.n_layer-1-first_sa_layer_id) * (f1_end - f1_begin)
f2 = f2_begin + (layer_id-first_sa_layer_id) / \
(config.n_layer-1-first_sa_layer_id) * (f2_end - f2_begin)
if layer_id == first_sa_layer_id:
f1 += 0.5
if layer_id == config.n_layer-2:
f2 = 0.4
if layer_id == config.n_layer-1:
f2 = 0.37
decay_speed[h][0] = math.pow(f2, h / (attn_sz-1) * 7) * f1
self.time_decay = nn.Parameter(torch.log(decay_speed)) # will use exp(self.time_decay) to ensure time_decay > 0
self.time_curve = torch.tensor(
[-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0)
self.time_curve = self.time_curve.to('cuda')
self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3))
#############################################################################
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # init to "shift half of the channels"
ww = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd // 2):
ww[0, 0, i] = 0
self.time_mix = nn.Parameter(ww)
self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
def forward(self, x):
B, T, C = x.size()
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
k = self.key(x).transpose(-1, -2)
v = self.value(x).transpose(-1, -2)
r = self.receptance(x)
# RWKV_K_CLAMP can be removed if the CUDA kernel substracts the correct k_max for each k (I will do this later)
k = torch.clamp(k, max=RWKV_K_CLAMP)
k = torch.exp(k)
kv = k * v
self.time_w = torch.cat(
[torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1)
w = torch.exp(self.time_w)
wkv = TimeX.apply(w, kv, B, C, T, 0)
# RWKV_K_EPS can be removed if the CUDA kernel sets 0/0 = 0 (I will do this later)
wk = TimeX.apply(w, k, B, C, T, RWKV_K_EPS)
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # init to "shift half of the channels"
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd // 2):
x[0, 0, i] = 0
self.time_mix = nn.Parameter(x)
hidden_sz = 4 * config.n_embd
self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
self.value.scale_init = 0
self.receptance.scale_init = 0
def forward(self, x):
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
k = self.key(x)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(x)) * kv
return rkv
########################################################################################################
# The GPT Model with our blocks
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
self.vocab_size = vocab_size
self.ctx_len = ctx_len
for k, v in kwargs.items():
setattr(self, k, v)
class Block(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
else:
self.att = RWKV_TimeMix(config, layer_id)
self.ffn = RWKV_ChannelMix(config, layer_id)
def forward(self, x):
x = self.ln1(x)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(x) # better in some cases
else:
x = x + self.att(x)
x = self.ln2(x)
x = x + self.ffn(x)
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.step = 0
self.config = config
self.emb = nn.Embedding(config.vocab_size, config.n_embd)
self.blocks = nn.Sequential(*[Block(config, i)
for i in range(config.n_layer)])
self.ln_out = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(config.ctx_len, config.ctx_len)))
self.ctx_len = config.ctx_len
RWKV_Init(self, config)
logger.info("number of parameters: %e", sum(p.numel()
for p in self.parameters()))
def get_ctx_len(self):
return self.ctx_len
def _init_weights(self, module):
if isinstance(module, (nn.Linear)):
module.weight.data.normal_(mean=0.0, std=0.01)
if isinstance(module, (nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=1e-5)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def configure_optimizers(self, train_config):
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
for mn, m in self.named_modules(): # here we disable weight_decay
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
no_decay.add(fpn)
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(
inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
optim_groups = [
{"params": [param_dict[pn]
for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.Adam(
optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
def forward(self, idx, targets=None):
self.step += 1
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).float()
x = self.head(x) + c
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss

@ -0,0 +1,170 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm.auto import tqdm
import numpy as np
import logging
import os
import datetime
import sys
import math
# import wandb # comment this if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')
logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
log_file = open("mylog.txt", "a")
class TrainerConfig:
max_epochs = 10
batch_size = 64
learning_rate = 4e-4
betas = (0.9, 0.99)
eps = 1e-8
grad_norm_clip = 1.0
lr_decay = True # linear warmup followed by cosine decay
warmup_tokens = 0
final_tokens = 0
epoch_save_frequency = 0
epoch_save_path = 'trained-'
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class Trainer:
def __init__(self, model, train_dataset, test_dataset, config):
self.model = model
self.train_dataset = train_dataset
self.test_dataset = test_dataset
self.config = config
self.avg_loss = -1
self.steps = 0
if 'wandb' in sys.modules:
cfg = model.config
for k in config.__dict__:
setattr(cfg, k, config.__dict__[k]) # combine cfg
wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' +
datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
self.device = 'cpu'
if torch.cuda.is_available(): # take over whatever gpus are on the system
self.device = torch.cuda.current_device()
def get_run_name(self):
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
cfg = raw_model.config
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name
def train(self):
model, config = self.model, self.config
raw_model = model.module if hasattr(self.model, "module") else model
optimizer = raw_model.configure_optimizers(config)
def run_epoch(split):
is_train = split == 'train'
model.train(is_train)
data = self.train_dataset if is_train else self.test_dataset
if config.num_workers > 0:
loader = DataLoader(data, shuffle=False, pin_memory=True,
batch_size=config.batch_size,
num_workers=config.num_workers)
else:
loader = DataLoader(data, shuffle=False,
batch_size=config.batch_size,
num_workers=config.num_workers)
pbar = tqdm(enumerate(loader), total=len(
loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
for it, (x, y) in pbar:
x = x.to(self.device) # place data on the correct device
y = y.to(self.device)
with torch.set_grad_enabled(is_train):
_, loss = model(x, y) # forward the model
if is_train: # backprop and update the parameters
model.zero_grad()
loss.backward()
if config.grad_norm_clip > 0:
torch.nn.utils.clip_grad_norm_(
model.parameters(), config.grad_norm_clip)
optimizer.step()
if config.lr_decay: # decay the learning rate based on our progress
# number of tokens processed this step (i.e. label is not -100)
self.tokens += (y >= 0).sum()
lr_final_factor = config.lr_final / config.learning_rate
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = lr_final_factor + \
(1 - lr_final_factor) * float(self.tokens) / \
float(config.warmup_tokens)
progress = 0
else:
# cosine learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(
max(1, config.final_tokens - config.warmup_tokens))
lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor /
2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
param_group['lr'] = lr
else:
lr = config.learning_rate
now_loss = loss.item() # report progress
self.lr = lr
if 'wandb' in sys.modules:
wandb.log({"loss": now_loss},
step=self.steps * self.config.batch_size)
self.steps += 1
if self.avg_loss < 0:
self.avg_loss = now_loss
else:
factor = 1 / (it + 1)
self.avg_loss = self.avg_loss * \
(1.0 - factor) + now_loss * factor
pbar.set_description(
f"epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):
run_epoch('train')
log_file.write(
f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n')
log_file.flush()
if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
# DataParallel wrappers keep raw model object in .module
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
torch.save(raw_model.state_dict(),
self.config.epoch_save_path + str(epoch+1) + '.pth')

@ -0,0 +1,80 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json
import random
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
class TOKENIZER():
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
self.word_table = json.load(result_file)
self.vocab_size = len(self.word_table)
self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(torch.tensor(out), dim=-1)
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
sorted_probs, s_index = torch.sort(probs, descending=True)
# for j in range(30):
# pp = sorted_probs[j].item()
# if pp < 0.005:
# break
# ss = self.itos[int(s_index[j])].replace('\n','_')
# print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
# print('')
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
# print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
return torch.multinomial(probs, num_samples=1)[0]
def to_float(x):
return x.cpu().detach().numpy().flatten()[0].astype(float)
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

@ -0,0 +1,141 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import logging
import datetime
import json
from src.model import GPT, GPTConfig
from src.trainer import Trainer, TrainerConfig
from torch.utils.data import Dataset
import torch
import numpy as np
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
### Step 1: set training data ##########################################################################
datafile = "enwik8"
datafile_encoding = 'utf-8'
# datafile_encoding = 'utf-16le'
### Step 2: set model size #############################################################################
ctx_len = 1024 # ===> increase T_MAX in model.py if your ctx_len > 1024
n_layer = 6
n_embd = 512
# 'RWKV' (better for char-level English) or 'RWKV-ffnPre' (better in some cases)
model_type = 'RWKV'
### Step 3: set batch size #############################################################################
# ===> batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD in model.py
# For example, if your batch_size = 20, you can set B_GROUP_FORWARD = 4, B_GROUP_BACKWARD = 2
# If you see "CUDA out of memory", reduce it. Use GPU-Z to find the highest value for your VRAM.
batch_size = 40
### Step 4: set learning rate, training 'epochs' #######################################################
lr_init = 6e-4
lr_final = 1e-5
# the 'epoch' here is very short and of fixed length (ctx_len * epoch_length_fixed tokens)
n_epoch = 1000
# 0 = never, 1 = every 'epoch', 2 = every two 'epoch', etc.
epoch_save_frequency = 30
epoch_save_path = 'trained-'
epoch_length_fixed = 10000
########################################################################################################
# import src.utils
# src.utils.set_seed(42) # remember to change seed if you load a model
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
grad_norm_clip = 1.0
warmup_tokens = 0
betas = (0.9, 0.99)
eps = 4e-9
num_workers = 0
########################################################################################################
# Load data
########################################################################################################
print('loading data... ' + datafile)
class Dataset(Dataset):
def __init__(self, data, ctx_len):
print('building token list...', end=' ')
unique = sorted(list(set(data)))
# print()
# for u in unique:
# print(u, end=' ')
# print('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open('vocab.json', "w", encoding="utf-16") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
data_size, vocab_size = len(data), len(unique)
print('data has %d tokens, %d unique.' % (data_size, vocab_size))
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
self.ctx_len = ctx_len
self.vocab_size = vocab_size
self.data = data
def __len__(self):
return epoch_length_fixed
def __getitem__(self, idx):
# cheat: pick a random spot in dataset
i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
chunk = self.data[i:i+self.ctx_len+1]
dix = [self.stoi[s] for s in chunk]
x = torch.tensor(dix[:-1], dtype=torch.long,
device=torch.device('cuda'))
y = torch.tensor(dix[1:], dtype=torch.long,
device=torch.device('cuda'))
return x, y
train_dataset = Dataset(
open(datafile, "r", encoding=datafile_encoding).read(), ctx_len)
########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
n_layer=n_layer, n_embd=n_embd)).cuda()
# # load a trained model. remember to change random seed
# m2 = torch.load('trained-10000.pth')
# model.load_state_dict(m2)
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip,
warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()
torch.save(model, 'trained-' + str(n_epoch) + trainer.get_run_name() +
'-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
Loading…
Cancel
Save