From 3a7e6a6aa34f56ef443f5539f4a687b6ebb352b9 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Sat, 17 Sep 2022 10:39:25 +0000
Subject: [PATCH] faster inference

---
 RWKV-v4neo/run.py           | 184 ++++++++++++++++++++++++++++++++++
 RWKV-v4neo/src/model_run.py | 192 ++++++++++++++++++++++++++++++++++++
 RWKV-v4neo/src/utils.py     |  82 ++++++++++++++-
 3 files changed, 457 insertions(+), 1 deletion(-)
 create mode 100644 RWKV-v4neo/run.py
 create mode 100644 RWKV-v4neo/src/model_run.py

diff --git a/RWKV-v4neo/run.py b/RWKV-v4neo/run.py
new file mode 100644
index 0000000..2d1cbed
--- /dev/null
+++ b/RWKV-v4neo/run.py
@@ -0,0 +1,184 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import numpy as np
+import math, os
+import time
+import types
+import copy
+import torch
+from src.utils import TOKENIZER
+
+torch.backends.cudnn.benchmark = True
+torch.backends.cudnn.allow_tf32 = True
+torch.backends.cuda.matmul.allow_tf32 = True
+np.set_printoptions(precision=4, suppress=True, linewidth=200)
+
+########################################################################################################
+# Step 1: set model
+#
+# Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch.
+#
+# Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models.
+########################################################################################################
+
+TOKEN_MODE = "pile"  # char / bpe / pile
+
+n_layer = 6
+n_embd = 512
+ctx_len = 1024
+
+if TOKEN_MODE == "char":
+    MODEL_NAME = "trained-500"  # your trained model
+    WORD_NAME = "vocab"  # the .json vocab (generated by train.py)
+    # set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown tokens in your prompt will be denoted by it
+    UNKNOWN_CHAR = " "  # here we just set it to ' ' for simplicity
+
+elif TOKEN_MODE == "bpe":
+    MODEL_NAME = "trained-500"  # your trained model
+    WORD_NAME = [
+        "model-vocab.json",
+        "model-merges.txt",
+    ]  # [vocab, merge] for your BPE model
+    UNKNOWN_CHAR = None
+
+elif TOKEN_MODE == "pile":
+    WORD_NAME = [
+        "20B_tokenizer.json",
+        "20B_tokenizer.json",
+    ]  # [vocab, vocab] for Pile model
+    UNKNOWN_CHAR = None
+
+    # ---> you can set MODEL_NAME to your fine-tuned model <---
+
+    # MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
+    # n_layer = 12
+    # n_embd = 768
+    # ctx_len = 1024
+
+    # MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
+    # n_layer = 24
+    # n_embd = 1024
+    # ctx_len = 1024
+
+    MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040'
+    n_layer = 24
+    n_embd = 2048
+    ctx_len = 1024
+
+os.environ["RWKV_FLOAT_MODE"] = "fp32"  # currently only supports fp32
+os.environ["RWKV_RUN_DEVICE"] = "cpu"  # 'cpu' (already very fast) or 'cuda'
+model_type = "RWKV"  # 'RWKV' or 'RWKV-ffnPre'
+
+########################################################################################################
+# Step 2: set prompt & sampling settings
+########################################################################################################
+
+# context = 'A'
+# context = "\nIn the"
+# context = '\nSugar:'
+context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
+
+NUM_TRIALS = 999
+LENGTH_PER_TRIAL = 333
+
+TEMPERATURE = 1.0
+top_p = 0.8
+top_p_newline = 0.9  # only used in TOKEN_MODE = char
+
+DEBUG_DEBUG = False  # True False --> show softmax output
+
+########################################################################################################
+
+print(f"Loading {MODEL_NAME}...")
+from src.model_run import RWKV_RNN
+
+model = RWKV_RNN(
+    MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len
+)
+tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
+
+########################################################################################################
+
+if tokenizer.charMode:
+    context = tokenizer.refine_context(context)
+    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
+else:
+    ctx = tokenizer.tokenizer.encode(context)
+src_len = len(ctx)
+src_ctx = ctx.copy()
+
+print("\nYour prompt has " + str(src_len) + " tokens.")
+print(
+    "\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n"
+)
+
+# time_slot = {}
+# time_ref = time.time_ns()
+
+# def record_time(name):
+#     if name not in time_slot:
+#         time_slot[name] = 1e20
+#     tt = (time.time_ns() - time_ref) / 1e9
+#     if tt < time_slot[name]:
+#         time_slot[name] = tt
+
+for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
+    # time_ref = time.time_ns()
+
+    print(("-" * 50) + context, end="")
+    ctx = src_ctx.copy()
+    model.clear()
+
+    if TRIAL == 0:
+        init_state = types.SimpleNamespace()
+        for i in range(src_len):
+            x = ctx[: i + 1]
+            if i == src_len - 1:
+                init_state.out = model.forward(x)
+            else:
+                model.forward(x, preprocess_only=True)
+        model.save(init_state)
+    else:
+        model.load(init_state)
+
+    # record_time('model_pre')
+    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
+        # time_ref = time.time_ns()
+
+        x = ctx[: i + 1]
+        x = x[-ctx_len:]
+
+        if i == src_len:
+            out = copy.deepcopy(init_state.out)
+        else:
+            out = model.forward(x)
+        # record_time('model_run')
+        if DEBUG_DEBUG:
+            print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))
+
+        if TOKEN_MODE == "pile":
+            out[0] = -999999999  # disable <|endoftext|>
+
+        time_ref = time.time_ns()
+        char = tokenizer.sample_logits(
+            out,
+            x,
+            ctx_len,
+            temperature=TEMPERATURE,
+            top_p_usual=top_p,
+            top_p_newline=top_p_newline,
+        )
+        if tokenizer.charMode:
+            print(tokenizer.itos[char], end="", flush=True)
+        else:
+            print(tokenizer.tokenizer.decode(char), end="", flush=True)
+        ctx += [char]
+
+        # record_time('model_sampling')
+    print()
+    # print(f'\n\n{time_slot}\n\n')
+    # print(
+    #     f"\n--- preprocess {round((t_mid - t_begin) / (10 ** 9), 2)}s, generation {round((t_end - t_mid) / (10 ** 9), 2)}s", end = ''
+    # )
diff --git a/RWKV-v4neo/src/model_run.py b/RWKV-v4neo/src/model_run.py
new file mode 100644
index 0000000..0d6499c
--- /dev/null
+++ b/RWKV-v4neo/src/model_run.py
@@ -0,0 +1,192 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import types
+import copy
+import torch
+import math, os
+from torch.nn import functional as F
+import torch.nn as nn
+
+def __nop(ob):
+    return ob
+
+
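+# By default inference runs as a plain nn.Module with no-op decorators; the
+# commented lines below swap MyModule / MyFunction to TorchScript compilation.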
+MyModule = nn.Module
+MyFunction = __nop
+# MyModule = torch.jit.ScriptModule
+# MyFunction = torch.jit.script_method
+
+RWKV_HEAD_QK_DIM = 0
+print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
+
+DEBUG_TIME = False  # True False - show trained time-coeffs
+
+############################################################################################################
+
+class RWKV_RNN(MyModule):  # this is running in FP32 at this moment
+    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
+        super().__init__()
+
+        self.RUN_DEVICE = RUN_DEVICE
+        self.model_type = model_type
+        self.n_layer = n_layer
+        self.n_embd = n_embd
+        self.ctx_len = ctx_len
+
+        self.w = types.SimpleNamespace()
+
+        w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))
+        for x in w.keys():
+            w[x] = w[x].float()
+            if '.time_' in x:
+                w[x] = w[x].squeeze()
+            if '.time_decay' in x:
+                w[x] = -torch.exp(w[x])
+            if DEBUG_TIME and '.time_' in x:
+                print(x, w[x].squeeze().cpu().numpy())
+
+            xx = x.split('.')
+            here = self.w
+            for i in range(len(xx)):
+                if xx[i].isdigit():
+                    ii = int(xx[i])
+                    if ii not in here:
+                        here[ii] = types.SimpleNamespace()
+                    here = here[ii]
+                else:
+                    if i == len(xx) - 1:
+                        setattr(here, xx[i], w[x])
+                    elif not hasattr(here, xx[i]):
+                        if xx[i+1].isdigit():
+                            setattr(here, xx[i], {})
+                        else:
+                            setattr(here, xx[i], types.SimpleNamespace())
+                    here = getattr(here, xx[i])
+
+        self.clear()
+        self.eval()
+
+    def clear(self):
+        self.xx = {}
+        self.aa = {}
+        self.bb = {}
+        self.pp = {}
+        self.hk = None
+
+    def save(self, target):
+        target.xx = copy.deepcopy(self.xx)
+        target.aa = copy.deepcopy(self.aa)
+        target.bb = copy.deepcopy(self.bb)
+        target.pp = copy.deepcopy(self.pp)
+        target.hk = copy.deepcopy(self.hk)
+
+    def load(self, target):
+        self.xx = copy.deepcopy(target.xx)
+        self.aa = copy.deepcopy(target.aa)
+        self.bb = copy.deepcopy(target.bb)
+        self.pp = copy.deepcopy(target.pp)
+        self.hk = copy.deepcopy(target.hk)
+
+    @MyFunction
+    def LN(self, xx, w):
+        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
+
+    @MyFunction
+    def FF(self, xx, w, name):
+        if name not in self.xx:
+            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
+        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
+        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
+        self.xx[name] = xx
+
+        r = torch.sigmoid(w.receptance.weight @ xr)
+        k = torch.square(torch.relu(w.key.weight @ xk))
+        kv = w.value.weight @ k
+
+        return r * kv
+
+    @MyFunction
+    def SA(self, xx, w, name):
+        if name not in self.xx:
+            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
+            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
+            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
+            self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30
+
+        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
+        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
+        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
+        self.xx[name] = xx
+
+        r = torch.sigmoid(w.receptance.weight @ xr)
+
+        k = w.key.weight @ xk
+        v = w.value.weight @ xv
+
+        pp = self.pp[name]
+        aa = self.aa[name]
+        bb = self.bb[name]
+        ww = w.time_first + k
+        p = torch.maximum(pp, ww)
+        e1 = torch.exp(pp - p)
+        e2 = torch.exp(ww - p)
+        a = e1 * aa + e2 * v
+        b = e1 * bb + e2
+        ww = pp + w.time_decay
+        p = torch.maximum(ww, k)
+        e1 = torch.exp(ww - p)
+        e2 = torch.exp(k - p)
+        self.aa[name] = e1 * aa + e2 * v
+        self.bb[name] = e1 * bb + e2
+        self.pp[name] = p
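+        # (aa/bb above hold the exp-weighted numerator/denominator of the WKV
+        #  state, and pp tracks the running maximum exponent, so torch.exp()
+        #  never overflows even for long sequences.)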
+
+        rwkv = r * a / b
+
+        return w.output.weight @ rwkv
+
+    def forward(self, ctx, preprocess_only = False):
+        with torch.no_grad():
+            w = self.w
+            x = w.emb.weight[ctx[-1]]
+
+            for i in range(self.n_layer):
+                if i == 0:
+                    x = self.LN(x, w.blocks[i].ln0)
+                if i == 0 and self.model_type == 'RWKV-ffnPre':
+                    x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
+                else:
+                    x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
+                x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')
+
+            x = self.LN(x, w.ln_out)
+
+            if RWKV_HEAD_QK_DIM > 0:
+                if self.hk is None:
+                    self.hk = (w.head_k.weight @ x).unsqueeze(0)
+                else:
+                    self.hk = torch.cat(
+                        [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
+                if self.hk.shape[0] > self.ctx_len:
+                    self.hk = self.hk[-self.ctx_len:, :]
+
+                if preprocess_only:
+                    return None
+
+                q = w.head_q.weight @ x
+
+                x = w.head.weight @ x
+                x = x
+
+                c = (self.hk @ q) / RWKV_HEAD_QK_DIM
+                for i in range(len(c)):
+                    x[ctx[i]] += c[i]
+            else:
+                if preprocess_only:
+                    return None
+
+                x = w.head.weight @ x
+                x = x
+
+            return x
diff --git a/RWKV-v4neo/src/utils.py b/RWKV-v4neo/src/utils.py
index f2cbe99..ea25990 100644
--- a/RWKV-v4neo/src/utils.py
+++ b/RWKV-v4neo/src/utils.py
@@ -1,5 +1,85 @@
-import random
+import json, time, random, os
+import numpy as np
+import torch
+from torch.nn import functional as F
+time_slot = {}
+time_ref = time.time_ns()
+
+def record_time(name):
+    if name not in time_slot:
+        time_slot[name] = 1e20
+    tt = (time.time_ns() - time_ref) / 1e9
+    if tt < time_slot[name]:
+        time_slot[name] = tt
+
+class TOKENIZER():
+    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
+        if 'list' in str(type(WORD_NAME)):
+            self.charMode = False
+            if WORD_NAME[0] == WORD_NAME[1]:
+                from transformers import PreTrainedTokenizerFast
+                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
+            else:
+                from transformers import GPT2TokenizerFast
+                self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
+            self.vocab_size = len(self.tokenizer)
+        else:
+            self.charMode = True
+            with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
+                self.word_table = json.load(result_file)
+
+            self.vocab_size = len(self.word_table)
+
+            self.stoi = {v: int(k) for k, v in self.word_table.items()}
+            self.itos = {int(k): v for k, v in self.word_table.items()}
+
+            self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
+
+    def refine_context(self, context):
+        context = context.strip().split('\n')
+        for c in range(len(context)):
+            context[c] = context[c].strip().strip('\u3000').strip('\r')
+        context = list(filter(lambda c: c != '', context))
+        context = '\n' + ('\n'.join(context)).strip()
+        if context == '':
+            context = '\n'
+        return context
+
+    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
+        # out[self.UNKNOWN_CHAR] = -float('Inf')
+        lastChar = int(x[-1])
+
+        probs = F.softmax(out, dim=-1)
+
+        if self.charMode:
+            if self.itos[lastChar] == '\n':
+                top_p = top_p_newline
+            else:
+                top_p = top_p_usual
+        else:
+            top_p = top_p_usual
+
+        if os.environ["RWKV_RUN_DEVICE"] == "cpu":
+            probs = probs.numpy()
+            sorted_probs = np.sort(probs)[::-1]
+            cumulative_probs = np.cumsum(sorted_probs)
+            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
+            probs[probs < cutoff] = 0
+            if temperature != 1.0:
+                probs = probs ** (1.0 / temperature)
+            probs = probs / np.sum(probs)
+            out = np.random.choice(a=len(probs), p=probs)
+            return out
+        else:
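+            # GPU / torch path: same top-p (nucleus) cutoff as the CPU branch,
+            # followed by multinomial sampling on the device.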
+            sorted_probs = torch.sort(probs, descending=True)[0]
+            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
+            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
+            probs[probs < cutoff] = 0
+            if temperature != 1.0:
+                probs = probs.pow(1.0 / temperature)
+            out = torch.multinomial(probs, num_samples=1)[0]
+            return out
 
 def MaybeIsPrime(number):
     if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):