refactoring
parent
5817d265c3
commit
71538e44a9
@ -0,0 +1,143 @@
|
||||
import types
|
||||
import copy
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
RWKV_K_CLAMP = 60
|
||||
RWKV_K_EPS = 1e-16
|
||||
RWKV_HEAD_QK_DIM = 256
|
||||
|
||||
DEBUG_TIME = False # True False - show trained time-coeffs
|
||||
|
||||
|
||||
class RWKV_RNN():
|
||||
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
|
||||
self.RUN_DEVICE = RUN_DEVICE
|
||||
self.model_type = model_type
|
||||
self.n_layer = n_layer
|
||||
self.n_embd = n_embd
|
||||
self.ctx_len = ctx_len
|
||||
|
||||
self.w = types.SimpleNamespace()
|
||||
|
||||
w = torch.load(MODEL_NAME + '.pth',
|
||||
map_location=torch.device(RUN_DEVICE))
|
||||
for x in w.keys():
|
||||
if '.time_' in x:
|
||||
w[x] = w[x].squeeze()
|
||||
if '.time_decay' in x:
|
||||
w[x] = torch.exp(-torch.exp(w[x]))
|
||||
if '.time_first' in x:
|
||||
w[x] = torch.exp(w[x])
|
||||
if DEBUG_TIME and '.time_' in x:
|
||||
print(x, w[x].squeeze().cpu().numpy())
|
||||
|
||||
xx = x.split('.')
|
||||
here = self.w
|
||||
for i in range(len(xx)):
|
||||
if xx[i].isdigit():
|
||||
ii = int(xx[i])
|
||||
if ii not in here:
|
||||
here[ii] = types.SimpleNamespace()
|
||||
here = here[ii]
|
||||
else:
|
||||
if i == len(xx) - 1:
|
||||
setattr(here, xx[i], w[x])
|
||||
elif not hasattr(here, xx[i]):
|
||||
if xx[i+1].isdigit():
|
||||
setattr(here, xx[i], {})
|
||||
else:
|
||||
setattr(here, xx[i], types.SimpleNamespace())
|
||||
here = getattr(here, xx[i])
|
||||
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
self.xx = {}
|
||||
self.aa = {}
|
||||
self.bb = {}
|
||||
self.hk = None
|
||||
|
||||
def save(self, target):
|
||||
target.xx = copy.deepcopy(self.xx)
|
||||
target.aa = copy.deepcopy(self.aa)
|
||||
target.bb = copy.deepcopy(self.bb)
|
||||
target.hk = copy.deepcopy(self.hk)
|
||||
|
||||
def load(self, target):
|
||||
self.xx = copy.deepcopy(target.xx)
|
||||
self.aa = copy.deepcopy(target.aa)
|
||||
self.bb = copy.deepcopy(target.bb)
|
||||
self.hk = copy.deepcopy(target.hk)
|
||||
|
||||
def LN(self, xx, w):
|
||||
return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
|
||||
|
||||
def FF(self, xx, w, name):
|
||||
if name not in self.xx:
|
||||
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
|
||||
x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
|
||||
self.xx[name] = xx
|
||||
|
||||
r = torch.sigmoid(w.receptance.weight @ x)
|
||||
k = torch.square(torch.relu(w.key.weight @ x))
|
||||
kv = w.value.weight @ k
|
||||
|
||||
return r * kv
|
||||
|
||||
def SA(self, xx, w, name):
|
||||
if name not in self.xx:
|
||||
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
|
||||
self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
|
||||
self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
|
||||
x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
|
||||
self.xx[name] = xx
|
||||
|
||||
r = torch.sigmoid(w.receptance.weight @ x)
|
||||
|
||||
k = torch.exp(torch.clamp(w.key.weight @ x, max=RWKV_K_CLAMP))
|
||||
v = w.value.weight @ x
|
||||
kv = k * v
|
||||
|
||||
a = self.aa[name] + w.time_first * kv
|
||||
b = self.bb[name] + w.time_first * k
|
||||
self.aa[name] = w.time_decay * self.aa[name] + kv
|
||||
self.bb[name] = w.time_decay * self.bb[name] + k
|
||||
|
||||
rwkv = r * a / (b + RWKV_K_EPS)
|
||||
|
||||
return w.output.weight @ rwkv
|
||||
|
||||
def run(self, ctx):
|
||||
w = self.w
|
||||
x = w.emb.weight[ctx[-1]]
|
||||
|
||||
for i in range(self.n_layer):
|
||||
x = self.LN(x, w.blocks[i].ln1)
|
||||
if i == 0 and self.model_type == 'RWKV-ffnPre':
|
||||
x = x + self.FF(x, w.blocks[i].ffnPre, f'ffnPre.{i}')
|
||||
else:
|
||||
x = x + self.SA(x, w.blocks[i].att, f'att.{i}')
|
||||
x = self.LN(x, w.blocks[i].ln2)
|
||||
x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')
|
||||
|
||||
x = self.LN(x, w.ln_out)
|
||||
|
||||
if self.hk == None:
|
||||
self.hk = (w.head_k.weight @ x).unsqueeze(0)
|
||||
else:
|
||||
self.hk = torch.cat(
|
||||
[self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
|
||||
if self.hk.shape[0] > self.ctx_len:
|
||||
self.hk = self.hk[-self.ctx_len:, :]
|
||||
|
||||
q = w.head_q.weight @ x
|
||||
|
||||
x = w.head.weight @ x
|
||||
x = x.cpu().numpy().tolist()
|
||||
|
||||
c = (self.hk @ q) / RWKV_HEAD_QK_DIM
|
||||
for i in range(len(c)):
|
||||
x[ctx[i]] += c[i]
|
||||
|
||||
return x
|
||||
Loading…
Reference in New Issue