refactoring

main
BlinkDL 4 years ago
parent 5817d265c3
commit 71538e44a9

@@ -4,12 +4,13 @@
 ########################################################################################################

 import numpy as np
-import time
 import types
 import copy
+import time
 import torch
 from torch.nn import functional as F
 from src.utils import TOKENIZER
+from src.model_run import RWKV_RNN
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -22,20 +23,23 @@ n_embd = 512
 model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'

 # your trained model
-MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
-WORD_NAME = 'enwik8-vocab' # the .json vocab (generated by train.py)
+MODEL_NAME = 'trained-31'
+WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
+
+# ### uncompress enwik8-model.zip to test my enwik8 model
+# MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
+# WORD_NAME = 'enwik8-vocab'

 # --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
-# --> unknown tokens in your context will be denoted by it <--
+# --> all unknown tokens in your context will be denoted by it <--
 UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity

 RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
 DEBUG_DEBUG = False # True False - show softmax output
-DEBUG_TIME = False # True False - show trained time-coeffs

 ### Step 2: set context ################################################################################

-context = "\n" # ==> this is your prompt
+context = "\nIn the" # ==> this is your prompt

 NUM_TRIALS = 999
 LENGTH_PER_TRIAL = 500
@@ -54,149 +58,7 @@ print('\nYour prompt has ' + str(len(context)) + ' tokens.')
 print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')

 print(f'Loading {MODEL_NAME}...')
+model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
-##############################################################################################################
-
-RWKV_K_CLAMP = 60
-RWKV_K_EPS = 1e-16
-RWKV_HEAD_QK_DIM = 256
-
-class RWKV_RNN():
-    def __init__(self, MODEL_NAME):
-        self.w = types.SimpleNamespace()
-        w = torch.load(MODEL_NAME + '.pth',
-                       map_location=torch.device(RUN_DEVICE))
-        for x in w.keys():
-            if '.time_' in x:
-                w[x] = w[x].squeeze()
-            if '.time_decay' in x:
-                w[x] = torch.exp(-torch.exp(w[x]))
-            if '.time_first' in x:
-                w[x] = torch.exp(w[x])
-
-            xx = x.split('.')
-            here = self.w
-            for i in range(len(xx)):
-                if xx[i].isdigit():
-                    ii = int(xx[i])
-                    if ii not in here:
-                        here[ii] = types.SimpleNamespace()
-                    here = here[ii]
-                else:
-                    if i == len(xx) - 1:
-                        setattr(here, xx[i], w[x])
-                    elif not hasattr(here, xx[i]):
-                        if xx[i+1].isdigit():
-                            setattr(here, xx[i], {})
-                        else:
-                            setattr(here, xx[i], types.SimpleNamespace())
-                    here = getattr(here, xx[i])
-
-        self.clear()
-
-    def clear(self):
-        self.xx = {}
-        self.aa = {}
-        self.bb = {}
-        self.hk = None
-
-    def save(self, target):
-        target.xx = copy.deepcopy(self.xx)
-        target.aa = copy.deepcopy(self.aa)
-        target.bb = copy.deepcopy(self.bb)
-        target.hk = copy.deepcopy(self.hk)
-
-    def load(self, target):
-        self.xx = copy.deepcopy(target.xx)
-        self.aa = copy.deepcopy(target.aa)
-        self.bb = copy.deepcopy(target.bb)
-        self.hk = copy.deepcopy(target.hk)
-
-    def LN(self, xx, w):
-        return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)
-
-    def FF(self, xx, w, name):
-        if DEBUG_TIME:
-            print(name+'.time_mix', w.time_mix.squeeze().numpy())
-        if name not in self.xx:
-            self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
-        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
-        self.xx[name] = xx
-
-        r = torch.sigmoid(w.receptance.weight @ x)
-        k = torch.square(torch.relu(w.key.weight @ x))
-        kv = w.value.weight @ k
-
-        return r * kv
-
-    def SA(self, xx, w, name):
-        if DEBUG_TIME:
-            print(name+'.time_mix', w.time_mix.squeeze().numpy())
-            print(name+'.time_decay', w.time_decay.squeeze().numpy())
-            print(name+'.time_first', w.time_first.squeeze().numpy())
-        if name not in self.xx:
-            self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
-            self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)
-            self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)
-        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
-        self.xx[name] = xx
-
-        r = torch.sigmoid(w.receptance.weight @ x)
-        k = torch.exp(torch.clamp(w.key.weight @ x, max=RWKV_K_CLAMP))
-        v = w.value.weight @ x
-        kv = k * v
-
-        a = self.aa[name] + w.time_first * kv
-        b = self.bb[name] + w.time_first * k
-        self.aa[name] = w.time_decay * self.aa[name] + kv
-        self.bb[name] = w.time_decay * self.bb[name] + k
-
-        rwkv = r * a / (b + RWKV_K_EPS)
-
-        return w.output.weight @ rwkv
-
-    def run(self, ctx):
-        w = self.w
-        x = w.emb.weight[ctx[-1]]
-
-        for i in range(n_layer):
-            x = self.LN(x, w.blocks[i].ln1)
-            if i == 0 and model_type == 'RWKV-ffnPre':
-                x = x + self.FF(x, w.blocks[i].ffnPre, f'ffnPre.{i}')
-            else:
-                x = x + self.SA(x, w.blocks[i].att, f'att.{i}')
-            x = self.LN(x, w.blocks[i].ln2)
-            x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')
-
-        x = self.LN(x, w.ln_out)
-
-        if self.hk == None:
-            self.hk = (w.head_k.weight @ x).unsqueeze(0)
-        else:
-            self.hk = torch.cat(
-                [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
-        if self.hk.shape[0] > ctx_len:
-            self.hk = self.hk[-ctx_len:, :]
-
-        q = w.head_q.weight @ x
-        x = w.head.weight @ x
-        x = x.cpu().numpy().tolist()
-
-        c = (self.hk @ q) / RWKV_HEAD_QK_DIM
-        for i in range(len(c)):
-            x[ctx[i]] += c[i]
-
-        return x
-
-##############################################################################################################
-
-model = RWKV_RNN(MODEL_NAME)

 for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
     t_begin = time.time_ns()
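Note on usage: run.py drives the refactored class one token at a time; the generation loop itself is unchanged by this commit and not shown in the hunk above. A minimal sketch of the pattern, with placeholder token ids and a hypothetical greedy sampler standing in for the script's TOKENIZER and softmax sampling:

import numpy as np

def pick_token(logits):
    # hypothetical greedy sampler; run.py actually samples from the softmax output
    return int(np.argmax([float(v) for v in logits]))

ctx = [0, 1, 2]               # placeholder token ids; run.py gets these from TOKENIZER
model.clear()                 # reset the xx/aa/bb/hk recurrent state
for i in range(len(ctx)):     # feed the prompt one token at a time (RNN mode)
    logits = model.run(ctx[:i+1])
for _ in range(LENGTH_PER_TRIAL):
    ctx.append(pick_token(logits))
    logits = model.run(ctx)   # one forward step per generated token

save() and load() deep-copy that recurrent state, so a caller can snapshot it once after the prompt and restore it for each of the NUM_TRIALS trials instead of re-processing the prompt every time.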

src/model_run.py (new file)
@@ -0,0 +1,143 @@
+import types
+import copy
+import torch
+from torch.nn import functional as F
+
+RWKV_K_CLAMP = 60
+RWKV_K_EPS = 1e-16
+RWKV_HEAD_QK_DIM = 256
+
+DEBUG_TIME = False # True False - show trained time-coeffs
+
+class RWKV_RNN():
+    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
+        self.RUN_DEVICE = RUN_DEVICE
+        self.model_type = model_type
+        self.n_layer = n_layer
+        self.n_embd = n_embd
+        self.ctx_len = ctx_len
+
+        self.w = types.SimpleNamespace()
+        w = torch.load(MODEL_NAME + '.pth',
+                       map_location=torch.device(RUN_DEVICE))
+        for x in w.keys():
+            if '.time_' in x:
+                w[x] = w[x].squeeze()
+            if '.time_decay' in x:
+                w[x] = torch.exp(-torch.exp(w[x]))
+            if '.time_first' in x:
+                w[x] = torch.exp(w[x])
+            if DEBUG_TIME and '.time_' in x:
+                print(x, w[x].squeeze().cpu().numpy())
+
+            xx = x.split('.')
+            here = self.w
+            for i in range(len(xx)):
+                if xx[i].isdigit():
+                    ii = int(xx[i])
+                    if ii not in here:
+                        here[ii] = types.SimpleNamespace()
+                    here = here[ii]
+                else:
+                    if i == len(xx) - 1:
+                        setattr(here, xx[i], w[x])
+                    elif not hasattr(here, xx[i]):
+                        if xx[i+1].isdigit():
+                            setattr(here, xx[i], {})
+                        else:
+                            setattr(here, xx[i], types.SimpleNamespace())
+                    here = getattr(here, xx[i])
+
+        self.clear()
+
+    def clear(self):
+        self.xx = {}
+        self.aa = {}
+        self.bb = {}
+        self.hk = None
+
+    def save(self, target):
+        target.xx = copy.deepcopy(self.xx)
+        target.aa = copy.deepcopy(self.aa)
+        target.bb = copy.deepcopy(self.bb)
+        target.hk = copy.deepcopy(self.hk)
+
+    def load(self, target):
+        self.xx = copy.deepcopy(target.xx)
+        self.aa = copy.deepcopy(target.aa)
+        self.bb = copy.deepcopy(target.bb)
+        self.hk = copy.deepcopy(target.hk)
+
+    def LN(self, xx, w):
+        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
+
+    def FF(self, xx, w, name):
+        if name not in self.xx:
+            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
+        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
+        self.xx[name] = xx
+
+        r = torch.sigmoid(w.receptance.weight @ x)
+        k = torch.square(torch.relu(w.key.weight @ x))
+        kv = w.value.weight @ k
+
+        return r * kv
+
+    def SA(self, xx, w, name):
+        if name not in self.xx:
+            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
+            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
+            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
+        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
+        self.xx[name] = xx
+
+        r = torch.sigmoid(w.receptance.weight @ x)
+        k = torch.exp(torch.clamp(w.key.weight @ x, max=RWKV_K_CLAMP))
+        v = w.value.weight @ x
+        kv = k * v
+
+        a = self.aa[name] + w.time_first * kv
+        b = self.bb[name] + w.time_first * k
+        self.aa[name] = w.time_decay * self.aa[name] + kv
+        self.bb[name] = w.time_decay * self.bb[name] + k
+
+        rwkv = r * a / (b + RWKV_K_EPS)
+
+        return w.output.weight @ rwkv
+
+    def run(self, ctx):
+        w = self.w
+        x = w.emb.weight[ctx[-1]]
+
+        for i in range(self.n_layer):
+            x = self.LN(x, w.blocks[i].ln1)
+            if i == 0 and self.model_type == 'RWKV-ffnPre':
+                x = x + self.FF(x, w.blocks[i].ffnPre, f'ffnPre.{i}')
+            else:
+                x = x + self.SA(x, w.blocks[i].att, f'att.{i}')
+            x = self.LN(x, w.blocks[i].ln2)
+            x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')
+
+        x = self.LN(x, w.ln_out)
+
+        if self.hk == None:
+            self.hk = (w.head_k.weight @ x).unsqueeze(0)
+        else:
+            self.hk = torch.cat(
+                [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
+        if self.hk.shape[0] > self.ctx_len:
+            self.hk = self.hk[-self.ctx_len:, :]
+
+        q = w.head_q.weight @ x
+        x = w.head.weight @ x
+        x = x.cpu().numpy().tolist()
+
+        c = (self.hk @ q) / RWKV_HEAD_QK_DIM
+        for i in range(len(c)):
+            x[ctx[i]] += c[i]
+
+        return x
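For readers skimming the diff: SA() above implements the RWKV time-mixing recurrence. Reconstructed directly from the code, with λ = exp(-exp(time_decay)) and μ = exp(time_first) as preprocessed in __init__, m = time_mix, ⊙ elementwise, and ε = RWKV_K_EPS:

\[
\begin{aligned}
x'_t &= m \odot x_t + (1 - m) \odot x_{t-1} \\
r_t &= \sigma(W_r\, x'_t), \qquad k_t = e^{\min(W_k\, x'_t,\ 60)}, \qquad v_t = W_v\, x'_t \\
\mathrm{wkv}_t &= \frac{a_{t-1} + \mu \odot k_t \odot v_t}{b_{t-1} + \mu \odot k_t + \varepsilon} \\
a_t &= \lambda \odot a_{t-1} + k_t \odot v_t, \qquad b_t = \lambda \odot b_{t-1} + k_t \\
\mathrm{out}_t &= W_{\mathrm{out}}\,(r_t \odot \mathrm{wkv}_t)
\end{aligned}
\]

Here a_t and b_t are decayed running sums of e^k ⊙ v and e^k, so wkv_t is a weighted average of past values; the clamp at RWKV_K_CLAMP = 60 before exponentiation and the ε in the denominator guard against overflow and division by zero. FF() is the analogous channel-mixing step, σ(W_r x') ⊙ W_v (ReLU(W_k x'))², and run() additionally keeps up to ctx_len past head_k projections, adding (hk @ q) / RWKV_HEAD_QK_DIM to the logits of tokens already seen in the context, a learned copy mechanism.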
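The practical effect of the refactor is that RWKV_RNN no longer reads globals from the run script: device, model type, and sizes are all passed to the constructor, so the model can be loaded from any entry point. A minimal sketch (the checkpoint name and hyperparameters below are placeholders and must match your trained model):

from src.model_run import RWKV_RNN

# placeholder hyperparameters; use the values your checkpoint was trained with
model = RWKV_RNN('trained-31', 'cpu', 'RWKV', n_layer=6, n_embd=512, ctx_len=1024)
logits = model.run([0])   # next-token logits for a 1-token context (token id 0)
model.clear()             # reset xx/aa/bb/hk before feeding an unrelated sequence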