RWKV-4 with DeepSpeed & FP16 & Better CUDA Kernel

parent dfb75dd89d
commit 165dfd1b9e

@ -0,0 +1,125 @@ cuda/wkv_cuda.cu
#include <stdio.h>
#include <assert.h>

#define MIN_VALUE (-1e38)

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    F p = 0, q = 0, o = MIN_VALUE;
    // p and q are running sums divided by exp(o) (to avoid overflows)
    for (int i = 0; i < T; i++) {
        const int ii = i * C;

        F no = max(o, u + k[ii]);
        F A = exp(o - no);
        F B = exp(u + k[ii] - no);
        y[ii] = (A * p + B * v[ii]) / (A * q + B);

        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }
}
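
// Added commentary (not in the original commit): the (A, B) rescalings above implement a streaming
// log-sum-exp. o tracks the running maximum exponent, so exp() only ever sees non-positive
// arguments, and p, q stay finite even for long T and large k values.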

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const gy = _gy + _offset;

    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    F y[Tmax], z[Tmax], zexp[Tmax];

    F gw = 0, gu = 0;
    F p = 0, q = 0;
    F dpdw = 0, dqdw = 0;
    F o = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        F no = max(o, k[ii] + u);
        F A = exp(o - no);
        F B = exp(k[ii] + u - no);

        F num = A * p + B * v[ii];
        F iden = 1 / (A * q + B);

        y[i] = num * iden;
        z[i] = iden;
        zexp[i] = k[ii] + u - no;

        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
        gu += gy[ii] * (v[ii] - y[i]) * B * iden;

        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        dpdw = A * (p + dpdw);
        dqdw = A * (q + dqdw);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }

    F gp = 0, gq = 0;
    o = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        F A = gy[ii] * z[i] * exp(zexp[i]);
        F B = exp(k[ii] + o);
        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
        gv[ii] = A + B * gp;

        F no = max(w + o, zexp[i] - k[ii] - u);
        A = exp(w + o - no);
        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
        gp = A * gp + B;
        gq = A * gq - B * y[i];
        o = no;
    }

    // Multiply by w to account for the w -> -exp(w) preprocessing done on the Python side:
    // d(-exp(w0))/dw0 = -exp(w0) = w, so the chain rule contributes a factor of w to the gradient
    // here, even though no such factor appears in the forward pass.
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] += gw * _w[_c];
    _gu[_offsetBC] += gu;
}
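
// Added commentary (not in the original commit): the backward kernel sweeps over time twice --
// a forward replay of the recurrence that stores y, z (the reciprocal denominator) and zexp (the
// exponent used), then a reverse sweep that carries the adjoints (gp, gq) back to form gk and gv.
// Tmax (set via -DTmax at compile time) bounds the per-thread buffers.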

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    dim3 threadsPerBlock( min(C, 1024) );
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
    dim3 threadsPerBlock( min(C, 1024) );
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}
@ -0,0 +1,21 @@ cuda/wkv_op.cpp
#include <torch/extension.h>

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);

void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("backward", &backward, "wkv backward");
}

TORCH_LIBRARY(wkv, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
@ -0,0 +1,37 @@ deepspeed.json
{
    "zero_allow_untested_optimizer": true,
    "zero_optimization": {
        "stage": 2,
        "contiguous_gradients": true,
        "overlap_comm": true,
        "allgather_partitions": true,
        "reduce_scatter": true,
        "allgather_bucket_size": 200000000,
        "reduce_bucket_size": 200000000,
        "sub_group_size": 1000000000000
    },
    "activation_checkpointing": {
        "partition_activations": false,
        "cpu_checkpointing": false,
        "contiguous_memory_optimization": false,
        "synchronize_checkpoint_boundary": false
    },
    "aio": {
        "block_size": 1048576,
        "queue_depth": 8,
        "single_submit": false,
        "overlap_events": true,
        "thread_count": 1
    },
    "gradient_clipping": 1.0,
    "gradient_accumulation_steps": 1,
    "fp16": {
        "fp16": true,
        "enabled": true,
        "loss_scale": 0,
        "initial_scale_power": 12,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    }
}
@ -0,0 +1,98 @@ run.py
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)

### Step 1: set model ##################################################################################

ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'

# your trained model
MODEL_NAME = 'trained-1'
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)

# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity

RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False # True False - show softmax output

### Step 2: set context ################################################################################

context = "\nIn the" # ==> this is your prompt

NUM_TRIALS = 999
LENGTH_PER_TRIAL = 500

TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9

########################################################################################################

print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

########################################################################################################

context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n')

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    t_begin = time.time_ns()

    src_len = len(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
    print(('-' * 30) + context, end='')

    model.clear()
    if TRIAL == 0:
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[:i+1]
            if i == src_len - 1:
                init_state.out = model.run(x)
            else:
                model.run(x)
        model.save(init_state)
    else:
        model.load(init_state)

    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        x = ctx[:i+1]
        x = x[-ctx_len:]

        if i == src_len:
            out = copy.deepcopy(init_state.out)
        else:
            out = model.run(x)
        if DEBUG_DEBUG:
            print('model', np.array(x), '==>', np.array(out), np.max(out), np.min(out))

        char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
                                       top_p_usual=top_p, top_p_newline=top_p_newline)
        char = char.item()
        print(tokenizer.itos[int(char)], end='', flush=True)
        ctx += [char]
    t_end = time.time_ns()
    print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')
@ -0,0 +1,348 @@ src/model.py
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import math, os
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from deepspeed.ops.adam import FusedAdam

logger = logging.getLogger(__name__)

RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')

########################################################################################################
# CUDA Kernel
########################################################################################################

T_MAX = 4096 # increase this if your ctx_len is long
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])

class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w = -torch.exp(w.float().contiguous())
        u = u.float().contiguous()
        k = k.float().contiguous()
        v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        return y.half()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda')
        gu = torch.zeros((B, C), device='cuda')
        gk = torch.zeros((B, T, C), device='cuda')
        gv = torch.zeros((B, T, C), device='cuda')
        wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())

def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
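
# Reference sketch (added commentary, not part of the commit): what the WKV kernel computes for one
# batch element, written as plain PyTorch. Purely illustrative and unoptimized; `wkv_reference` is
# a hypothetical name. It mirrors the kernel's numerically stable form, where p and q are running
# sums rescaled by exp(o).
#
# def wkv_reference(w, u, k, v):
#     # w: (C,) already preprocessed as -exp(w), u: (C,), k and v: (T, C)
#     T, C = k.shape
#     y = torch.empty_like(v)
#     p = torch.zeros(C); q = torch.zeros(C); o = torch.full((C,), -1e38)
#     for t in range(T):
#         no = torch.maximum(o, u + k[t])
#         A = torch.exp(o - no); B = torch.exp(u + k[t] - no)
#         y[t] = (A * p + B * v[t]) / (A * q + B)
#         no = torch.maximum(w + o, k[t])
#         A = torch.exp(w + o - no); B = torch.exp(k[t] - no)
#         p = A * p + B * v[t]; q = A * q + B; o = no
#     return y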

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################

def RWKV_Init(module, config): # fancy initialization of all lin & emb layers in the module
    print('\n[--> first run, init model params (very slow for large models) <--]\n')
    print('\n[so do this on a single GPU, save the checkpoint, and load it when training on multiple GPUs]\n')
    for m in module.modules():
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            continue
        with torch.no_grad():
            name = '[unknown weight]'
            for name, parameter in module.named_parameters(): # find the name of the weight
                if id(m.weight) == id(parameter):
                    break

            shape = m.weight.data.shape
            gain = 1.0
            scale = 1.0 # extra scale for gain

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
                    scale = 1e-4
                else:
                    scale = 0

            if isinstance(m, nn.Linear):
                if m.bias is not None:
                    m.bias.data.zero_()
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
                    scale = 0.5

            if hasattr(m, 'scale_init'):
                scale = m.scale_init

            # print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)

            gain *= scale
            if scale == -999:
                nn.init.eye_(m.weight)
            elif gain == 0:
                # zero init is great for some RWKV matrices
                nn.init.zeros_(m.weight)
            elif gain > 0:
                nn.init.orthogonal_(m.weight, gain=gain)
            else:
                nn.init.normal_(m.weight, mean=0.0, std=-scale)


class RWKV_TimeMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd

        with torch.no_grad(): # fancy init
            ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0

            # fancy time_decay
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            # fancy time_first
            zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5)
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

            # fancy time_mix
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def forward(self, x):
        B, T, C = x.size() # x = (Batch,Time,Channel)

        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # Use xk, xv, xr to produce k, v, r
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)

        rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
        rwkv = self.output(rwkv)
        return rwkv
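
# Added commentary (not part of the commit): nn.ZeroPad2d((0, 0, 1, -1)) applied to a (B, T, C)
# tensor pads one zero step at the start of the time axis and drops the last step, so
# xx[:, t] == x[:, t-1] with xx[:, 0] == 0. Toy example:
#   x = torch.arange(6.).view(1, 3, 2)   # [[[0, 1], [2, 3], [4, 5]]]
#   nn.ZeroPad2d((0, 0, 1, -1))(x)       # [[[0, 0], [0, 1], [2, 3]]]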


class RWKV_ChannelMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad(): # fancy init of time_mix
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0

            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

        hidden_sz = 4 * config.n_embd
        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)

        self.value.scale_init = 0
        self.receptance.scale_init = 0

    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

########################################################################################################
# The GPT Model with our blocks
########################################################################################################


class GPTConfig:
    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)


class Block(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
        else:
            self.att = RWKV_TimeMix(config, layer_id)

        self.ffn = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x)) # better in some cases
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
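
# Added commentary (not part of the commit): this is a pre-LN residual block. Layer 0 additionally
# normalizes the raw embeddings via ln0, and in the 'RWKV-ffnPre' variant layer 0 swaps its
# time-mix for an extra channel-mix, which is "better in some cases" per the comment above.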


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.step = 0
        self.config = config

        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i)
                                      for i in range(config.n_layer)])

        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        try:
            if os.environ['RWKV_LOAD_MODEL'] == str(False):
                RWKV_Init(self, config)
        except:
            pass

        logger.info("number of parameters: %e", sum(p.numel()
                    for p in self.parameters()))

    def get_ctx_len(self):
        return self.ctx_len

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, (nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1e-5)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        no_decay = set()

        for mn, m in self.named_modules(): # here we disable weight_decay
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
                no_decay.add(fpn)

        param_dict = {pn: p for pn, p in self.named_parameters()}
        optim_groups = [
            {"params": [param_dict[pn]
                        for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]

        optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)

        return optimizer

    def forward(self, idx, targets=None):
        idx = idx.to(self.emb.weight.device)

        self.step += 1
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
            c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
            x = self.head(x) + c
        else:
            x = self.head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))

        return x, loss
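
# Added commentary (not part of the commit): the RWKV_HEAD_QK_DIM branch above is the "headQK
# trick" described in train.py -- one tiny attention head whose causally-masked q @ k^T scores are
# multiplied into a one-hot encoding of the input ids, letting the model copy tokens it has
# already seen in the context.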
@ -0,0 +1,366 @@ src/model_run.py
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import types
import copy
import torch
import math
from torch.nn import functional as F
import torch.nn as nn

RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')

DEBUG_TIME = False # True False - show trained time-coeffs

########################################################################################################
# CUDA Kernel
########################################################################################################

T_MAX = 4096 # increase this if your ctx_len is long
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])

class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w = -torch.exp(w.float().contiguous())
        u = u.float().contiguous()
        k = k.float().contiguous()
        v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        return y.half()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda')
        gu = torch.zeros((B, C), device='cuda')
        gk = torch.zeros((B, T, C), device='cuda')
        gv = torch.zeros((B, T, C), device='cuda')
        wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())

def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())

############################################################################################################

RWKV_CFG = types.SimpleNamespace()

class RWKV_ChannelMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))

        hidden_sz = 4 * RWKV_CFG.n_embd
        self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

class RWKV_TimeMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd))
        self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd) * math.log(0.3))

        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))

        self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

        self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.size()

        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)

        rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)

        rwkv = self.output(rwkv)
        return rwkv

class Block(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
        self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)

        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(layer_id+1000)
        else:
            self.att = RWKV_TimeMix(layer_id)

        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class RWKV_GPT(nn.Module):
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
        global RWKV_CFG
        super().__init__()

        RWKV_CFG.RUN_DEVICE = RUN_DEVICE
        RWKV_CFG.model_type = model_type
        RWKV_CFG.vocab_size = vocab_size
        RWKV_CFG.n_layer = n_layer
        RWKV_CFG.n_embd = n_embd
        RWKV_CFG.ctx_len = ctx_len

        print('\nloading RWKV-GPT', MODEL_NAME)

        self.emb = nn.Embedding(vocab_size, n_embd)

        self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])

        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(ctx_len, ctx_len)))

        self.ctx_len = ctx_len
        self.eval()
        self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
        self.eval()

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float()
            x = self.head(x) + c
        else:
            x = self.head(x)

        return x

############################################################################################################

class RWKV_RNN(): # this is running in FP32 at this moment
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
        self.RUN_DEVICE = RUN_DEVICE
        self.model_type = model_type
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.ctx_len = ctx_len

        self.w = types.SimpleNamespace()

        w = torch.load(MODEL_NAME + '.pth',
                       map_location=torch.device(RUN_DEVICE))
        for x in w.keys():
            w[x] = w[x].float()
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                w[x] = -torch.exp(w[x])
            if DEBUG_TIME and '.time_' in x:
                print(x, w[x].squeeze().cpu().numpy())

            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):
        self.xx = {}
        self.aa = {}
        self.bb = {}
        self.pp = {}
        self.hk = None

    def save(self, target):
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.pp = copy.deepcopy(self.pp)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.pp = copy.deepcopy(target.pp)
        self.hk = copy.deepcopy(target.hk)

    def LN(self, xx, w):
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.square(torch.relu(w.key.weight @ xk))
        kv = w.value.weight @ k

        return r * kv

    def SA(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30

        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)

        k = w.key.weight @ xk
        v = w.value.weight @ xv

        pp = self.pp[name]
        aa = self.aa[name]
        bb = self.bb[name]
        ww = w.time_first + k
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        ww = pp + w.time_decay
        p = torch.maximum(ww, k)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k - p)
        self.aa[name] = e1 * aa + e2 * v
        self.bb[name] = e1 * bb + e2
        self.pp[name] = p

        rwkv = r * a / b

        return w.output.weight @ rwkv
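
    # Added commentary (not part of the commit): SA is the stateful, one-token-at-a-time form of
    # the same recurrence the CUDA kernel computes over a whole sequence; (aa, bb, pp) here play
    # the roles of (p, q, o) in the kernel.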

    def run(self, ctx):
        w = self.w
        x = w.emb.weight[ctx[-1]]

        for i in range(self.n_layer):
            if i == 0:
                x = self.LN(x, w.blocks[i].ln0)
            if i == 0 and self.model_type == 'RWKV-ffnPre':
                x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
            x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        if RWKV_HEAD_QK_DIM > 0:
            if self.hk is None:
                self.hk = (w.head_k.weight @ x).unsqueeze(0)
            else:
                self.hk = torch.cat(
                    [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
            if self.hk.shape[0] > self.ctx_len:
                self.hk = self.hk[-self.ctx_len:, :]

            q = w.head_q.weight @ x

            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

            c = (self.hk @ q) / RWKV_HEAD_QK_DIM
            for i in range(len(c)):
                x[ctx[i]] += c[i]
        else:
            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

        return x
@ -0,0 +1,177 @@ src/trainer.py
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os
NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])
USE_WANDB = (int(os.environ['USE_WANDB']) == 1)

from torch.utils.data.dataloader import DataLoader
import torch
from tqdm.auto import tqdm
import logging
import datetime
import math
from pytorch_lightning.lite import LightningLite

logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

class TrainerConfig:
    batch_size = 64
    learning_rate = 4e-4
    betas = (0.9, 0.99)
    eps = 1e-8
    grad_norm_clip = 1.0
    warmup_tokens = 0
    final_tokens = 0
    epoch_save_frequency = 0
    epoch_save_path = 'trained-'
    num_workers = 0 # for DataLoader

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

from src.model import GPT, GPTConfig

class Trainer(LightningLite):

    def get_run_name(self):
        raw_model = self.model.module if hasattr(
            self.model, "module") else self.model
        cfg = raw_model.config
        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
            cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
        return run_name

    def run(self, m_cfg, train_dataset, test_dataset, config):
        self.cuda_id = int(str(self.device).strip('cuda:'))
        print('[0]')
        model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=m_cfg.model_type,
                              n_layer=m_cfg.n_layer, n_embd=m_cfg.n_embd))
        print('[1]')
        model.to(self.device)
        print('[2]')
        with torch.no_grad():
            if m_cfg.LOAD_MODEL:
                print('loading', m_cfg.MODEL_NAME)
                m2 = torch.load(m_cfg.MODEL_NAME + '.pth', map_location=torch.device(self.device))
                model.load_state_dict(m2)
                del m2

        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.avg_loss = -1
        self.EPOCH_BEGIN = m_cfg.EPOCH_BEGIN

        self.steps = self.EPOCH_BEGIN * (len(self.train_dataset) // (config.batch_size // NUM_GPUS))

        if self.cuda_id == 0:
            log_file = open("mylog.txt", "a")
            if USE_WANDB:
                print('logging to wandb... (comment it if you don\'t have wandb)')
                import wandb # comment this if you don't have wandb
                cfg = model.config
                for k in config.__dict__:
                    setattr(cfg, k, config.__dict__[k]) # combine cfg
                wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)

        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)
        model, optimizer = self.setup(model, optimizer)
        print('[3]')

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
            data.idx_begin = self.steps * config.batch_size + 1
            data.cuda_id = self.cuda_id

            if config.num_workers > 0:
                loader = DataLoader(data, shuffle=False, pin_memory=True,
                                    batch_size=config.batch_size // NUM_GPUS,
                                    num_workers=config.num_workers)
            else:
                loader = DataLoader(data, shuffle=False,
                                    batch_size=config.batch_size // NUM_GPUS,
                                    num_workers=config.num_workers)

            pbar = tqdm(enumerate(loader), total=len(
                loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
            loader = self.setup_dataloaders(loader)

            for it, (x, y) in pbar:
                with torch.set_grad_enabled(is_train):
                    _, loss = model(x, y) # forward the model

                all_loss = [loss.clone() for _ in range(NUM_GPUS)]
                torch.distributed.all_gather(all_loss, loss)

                if is_train: # backprop and update the parameters
                    model.zero_grad()
                    self.backward(loss)

                    # deepspeed will handle gradient_clipping

                    optimizer.step()

                    # decay the learning rate based on our progress
                    self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
                    lr_final_factor = config.lr_final / config.learning_rate
                    if self.tokens < config.warmup_tokens:
                        # linear warmup
                        lr_mult = lr_final_factor + \
                            (1 - lr_final_factor) * float(self.tokens) / \
                            float(config.warmup_tokens)
                        progress = 0
                    else:
                        # exponential learning rate decay
                        progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                        if progress >= 1:
                            lr_mult = lr_final_factor
                        else:
                            lr_mult = math.exp(math.log(lr_final_factor) * pow(progress, 1))
                    lr = config.learning_rate * lr_mult
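                    # Added commentary (not part of the commit): e.g. with lr_init = 8e-4 and
                    # lr_final = 1e-5, lr_final_factor = 0.0125, so at progress = 0.5 this gives
                    # lr = 8e-4 * exp(ln(0.0125) * 0.5) ≈ 8.9e-5.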

                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr

                    self.lr = lr
                    self.steps += 1

                    now_loss = 0
                    for gg in range(NUM_GPUS):
                        now_loss += all_loss[gg].item()
                    now_loss = now_loss / NUM_GPUS # report progress
                    if USE_WANDB and self.cuda_id == 0:
                        wandb.log({"loss": now_loss}, step=self.steps)

                    if self.avg_loss < 0:
                        self.avg_loss = now_loss
                    else:
                        factor = 1 / (it + 1)
                        self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor

                    pbar.set_description(f"miniE {epoch+1+self.EPOCH_BEGIN} s {self.steps} prog {progress*100.0:.2f}% : ppl {math.exp(self.avg_loss):.6f} loss {self.avg_loss:.6f} lr {lr:e}")

        self.tokens = 0 # counter used for learning rate decay
        for epoch in range(99999999):

            run_epoch('train')
            if math.isnan(self.avg_loss):
                exit(0)

            if self.cuda_id == 0:
                log_file.write(f'{epoch+1+self.EPOCH_BEGIN} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} {epoch+1} \n')
                log_file.flush()

            if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
                raw_model = self.model.module if hasattr(self.model, "module") else self.model
                torch.save(raw_model.state_dict(), self.config.epoch_save_path + str(epoch+1+self.EPOCH_BEGIN) + '.pth')
@ -0,0 +1,122 @@ src/utils.py
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os
try:
    NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])
except:
    NUM_GPUS = 1

import json
import random
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset as TorchDataset

class Dataset(TorchDataset):
    def __init__(self, data, ctx_len, epoch_length_fixed):
        print('building token list...', end=' ')
        unique = sorted(list(set(data)))
        # print()
        # for u in unique:
        #     print(u, end=' ')
        # print('\n\n')

        xx = 0
        xxObj = {}
        for u in unique:
            xxObj[xx] = u
            xx += 1
        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))

        data_size, vocab_size = len(data), len(unique)
        print('data has %d tokens, %d unique.' % (data_size, vocab_size))
        self.stoi = {ch: i for i, ch in enumerate(unique)}
        self.itos = {i: ch for i, ch in enumerate(unique)}
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return self.epoch_length_fixed // NUM_GPUS

    def __getitem__(self, idx):
        # cheat: pick a random spot in dataset
        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
        chunk = self.data[i:i+self.ctx_len+1]
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y


class TOKENIZER():
    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
            self.word_table = json.load(result_file)

        self.vocab_size = len(self.word_table)

        self.stoi = {v: int(k) for k, v in self.word_table.items()}
        self.itos = {int(k): v for k, v in self.word_table.items()}

        self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        if context == '':
            context = '\n'

        return context

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
        # out[self.UNKNOWN_CHAR] = -float('Inf')

        lastChar = int(x[-1])

        probs = F.softmax(torch.tensor(out), dim=-1)

        if self.itos[lastChar] == '\n':
            top_p = top_p_newline
        else:
            top_p = top_p_usual

        sorted_probs, s_index = torch.sort(probs, descending=True)

        # for j in range(30):
        #     pp = sorted_probs[j].item()
        #     if pp < 0.005:
        #         break
        #     ss = self.itos[int(s_index[j])].replace('\n','_')
        #     print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
        # print('')

        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])

        probs[probs < cutoff] = 0
        # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end="")

        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)

        return torch.multinomial(probs, num_samples=1)[0]
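
    # Usage sketch (added commentary, not part of the commit): given flat logits `out` and the
    # current token ids `x`:
    #   token = tokenizer.sample_logits(out, x, ctx_len, temperature=1.0, top_p_usual=0.7, top_p_newline=0.9)
    # Nucleus sampling keeps only the highest-probability tokens whose cumulative probability
    # reaches top_p; a separate top_p is used right after '\n'. See the generation loop in run.py.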


def to_float(x):
    return x.cpu().detach().numpy().flatten()[0].astype(float)


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
@ -0,0 +1,135 @@ train.py
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os

os.environ['USE_WANDB'] = '0' # 0 = False, 1 = True

### This is using DeepSpeed stage2 + FP16 ##############################################################

os.environ['RWKV_NUM_GPUS'] = '1' # num of GPUs to use
NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])

### Change these if you want to continue training from a saved model ###################################

EPOCH_BEGIN = 0
LOAD_MODEL = False # shall we continue from the EPOCH_BEGIN checkpoint?
os.environ['RWKV_LOAD_MODEL'] = str(LOAD_MODEL)

########################################################################################################

# if False: # True False ---> Set to False if you don't understand it
#     print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n")
#     import src.utils
#     src.utils.set_seed(42) # make training deterministic (including dataloader). if you are doing this, remember to change seed when you load a model (otherwise the dataloader loads old samples)

import logging, types
from src.utils import Dataset
import torch
import numpy as np

np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

### Step 1: set training data ##########################################################################

datafile = "../data/enwik8" # your data
datafile_encoding = 'utf-8'
# datafile_encoding = 'utf-16le'

### Step 2: set model size #############################################################################

ctx_len = 1024 # increase T_MAX in model.py if your ctx_len is very long
n_layer = 6
n_embd = 512

# 'RWKV' or 'RWKV-ffnPre' (better in some cases)
model_type = 'RWKV'

# ---> there is also a RWKV_HEAD_QK_DIM in model.py and model_run.py <---
# set it to 256 to use my headQK trick (similar to a tiny attention) to improve loss
# set it to 0 for a pure RNN (attention-free)

### Step 3: set batch size #############################################################################

# if you see "CUDA out of memory", reduce batch_size. Use nvidia-smi to find the highest value for your GPU.
batch_size = 12
assert (batch_size % NUM_GPUS == 0)

### Step 4: set learning rate, number of mini-epochs ###################################################
#
# By default we are using exponential LR decay.
# Here are my suggestions for training.
# Let's say you are training a L6-D512 model.
# 1) Set lr_init = lr_final = 8e-4. Let it run for some mini-epochs, until you feel like reducing LR.
# 2) Check epoch_save_frequency and make sure the partially-trained model is saved. Ctrl+C to stop the run.
# 3) Set lr_init = 8e-4, lr_final = 1e-5, betas = (0.9, 0.999).
# 4) Set EPOCH_BEGIN & LOAD_MODEL to load the partially-trained model. Continue the training.
#
# For L12-D768, set lr_init = 6e-4. For L24-D1024, set lr_init = 4e-4. For L24-D2048, set lr_init = 3e-4.

lr_init = 8e-4
lr_final = 1e-5

# the mini-epoch is very short and of fixed length (length = ctx_len * epoch_length_fixed tokens)
n_epoch = 500
epoch_length_fixed = (10000 // batch_size) * batch_size # feel free to increase it if you have lots of GPUs

# epoch_save_frequency 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, ...
epoch_save_frequency = 10
epoch_save_path = 'trained-'
MODEL_NAME = epoch_save_path + str(EPOCH_BEGIN)

########################################################################################################

if LOAD_MODEL and EPOCH_BEGIN > 0: # we are not saving gradients/optimizer state, so do a short warmup when resuming from a model
    warmup_tokens = ctx_len * batch_size * 50
else:
    warmup_tokens = ctx_len * batch_size * 0

betas = (0.9, 0.99)
eps = 1e-8

num_workers = 1 # DataLoader workers. I have only tested num_workers = 1

########################################################################################################
# Load data
########################################################################################################

print('loading data... ' + datafile)
train_dataset = Dataset(open(
    datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)

########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':
    from src.trainer import Trainer, TrainerConfig

    print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
          betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )

    tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
                          learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
                          warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
    m_cfg = types.SimpleNamespace()
    m_cfg.model_type = model_type
    m_cfg.n_layer = n_layer
    m_cfg.n_embd = n_embd
    m_cfg.EPOCH_BEGIN = EPOCH_BEGIN
    m_cfg.LOAD_MODEL = LOAD_MODEL
    m_cfg.MODEL_NAME = MODEL_NAME

    from pytorch_lightning.strategies import DeepSpeedStrategy

    # you can set grad_norm_clip in deepspeed.json

    trainer = Trainer(strategy=DeepSpeedStrategy(config='deepspeed.json'), devices=NUM_GPUS, accelerator="gpu", precision=16)
    print(trainer._strategy.config)

    trainer.run(m_cfg, train_dataset, None, tconf)
@ -0,0 +1,63 @@ verify.py
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

# this is for verifying the outputs of the different model implementations and making sure they agree with each other

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
RUN_DEVICE = 'cuda'

import torch
from src.model_run import RWKV_RNN, RWKV_GPT
from src.model import GPT, GPTConfig

ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV'

model_name = 'trained-1'

from src.utils import TOKENIZER
tokenizer = TOKENIZER('vocab', UNKNOWN_CHAR=' ')

########################################################################################################

model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda().half()
print('loading ' + model_name)
m2 = torch.load(model_name + '.pth', map_location=RUN_DEVICE)
model_train.load_state_dict(m2)

model_rnn = RWKV_RNN(model_name, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
model_gpt = RWKV_GPT(model_name, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda()

########################################################################################################

context = '\nIn a'
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(f'input len {len(ctx)} data {ctx}')

########################################################################################################

print('\nRWKV-GPT output')
out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy()
print(out)

print('\nRWKV-RNN output')
model_rnn.clear()
src_len = len(ctx)
for i in range(src_len):
    x = ctx[:i+1]
    out = model_rnn.run(x)
    if i < 3 or i >= src_len - 3:
        print(torch.tensor(out).detach().cpu().numpy())
    if i == 2:
        print('...')

print('\nRWKV-train output')
out = model_train.forward(torch.tensor([ctx]).cuda())[0][0].detach().cpu().numpy()
print(out, '\n')