faster inference
parent
bd6803df76
commit
3a7e6a6aa3
@ -0,0 +1,184 @@
|
|||||||
|
########################################################################################################
|
||||||
|
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import math, os
|
||||||
|
import time
|
||||||
|
import types
|
||||||
|
import copy
|
||||||
|
import torch
|
||||||
|
from src.utils import TOKENIZER
|
||||||
|
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
# Step 1: set model
|
||||||
|
#
|
||||||
|
# Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch.
|
||||||
|
#
|
||||||
|
# Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models.
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
TOKEN_MODE = "pile" # char / bpe / pile
|
||||||
|
|
||||||
|
n_layer = 6
|
||||||
|
n_embd = 512
|
||||||
|
ctx_len = 1024
|
||||||
|
|
||||||
|
if TOKEN_MODE == "char":
|
||||||
|
MODEL_NAME = "trained-500" # your trained model
|
||||||
|
WORD_NAME = "vocab" # the .json vocab (generated by train.py)
|
||||||
|
# set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown tokens in your prompt will be denoted by it
|
||||||
|
UNKNOWN_CHAR = " " # here we just set it to ' ' for simplicity
|
||||||
|
|
||||||
|
elif TOKEN_MODE == "bpe":
|
||||||
|
MODEL_NAME = "trained-500" # your trained model
|
||||||
|
WORD_NAME = [
|
||||||
|
"model-vocab.json",
|
||||||
|
"model-merges.txt",
|
||||||
|
] # [vocab, merge] for your BPE model
|
||||||
|
UNKNOWN_CHAR = None
|
||||||
|
|
||||||
|
elif TOKEN_MODE == "pile":
|
||||||
|
WORD_NAME = [
|
||||||
|
"20B_tokenizer.json",
|
||||||
|
"20B_tokenizer.json",
|
||||||
|
] # [vocab, vocab] for Pile model
|
||||||
|
UNKNOWN_CHAR = None
|
||||||
|
|
||||||
|
# ---> you can set MODEL_NAME to your fine-tuned model <---
|
||||||
|
|
||||||
|
# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
|
||||||
|
# n_layer = 12
|
||||||
|
# n_embd = 768
|
||||||
|
# ctx_len = 1024
|
||||||
|
|
||||||
|
# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
|
||||||
|
# n_layer = 24
|
||||||
|
# n_embd = 1024
|
||||||
|
# ctx_len = 1024
|
||||||
|
|
||||||
|
MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040'
|
||||||
|
n_layer = 24
|
||||||
|
n_embd = 2048
|
||||||
|
ctx_len = 1024
|
||||||
|
|
||||||
|
os.environ["RWKV_FLOAT_MODE"] = "fp32" # currently only supprts fp32
|
||||||
|
os.environ["RWKV_RUN_DEVICE"] = "cpu" # 'cpu' (already very fast) or 'cuda'
|
||||||
|
model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre'
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
# Step 2: set prompt & sampling stuffs
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
# context = 'A'
|
||||||
|
# context = "\nIn the"
|
||||||
|
# context = '\nSugar:'
|
||||||
|
context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
|
||||||
|
|
||||||
|
NUM_TRIALS = 999
|
||||||
|
LENGTH_PER_TRIAL = 333
|
||||||
|
|
||||||
|
TEMPERATURE = 1.0
|
||||||
|
top_p = 0.8
|
||||||
|
top_p_newline = 0.9 # only used in TOKEN_MODE = char
|
||||||
|
|
||||||
|
DEBUG_DEBUG = False # True False --> show softmax output
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
print(f"Loading {MODEL_NAME}...")
|
||||||
|
from src.model_run import RWKV_RNN
|
||||||
|
|
||||||
|
model = RWKV_RNN(
|
||||||
|
MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len
|
||||||
|
)
|
||||||
|
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
if tokenizer.charMode:
|
||||||
|
context = tokenizer.refine_context(context)
|
||||||
|
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
|
||||||
|
else:
|
||||||
|
ctx = tokenizer.tokenizer.encode(context)
|
||||||
|
src_len = len(ctx)
|
||||||
|
src_ctx = ctx.copy()
|
||||||
|
|
||||||
|
print("\nYour prompt has " + str(src_len) + " tokens.")
|
||||||
|
print(
|
||||||
|
"\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# time_slot = {}
|
||||||
|
# time_ref = time.time_ns()
|
||||||
|
|
||||||
|
# def record_time(name):
|
||||||
|
# if name not in time_slot:
|
||||||
|
# time_slot[name] = 1e20
|
||||||
|
# tt = (time.time_ns() - time_ref) / 1e9
|
||||||
|
# if tt < time_slot[name]:
|
||||||
|
# time_slot[name] = tt
|
||||||
|
|
||||||
|
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
|
||||||
|
# time_ref = time.time_ns()
|
||||||
|
|
||||||
|
print(("-" * 50) + context, end="")
|
||||||
|
ctx = src_ctx.copy()
|
||||||
|
model.clear()
|
||||||
|
|
||||||
|
if TRIAL == 0:
|
||||||
|
init_state = types.SimpleNamespace()
|
||||||
|
for i in range(src_len):
|
||||||
|
x = ctx[: i + 1]
|
||||||
|
if i == src_len - 1:
|
||||||
|
init_state.out = model.forward(x)
|
||||||
|
else:
|
||||||
|
model.forward(x, preprocess_only=True)
|
||||||
|
model.save(init_state)
|
||||||
|
else:
|
||||||
|
model.load(init_state)
|
||||||
|
|
||||||
|
# record_time('model_pre')
|
||||||
|
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
|
||||||
|
# time_ref = time.time_ns()
|
||||||
|
|
||||||
|
x = ctx[: i + 1]
|
||||||
|
x = x[-ctx_len:]
|
||||||
|
|
||||||
|
if i == src_len:
|
||||||
|
out = copy.deepcopy(init_state.out)
|
||||||
|
else:
|
||||||
|
out = model.forward(x)
|
||||||
|
# record_time('model_run')
|
||||||
|
if DEBUG_DEBUG:
|
||||||
|
print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))
|
||||||
|
|
||||||
|
if TOKEN_MODE == "pile":
|
||||||
|
out[0] = -999999999 # disable <|endoftext|>
|
||||||
|
|
||||||
|
time_ref = time.time_ns()
|
||||||
|
char = tokenizer.sample_logits(
|
||||||
|
out,
|
||||||
|
x,
|
||||||
|
ctx_len,
|
||||||
|
temperature=TEMPERATURE,
|
||||||
|
top_p_usual=top_p,
|
||||||
|
top_p_newline=top_p_newline,
|
||||||
|
)
|
||||||
|
if tokenizer.charMode:
|
||||||
|
print(tokenizer.itos[char], end="", flush=True)
|
||||||
|
else:
|
||||||
|
print(tokenizer.tokenizer.decode(char), end="", flush=True)
|
||||||
|
ctx += [char]
|
||||||
|
|
||||||
|
# record_time('model_sampling')
|
||||||
|
print()
|
||||||
|
# print(f'\n\n{time_slot}\n\n')
|
||||||
|
# print(
|
||||||
|
# f"\n--- preprocess {round((t_mid - t_begin) / (10 ** 9), 2)}s, generation {round((t_end - t_mid) / (10 ** 9), 2)}s", end = ''
|
||||||
|
# )
|
||||||
@ -0,0 +1,192 @@
|
|||||||
|
########################################################################################################
|
||||||
|
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
import types
|
||||||
|
import copy
|
||||||
|
import torch
|
||||||
|
import math, os
|
||||||
|
from torch.nn import functional as F
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
def __nop(ob):
|
||||||
|
return ob
|
||||||
|
|
||||||
|
|
||||||
|
MyModule = nn.Module
|
||||||
|
MyFunction = __nop
|
||||||
|
# MyModule = torch.jit.ScriptModule
|
||||||
|
# MyFunction = torch.jit.script_method
|
||||||
|
|
||||||
|
RWKV_HEAD_QK_DIM = 0
|
||||||
|
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
|
||||||
|
|
||||||
|
DEBUG_TIME = False # True False - show trained time-coeffs
|
||||||
|
|
||||||
|
############################################################################################################
|
||||||
|
|
||||||
|
class RWKV_RNN(MyModule): # this is running in FP32 at this moment
|
||||||
|
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.RUN_DEVICE = RUN_DEVICE
|
||||||
|
self.model_type = model_type
|
||||||
|
self.n_layer = n_layer
|
||||||
|
self.n_embd = n_embd
|
||||||
|
self.ctx_len = ctx_len
|
||||||
|
|
||||||
|
self.w = types.SimpleNamespace()
|
||||||
|
|
||||||
|
w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))
|
||||||
|
for x in w.keys():
|
||||||
|
w[x] = w[x].float()
|
||||||
|
if '.time_' in x:
|
||||||
|
w[x] = w[x].squeeze()
|
||||||
|
if '.time_decay' in x:
|
||||||
|
w[x] = -torch.exp(w[x])
|
||||||
|
if DEBUG_TIME and '.time_' in x:
|
||||||
|
print(x, w[x].squeeze().cpu().numpy())
|
||||||
|
|
||||||
|
xx = x.split('.')
|
||||||
|
here = self.w
|
||||||
|
for i in range(len(xx)):
|
||||||
|
if xx[i].isdigit():
|
||||||
|
ii = int(xx[i])
|
||||||
|
if ii not in here:
|
||||||
|
here[ii] = types.SimpleNamespace()
|
||||||
|
here = here[ii]
|
||||||
|
else:
|
||||||
|
if i == len(xx) - 1:
|
||||||
|
setattr(here, xx[i], w[x])
|
||||||
|
elif not hasattr(here, xx[i]):
|
||||||
|
if xx[i+1].isdigit():
|
||||||
|
setattr(here, xx[i], {})
|
||||||
|
else:
|
||||||
|
setattr(here, xx[i], types.SimpleNamespace())
|
||||||
|
here = getattr(here, xx[i])
|
||||||
|
|
||||||
|
self.clear()
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.xx = {}
|
||||||
|
self.aa = {}
|
||||||
|
self.bb = {}
|
||||||
|
self.pp = {}
|
||||||
|
self.hk = None
|
||||||
|
|
||||||
|
def save(self, target):
|
||||||
|
target.xx = copy.deepcopy(self.xx)
|
||||||
|
target.aa = copy.deepcopy(self.aa)
|
||||||
|
target.bb = copy.deepcopy(self.bb)
|
||||||
|
target.pp = copy.deepcopy(self.pp)
|
||||||
|
target.hk = copy.deepcopy(self.hk)
|
||||||
|
|
||||||
|
def load(self, target):
|
||||||
|
self.xx = copy.deepcopy(target.xx)
|
||||||
|
self.aa = copy.deepcopy(target.aa)
|
||||||
|
self.bb = copy.deepcopy(target.bb)
|
||||||
|
self.pp = copy.deepcopy(target.pp)
|
||||||
|
self.hk = copy.deepcopy(target.hk)
|
||||||
|
|
||||||
|
@MyFunction
|
||||||
|
def LN(self, xx, w):
|
||||||
|
return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
|
||||||
|
|
||||||
|
@MyFunction
|
||||||
|
def FF(self, xx, w, name):
|
||||||
|
if name not in self.xx:
|
||||||
|
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
|
||||||
|
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
|
||||||
|
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
|
||||||
|
self.xx[name] = xx
|
||||||
|
|
||||||
|
r = torch.sigmoid(w.receptance.weight @ xr)
|
||||||
|
k = torch.square(torch.relu(w.key.weight @ xk))
|
||||||
|
kv = w.value.weight @ k
|
||||||
|
|
||||||
|
return r * kv
|
||||||
|
|
||||||
|
@MyFunction
|
||||||
|
def SA(self, xx, w, name):
|
||||||
|
if name not in self.xx:
|
||||||
|
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
|
||||||
|
self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
|
||||||
|
self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
|
||||||
|
self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30
|
||||||
|
|
||||||
|
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
|
||||||
|
xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
|
||||||
|
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
|
||||||
|
self.xx[name] = xx
|
||||||
|
|
||||||
|
r = torch.sigmoid(w.receptance.weight @ xr)
|
||||||
|
|
||||||
|
k = w.key.weight @ xk
|
||||||
|
v = w.value.weight @ xv
|
||||||
|
|
||||||
|
pp = self.pp[name]
|
||||||
|
aa = self.aa[name]
|
||||||
|
bb = self.bb[name]
|
||||||
|
ww = w.time_first + k
|
||||||
|
p = torch.maximum(pp, ww)
|
||||||
|
e1 = torch.exp(pp - p)
|
||||||
|
e2 = torch.exp(ww - p)
|
||||||
|
a = e1 * aa + e2 * v
|
||||||
|
b = e1 * bb + e2
|
||||||
|
ww = pp + w.time_decay
|
||||||
|
p = torch.maximum(ww, k)
|
||||||
|
e1 = torch.exp(ww - p)
|
||||||
|
e2 = torch.exp(k - p)
|
||||||
|
self.aa[name] = e1 * aa + e2 * v
|
||||||
|
self.bb[name] = e1 * bb + e2
|
||||||
|
self.pp[name] = p
|
||||||
|
|
||||||
|
rwkv = r * a / b
|
||||||
|
|
||||||
|
return w.output.weight @ rwkv
|
||||||
|
|
||||||
|
def forward(self, ctx, preprocess_only = False):
|
||||||
|
with torch.no_grad():
|
||||||
|
w = self.w
|
||||||
|
x = w.emb.weight[ctx[-1]]
|
||||||
|
|
||||||
|
for i in range(self.n_layer):
|
||||||
|
if i == 0:
|
||||||
|
x = self.LN(x, w.blocks[i].ln0)
|
||||||
|
if i == 0 and self.model_type == 'RWKV-ffnPre':
|
||||||
|
x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
|
||||||
|
else:
|
||||||
|
x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
|
||||||
|
x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')
|
||||||
|
|
||||||
|
x = self.LN(x, w.ln_out)
|
||||||
|
|
||||||
|
if RWKV_HEAD_QK_DIM > 0:
|
||||||
|
if self.hk == None:
|
||||||
|
self.hk = (w.head_k.weight @ x).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
self.hk = torch.cat(
|
||||||
|
[self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
|
||||||
|
if self.hk.shape[0] > self.ctx_len:
|
||||||
|
self.hk = self.hk[-self.ctx_len:, :]
|
||||||
|
|
||||||
|
if preprocess_only:
|
||||||
|
return None
|
||||||
|
|
||||||
|
q = w.head_q.weight @ x
|
||||||
|
|
||||||
|
x = w.head.weight @ x
|
||||||
|
x = x
|
||||||
|
|
||||||
|
c = (self.hk @ q) / RWKV_HEAD_QK_DIM
|
||||||
|
for i in range(len(c)):
|
||||||
|
x[ctx[i]] += c[i]
|
||||||
|
else:
|
||||||
|
if preprocess_only:
|
||||||
|
return None
|
||||||
|
|
||||||
|
x = w.head.weight @ x
|
||||||
|
x = x
|
||||||
|
|
||||||
|
return x
|
||||||
Loading…
Reference in New Issue