RWKV-3 (test deeper models (n_layer >= 12) to see the advantage)

main
PENG Bo 4 years ago committed by GitHub
parent 1f6461b90b
commit b6403a8aef

@@ -0,0 +1,172 @@
#include <stdio.h>
// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax, BF and BB are passed by the compiler)
#define F4(A, B) ((float4 *)(A))[(B) >> 2]
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
const F eps, const int B, const int C, const int T) {
const int i = blockIdx.y;
const int ij = (B * C) / BF;
const int t = threadIdx.x << 2;
__shared__ F ww[Tmax];
__shared__ F kk[Tmax * BF];
F4(ww, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BF; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
}
__syncthreads();
float4 s[BF];
#pragma unroll
for (int j = 0; j < BF; j++) {
s[j] = {eps, eps, eps, eps};
}
const F *__restrict__ const w = ww + T - t - 4;
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BF; j++) {
const F x = kk[u + Tmax * j];
s[j].x += w[u + 3] * x;
s[j].y += w[u + 2] * x;
s[j].z += w[u + 1] * x;
s[j].w += w[u + 0] * x;
}
}
#pragma unroll
for (int j = 0; j < BF; j++) {
const F *__restrict__ const k = kk + Tmax * j;
s[j].y += w[t + 3] * k[t + 1];
s[j].z += w[t + 2] * k[t + 1];
s[j].z += w[t + 3] * k[t + 2];
s[j].w += w[t + 1] * k[t + 1];
s[j].w += w[t + 2] * k[t + 2];
s[j].w += w[t + 3] * k[t + 3];
F4(x, t + T * (i + ij * j)) = s[j];
}
}
template <typename F>
__global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
const int i = blockIdx.y;
const int t = threadIdx.x << 2;
__shared__ F k[Tmax];
__shared__ F gg[Tmax];
F4(k, t) = F4(__k, t + T * i);
F4(gg, t) = F4(__gwk, t + T * i);
__syncthreads();
float4 s = {0, 0, 0, 0};
const F *__restrict__ const g = gg + T - t - 4;
for (int u = 0; u <= t; u++) {
F x = k[u];
s.x += g[u + 3] * x;
s.y += g[u + 2] * x;
s.z += g[u + 1] * x;
s.w += g[u + 0] * x;
}
s.y += g[t + 3] * k[t + 1];
s.z += g[t + 2] * k[t + 1];
s.z += g[t + 3] * k[t + 2];
s.w += g[t + 1] * k[t + 1];
s.w += g[t + 2] * k[t + 2];
s.w += g[t + 3] * k[t + 3];
F4(gw, t + T * i) = s;
}
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
dim3 gridDim(1, B * C / BF);
dim3 blockDim(T >> 2);
kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
}
template <typename F>
__global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
const int i = blockIdx.y;
const int ij = (B * C) / BB;
const int t = threadIdx.x << 2;
__shared__ F w[Tmax];
__shared__ F kk[Tmax * BB];
__shared__ F gg[Tmax * BB];
F4(w, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BB; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
}
__syncthreads();
float4 s[BB];
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
F x = kk[u + Tmax * j];
s[j].x += g[u + 3] * x;
s[j].y += g[u + 2] * x;
s[j].z += g[u + 1] * x;
s[j].w += g[u + 0] * x;
}
}
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const k = kk + Tmax * j;
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
s[j].y += g[t + 3] * k[t + 1];
s[j].z += g[t + 2] * k[t + 1];
s[j].z += g[t + 3] * k[t + 2];
s[j].w += g[t + 1] * k[t + 1];
s[j].w += g[t + 2] * k[t + 2];
s[j].w += g[t + 3] * k[t + 3];
F4(gw, t + T * (i + ij * j)) = s[j];
}
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = t + 3; u < T; u++) {
F x = w[u];
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - u] * x;
s[j].y += g[3 - u] * x;
s[j].z += g[4 - u] * x;
s[j].w += g[5 - u] * x;
}
}
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - t] * w[t + 0];
s[j].x += g[1 - t] * w[t + 1];
s[j].x += g[0 - t] * w[t + 2];
s[j].y += g[2 - t] * w[t + 1];
s[j].y += g[1 - t] * w[t + 2];
s[j].z += g[2 - t] * w[t + 2];
F4(gk, t + T * (i + ij * j)) = s[j];
}
}
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
dim3 gridDim(1, B * C / BB);
dim3 blockDim(T >> 2);
kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
}
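A slow PyTorch reference for what kernel_forward computes can help when sanity-checking the kernel. The sketch below is illustrative only (it is not part of this commit) and mirrors the per-element formula out[b, c, t] = eps + sum over u <= t of w[c, T-1-(t-u)] * k[b, c, u]:

import torch

def timex_forward_ref(w, k, eps):
    # w: (C, T) per-channel time curve, k: (B, C, T); returns (B, C, T)
    B, C, T = k.shape
    out = torch.full_like(k, eps)
    for t in range(T):
        # out[b, c, t] = eps + sum_{u<=t} w[c, T-1-(t-u)] * k[b, c, u]
        out[:, :, t] += (w[:, T - 1 - t:].unsqueeze(0) * k[:, :, :t + 1]).sum(-1)
    return out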

@@ -0,0 +1,21 @@
#include <torch/extension.h>
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);
void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T);
}
void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "timex forward");
m.def("backward", &backward, "timex backward");
}
TORCH_LIBRARY(timex, m) {
m.def("forward", forward);
m.def("backward", backward);
}

@@ -0,0 +1,98 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
### Step 1: set model ##################################################################################
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
# your trained model
MODEL_NAME = 'trained-1'
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False # True False - show softmax output
### Step 2: set context ################################################################################
context = "\nIn the" # ==> this is your prompt
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 500
TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9
########################################################################################################
print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
########################################################################################################
context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n')
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
t_begin = time.time_ns()
src_len = len(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(('-' * 30) + context, end='')
model.clear()
if TRIAL == 0:
init_state = types.SimpleNamespace()
for i in range(src_len):
x = ctx[:i+1]
if i == src_len - 1:
init_state.out = model.run(x)
else:
model.run(x)
model.save(init_state)
else:
model.load(init_state)
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
x = ctx[:i+1]
x = x[-ctx_len:]
if i == src_len:
out = copy.deepcopy(init_state.out)
else:
out = model.run(x)
if DEBUG_DEBUG:
print('model', np.array(x), '==>', np.array(
out), np.max(out), np.min(out))
char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
top_p_usual=top_p, top_p_newline=top_p_newline)
char = char.item()
print(tokenizer.itos[int(char)], end='', flush=True)
ctx += [char]
t_end = time.time_ns()
print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')

@@ -0,0 +1,363 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.cpp_extension import load
import math
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
RWKV_K_CLAMP = 60 # e^60 = 1e26
RWKV_K_EPS = 1e-8
RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
########################################################################################################
# CUDA Kernel
########################################################################################################
T_MAX = 1024 # increase this if your ctx_len > 1024
B_GROUP_FORWARD = 4 # set to 8 for best performance
B_GROUP_BACKWARD = 2 # set to 2 for best performance (sometimes 8 is faster)
timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'])
class TimeX(torch.autograd.Function):
@staticmethod
def forward(ctx, w, k, B, C, T, eps):
ctx.B = B
ctx.C = C
ctx.T = T
assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
w = w.contiguous()
k = k.contiguous()
ctx.save_for_backward(w, k)
wk = torch.empty((B, C, T), device='cuda',
memory_format=torch.contiguous_format)
timex_cuda.forward(w, k, wk, eps, B, C, T)
return wk
@staticmethod
def backward(ctx, gwk):
assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
w, k = ctx.saved_tensors
gw = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
memory_format=torch.contiguous_format)
gk = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
memory_format=torch.contiguous_format)
timex_cuda.backward(w, k, gwk.contiguous(), gw,
gk, ctx.B, ctx.C, ctx.T)
return (gw.sum(dim=0), gk, None, None, None, None)
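# --- Illustrative sketch only; not part of this commit. A quick self-check one could
# --- drop in here: compare TimeX.apply against the F.conv1d formulation used in
# --- model_run.py, on small shapes that satisfy the asserts above.
def _timex_selfcheck():  # hypothetical helper; call it manually after the extension compiles
    B, C, T = 4, 8, 64  # B divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD, T <= T_MAX
    w = torch.rand(C, T, device='cuda')
    k = torch.rand(B, C, T, device='cuda')
    out_cuda = TimeX.apply(w, k, B, C, T, 0)
    out_ref = F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(k), w.unsqueeze(1), groups=C)
    print('TimeX max |diff| =', (out_cuda - out_ref).abs().max().item())  # expect ~1e-5 or smaller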
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
def RWKV_Init(module, config): # fancy initialization of all Linear & Embedding layers in the module
for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
with torch.no_grad():
name = '[unknown weight]'
for name, parameter in module.named_parameters(): # find the name of the weight
if id(m.weight) == id(parameter):
break
shape = m.weight.data.shape
gain = 1.0
scale = 1.0 # extra scale for gain
if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
scale = 1e-4
else:
scale = 0
if isinstance(m, nn.Linear):
if m.bias is not None:
m.bias.data.zero_()
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
scale = 0.5
if hasattr(m, 'scale_init'):
scale = m.scale_init
# print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
gain *= scale
if scale == -999:
nn.init.eye_(m.weight)
elif gain == 0:
# zero init is great for some RWKV matrices
nn.init.zeros_(m.weight)
elif gain > 0:
nn.init.orthogonal_(m.weight, gain=gain)
else:
nn.init.normal_(m.weight, mean=0.0, std=-scale)
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_embd = config.n_embd
attn_sz = config.n_embd
with torch.no_grad(): # fancy init
self.time_curve = torch.tensor([-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0)
self.time_curve = self.time_curve.to('cuda')
ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1
ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0
# fancy time_decay
decay_speed = torch.ones(attn_sz, 1)
for h in range(attn_sz):
decay_speed[h][0] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)
self.time_decay = nn.Parameter(decay_speed)
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
# fancy time_first
zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5).unsqueeze(1)
self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3) + zigzag)
# fancy time_mix
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd):
x[0, 0, i] = i / config.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
def forward(self, x):
B, T, C = x.size() # x = (Batch,Time,Channel)
# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
# Use xk, xv, xr to produce k, v, r
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
# RWKV_K_CLAMP can be removed if the CUDA kernel subtracts the correct k_max for each k (I will do this later)
k = torch.clamp(k, max=RWKV_K_CLAMP) # clamp k to avoid overflow
k = torch.exp(k)
kv = k * v
# Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)]
self.time_w = torch.cat(
[torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1)
w = torch.exp(self.time_w)
# Use W to mix kv and k respectively. Add K_EPS to wk to avoid divide-by-zero
wkv = TimeX.apply(w, kv, B, C, T, 0)
# RWKV_K_EPS can be removed if the CUDA kernel sets 0/0 = 0 (I will do this later)
wk = TimeX.apply(w, k, B, C, T, RWKV_K_EPS)
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # fancy init of time_mix
ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0
x = torch.ones(1, 1, config.n_embd)
for i in range(config.n_embd):
x[0, 0, i] = i / config.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
hidden_sz = 4 * config.n_embd
self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
self.value.scale_init = 0
self.receptance.scale_init = 0
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
########################################################################################################
# The GPT Model with our blocks
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
self.vocab_size = vocab_size
self.ctx_len = ctx_len
for k, v in kwargs.items():
setattr(self, k, v)
class Block(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
else:
self.att = RWKV_TimeMix(config, layer_id)
self.ffn = RWKV_ChannelMix(config, layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(self.ln1(x)) # better in some cases
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.step = 0
self.config = config
self.emb = nn.Embedding(config.vocab_size, config.n_embd)
self.blocks = nn.Sequential(*[Block(config, i)
for i in range(config.n_layer)])
self.ln_out = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
if RWKV_HEAD_QK_DIM > 0:
self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(config.ctx_len, config.ctx_len)))
self.ctx_len = config.ctx_len
RWKV_Init(self, config)
logger.info("number of parameters: %e", sum(p.numel()
for p in self.parameters()))
def get_ctx_len(self):
return self.ctx_len
def _init_weights(self, module):
if isinstance(module, (nn.Linear)):
module.weight.data.normal_(mean=0.0, std=0.01)
if isinstance(module, (nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=1e-5)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def configure_optimizers(self, train_config):
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
for mn, m in self.named_modules(): # here we disable weight_decay
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
no_decay.add(fpn)
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(
inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
optim_groups = [
{"params": [param_dict[pn]
for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.Adam(
optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
def forward(self, idx, targets=None):
self.step += 1
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
if RWKV_HEAD_QK_DIM > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).float()
x = self.head(x) + c
else:
x = self.head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss

@@ -0,0 +1,319 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types
import copy
import torch
import math
from torch.nn import functional as F
import torch.nn as nn
RWKV_K_CLAMP = 60
RWKV_K_EPS = 1e-8
RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
DEBUG_TIME = False # True False - show trained time-coeffs
############################################################################################################
RWKV_CFG = types.SimpleNamespace()
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
hidden_sz = 4 * RWKV_CFG.n_embd
self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
class RWKV_TimeMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1))
self.time_curve = torch.tensor([-(RWKV_CFG.ctx_len - 2 - i) for i in range(RWKV_CFG.ctx_len-1)]).unsqueeze(0)
self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1) * math.log(0.3))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
def forward(self, x):
B, T, C = x.size()
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
k = torch.clamp(k, max=RWKV_K_CLAMP)
k = torch.exp(k)
kv = k * v
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)
w = w[:,-T:].unsqueeze(1)
wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + RWKV_K_EPS
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
class Block(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)
if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(layer_id+1000)
else:
self.att = RWKV_TimeMix(layer_id)
self.ffn = RWKV_ChannelMix(layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(x)
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class RWKV_GPT(nn.Module):
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
global RWKV_CFG
super().__init__()
RWKV_CFG.RUN_DEVICE = RUN_DEVICE
RWKV_CFG.model_type = model_type
RWKV_CFG.vocab_size = vocab_size
RWKV_CFG.n_layer = n_layer
RWKV_CFG.n_embd = n_embd
RWKV_CFG.ctx_len = ctx_len
print('\nloading RWKV-GPT', MODEL_NAME)
self.emb = nn.Embedding(vocab_size, n_embd)
self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
self.ln_out = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
if RWKV_HEAD_QK_DIM > 0:
self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(ctx_len, ctx_len)))
self.ctx_len = ctx_len
self.eval()
self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
self.eval()
def forward(self, idx):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
if RWKV_HEAD_QK_DIM > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float()
x = self.head(x) + c
else:
x = self.head(x)
return x
############################################################################################################
class RWKV_RNN():
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
self.RUN_DEVICE = RUN_DEVICE
self.model_type = model_type
self.n_layer = n_layer
self.n_embd = n_embd
self.ctx_len = ctx_len
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth',
map_location=torch.device(RUN_DEVICE))
for x in w.keys():
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = torch.exp(-torch.exp(w[x]))
if '.time_first' in x:
w[x] = torch.exp(w[x])
if DEBUG_TIME and '.time_' in x:
print(x, w[x].squeeze().cpu().numpy())
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
self.hk = None
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
target.hk = copy.deepcopy(self.hk)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
self.hk = copy.deepcopy(target.hk)
def LN(self, xx, w):
return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.square(torch.relu(w.key.weight @ xk))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.exp(torch.clamp(w.key.weight @ xk, max=RWKV_K_CLAMP))
v = w.value.weight @ xv
kv = k * v
a = self.aa[name] + w.time_first * kv
b = self.bb[name] + w.time_first * k
self.aa[name] = w.time_decay * self.aa[name] + kv
self.bb[name] = w.time_decay * self.bb[name] + k
rwkv = r * a / (b + RWKV_K_EPS)
return w.output.weight @ rwkv
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
for i in range(self.n_layer):
if i == 0:
x = self.LN(x, w.blocks[i].ln0)
if i == 0 and self.model_type == 'RWKV-ffnPre':
x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
else:
x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
if RWKV_HEAD_QK_DIM > 0:
if self.hk == None:
self.hk = (w.head_k.weight @ x).unsqueeze(0)
else:
self.hk = torch.cat(
[self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
if self.hk.shape[0] > self.ctx_len:
self.hk = self.hk[-self.ctx_len:, :]
q = w.head_q.weight @ x
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
c = (self.hk @ q) / RWKV_HEAD_QK_DIM
for i in range(len(c)):
x[ctx[i]] += c[i]
else:
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
return x
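The aa/bb state kept by SA() implements the same weighting as the GPT-mode W-curve: the current token is weighted by exp(time_first) while older tokens decay geometrically by exp(-exp(time_decay)). Below is a per-channel toy version of that recurrence (illustrative only, not part of this commit; the receptance gate and output projection are omitted, and k / time_decay / time_first are assumed to be already exponentiated as in the code above):

import torch

def wkv_recurrence(k, v, time_decay, time_first, eps=1e-8):
    a = torch.zeros(())  # running sum of decayed k*v
    b = torch.zeros(())  # running sum of decayed k
    out = []
    for t in range(len(k)):
        out.append((a + time_first * k[t] * v[t]) / (b + time_first * k[t] + eps))
        a = time_decay * a + k[t] * v[t]
        b = time_decay * b + k[t]
    return torch.stack(out)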

@@ -0,0 +1,170 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm.auto import tqdm
import numpy as np
import logging
import os
import datetime
import sys
import math
# import wandb # comment this if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')
logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
log_file = open("mylog.txt", "a")
class TrainerConfig:
max_epochs = 10
batch_size = 64
learning_rate = 4e-4
betas = (0.9, 0.99)
eps = 1e-8
grad_norm_clip = 1.0
lr_decay = True # linear warmup followed by cosine decay
warmup_tokens = 0
final_tokens = 0
epoch_save_frequency = 0
epoch_save_path = 'trained-'
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class Trainer:
def __init__(self, model, train_dataset, test_dataset, config):
self.model = model
self.train_dataset = train_dataset
self.test_dataset = test_dataset
self.config = config
self.avg_loss = -1
self.steps = 0
if 'wandb' in sys.modules:
cfg = model.config
for k in config.__dict__:
setattr(cfg, k, config.__dict__[k]) # combine cfg
wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' +
datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)
self.device = 'cpu'
if torch.cuda.is_available(): # take over whatever gpus are on the system
self.device = torch.cuda.current_device()
def get_run_name(self):
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
cfg = raw_model.config
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name
def train(self):
model, config = self.model, self.config
raw_model = model.module if hasattr(self.model, "module") else model
optimizer = raw_model.configure_optimizers(config)
def run_epoch(split):
is_train = split == 'train'
model.train(is_train)
data = self.train_dataset if is_train else self.test_dataset
if config.num_workers > 0:
loader = DataLoader(data, shuffle=False, pin_memory=True,
batch_size=config.batch_size,
num_workers=config.num_workers)
else:
loader = DataLoader(data, shuffle=False,
batch_size=config.batch_size,
num_workers=config.num_workers)
pbar = tqdm(enumerate(loader), total=len(
loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
for it, (x, y) in pbar:
x = x.to(self.device) # place data on the correct device
y = y.to(self.device)
with torch.set_grad_enabled(is_train):
_, loss = model(x, y) # forward the model
if is_train: # backprop and update the parameters
model.zero_grad()
loss.backward()
if config.grad_norm_clip > 0:
torch.nn.utils.clip_grad_norm_(
model.parameters(), config.grad_norm_clip)
optimizer.step()
if config.lr_decay: # decay the learning rate based on our progress
# number of tokens processed this step (i.e. label is not -100)
self.tokens += (y >= 0).sum()
lr_final_factor = config.lr_final / config.learning_rate
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = lr_final_factor + \
(1 - lr_final_factor) * float(self.tokens) / \
float(config.warmup_tokens)
progress = 0
else:
# cosine learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(
max(1, config.final_tokens - config.warmup_tokens))
lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor /
2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
param_group['lr'] = lr
else:
lr = config.learning_rate
now_loss = loss.item() # report progress
self.lr = lr
if 'wandb' in sys.modules:
wandb.log({"loss": now_loss},
step=self.steps * self.config.batch_size)
self.steps += 1
if self.avg_loss < 0:
self.avg_loss = now_loss
else:
factor = 1 / (it + 1)
self.avg_loss = self.avg_loss * \
(1.0 - factor) + now_loss * factor
pbar.set_description(
f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):
run_epoch('train')
log_file.write(
f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n')
log_file.flush()
if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
# DataParallel wrappers keep raw model object in .module
raw_model = self.model.module if hasattr(
self.model, "module") else self.model
torch.save(raw_model.state_dict(),
self.config.epoch_save_path + str(epoch+1) + '.pth')
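The learning-rate schedule inside run_epoch() is a linear warmup to lr_init followed by a cosine decay toward lr_final. A standalone sketch of the same formula (illustrative only, not part of this commit):

import math

def lr_at(tokens, lr_init, lr_final, warmup_tokens, final_tokens):
    f = lr_final / lr_init
    if tokens < warmup_tokens:
        lr_mult = f + (1 - f) * tokens / warmup_tokens  # linear warmup from lr_final to lr_init
    else:
        progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
        lr_mult = (0.5 + f / 2) + (0.5 - f / 2) * math.cos(math.pi * progress)  # cosine decay
    return lr_init * lr_mult

With warmup_tokens = 0 (the default in train.py) the warmup branch is skipped, so the rate starts near lr_init and decays to roughly lr_final by final_tokens.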

@@ -0,0 +1,122 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json
import random
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
class Dataset(Dataset):
def __init__(self, data, ctx_len, epoch_length_fixed):
print('building token list...', end=' ')
unique = sorted(list(set(data)))
# print()
# for u in unique:
# print(u, end=' ')
# print('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open('vocab.json', "w", encoding="utf-16") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
data_size, vocab_size = len(data), len(unique)
print('data has %d tokens, %d unique.' % (data_size, vocab_size))
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
self.ctx_len = ctx_len
self.epoch_length_fixed = epoch_length_fixed
self.vocab_size = vocab_size
self.data = data
def __len__(self):
return self.epoch_length_fixed
def __getitem__(self, idx):
# cheat: pick a random spot in dataset
i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
chunk = self.data[i:i+self.ctx_len+1]
dix = [self.stoi[s] for s in chunk]
x = torch.tensor(dix[:-1], dtype=torch.long,
device=torch.device('cuda'))
y = torch.tensor(dix[1:], dtype=torch.long,
device=torch.device('cuda'))
return x, y
class TOKENIZER():
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
self.word_table = json.load(result_file)
self.vocab_size = len(self.word_table)
self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(torch.tensor(out), dim=-1)
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
sorted_probs, s_index = torch.sort(probs, descending=True)
# for j in range(30):
# pp = sorted_probs[j].item()
# if pp < 0.005:
# break
# ss = self.itos[int(s_index[j])].replace('\n','_')
# print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
# print('')
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
# print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
return torch.multinomial(probs, num_samples=1)[0]
def to_float(x):
return x.cpu().detach().numpy().flatten()[0].astype(float)
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
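TOKENIZER.sample_logits applies top-p (nucleus) sampling, with a separate threshold after newlines. A tiny standalone demo of the cutoff logic (illustrative only, not part of this commit):

import numpy as np
import torch
from torch.nn import functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
probs = F.softmax(logits, dim=-1)
sorted_probs, _ = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1).numpy()
cutoff = float(sorted_probs[np.argmax(cumulative > 0.7)])  # top_p = 0.7
probs[probs < cutoff] = 0                                  # drop tokens outside the nucleus
sample = torch.multinomial(probs, num_samples=1)[0]        # multinomial accepts unnormalized weights
print(int(sample))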

@@ -0,0 +1,108 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os
if True: # True False ---> Set to False if you don't understand it
print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import src.utils
src.utils.set_seed(42) # make training deterministic (including dataloader). if you are doing this, remember to change seed when you load a model (otherwise the dataloader loads old samples)
import logging
import datetime
from src.model import GPT, GPTConfig
from src.trainer import Trainer, TrainerConfig
from src.utils import Dataset
import torch
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
### Step 1: set training data ##########################################################################
datafile = "../data/enwik8" # your data
datafile_encoding = 'utf-8'
# datafile_encoding = 'utf-16le'
### Step 2: set model size #############################################################################
# ----> test deeper models (n_layer at least 12) to see the advantage of RWKV-3 over RWKV-2
ctx_len = 1024 # increase T_MAX in model.py if your ctx_len > 1024
n_layer = 6
n_embd = 512
# 'RWKV' (better for English) or 'RWKV-ffnPre' (better in some cases)
model_type = 'RWKV'
# ---> there is a RWKV_HEAD_QK_DIM in model.py and model_run.py
# set it to 256, then it's using my headQK trick (similar to a tiny attention) to improve loss
# set it to 0, then it's a pure RNN (attention-free)
### Step 3: set batch size #############################################################################
# ---> batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD in model.py
# for example, if your batch_size = 20, you can set B_GROUP_FORWARD = 4, B_GROUP_BACKWARD = 2
# if you see "CUDA out of memory", reduce batch_size. Use nvidia-smi to find the highest value for your GPU.
batch_size = 12
### Step 4: set learning rate, number of mini-epochs #######################################################
lr_init = 8e-4 # we can use larger lr because of preLN
lr_final = 1e-5
# the mini-epoch is very short and of fixed length (length = ctx_len * epoch_length_fixed tokens)
n_epoch = 500
epoch_length_fixed = 10000
# 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, ...
epoch_save_frequency = 10
epoch_save_path = 'trained-'
########################################################################################################
grad_norm_clip = 1.0
warmup_tokens = 0
betas = (0.9, 0.99)
eps = 4e-9
num_workers = 0
########################################################################################################
# Load data
########################################################################################################
print('loading data... ' + datafile)
train_dataset = Dataset(open(
datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)
########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
n_layer=n_layer, n_embd=n_embd)).cuda()
### load a trained model
# m2 = torch.load('trained-61.pth')
# model.load_state_dict(m2)
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip,
warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()
torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() +
'-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')

@@ -0,0 +1,65 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# this is for verifying the results of different models and making sure they agree with each other
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
RUN_DEVICE = 'cuda'
import torch
from src.model_run import RWKV_RNN, RWKV_GPT
from src.model import GPT, GPTConfig
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV'
model_name = 'trained-1'
from src.utils import TOKENIZER
tokenizer = TOKENIZER('vocab', UNKNOWN_CHAR=' ')
########################################################################################################
model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda()
print('loading ' + model_name)
m2 = torch.load(model_name + '.pth', map_location=RUN_DEVICE)
model_train.load_state_dict(m2)
model_rnn = RWKV_RNN(model_name, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
model_gpt = RWKV_GPT(model_name, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda()
########################################################################################################
context = '\nIn a'
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(f'input len {len(ctx)} data {ctx}')
########################################################################################################
print('\nRWKV-GPT output')
out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy()
print(out)
print('\nRWKV-RNN output')
model_rnn.clear()
src_len = len(ctx)
for i in range(src_len):
x = ctx[:i+1]
out = model_rnn.run(x)
if i < 3 or i >= src_len - 3:
print(torch.tensor(out).detach().cpu().numpy())
if i == 2:
print('...')
print('\nRWKV-train output')
ctx += [0] * (ctx_len - src_len) # pad to ctx_len
ctx = [ctx] * 4 # increase batch size (to make it work with B_GROUP_FORWARD & B_GROUP_BACKWARD)
out = model_train.forward(torch.tensor(ctx).cuda())[0][0][:src_len].detach().cpu().numpy()
print(out, '\n')