You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

240 lines
7.9 KiB
Python

# -*- coding:utf-8 -*-
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import types
import copy
import time
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER
# Backend tuning: let cuDNN benchmark and pick its fastest kernels, and allow
# TF32 math in cuDNN and in CUDA matmuls. These flags concern the CUDA
# backends, so they are only relevant when RUN_DEVICE is 'cuda'.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
### Step 1: set model ##################################################################################
# Architecture hyper-parameters. These should agree with the checkpoint:
# n_embd sizes the layer norms and state vectors, n_layer the block loop.
ctx_len: int = 1024
n_layer: int = 6
n_embd: int = 512
model_type: str = 'RWKV'  # one of: 'RWKV', 'RWKV-ffnPre'

# Your trained weights (MODEL_NAME + '.pth') and the .json vocab that
# train.py generated.
MODEL_NAME: str = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
WORD_NAME: str = 'enwik8-vocab'

# Tokens of your prompt that are missing from the vocab are replaced by this
# character; ideally pick the rarest token in vocab.json. A space is used
# here only for simplicity.
UNKNOWN_CHAR: str = ' '

RUN_DEVICE: str = 'cpu'  # 'cpu' (already very fast) or 'cuda'

# Debug switches (True / False):
#   DEBUG_DEBUG - run a single one-token trial and print the raw model output
#   DEBUG_TIME  - print the trained time_mix / time_decay / time_first values
DEBUG_DEBUG: bool = False
DEBUG_TIME: bool = False

### Step 2: set context ################################################################################
context: str = "\n"  # ==> the prompt
NUM_TRIALS: int = 999  # how many independent samples to draw
LENGTH_PER_TRIAL: int = 500  # tokens generated per sample
TEMPERATURE: float = 1.0
# Nucleus-sampling thresholds handed to TOKENIZER.sample_logits (as
# top_p_usual / top_p_newline); see src.utils for how each is applied.
top_p: float = 0.7
top_p_newline: float = 0.9
########################################################################################################
np.set_printoptions(precision=4, suppress=True, linewidth=200)

# Build the tokenizer from the vocab, then pass the prompt through its
# refine_context() (defined in src.utils).
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
context = tokenizer.refine_context(context)

print(f'\nYour prompt has {len(context)} tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')
print(f'Loading {MODEL_NAME}...')
##############################################################################################################
# Numerical constants used by RWKV_RNN.
RWKV_K_CLAMP: int = 60  # upper clamp on the key logit before exp() in SA, so exp(k) stays finite
RWKV_K_EPS: float = 1e-16  # added to the SA denominator to avoid dividing by zero
RWKV_HEAD_QK_DIM: int = 256  # divisor applied to the head-QK scores in run()
class RWKV_RNN():
    """Token-by-token (RNN-mode) inference for an RWKV v2 checkpoint.

    The flat state dict is rebuilt into a nested attribute tree ``self.w``:
    numeric key components become dict indices and the rest become
    attributes, so ``blocks.0.att.key.weight`` is reachable as
    ``self.w.blocks[0].att.key.weight``.

    Recurrent state (managed by ``clear`` / ``save`` / ``load``):

    * ``xx`` -- per sub-layer, the input seen at the previous step (time_mix);
    * ``aa`` / ``bb`` -- per attention layer, running numerator / denominator;
    * ``hk`` -- head-K projections of the most recent positions (at most
      ``ctx_len`` rows), used for the head-QK logit bonus in ``run``.
    """

    # Attributes that together make up the recurrent state.
    _STATE_FIELDS = ('xx', 'aa', 'bb', 'hk')

    def __init__(self, MODEL_NAME):
        """Load ``MODEL_NAME + '.pth'`` onto ``RUN_DEVICE`` and reset state.

        Time parameters are transformed once here so ``SA`` can use them
        directly: ``time_decay`` -> ``exp(-exp(d))`` and
        ``time_first`` -> ``exp(f)`` (all ``.time_*`` tensors are squeezed).
        """
        self.w = types.SimpleNamespace()

        w = torch.load(MODEL_NAME + '.pth',
                       map_location=torch.device(RUN_DEVICE))

        for x in w.keys():
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                w[x] = torch.exp(-torch.exp(w[x]))
            if '.time_first' in x:
                w[x] = torch.exp(w[x])

            # Walk the dotted key, creating containers on the way: a part
            # followed by a digit becomes a dict, any other part a namespace.
            parts = x.split('.')
            here = self.w
            for i, part in enumerate(parts):
                if part.isdigit():
                    idx = int(part)
                    if idx not in here:
                        here[idx] = types.SimpleNamespace()
                    here = here[idx]
                else:
                    if i == len(parts) - 1:
                        setattr(here, part, w[x])
                    elif not hasattr(here, part):
                        if parts[i + 1].isdigit():
                            setattr(here, part, {})
                        else:
                            setattr(here, part, types.SimpleNamespace())
                    here = getattr(here, part)

        self.clear()

    def clear(self):
        """Reset all recurrent state, e.g. before processing a new sequence."""
        self.xx = {}    # sub-layer name -> previous input vector
        self.aa = {}    # attention name -> running numerator
        self.bb = {}    # attention name -> running denominator
        self.hk = None  # stacked head-K rows of recent positions

    def save(self, target):
        """Deep-copy the recurrent state onto attributes of ``target``."""
        for field in self._STATE_FIELDS:
            setattr(target, field, copy.deepcopy(getattr(self, field)))

    def load(self, target):
        """Replace the recurrent state with deep copies of ``target``'s."""
        for field in self._STATE_FIELDS:
            setattr(self, field, copy.deepcopy(getattr(target, field)))

    def LN(self, xx, w):
        """Layer-norm ``xx`` over its last ``n_embd`` features with ``w``'s affine params."""
        return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        """One step of the feed-forward (``ffn`` / ``ffnPre``) sub-layer.

        ``xx`` is the current 1-D hidden vector. It is blended with the
        previous step's input via ``time_mix``, then
        ``sigmoid(R x) * (V relu(K x)^2)`` is returned. Updates ``self.xx[name]``.
        """
        if DEBUG_TIME:
            print(name+'.time_mix', w.time_mix.squeeze().numpy())
        if name not in self.xx:
            self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ x)
        k = torch.square(torch.relu(w.key.weight @ x))
        kv = w.value.weight @ k
        return r * kv

    def SA(self, xx, w, name):
        """One step of the attention (``att``) sub-layer in recurrent form.

        With ``k = exp(min(K x, RWKV_K_CLAMP))`` and ``v = V x``, the output is
        ``O (sigmoid(R x) * a / (b + eps))``, where ``a`` / ``b`` add a
        ``time_first``-weighted current term to the running sums ``aa`` /
        ``bb``. Those sums are then decayed by ``time_decay`` and accumulate
        ``k*v`` / ``k``. Updates ``self.xx``, ``self.aa`` and ``self.bb`` for ``name``.
        """
        if DEBUG_TIME:
            print(name+'.time_mix', w.time_mix.squeeze().numpy())
            print(name+'.time_decay', w.time_decay.squeeze().numpy())
            print(name+'.time_first', w.time_first.squeeze().numpy())
        if name not in self.xx:
            self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
            self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)
            self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)
        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ x)
        # Clamp before exp() so the key weight stays finite.
        k = torch.exp(torch.clamp(w.key.weight @ x, max=RWKV_K_CLAMP))
        v = w.value.weight @ x
        kv = k * v

        # The current position contributes with the time_first bonus...
        a = self.aa[name] + w.time_first * kv
        b = self.bb[name] + w.time_first * k
        # ...while the carried sums decay and absorb it for future steps.
        self.aa[name] = w.time_decay * self.aa[name] + kv
        self.bb[name] = w.time_decay * self.bb[name] + k

        rwkv = r * a / (b + RWKV_K_EPS)
        return w.output.weight @ rwkv

    def run(self, ctx):
        """Advance the model by the token ``ctx[-1]`` and return its logits.

        Only ``ctx[-1]`` is embedded. Earlier tokens are represented by the
        recurrent state, so call ``run`` once per position, in order. The
        tail of ``ctx`` is also used for the head-QK bonus: each remembered
        position's score ``hk @ q / RWKV_HEAD_QK_DIM`` is added to the logit
        of the token at that position.

        Returns a plain Python list of floats, one per vocabulary entry.
        """
        w = self.w
        x = w.emb.weight[ctx[-1]]

        for i in range(n_layer):
            block = w.blocks[i]
            x = self.LN(x, block.ln1)
            if i == 0 and model_type == 'RWKV-ffnPre':
                x = x + self.FF(x, block.ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(x, block.att, f'att.{i}')
            x = self.LN(x, block.ln2)
            x = x + self.FF(x, block.ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        # Remember this position's head-K row; keep only the last ctx_len.
        hk_row = (w.head_k.weight @ x).unsqueeze(0)
        if self.hk is None:
            self.hk = hk_row
        else:
            self.hk = torch.cat([self.hk, hk_row], dim=0)
        if self.hk.shape[0] > ctx_len:
            self.hk = self.hk[-ctx_len:, :]

        q = w.head_q.weight @ x

        x = w.head.weight @ x
        x = x.cpu().numpy().tolist()

        # Convert to floats so every logit stays a plain Python float.
        c = ((self.hk @ q) / RWKV_HEAD_QK_DIM).tolist()
        # hk's rows are the MOST RECENT positions (its last row is ctx[-1]),
        # so align scores with the tail of ctx. Pairing with ctx[:len(c)]
        # would misplace the bonus whenever ctx is longer than hk.
        n = min(len(c), len(ctx))
        for token, score in zip(ctx[-n:], c[-n:]):
            x[token] += score

        return x
##############################################################################################################
model = RWKV_RNN(MODEL_NAME)
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
t_begin = time.time_ns()
src_len = len(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(('-' * 30) + context, end='')
model.clear()
if TRIAL == 0:
init_state = types.SimpleNamespace()
for i in range(src_len):
x = ctx[:i+1]
if i == src_len - 1:
init_state.out = model.run(x)
else:
model.run(x)
model.save(init_state)
else:
model.load(init_state)
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
x = ctx[:i+1]
x = x[-ctx_len:]
if i == src_len:
out = copy.deepcopy(init_state.out)
else:
out = model.run(x)
if DEBUG_DEBUG:
print('model', np.array(x), '==>', np.array(
out), np.max(out), np.min(out))
char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
top_p_usual=top_p, top_p_newline=top_p_newline)
char = char.item()
print(tokenizer.itos[int(char)], end='', flush=True)
ctx += [char]
t_end = time.time_ns()
print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')