RWKV v2 RNN is here. Probably the strongest LM as of now.
parent 1f189a4034
commit 5f21ddf20d
cuda/timex_cuda.cu
@@ -0,0 +1,172 @@
#include <stdio.h>

// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax and BF and BB are passed by compiler)

#define F4(A, B) ((float4 *)(A))[(B) >> 2]

// kernel_forward: for each row of the flattened (B*C, T) input k, writes
//   x[t] = eps + sum_{u=0..t} w[T-1-(t-u)] * k[u]
// Each thread produces 4 consecutive outputs via float4; a block handles BF rows that
// share the same channel (stride ij), so that channel's w is loaded into shared memory once.
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
                               const F eps, const int B, const int C, const int T) {
    const int i = blockIdx.y;
    const int ij = (B * C) / BF;
    const int t = threadIdx.x << 2;

    __shared__ F ww[Tmax];
    __shared__ F kk[Tmax * BF];
    F4(ww, t) = F4(__w, t + T * (i % C));

    #pragma unroll
    for (int j = 0; j < BF; j++) {
        F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
    }
    __syncthreads();

    float4 s[BF];
    #pragma unroll
    for (int j = 0; j < BF; j++) {
        s[j] = {eps, eps, eps, eps};
    }
    const F *__restrict__ const w = ww + T - t - 4;
    for (int u = 0; u <= t; u++) {
        #pragma unroll
        for (int j = 0; j < BF; j++) {
            const F x = kk[u + Tmax * j];
            s[j].x += w[u + 3] * x;
            s[j].y += w[u + 2] * x;
            s[j].z += w[u + 1] * x;
            s[j].w += w[u + 0] * x;
        }
    }
    #pragma unroll
    for (int j = 0; j < BF; j++) {
        const F *__restrict__ const k = kk + Tmax * j;
        s[j].y += w[t + 3] * k[t + 1];
        s[j].z += w[t + 2] * k[t + 1];
        s[j].z += w[t + 3] * k[t + 2];
        s[j].w += w[t + 1] * k[t + 1];
        s[j].w += w[t + 2] * k[t + 2];
        s[j].w += w[t + 3] * k[t + 3];
        F4(x, t + T * (i + ij * j)) = s[j];
    }
}

// Standalone gradient-of-w kernel (note: not launched by cuda_backward below, which
// uses kernel_backward to produce both gw and gk in one pass).
template <typename F>
__global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
                                  F *__restrict__ const gw, F *__restrict__ const gk,
                                  const int B, const int C, const int T) {
    const int i = blockIdx.y;
    const int t = threadIdx.x << 2;

    __shared__ F k[Tmax];
    __shared__ F gg[Tmax];
    F4(k, t) = F4(__k, t + T * i);
    F4(gg, t) = F4(__gwk, t + T * i);
    __syncthreads();

    float4 s = {0, 0, 0, 0};

    const F *__restrict__ const g = gg + T - t - 4;
    for (int u = 0; u <= t; u++) {
        F x = k[u];
        s.x += g[u + 3] * x;
        s.y += g[u + 2] * x;
        s.z += g[u + 1] * x;
        s.w += g[u + 0] * x;
    }
    s.y += g[t + 3] * k[t + 1];
    s.z += g[t + 2] * k[t + 1];
    s.z += g[t + 3] * k[t + 2];
    s.w += g[t + 1] * k[t + 1];
    s.w += g[t + 2] * k[t + 2];
    s.w += g[t + 3] * k[t + 3];
    F4(gw, t + T * i) = s;
}

// launch: one block per BF rows of the flattened (B*C, T) input; T/4 threads, each writing 4 outputs
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
    dim3 gridDim(1, B * C / BF);
    dim3 blockDim(T >> 2);
    kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
}

// kernel_backward: given g = dL/dx, produces (per row)
//   gw[t] = sum_{p=T-1-t..T-1} g[p] * k[p - (T-1-t)]   (summed over the batch dimension on the Python side)
//   gk[t] = sum_{p=t..T-1}     g[p] * w[T-1-(p-t)]
template <typename F>
__global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
                                F *__restrict__ const gw, F *__restrict__ const gk,
                                const int B, const int C, const int T) {
    const int i = blockIdx.y;
    const int ij = (B * C) / BB;
    const int t = threadIdx.x << 2;

    __shared__ F w[Tmax];
    __shared__ F kk[Tmax * BB];
    __shared__ F gg[Tmax * BB];
    F4(w, t) = F4(__w, t + T * (i % C));

    #pragma unroll
    for (int j = 0; j < BB; j++) {
        F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
        F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
    }
    __syncthreads();

    float4 s[BB];
    #pragma unroll
    for (int j = 0; j < BB; j++) {
        s[j] = {0, 0, 0, 0};
    }

    for (int u = 0; u <= t; u++) {
        #pragma unroll
        for (int j = 0; j < BB; j++) {
            const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
            F x = kk[u + Tmax * j];
            s[j].x += g[u + 3] * x;
            s[j].y += g[u + 2] * x;
            s[j].z += g[u + 1] * x;
            s[j].w += g[u + 0] * x;
        }
    }
    #pragma unroll
    for (int j = 0; j < BB; j++) {
        const F *__restrict__ const k = kk + Tmax * j;
        const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
        s[j].y += g[t + 3] * k[t + 1];
        s[j].z += g[t + 2] * k[t + 1];
        s[j].z += g[t + 3] * k[t + 2];
        s[j].w += g[t + 1] * k[t + 1];
        s[j].w += g[t + 2] * k[t + 2];
        s[j].w += g[t + 3] * k[t + 3];
        F4(gw, t + T * (i + ij * j)) = s[j];
    }

    #pragma unroll
    for (int j = 0; j < BB; j++) {
        s[j] = {0, 0, 0, 0};
    }

    for (int u = t + 3; u < T; u++) {
        F x = w[u];
        #pragma unroll
        for (int j = 0; j < BB; j++) {
            const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
            s[j].x += g[2 - u] * x;
            s[j].y += g[3 - u] * x;
            s[j].z += g[4 - u] * x;
            s[j].w += g[5 - u] * x;
        }
    }
    #pragma unroll
    for (int j = 0; j < BB; j++) {
        const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
        s[j].x += g[2 - t] * w[t + 0];
        s[j].x += g[1 - t] * w[t + 1];
        s[j].x += g[0 - t] * w[t + 2];
        s[j].y += g[2 - t] * w[t + 1];
        s[j].y += g[1 - t] * w[t + 2];
        s[j].z += g[2 - t] * w[t + 2];
        F4(gk, t + T * (i + ij * j)) = s[j];
    }
}

// launch: one block per BB rows of the flattened (B*C, T) gradient; T/4 threads, each writing 4 outputs
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
    dim3 gridDim(1, B * C / BB);
    dim3 blockDim(T >> 2);
    kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
}

cuda/timex_op.cpp
@@ -0,0 +1,21 @@
#include <torch/extension.h>

void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);

void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
    cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T);
}
void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
    cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "timex forward");
    m.def("backward", &backward, "timex backward");
}

TORCH_LIBRARY(timex, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}

src/model.py
@@ -0,0 +1,349 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

from torch.utils.cpp_extension import load
import math
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)

########################################################################################################
# CUDA Kernel
########################################################################################################

T_MAX = 1024  # increase this if your ctx_len > 1024
B_GROUP_FORWARD = 8  # set to 8 for best performance
B_GROUP_BACKWARD = 2  # set to 2 for best performance

timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
                  verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'])
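
# Note (added for clarity): the -D flags above become the compile-time constants used in
# cuda/timex_cuda.cu -- Tmax bounds the shared-memory buffers (so T must be <= T_MAX),
# and BF / BB are how many (batch, channel) rows each forward / backward thread block
# processes at once (so the batch size must be divisible by both).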


class TimeX(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w, k, B, C, T, eps):
        ctx.B = B
        ctx.C = C
        ctx.T = T
        assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
        w = w.contiguous()
        k = k.contiguous()
        ctx.save_for_backward(w, k)
        wk = torch.empty((B, C, T), device='cuda',
                         memory_format=torch.contiguous_format)
        timex_cuda.forward(w, k, wk, eps, B, C, T)
        return wk

    @staticmethod
    def backward(ctx, gwk):
        assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
        w, k = ctx.saved_tensors
        gw = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
                         memory_format=torch.contiguous_format)
        gk = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
                         memory_format=torch.contiguous_format)
        timex_cuda.backward(w, k, gwk.contiguous(), gw,
                            gk, ctx.B, ctx.C, ctx.T)
        return (gw.sum(dim=0), gk, None, None, None, None)
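

# A slow reference for the TimeX op above -- a sketch added for clarity, not part of the
# original commit. It assumes w has shape (C, T) and k has shape (B, C, T), matching how
# RWKV_TimeMix calls TimeX.apply below; the function name and signature are illustrative.
def timex_reference(w, k, B, C, T, eps):
    # out[b, c, t] = eps + sum_{u<=t} w[c, T-1-(t-u)] * k[b, c, u]
    # i.e. a per-channel causal correlation of k with the learned time-weight curve.
    # (B is unused here; it is kept only to mirror TimeX.apply's argument order.)
    return F.conv1d(F.pad(k, (T - 1, 0)), w.unsqueeze(1), groups=C) + eps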


########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################


RWKV_K_CLAMP = 60  # e^60 = 1e26
RWKV_K_EPS = 1e-16
RWKV_HEAD_QK_DIM = 256


def RWKV_Init(module, config):  # fancy initialization of all lin & emb layers in the module
    for m in module.modules():
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            continue
        with torch.no_grad():
            name = '[unknown weight]'
            for name, parameter in module.named_parameters():  # find the name of the weight
                if id(m.weight) == id(parameter):
                    break

            shape = m.weight.data.shape
            gain = 1.0
            scale = 1.0  # extra scale for gain

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == config.vocab_size and shape[1] == config.n_embd:  # token emb?
                    scale = 1e-4
                else:
                    scale = 0

            if isinstance(m, nn.Linear):
                if m.bias is not None:
                    m.bias.data.zero_()
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == config.vocab_size and shape[1] == config.n_embd:  # final projection?
                    scale = 0.5

            if hasattr(m, 'scale_init'):
                scale = m.scale_init

            # print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)

            gain *= scale
            if scale == -999:
                nn.init.eye_(m.weight)
            elif gain == 0:
                # zero init is great for some RWKV matrices
                nn.init.zeros_(m.weight)
            elif gain > 0:
                nn.init.orthogonal_(m.weight, gain=gain)
            else:
                nn.init.normal_(m.weight, mean=0.0, std=-scale)


class RWKV_TimeMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd

        ############# fancy init of time_w curves ###################################
        f1_begin = 3.0
        f1_end = 1.2
        f2_begin = 0.65
        f2_end = 0.4
        with torch.no_grad():  # initial time_w curves for better convergence
            decay_speed = torch.ones(attn_sz, 1)
            first_sa_layer_id = 1
            for h in range(attn_sz):
                f1 = f1_begin + (layer_id-first_sa_layer_id) / \
                    (config.n_layer-1-first_sa_layer_id) * (f1_end - f1_begin)
                f2 = f2_begin + (layer_id-first_sa_layer_id) / \
                    (config.n_layer-1-first_sa_layer_id) * (f2_end - f2_begin)
                if layer_id == first_sa_layer_id:
                    f1 += 0.5
                if layer_id == config.n_layer-2:
                    f2 = 0.4
                if layer_id == config.n_layer-1:
                    f2 = 0.37
                decay_speed[h][0] = math.pow(f2, h / (attn_sz-1) * 7) * f1
            self.time_decay = nn.Parameter(torch.log(decay_speed))  # will use exp(self.time_decay) to ensure time_decay > 0
            self.time_curve = torch.tensor(
                [-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0)
            self.time_curve = self.time_curve.to('cuda')
            self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3))
        #############################################################################

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        with torch.no_grad():  # init to "shift half of the channels"
            ww = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd // 2):
                ww[0, 0, i] = 0
        self.time_mix = nn.Parameter(ww)

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def forward(self, x):
        B, T, C = x.size()

        x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)

        k = self.key(x).transpose(-1, -2)
        v = self.value(x).transpose(-1, -2)
        r = self.receptance(x)

        # RWKV_K_CLAMP can be removed if the CUDA kernel subtracts the correct k_max for each k (I will do this later)
        k = torch.clamp(k, max=RWKV_K_CLAMP)
        k = torch.exp(k)
        kv = k * v

        self.time_w = torch.cat(
            [torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1)
        w = torch.exp(self.time_w)

        wkv = TimeX.apply(w, kv, B, C, T, 0)
        # RWKV_K_EPS can be removed if the CUDA kernel sets 0/0 = 0 (I will do this later)
        wk = TimeX.apply(w, k, B, C, T, RWKV_K_EPS)

        rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
        rwkv = self.output(rwkv)
        return rwkv
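
# A reading of the Time-mix forward pass above, per batch element b and channel c
# (added for clarity, not part of the original commit):
#   wkv[b, c, t] = sum_{u<=t} W[c, t-u] * exp(k[b, c, u]) * v[b, c, u]
#   wk [b, c, t] = sum_{u<=t} W[c, t-u] * exp(k[b, c, u]) + RWKV_K_EPS
#   rwkv[b, t, c] = sigmoid(r[b, t, c]) * wkv[b, c, t] / wk[b, c, t], then projected by self.output
# where W[c, 0] = exp(time_first[c]) weighs the current token, and
# W[c, d] = exp(-exp(time_decay[c]) * (d - 1)) for lags d >= 1 decays into the past.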


class RWKV_ChannelMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # init to "shift half of the channels"
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd // 2):
                x[0, 0, i] = 0
        self.time_mix = nn.Parameter(x)

        hidden_sz = 4 * config.n_embd
        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)

        self.value.scale_init = 0
        self.receptance.scale_init = 0

    def forward(self, x):
        x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)

        k = self.key(x)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(x)) * kv
        return rkv
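
# In the same spirit (added note): Channel-mix computes rkv = sigmoid(R x') * V(relu(K x')^2),
# where x' is the time-shift-mixed input and K expands to 4 * n_embd before V projects back.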


########################################################################################################
# The GPT Model with our blocks
########################################################################################################


class GPTConfig:
    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)


class Block(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
        else:
            self.att = RWKV_TimeMix(config, layer_id)

        self.ffn = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        x = self.ln1(x)
        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(x)  # better in some cases
        else:
            x = x + self.att(x)
        x = self.ln2(x)
        x = x + self.ffn(x)
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.step = 0
        self.config = config

        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i)
                                    for i in range(config.n_layer)])

        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
        self.head_q.scale_init = 0
        self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
        self.head_k.scale_init = 0.1
        self.register_buffer("copy_mask", torch.tril(
            torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        RWKV_Init(self, config)

        logger.info("number of parameters: %e", sum(p.numel()
                    for p in self.parameters()))

    def get_ctx_len(self):
        return self.ctx_len

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, (nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1e-5)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        # separate out all parameters into those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()

        for mn, m in self.named_modules():  # here we disable weight_decay
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
                no_decay.add(fpn)

        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(
            inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        optim_groups = [
            {"params": [param_dict[pn]
                        for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]

        optimizer = torch.optim.Adam(
            optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)

        return optimizer

    def forward(self, idx, targets=None):
        self.step += 1
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
        x = self.emb(idx)

        x = self.blocks(x)

        x = self.ln_out(x)

        q = self.head_q(x)[:, :T, :]
        k = self.head_k(x)[:, :T, :]
        c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
        c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

        c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).float()
        x = self.head(x) + c

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))

        return x, loss
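
# Minimal usage sketch (added for clarity, not part of the original commit; the values
# mirror the training script at the end of this commit and are illustrative):
#   config = GPTConfig(vocab_size, ctx_len=1024, model_type='RWKV', n_layer=6, n_embd=512)
#   model = GPT(config).cuda()
#   idx = torch.randint(0, vocab_size, (8, 1024), device='cuda')  # batch divisible by B_GROUP_*
#   logits, loss = model(idx, targets=idx)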
src/trainer.py
@@ -0,0 +1,170 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm.auto import tqdm
import numpy as np
import logging
import os
import datetime
import sys
import math

# import wandb  # comment this if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')

logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

log_file = open("mylog.txt", "a")


class TrainerConfig:
    max_epochs = 10
    batch_size = 64
    learning_rate = 4e-4
    betas = (0.9, 0.99)
    eps = 1e-8
    grad_norm_clip = 1.0
    lr_decay = True  # linear warmup followed by cosine decay
    warmup_tokens = 0
    final_tokens = 0
    epoch_save_frequency = 0
    epoch_save_path = 'trained-'
    num_workers = 0  # for DataLoader

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)


class Trainer:

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.avg_loss = -1
        self.steps = 0

        if 'wandb' in sys.modules:
            cfg = model.config
            for k in config.__dict__:
                setattr(cfg, k, config.__dict__[k])  # combine cfg
            wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' +
                       datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)

        self.device = 'cpu'
        if torch.cuda.is_available():  # take over whatever gpus are on the system
            self.device = torch.cuda.current_device()

    def get_run_name(self):
        raw_model = self.model.module if hasattr(
            self.model, "module") else self.model
        cfg = raw_model.config
        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
            cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
        return run_name

    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset

            if config.num_workers > 0:
                loader = DataLoader(data, shuffle=False, pin_memory=True,
                                    batch_size=config.batch_size,
                                    num_workers=config.num_workers)
            else:
                loader = DataLoader(data, shuffle=False,
                                    batch_size=config.batch_size,
                                    num_workers=config.num_workers)

            pbar = tqdm(enumerate(loader), total=len(
                loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)

            for it, (x, y) in pbar:
                x = x.to(self.device)  # place data on the correct device
                y = y.to(self.device)

                with torch.set_grad_enabled(is_train):
                    _, loss = model(x, y)  # forward the model

                if is_train:  # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()

                    if config.grad_norm_clip > 0:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), config.grad_norm_clip)

                    optimizer.step()

                    if config.lr_decay:  # decay the learning rate based on our progress
                        # number of tokens processed this step (i.e. label is not -100)
                        self.tokens += (y >= 0).sum()
                        lr_final_factor = config.lr_final / config.learning_rate
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = lr_final_factor + \
                                (1 - lr_final_factor) * float(self.tokens) / \
                                float(config.warmup_tokens)
                            progress = 0
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(
                                max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor /
                                                                     2) * math.cos(math.pi * progress)  # better 1.0 ~ 0.1
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    now_loss = loss.item()  # report progress
                    self.lr = lr

                    if 'wandb' in sys.modules:
                        wandb.log({"loss": now_loss},
                                  step=self.steps * self.config.batch_size)
                    self.steps += 1

                    if self.avg_loss < 0:
                        self.avg_loss = now_loss
                    else:
                        factor = 1 / (it + 1)
                        self.avg_loss = self.avg_loss * \
                            (1.0 - factor) + now_loss * factor
                    pbar.set_description(
                        f"epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")

        self.tokens = 0  # counter used for learning rate decay
        for epoch in range(config.max_epochs):

            run_epoch('train')

            log_file.write(
                f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n')
            log_file.flush()

            if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
                # DataParallel wrappers keep raw model object in .module
                raw_model = self.model.module if hasattr(
                    self.model, "module") else self.model
                torch.save(raw_model.state_dict(),
                           self.config.epoch_save_path + str(epoch+1) + '.pth')
src/utils.py
@@ -0,0 +1,80 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import json
import random
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F


class TOKENIZER():
    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
            self.word_table = json.load(result_file)

        self.vocab_size = len(self.word_table)

        self.stoi = {v: int(k) for k, v in self.word_table.items()}
        self.itos = {int(k): v for k, v in self.word_table.items()}

        self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        if context == '':
            context = '\n'

        return context

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
        # out[self.UNKNOWN_CHAR] = -float('Inf')

        lastChar = int(x[-1])

        probs = F.softmax(torch.tensor(out), dim=-1)

        if self.itos[lastChar] == '\n':
            top_p = top_p_newline
        else:
            top_p = top_p_usual

        sorted_probs, s_index = torch.sort(probs, descending=True)

        # for j in range(30):
        #     pp = sorted_probs[j].item()
        #     if pp < 0.005:
        #         break
        #     ss = self.itos[int(s_index[j])].replace('\n', '_')
        #     print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
        # print('')

        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])

        probs[probs < cutoff] = 0
        # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")

        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)

        return torch.multinomial(probs, num_samples=1)[0]
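
# Usage sketch (added for clarity, not part of the original commit; the top-p values and
# UNKNOWN_CHAR are illustrative): given next-token logits `out` and the context `x` as token ids,
#   tokenizer = TOKENIZER('vocab', UNKNOWN_CHAR=' ')  # reads the vocab.json written by the training script
#   token = tokenizer.sample_logits(out, x, ctx_len, temperature=1.0,
#                                   top_p_usual=0.85, top_p_newline=0.9)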


def to_float(x):
    return x.cpu().detach().numpy().flatten()[0].astype(float)


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
train.py
@@ -0,0 +1,141 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import logging
import datetime
import json
from src.model import GPT, GPTConfig
from src.trainer import Trainer, TrainerConfig
from torch.utils.data import Dataset
import torch
import numpy as np
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

### Step 1: set training data ##########################################################################

datafile = "enwik8"
datafile_encoding = 'utf-8'
# datafile_encoding = 'utf-16le'

### Step 2: set model size #############################################################################

ctx_len = 1024  # ===> increase T_MAX in model.py if your ctx_len > 1024
n_layer = 6
n_embd = 512

# 'RWKV' (better for char-level English) or 'RWKV-ffnPre' (better in some cases)
model_type = 'RWKV'

### Step 3: set batch size #############################################################################

# ===> batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD in model.py
# For example, if your batch_size = 20, you can set B_GROUP_FORWARD = 4, B_GROUP_BACKWARD = 2
# If you see "CUDA out of memory", reduce it. Use GPU-Z to find the highest value for your VRAM.
batch_size = 40

### Step 4: set learning rate, number of training 'epochs' #############################################

lr_init = 6e-4
lr_final = 1e-5
# the 'epoch' here is very short and of fixed length (ctx_len * epoch_length_fixed tokens)
n_epoch = 1000
# 0 = never, 1 = every 'epoch', 2 = every two 'epochs', etc.
epoch_save_frequency = 30
epoch_save_path = 'trained-'

epoch_length_fixed = 10000

########################################################################################################

# import src.utils
# src.utils.set_seed(42)  # remember to change seed if you load a model

np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)

grad_norm_clip = 1.0
warmup_tokens = 0

betas = (0.9, 0.99)
eps = 4e-9

num_workers = 0

########################################################################################################
# Load data
########################################################################################################

print('loading data... ' + datafile)


class Dataset(Dataset):
    def __init__(self, data, ctx_len):
        print('building token list...', end=' ')
        unique = sorted(list(set(data)))
        # print()
        # for u in unique:
        #     print(u, end=' ')
        # print('\n\n')

        xx = 0
        xxObj = {}
        for u in unique:
            xxObj[xx] = u
            xx += 1
        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))

        data_size, vocab_size = len(data), len(unique)
        print('data has %d tokens, %d unique.' % (data_size, vocab_size))
        self.stoi = {ch: i for i, ch in enumerate(unique)}
        self.itos = {i: ch for i, ch in enumerate(unique)}
        self.ctx_len = ctx_len
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return epoch_length_fixed

    def __getitem__(self, idx):
        # cheat: pick a random spot in the dataset
        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
        chunk = self.data[i:i+self.ctx_len+1]
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long,
                         device=torch.device('cuda'))
        y = torch.tensor(dix[1:], dtype=torch.long,
                         device=torch.device('cuda'))
        return x, y


train_dataset = Dataset(
    open(datafile, "r", encoding=datafile_encoding).read(), ctx_len)

########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':

    model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
                          n_layer=n_layer, n_embd=n_embd)).cuda()

    # # load a trained model. remember to change the random seed
    # m2 = torch.load('trained-10000.pth')
    # model.load_state_dict(m2)

    print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
          betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )
    tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
                          learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip,
                          warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
    trainer = Trainer(model, train_dataset, None, tconf)

    trainer.train()

    torch.save(model, 'trained-' + str(n_epoch) + trainer.get_run_name() +
               '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')