diff --git a/RWKV-v4/cuda/wkv_cuda.cu b/RWKV-v4/cuda/wkv_cuda.cu new file mode 100644 index 0000000..720317c --- /dev/null +++ b/RWKV-v4/cuda/wkv_cuda.cu @@ -0,0 +1,125 @@ +#include +#include + +#define MIN_VALUE (-1e38) + +template +__global__ void kernel_forward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, + F *__restrict__ const _y) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + + F p = 0, q = 0, o = MIN_VALUE; + // p and q are running sums divided by exp(o) (to avoid overflows) + for (int i = 0; i < T; i++) { + const int ii = i * C; + + F no = max(o, u + k[ii]); + F A = exp(o - no); + F B = exp(u + k[ii] - no); + y[ii] = (A * p + B * v[ii]) / (A * q + B); + + no = max(w + o, k[ii]); + A = exp(w + o - no); + B = exp(k[ii] - no); + p = A * p + B * v[ii]; + q = A * q + B; + o = no; + } +} + +template +__global__ void kernel_backward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy, + F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + const F *__restrict__ const gy = _gy + _offset; + + F *__restrict__ const gk = _gk + _offset; + F *__restrict__ const gv = _gv + _offset; + + F y[Tmax], z[Tmax], zexp[Tmax]; + + F gw = 0, gu = 0; + F p = 0, q = 0; + F dpdw = 0, dqdw = 0; + F o = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + F no = max(o, k[ii] + u); + F A = exp(o - no); + F B = exp(k[ii] + u - no); + + F num = A * p + B * v[ii]; + F iden = 1 / (A * q + B); + + y[i] = num * iden; + z[i] = iden; + zexp[i] = k[ii] + u - no; + + gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A; + gu += gy[ii] * (v[ii] - y[i]) * B * iden; + + no = max(w + o, k[ii]); + A = exp(w + o - no); + B = exp(k[ii] - no); + dpdw = A * (p + dpdw); + dqdw = A * (q + dqdw); + p = A * p + B * v[ii]; + q = A * q + B; + o = no; + } + + F gp = 0, gq = 0; + o = MIN_VALUE; + for (int i = T - 1; i >= 0; i--) { + const int ii = i * C; + F A = gy[ii] * z[i] * exp(zexp[i]); + F B = exp(k[ii] + o); + gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq); + gv[ii] = A + B * gp; + + F no = max(w + o, zexp[i] - k[ii] - u); + A = exp(w + o - no); + B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no); + gp = A * gp + B; + gq = A * gq - B * y[i]; + o = no; + } + + // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] += gw * _w[_c]; + _gu[_offsetBC] += gu; +} + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { + dim3 threadsPerBlock( min(C, 1024) ); + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward<<>>(B, T, C, w, u, k, v, y); +} + +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) { + dim3 threadsPerBlock( min(C, 1024) ); + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward<<>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv); +} diff --git a/RWKV-v4/cuda/wkv_op.cpp b/RWKV-v4/cuda/wkv_op.cpp new file mode 100644 index 0000000..efe56d8 --- /dev/null +++ b/RWKV-v4/cuda/wkv_op.cpp @@ -0,0 +1,21 @@ +#include + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv); + +void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv forward"); + m.def("backward", &backward, "wkv backward"); +} + +TORCH_LIBRARY(wkv, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/RWKV-v4/deepspeed.json b/RWKV-v4/deepspeed.json new file mode 100644 index 0000000..6bbfe74 --- /dev/null +++ b/RWKV-v4/deepspeed.json @@ -0,0 +1,37 @@ +{ + "zero_allow_untested_optimizer":true, + "zero_optimization":{ + "stage":2, + "contiguous_gradients":true, + "overlap_comm":true, + "allgather_partitions":true, + "reduce_scatter":true, + "allgather_bucket_size":200000000, + "reduce_bucket_size":200000000, + "sub_group_size":1000000000000 + }, + "activation_checkpointing":{ + "partition_activations":false, + "cpu_checkpointing":false, + "contiguous_memory_optimization":false, + "synchronize_checkpoint_boundary":false + }, + "aio":{ + "block_size":1048576, + "queue_depth":8, + "single_submit":false, + "overlap_events":true, + "thread_count":1 + }, + "gradient_clipping": 1.0, + "gradient_accumulation_steps": 1, + "fp16": { + "fp16": true, + "enabled": true, + "loss_scale": 0, + "initial_scale_power": 12, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + } +} diff --git a/RWKV-v4/run.py b/RWKV-v4/run.py new file mode 100644 index 0000000..c6862a5 --- /dev/null +++ b/RWKV-v4/run.py @@ -0,0 +1,98 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import numpy as np +import math +import time +import types +import copy +import torch +from torch.nn import functional as F +from src.utils import TOKENIZER, Dataset +from src.model_run import RWKV_RNN +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True +np.set_printoptions(precision=4, suppress=True, linewidth=200) + +### Step 1: set model ################################################################################## + +ctx_len = 1024 +n_layer = 6 +n_embd = 512 +model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' + +# your trained model +MODEL_NAME = 'trained-1' +WORD_NAME = 'vocab' # the .json vocab (generated by train.py + +# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <-- +# --> all unknown tokens in your context will be denoted by it <-- +UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity + +RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda' +DEBUG_DEBUG = False # True False - show softmax output + +### Step 2: set context ################################################################################ + +context = "\nIn the" # ==> this is your prompt + +NUM_TRIALS = 999 +LENGTH_PER_TRIAL = 500 + +TEMPERATURE = 1.0 +top_p = 0.7 +top_p_newline = 0.9 + +######################################################################################################## + +print(f'Loading {MODEL_NAME}...') +model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) +tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) + +######################################################################################################## + +context = tokenizer.refine_context(context) +print('\nYour prompt has ' + str(len(context)) + ' tokens.') +print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n') + +for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): + t_begin = time.time_ns() + + src_len = len(context) + ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] + print(('-' * 30) + context, end='') + + model.clear() + if TRIAL == 0: + init_state = types.SimpleNamespace() + for i in range(src_len): + x = ctx[:i+1] + if i == src_len - 1: + init_state.out = model.run(x) + else: + model.run(x) + model.save(init_state) + else: + model.load(init_state) + + for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)): + x = ctx[:i+1] + x = x[-ctx_len:] + + if i == src_len: + out = copy.deepcopy(init_state.out) + else: + out = model.run(x) + if DEBUG_DEBUG: + print('model', np.array(x), '==>', np.array( + out), np.max(out), np.min(out)) + + char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE, + top_p_usual=top_p, top_p_newline=top_p_newline) + char = char.item() + print(tokenizer.itos[int(char)], end='', flush=True) + ctx += [char] + t_end = time.time_ns() + print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ') diff --git a/RWKV-v4/src/model.py b/RWKV-v4/src/model.py new file mode 100644 index 0000000..6151b65 --- /dev/null +++ b/RWKV-v4/src/model.py @@ -0,0 +1,348 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import math, os +import numpy as np +import logging +import torch +import torch.nn as nn +from torch.nn import functional as F +from deepspeed.ops.adam import FusedAdam + +logger = logging.getLogger(__name__) + +RWKV_HEAD_QK_DIM = 256 +print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') + +######################################################################################################## +# CUDA Kernel +######################################################################################################## + +T_MAX = 4096 # increase this if your ctx_len is long +# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice + +from torch.utils.cpp_extension import load +wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], + verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}']) + +class WKV(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, w, u, k, v): + ctx.B = B + ctx.T = T + ctx.C = C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w = -torch.exp(w.float().contiguous()) + u = u.float().contiguous() + k = k.float().contiguous() + v = v.float().contiguous() + ctx.save_for_backward(w, u, k, v) + y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) + wkv_cuda.forward(B, T, C, w, u, k, v, y) + return y.half() + + @staticmethod + def backward(ctx, gy): + B = ctx.B + T = ctx.T + C = ctx.C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w, u, k, v = ctx.saved_tensors + gw = torch.zeros((B, C), device='cuda') + gu = torch.zeros((B, C), device='cuda') + gk = torch.zeros((B, T, C), device='cuda') + gv = torch.zeros((B, T, C), device='cuda') + wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) + gw = torch.sum(gw, dim=0) + gu = torch.sum(gu, dim=0) + return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) + +def RUN_CUDA(B, T, C, w, u, k, v): + return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda()) + +######################################################################################################## +# RWKV: RWKV Time-mix + RWKV Channel-mix +######################################################################################################## + +def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module + print('\n[--> first run, init model params (very slow for large models) <--]\n') + print('\n[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n') + for m in module.modules(): + if not isinstance(m, (nn.Linear, nn.Embedding)): + continue + with torch.no_grad(): + name = '[unknown weight]' + for name, parameter in module.named_parameters(): # find the name of the weight + if id(m.weight) == id(parameter): + break + + shape = m.weight.data.shape + gain = 1.0 + scale = 1.0 # extra scale for gain + + if isinstance(m, nn.Embedding): + gain = math.sqrt(max(shape[0], shape[1])) + if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb? + scale = 1e-4 + else: + scale = 0 + + if isinstance(m, nn.Linear): + if m.bias is not None: + m.bias.data.zero_() + if shape[0] > shape[1]: + gain = math.sqrt(shape[0] / shape[1]) + if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection? + scale = 0.5 + + if hasattr(m, 'scale_init'): + scale = m.scale_init + + # print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name) + + gain *= scale + if scale == -999: + nn.init.eye_(m.weight) + elif gain == 0: + # zero init is great for some RWKV matrices + nn.init.zeros_(m.weight) + elif gain > 0: + nn.init.orthogonal_(m.weight, gain=gain) + else: + nn.init.normal_(m.weight, mean=0.0, std=-scale) + + +class RWKV_TimeMix(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + self.layer_id = layer_id + self.ctx_len = config.ctx_len + self.n_embd = config.n_embd + + attn_sz = config.n_embd + + with torch.no_grad(): # fancy init + ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1 + ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0 + + # fancy time_decay + decay_speed = torch.ones(attn_sz) + for h in range(attn_sz): + decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1) + self.time_decay = nn.Parameter(decay_speed) + # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) + + # fancy time_first + zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5) + self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag) + + # fancy time_mix + x = torch.ones(1, 1, config.n_embd) + for i in range(config.n_embd): + x[0, 0, i] = i / config.n_embd + self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + self.key = nn.Linear(config.n_embd, attn_sz, bias=False) + self.value = nn.Linear(config.n_embd, attn_sz, bias=False) + self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False) + + self.output = nn.Linear(attn_sz, config.n_embd, bias=False) + + self.key.scale_init = 0 + self.receptance.scale_init = 0 + self.output.scale_init = 0 + + def forward(self, x): + B, T, C = x.size() # x = (Batch,Time,Channel) + + # Mix x with the previous timestep to produce xk, xv, xr + xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1)) + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + + # Use xk, xv, xr to produce k, v, r + k = self.key(xk) + v = self.value(xv) + r = self.receptance(xr) + + rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v) + rwkv = self.output(rwkv) + return rwkv + + +class RWKV_ChannelMix(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + self.layer_id = layer_id + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + with torch.no_grad(): # fancy init of time_mix + ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0 + + x = torch.ones(1, 1, config.n_embd) + for i in range(config.n_embd): + x[0, 0, i] = i / config.n_embd + + self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + + hidden_sz = 4 * config.n_embd + self.key = nn.Linear(config.n_embd, hidden_sz, bias=False) + self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False) + self.value = nn.Linear(hidden_sz, config.n_embd, bias=False) + + self.value.scale_init = 0 + self.receptance.scale_init = 0 + + def forward(self, x): + xx = self.time_shift(x) + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + + k = self.key(xk) + k = torch.square(torch.relu(k)) + kv = self.value(k) + + rkv = torch.sigmoid(self.receptance(xr)) * kv + return rkv + +######################################################################################################## +# The GPT Model with our blocks +######################################################################################################## + + +class GPTConfig: + def __init__(self, vocab_size, ctx_len, **kwargs): + self.vocab_size = vocab_size + self.ctx_len = ctx_len + for k, v in kwargs.items(): + setattr(self, k, v) + + +class Block(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + self.config = config + self.layer_id = layer_id + + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + + if self.layer_id == 0: + self.ln0 = nn.LayerNorm(config.n_embd) + + if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': + self.ffnPre = RWKV_ChannelMix(config, layer_id+1000) + else: + self.att = RWKV_TimeMix(config, layer_id) + + self.ffn = RWKV_ChannelMix(config, layer_id) + + def forward(self, x): + if self.layer_id == 0: + x = self.ln0(x) + if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': + x = x + self.ffnPre(self.ln1(x)) # better in some cases + else: + x = x + self.att(self.ln1(x)) + x = x + self.ffn(self.ln2(x)) + return x + + +class GPT(nn.Module): + def __init__(self, config): + super().__init__() + self.step = 0 + self.config = config + + self.emb = nn.Embedding(config.vocab_size, config.n_embd) + + self.blocks = nn.Sequential(*[Block(config, i) + for i in range(config.n_layer)]) + + self.ln_out = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + if RWKV_HEAD_QK_DIM > 0: + self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False) + self.head_q.scale_init = 0 + self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False) + self.head_k.scale_init = 0.1 + self.register_buffer("copy_mask", torch.tril( + torch.ones(config.ctx_len, config.ctx_len))) + + self.ctx_len = config.ctx_len + + try: + if os.environ['RWKV_LOAD_MODEL'] == str(False): + RWKV_Init(self, config) + except: + pass + + logger.info("number of parameters: %e", sum(p.numel() + for p in self.parameters())) + + def get_ctx_len(self): + return self.ctx_len + + def _init_weights(self, module): + if isinstance(module, (nn.Linear)): + module.weight.data.normal_(mean=0.0, std=0.01) + if isinstance(module, (nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=1e-5) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def configure_optimizers(self, train_config): + no_decay = set() + + for mn, m in self.named_modules(): # here we disable weight_decay + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + no_decay.add(fpn) + + param_dict = {pn: p for pn, p in self.named_parameters()} + optim_groups = [ + {"params": [param_dict[pn] + for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + + optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + + return optimizer + + def forward(self, idx, targets=None): + idx = idx.to(self.emb.weight.device) + + self.step += 1 + B, T = idx.size() + assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len." + + x = self.emb(idx) + x = self.blocks(x) + x = self.ln_out(x) + + if RWKV_HEAD_QK_DIM > 0: + q = self.head_q(x)[:, :T, :] + k = self.head_k(x)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM) + c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) + c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half() + x = self.head(x) + c + else: + x = self.head(x) + + loss = None + if targets is not None: + loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1)) + + return x, loss diff --git a/RWKV-v4/src/model_run.py b/RWKV-v4/src/model_run.py new file mode 100644 index 0000000..7eb3809 --- /dev/null +++ b/RWKV-v4/src/model_run.py @@ -0,0 +1,366 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import types +import copy +import torch +import math +from torch.nn import functional as F +import torch.nn as nn + +RWKV_HEAD_QK_DIM = 256 +print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') + +DEBUG_TIME = False # True False - show trained time-coeffs + +######################################################################################################## +# CUDA Kernel +######################################################################################################## + +T_MAX = 4096 # increase this if your ctx_len is long +# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice + +from torch.utils.cpp_extension import load +wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], + verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}']) + +class WKV(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, w, u, k, v): + ctx.B = B + ctx.T = T + ctx.C = C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w = -torch.exp(w.float().contiguous()) + u = u.float().contiguous() + k = k.float().contiguous() + v = v.float().contiguous() + ctx.save_for_backward(w, u, k, v) + y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) + wkv_cuda.forward(B, T, C, w, u, k, v, y) + return y.half() + + @staticmethod + def backward(ctx, gy): + B = ctx.B + T = ctx.T + C = ctx.C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w, u, k, v = ctx.saved_tensors + gw = torch.zeros((B, C), device='cuda') + gu = torch.zeros((B, C), device='cuda') + gk = torch.zeros((B, T, C), device='cuda') + gv = torch.zeros((B, T, C), device='cuda') + wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) + gw = torch.sum(gw, dim=0) + gu = torch.sum(gu, dim=0) + return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) + +def RUN_CUDA(B, T, C, w, u, k, v): + return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda()) + +############################################################################################################ + +RWKV_CFG = types.SimpleNamespace() + +class RWKV_ChannelMix(nn.Module): + def __init__(self, layer_id): + super().__init__() + self.layer_id = layer_id + + self.time_shift = nn.ZeroPad2d((0,0,1,-1)) + self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) + self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) + + hidden_sz = 4 * RWKV_CFG.n_embd + self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False) + self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) + self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False) + + def forward(self, x): + xx = self.time_shift(x) + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + + k = self.key(xk) + k = torch.square(torch.relu(k)) + kv = self.value(k) + + rkv = torch.sigmoid(self.receptance(xr)) * kv + return rkv + +class RWKV_TimeMix(nn.Module): + def __init__(self, layer_id): + super().__init__() + self.layer_id = layer_id + self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd)) + self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd) * math.log(0.3)) + + self.time_shift = nn.ZeroPad2d((0,0,1,-1)) + self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) + self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) + self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) + + self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) + self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) + self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) + + self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) + + def forward(self, x): + B, T, C = x.size() + + xx = self.time_shift(x) + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + + k = self.key(xk) + v = self.value(xv) + r = self.receptance(xr) + + rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v) + + rwkv = self.output(rwkv) + return rwkv + +class Block(nn.Module): + def __init__(self, layer_id): + super().__init__() + self.layer_id = layer_id + + self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd) + self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd) + if self.layer_id == 0: + self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd) + + if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre': + self.ffnPre = RWKV_ChannelMix(layer_id+1000) + else: + self.att = RWKV_TimeMix(layer_id) + + self.ffn = RWKV_ChannelMix(layer_id) + + def forward(self, x): + if self.layer_id == 0: + x = self.ln0(x) + if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre': + x = x + self.ffnPre(self.ln1(x)) + else: + x = x + self.att(self.ln1(x)) + x = x + self.ffn(self.ln2(x)) + return x + +class RWKV_GPT(nn.Module): + def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len): + global RWKV_CFG + super().__init__() + + RWKV_CFG.RUN_DEVICE = RUN_DEVICE + RWKV_CFG.model_type = model_type + RWKV_CFG.vocab_size = vocab_size + RWKV_CFG.n_layer = n_layer + RWKV_CFG.n_embd = n_embd + RWKV_CFG.ctx_len = ctx_len + + print('\nloading RWKV-GPT', MODEL_NAME) + + self.emb = nn.Embedding(vocab_size, n_embd) + + self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)]) + + self.ln_out = nn.LayerNorm(n_embd) + self.head = nn.Linear(n_embd, vocab_size, bias=False) + + if RWKV_HEAD_QK_DIM > 0: + self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False) + self.head_q.scale_init = 0 + self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False) + self.head_k.scale_init = 0.1 + self.register_buffer("copy_mask", torch.tril( + torch.ones(ctx_len, ctx_len))) + + self.ctx_len = ctx_len + self.eval() + self.load_state_dict(torch.load(MODEL_NAME + '.pth')) + self.eval() + + def forward(self, idx): + B, T = idx.size() + assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len." + + x = self.emb(idx) + x = self.blocks(x) + x = self.ln_out(x) + + if RWKV_HEAD_QK_DIM > 0: + q = self.head_q(x)[:, :T, :] + k = self.head_k(x)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM) + c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) + + c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float() + x = self.head(x) + c + else: + x = self.head(x) + + return x + +############################################################################################################ + +class RWKV_RNN(): # this is running in FP32 at this moment + def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len): + self.RUN_DEVICE = RUN_DEVICE + self.model_type = model_type + self.n_layer = n_layer + self.n_embd = n_embd + self.ctx_len = ctx_len + + self.w = types.SimpleNamespace() + + w = torch.load(MODEL_NAME + '.pth', + map_location=torch.device(RUN_DEVICE)) + for x in w.keys(): + w[x] = w[x].float() + if '.time_' in x: + w[x] = w[x].squeeze() + if '.time_decay' in x: + w[x] = -torch.exp(w[x]) + if DEBUG_TIME and '.time_' in x: + print(x, w[x].squeeze().cpu().numpy()) + + xx = x.split('.') + here = self.w + for i in range(len(xx)): + if xx[i].isdigit(): + ii = int(xx[i]) + if ii not in here: + here[ii] = types.SimpleNamespace() + here = here[ii] + else: + if i == len(xx) - 1: + setattr(here, xx[i], w[x]) + elif not hasattr(here, xx[i]): + if xx[i+1].isdigit(): + setattr(here, xx[i], {}) + else: + setattr(here, xx[i], types.SimpleNamespace()) + here = getattr(here, xx[i]) + + self.clear() + + def clear(self): + self.xx = {} + self.aa = {} + self.bb = {} + self.pp = {} + self.hk = None + + def save(self, target): + target.xx = copy.deepcopy(self.xx) + target.aa = copy.deepcopy(self.aa) + target.bb = copy.deepcopy(self.bb) + target.pp = copy.deepcopy(self.pp) + target.hk = copy.deepcopy(self.hk) + + def load(self, target): + self.xx = copy.deepcopy(target.xx) + self.aa = copy.deepcopy(target.aa) + self.bb = copy.deepcopy(target.bb) + self.pp = copy.deepcopy(target.pp) + self.hk = copy.deepcopy(target.hk) + + def LN(self, xx, w): + return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias) + + def FF(self, xx, w, name): + if name not in self.xx: + self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) + xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k) + xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r) + self.xx[name] = xx + + r = torch.sigmoid(w.receptance.weight @ xr) + k = torch.square(torch.relu(w.key.weight @ xk)) + kv = w.value.weight @ k + + return r * kv + + def SA(self, xx, w, name): + if name not in self.xx: + self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) + self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) + self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) + self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30 + + xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k) + xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v) + xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r) + self.xx[name] = xx + + r = torch.sigmoid(w.receptance.weight @ xr) + + k = w.key.weight @ xk + v = w.value.weight @ xv + + pp = self.pp[name] + aa = self.aa[name] + bb = self.bb[name] + ww = w.time_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + a = e1 * aa + e2 * v + b = e1 * bb + e2 + ww = pp + w.time_decay + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + self.aa[name] = e1 * aa + e2 * v + self.bb[name] = e1 * bb + e2 + self.pp[name] = p + + rwkv = r * a / b + + return w.output.weight @ rwkv + + def run(self, ctx): + w = self.w + x = w.emb.weight[ctx[-1]] + + for i in range(self.n_layer): + if i == 0: + x = self.LN(x, w.blocks[i].ln0) + if i == 0 and self.model_type == 'RWKV-ffnPre': + x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}') + else: + x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}') + x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}') + + x = self.LN(x, w.ln_out) + + if RWKV_HEAD_QK_DIM > 0: + if self.hk == None: + self.hk = (w.head_k.weight @ x).unsqueeze(0) + else: + self.hk = torch.cat( + [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0) + if self.hk.shape[0] > self.ctx_len: + self.hk = self.hk[-self.ctx_len:, :] + + q = w.head_q.weight @ x + + x = w.head.weight @ x + x = x.cpu().numpy().tolist() + + c = (self.hk @ q) / RWKV_HEAD_QK_DIM + for i in range(len(c)): + x[ctx[i]] += c[i] + else: + x = w.head.weight @ x + x = x.cpu().numpy().tolist() + + return x diff --git a/RWKV-v4/src/trainer.py b/RWKV-v4/src/trainer.py new file mode 100644 index 0000000..2966daa --- /dev/null +++ b/RWKV-v4/src/trainer.py @@ -0,0 +1,177 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import os +NUM_GPUS = int(os.environ['RWKV_NUM_GPUS']) +USE_WANDB = (int(os.environ['USE_WANDB']) == 1) + +from torch.utils.data.dataloader import DataLoader +import torch +from tqdm.auto import tqdm +import logging +import datetime +import math +from pytorch_lightning.lite import LightningLite + +logger = logging.getLogger(__name__) +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True + +class TrainerConfig: + batch_size = 64 + learning_rate = 4e-4 + betas = (0.9, 0.99) + eps = 1e-8 + grad_norm_clip = 1.0 + warmup_tokens = 0 + final_tokens = 0 + epoch_save_frequency = 0 + epoch_save_path = 'trained-' + num_workers = 0 # for DataLoader + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + +from src.model import GPT, GPTConfig + +class Trainer(LightningLite): + + def get_run_name(self): + raw_model = self.model.module if hasattr( + self.model, "module") else self.model + cfg = raw_model.config + run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \ + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd) + return run_name + + def run(self, m_cfg, train_dataset, test_dataset, config): + self.cuda_id = int(str(self.device).strip('cuda:')) + print('[0]') + model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=m_cfg.model_type, + n_layer=m_cfg.n_layer, n_embd=m_cfg.n_embd)) + print('[1]') + model.to(self.device) + print('[2]') + with torch.no_grad(): + if m_cfg.LOAD_MODEL: + print('loading', m_cfg.MODEL_NAME) + m2 = torch.load(m_cfg.MODEL_NAME + '.pth', map_location=torch.device(self.device)) + model.load_state_dict(m2) + del m2 + + self.model = model + self.train_dataset = train_dataset + self.test_dataset = test_dataset + self.config = config + self.avg_loss = -1 + self.EPOCH_BEGIN = m_cfg.EPOCH_BEGIN + + self.steps = self.EPOCH_BEGIN * (len(self.train_dataset) // (config.batch_size // NUM_GPUS)) + + if self.cuda_id == 0: + log_file = open("mylog.txt", "a") + if USE_WANDB: + print('logging to wandb... (comment it if you don\'t have wandb)') + import wandb # comment this if you don't have wandb + cfg = model.config + for k in config.__dict__: + setattr(cfg, k, config.__dict__[k]) # combine cfg + wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False) + + model, config = self.model, self.config + raw_model = model.module if hasattr(self.model, "module") else model + optimizer = raw_model.configure_optimizers(config) + model, optimizer = self.setup(model, optimizer) + print('[3]') + + def run_epoch(split): + is_train = split == 'train' + model.train(is_train) + data = self.train_dataset if is_train else self.test_dataset + data.idx_begin = self.steps * config.batch_size + 1 + data.cuda_id = self.cuda_id + + if config.num_workers > 0: + loader = DataLoader(data, shuffle=False, pin_memory=True, + batch_size=config.batch_size // NUM_GPUS, + num_workers=config.num_workers) + else: + loader = DataLoader(data, shuffle=False, + batch_size=config.batch_size // NUM_GPUS, + num_workers=config.num_workers) + + pbar = tqdm(enumerate(loader), total=len( + loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader) + loader = self.setup_dataloaders(loader) + + for it, (x, y) in pbar: + with torch.set_grad_enabled(is_train): + _, loss = model(x, y) # forward the model + + all_loss = [loss.clone() for _ in range(NUM_GPUS)] + torch.distributed.all_gather(all_loss, loss) + + if is_train: # backprop and update the parameters + model.zero_grad() + self.backward(loss) + + # deepspeed will handle gradient_clipping + + optimizer.step() + + # decay the learning rate based on our progress + self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100) + lr_final_factor = config.lr_final / config.learning_rate + if self.tokens < config.warmup_tokens: + # linear warmup + lr_mult = lr_final_factor + \ + (1 - lr_final_factor) * float(self.tokens) / \ + float(config.warmup_tokens) + progress = 0 + else: + # exponential learning rate decay + progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) + if progress >= 1: + lr_mult = lr_final_factor + else: + lr_mult = math.exp(math.log(lr_final_factor) * pow(progress, 1)) + lr = config.learning_rate * lr_mult + + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + self.lr = lr + self.steps += 1 + + now_loss = 0 + for gg in range(NUM_GPUS): + now_loss += all_loss[gg].item() + now_loss = now_loss / NUM_GPUS # report progress + if USE_WANDB and self.cuda_id == 0: + wandb.log({"loss": now_loss}, step = self.steps) + + if self.avg_loss < 0: + self.avg_loss = now_loss + else: + factor = 1 / (it + 1) + self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor + + pbar.set_description(f"miniE {epoch+1+self.EPOCH_BEGIN} s {self.steps} prog {progress*100.0:.2f}% : ppl {math.exp(self.avg_loss):.6f} loss {self.avg_loss:.6f} lr {lr:e}") + + self.tokens = 0 # counter used for learning rate decay + for epoch in range(99999999): + + run_epoch('train') + if math.isnan(self.avg_loss): + exit(0) + + if self.cuda_id == 0: + log_file.write(f'{epoch+1+self.EPOCH_BEGIN} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} {epoch+1} \n') + log_file.flush() + + if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1): + raw_model = self.model.module if hasattr(self.model, "module") else self.model + torch.save(raw_model.state_dict(), self.config.epoch_save_path + str(epoch+1+self.EPOCH_BEGIN) + '.pth') diff --git a/RWKV-v4/src/utils.py b/RWKV-v4/src/utils.py new file mode 100644 index 0000000..1cdf01f --- /dev/null +++ b/RWKV-v4/src/utils.py @@ -0,0 +1,122 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import os +try: + NUM_GPUS = int(os.environ['RWKV_NUM_GPUS']) +except: + NUM_GPUS = 1 + +import json +import random +import numpy as np +import torch +from torch.nn import functional as F +from torch.utils.data import Dataset + +class Dataset(Dataset): + def __init__(self, data, ctx_len, epoch_length_fixed): + print('building token list...', end=' ') + unique = sorted(list(set(data))) + # print() + # for u in unique: + # print(u, end=' ') + # print('\n\n') + + xx = 0 + xxObj = {} + for u in unique: + xxObj[xx] = u + xx += 1 + with open('vocab.json', "w", encoding="utf-16") as vocab_file: + vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) + + data_size, vocab_size = len(data), len(unique) + print('data has %d tokens, %d unique.' % (data_size, vocab_size)) + self.stoi = {ch: i for i, ch in enumerate(unique)} + self.itos = {i: ch for i, ch in enumerate(unique)} + self.ctx_len = ctx_len + self.epoch_length_fixed = epoch_length_fixed + self.vocab_size = vocab_size + self.data = data + + def __len__(self): + return self.epoch_length_fixed // NUM_GPUS + + def __getitem__(self, idx): + # cheat: pick a random spot in dataset + i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) + chunk = self.data[i:i+self.ctx_len+1] + dix = [self.stoi[s] for s in chunk] + x = torch.tensor(dix[:-1], dtype=torch.long) + y = torch.tensor(dix[1:], dtype=torch.long) + return x, y + + +class TOKENIZER(): + def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): + with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: + self.word_table = json.load(result_file) + + self.vocab_size = len(self.word_table) + + self.stoi = {v: int(k) for k, v in self.word_table.items()} + self.itos = {int(k): v for k, v in self.word_table.items()} + + self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] + + def refine_context(self, context): + context = context.strip().split('\n') + for c in range(len(context)): + context[c] = context[c].strip().strip('\u3000').strip('\r') + context = list(filter(lambda c: c != '', context)) + context = '\n' + ('\n'.join(context)).strip() + if context == '': + context = '\n' + + return context + + def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): + # out[self.UNKNOWN_CHAR] = -float('Inf') + + lastChar = int(x[-1]) + + probs = F.softmax(torch.tensor(out), dim=-1) + + if self.itos[lastChar] == '\n': + top_p = top_p_newline + else: + top_p = top_p_usual + + sorted_probs, s_index = torch.sort(probs, descending=True) + + # for j in range(30): + # pp = sorted_probs[j].item() + # if pp < 0.005: + # break + # ss = self.itos[int(s_index[j])].replace('\n','_') + # print(f'{math.floor(pp*100):>3.0f}{ss}', end='') + # print('') + + cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy() + cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) + + probs[probs < cutoff] = 0 + # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "") + + if temperature != 1.0: + probs = probs.pow(1.0 / temperature) + + return torch.multinomial(probs, num_samples=1)[0] + + +def to_float(x): + return x.cpu().detach().numpy().flatten()[0].astype(float) + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) diff --git a/RWKV-v4/train.py b/RWKV-v4/train.py new file mode 100644 index 0000000..865a715 --- /dev/null +++ b/RWKV-v4/train.py @@ -0,0 +1,135 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import os + +os.environ['USE_WANDB'] = '0' # 0 = False, 1 = True + +### This is using DeepSpeed stage2 + FP16 ############################################################## + +os.environ['RWKV_NUM_GPUS'] = '1' # num of GPUs to use +NUM_GPUS = int(os.environ['RWKV_NUM_GPUS']) + +### Change these if you want to continue training from a saved model ################################### + +EPOCH_BEGIN = 0 +LOAD_MODEL = False # shall we continue from the #EPOCH_BEGIN model? +os.environ['RWKV_LOAD_MODEL'] = str(LOAD_MODEL) + +######################################################################################################## + +# if False: # True False ---> Set to False if you don't understand it +# print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n") +# import src.utils +# src.utils.set_seed(42) # make training deterministic (including dataloader). if you are doing this, remember to change seed when you load a model (otherwise the dataloader loads old samples) + +import logging, types +from src.utils import Dataset +import torch +import numpy as np + +np.set_printoptions(precision=4, suppress=True, linewidth=200) +logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,) +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True + +### Step 1: set training data ########################################################################## + +datafile = "../data/enwik8" # your data +datafile_encoding = 'utf-8' +# datafile_encoding = 'utf-16le' + +### Step 2: set model size ############################################################################# + +ctx_len = 1024 # increase T_MAX in model.py if your ctx_len is very long +n_layer = 6 +n_embd = 512 + +# 'RWKV' or 'RWKV-ffnPre' (better in some cases) +model_type = 'RWKV' + +# ---> there is also a RWKV_HEAD_QK_DIM in model.py and model_run.py <--- +# set it to 256, then it's using my headQK trick (similar to a tiny attention) to improve loss +# set it to 0, then it's a pure RNN (attention-free) + +### Step 3: set batch size ############################################################################# + +# if you see "CUDA out of memory", reduce batch_size. Use nvidia-smi to find the highest value for your GPU. +batch_size = 12 +assert (batch_size % NUM_GPUS == 0) + +### Step 4: set learning rate, number of mini-epochs ####################################################### +# +# By default we are using exponential LR decay. +# Here are my suggestions for training. +# Let's say you are training a L6-D512 model. +# 1) Set lr_init = lr_final = 8e-4. Let it run for some mini-epochs, until you feel like reducing LR. +# 2) Check epoch_save_frequency and make sure the partially-trained model is saved. Ctrl+C to stop the run. +# 3) Set lr_init = 8e-4, lr_final = 1e-5, betas = (0.9, 0.999). +# 4) Set EPOCH_BEGIN & LOAD_MODEL to load the partially-trained model. Continue the training. +# +# For L12-D768, set lr_init = 6e-4. For L24-D1024, set lr_init = 4e-4. For L24-D2048, set lr_init = 3e-4. + +lr_init = 8e-4 +lr_final = 1e-5 + +# the mini-epoch is very short and of fixed length (length = ctx_len * epoch_length_fixed tokens) +n_epoch = 500 +epoch_length_fixed = (10000 // batch_size) * batch_size # feel free to increase it if you have lots of GPU + +# epoch_save_frequency 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, ... +epoch_save_frequency = 10 +epoch_save_path = 'trained-' +MODEL_NAME = epoch_save_path + str(EPOCH_BEGIN) + +######################################################################################################## + +if LOAD_MODEL and EPOCH_BEGIN > 0: # we are not saving gradients. so let's have some warmup if we load a model + warmup_tokens = ctx_len * batch_size * 50 +else: + warmup_tokens = ctx_len * batch_size * 0 + +betas = (0.9, 0.99) +eps = 1e-8 + +num_workers = 1 # DataLoader worker. I only tested num_workers = 1 + +######################################################################################################## +# Load data +######################################################################################################## + +print('loading data... ' + datafile) +train_dataset = Dataset(open( + datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed) + +######################################################################################################## +# Train model +######################################################################################################## +if __name__ == '__main__': + from src.trainer import Trainer, TrainerConfig + + print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', + betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, ) + + tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, + learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, + warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path) + m_cfg = types.SimpleNamespace() + m_cfg.model_type = model_type + m_cfg.n_layer = n_layer + m_cfg.n_embd = n_embd + m_cfg.EPOCH_BEGIN = EPOCH_BEGIN + m_cfg.LOAD_MODEL = LOAD_MODEL + m_cfg.MODEL_NAME = MODEL_NAME + + from pytorch_lightning.strategies import DeepSpeedStrategy + + # you can set grad_norm_clip in deepspeed.json + + trainer = Trainer(strategy=DeepSpeedStrategy(config='deepspeed.json'), devices=NUM_GPUS, accelerator="gpu", precision=16) + print(trainer._strategy.config) + + trainer.run(m_cfg, train_dataset, None, tconf) diff --git a/RWKV-v4/verify.py b/RWKV-v4/verify.py new file mode 100644 index 0000000..75e8e55 --- /dev/null +++ b/RWKV-v4/verify.py @@ -0,0 +1,63 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +# this is for verifying the results of different models and make sure they agree with each other + +import numpy as np +np.set_printoptions(precision=4, suppress=True, linewidth=200) + +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +RUN_DEVICE = 'cuda' + +import torch +from src.model_run import RWKV_RNN, RWKV_GPT +from src.model import GPT, GPTConfig + +ctx_len = 1024 +n_layer = 6 +n_embd = 512 +model_type = 'RWKV' + +model_name = 'trained-1' + +from src.utils import TOKENIZER +tokenizer = TOKENIZER('vocab', UNKNOWN_CHAR=' ') + +######################################################################################################## + +model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda().half() +print('loading ' + model_name) +m2 = torch.load(model_name + '.pth', map_location=RUN_DEVICE) +model_train.load_state_dict(m2) + +model_rnn = RWKV_RNN(model_name, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) +model_gpt = RWKV_GPT(model_name, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda() + +######################################################################################################## + +context = '\nIn a' +ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] +print(f'input len {len(ctx)} data {ctx}') + +######################################################################################################## + +print('\nRWKV-GPT output') +out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy() +print(out) + +print('\nRWKV-RNN output') +model_rnn.clear() +src_len = len(ctx) +for i in range(src_len): + x = ctx[:i+1] + out = model_rnn.run(x) + if i < 3 or i >= src_len - 3: + print(torch.tensor(out).detach().cpu().numpy()) + if i == 2: + print('...') + +print('\nRWKV-train output') +out = model_train.forward(torch.tensor([ctx]).cuda())[0][0].detach().cpu().numpy() +print(out, '\n')