@@ -7,13 +7,15 @@ import torch
 import math, os, gc
 from torch.nn import functional as F
 import torch.nn as nn
+from typing import List, Dict
 
-# try:
-#     import torchdynamo
-#     MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
-# except:
 def __nop(ob):
     return ob
 MyModule = nn.Module
 MyFunction = __nop
+# MyModule = torch.jit.ScriptModule
+# MyFunction = torch.jit.script_method
 
 RWKV_HEAD_QK_DIM = 0
 print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
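Note: the `__nop` / `MyModule` indirection above keeps the decorator syntax fixed while making TorchScript opt-in; uncommenting the two jit lines scripts the model without touching the class body. A standalone sketch of the same pattern (the class and method names here are illustrative, not from the patch):

    import torch.nn as nn

    def __nop(ob):
        return ob           # identity decorator: leaves the method untouched

    MyModule = nn.Module    # or torch.jit.ScriptModule
    MyFunction = __nop      # or torch.jit.script_method

    class Tiny(MyModule):
        @MyFunction         # a no-op here; becomes scripting if the aliases are swapped
        def run(self, x):
            return x + 1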
@@ -22,24 +24,20 @@ DEBUG_TIME = False # True False - show trained time-coeffs
 
 ############################################################################################################
 
-class RWKV_RNN(MyModule):
-    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
+class RWKV_RNN(nn.Module):
+    def __init__(self, args):
         super().__init__()
 
-        self.RUN_DEVICE = RUN_DEVICE
-        self.model_type = model_type
-        self.n_layer = n_layer
-        self.n_embd = n_embd
-        self.ctx_len = ctx_len
-
-        w = torch.load(MODEL_NAME + '.pth', map_location='cpu')
+        self.args = args
+        self.FLOAT_MODE = args.FLOAT_MODE
+        self.RUN_DEVICE = args.RUN_DEVICE
+
+        with torch.no_grad():
+            w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
             # refine weights and send to correct device
-
             keys = list(w.keys())
             if 'pos_emb_x' in keys:
-                w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:]
+                w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
             keys = list(w.keys())
             print_need_newline = False
             for x in keys:
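For reference, everything the new `__init__(self, args)` reads off `args` appears in this diff: `MODEL_NAME`, `RUN_DEVICE`, `FLOAT_MODE`, `ctx_len`, and (further down) `n_layer` and `n_embd`. A minimal construction sketch (the path and sizes are placeholders, not values from the patch):

    import types

    args = types.SimpleNamespace()
    args.MODEL_NAME = 'path/to/checkpoint'  # loaded as MODEL_NAME + '.pth'
    args.RUN_DEVICE = 'cpu'                 # 'cpu' or 'cuda'
    args.FLOAT_MODE = 'fp32'                # 'fp32' or 'bf16'
    args.n_layer = 24                       # must match the checkpoint
    args.n_embd = 1024
    args.ctx_len = 1024
    model = RWKV_RNN(args)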
@@ -53,13 +51,13 @@ class RWKV_RNN(MyModule):
                 elif '.time_first' in x:
                     w[x] = w[x].float()
                 else:
-                    if os.environ["RWKV_FLOAT_MODE"] == "fp32":
+                    if self.FLOAT_MODE == "fp32":
                         w[x] = w[x].float()
-                    elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+                    elif self.FLOAT_MODE == "bf16":
                         w[x] = w[x].bfloat16()
 
                 w[x].requires_grad = False
-                if RUN_DEVICE == 'cuda' and x != 'emb.weight':
+                if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
                     w[x] = w[x].cuda()
 
                 if ('blocks.' not in x) or ('blocks.0.' in x):
@@ -72,7 +70,6 @@ class RWKV_RNN(MyModule):
                     print('.', end = '', flush = True)
 
             # store weights in self.w
-
             keys = list(w.keys())
             self.w = types.SimpleNamespace()
             for x in keys:
@@ -98,91 +95,78 @@ class RWKV_RNN(MyModule):
         gc.collect()
         torch.cuda.empty_cache()
 
     @MyFunction
     def LN(self, x, w):
-        return F.layer_norm(x, (self.n_embd,), weight=w.weight, bias=w.bias)
+        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
 
-    # state: ffn_xx att_xx att_aa att_bb att_pp
+    # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
 
     @MyFunction
-    def FF(self, x, w, state, i):
-        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
-            xk = x * w.time_mix_k + state[5*i+0].bfloat16() * (1 - w.time_mix_k)
-            xr = x * w.time_mix_r + state[5*i+0].bfloat16() * (1 - w.time_mix_r)
+    def FF(self, x, state, i, time_mix_k, time_mix_r, kw, vw, rw):
+        if self.FLOAT_MODE == "bf16":
+            xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
+            xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
             state[5*i+0] = x.float()
         else:
-            xk = x * w.time_mix_k + state[5*i+0] * (1 - w.time_mix_k)
-            xr = x * w.time_mix_r + state[5*i+0] * (1 - w.time_mix_r)
+            xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
+            xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
            state[5*i+0] = x
 
-        r = torch.sigmoid(w.receptance.weight @ xr)
-        k = torch.square(torch.relu(w.key.weight @ xk))
-        kv = w.value.weight @ k
+        r = torch.sigmoid(rw @ xr)
+        k = torch.square(torch.relu(kw @ xk))
+        kv = vw @ k
 
         return r * kv
 
     @MyFunction
-    def SA(self, x, w, state, i):
-        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
-            xk = x * w.time_mix_k + state[5*i+1].bfloat16() * (1 - w.time_mix_k)
-            xv = x * w.time_mix_v + state[5*i+1].bfloat16() * (1 - w.time_mix_v)
-            xr = x * w.time_mix_r + state[5*i+1].bfloat16() * (1 - w.time_mix_r)
+    def SA(self, x, state, i, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
+        if self.FLOAT_MODE == "bf16":
+            xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
+            xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
+            xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
            state[5*i+1] = x.float()
         else:
-            xk = x * w.time_mix_k + state[5*i+1] * (1 - w.time_mix_k)
-            xv = x * w.time_mix_v + state[5*i+1] * (1 - w.time_mix_v)
-            xr = x * w.time_mix_r + state[5*i+1] * (1 - w.time_mix_r)
+            xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
+            xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
+            xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
             state[5*i+1] = x
 
-        r = torch.sigmoid(w.receptance.weight @ xr)
-
-        k = w.key.weight @ xk
-        v = w.value.weight @ xv
+        r = torch.sigmoid(rw @ xr)
+        k = kw @ xk
+        v = vw @ xv
 
-        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
+        if self.FLOAT_MODE == "bf16":
             kk = k.float()
             vv = v.float()
+        else:
+            kk = k
+            vv = v
         aa = state[5*i+2]
         bb = state[5*i+3]
         pp = state[5*i+4]
-        ww = w.time_first + kk
+        ww = time_first + kk
         p = torch.maximum(pp, ww)
         e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
         a = e1 * aa + e2 * vv
         b = e1 * bb + e2
-        ww = pp + w.time_decay
+        ww = pp + time_decay
         p = torch.maximum(ww, kk)
         e1 = torch.exp(ww - p)
         e2 = torch.exp(kk - p)
         state[5*i+2] = e1 * aa + e2 * vv
         state[5*i+3] = e1 * bb + e2
         state[5*i+4] = p
-        rwkv = r * (a / b).bfloat16()
+        if self.FLOAT_MODE == "bf16":
+            wkv = (a / b).type(torch.bfloat16)
         else:
-            aa = state[5*i+2]
-            bb = state[5*i+3]
-            pp = state[5*i+4]
-            ww = w.time_first + k
-            p = torch.maximum(pp, ww)
-            e1 = torch.exp(pp - p)
-            e2 = torch.exp(ww - p)
-            a = e1 * aa + e2 * v
-            b = e1 * bb + e2
-            ww = pp + w.time_decay
-            p = torch.maximum(ww, k)
-            e1 = torch.exp(ww - p)
-            e2 = torch.exp(k - p)
-            state[5*i+2] = e1 * aa + e2 * v
-            state[5*i+3] = e1 * bb + e2
-            state[5*i+4] = p
-            rwkv = r * a / b
+            wkv = a / b
 
-        return w.output.weight @ rwkv
+        return ow @ (r * wkv)
 
     def forward(self, ctx, state, preprocess_only = False):
         with torch.no_grad():
             w = self.w
+            args = self.args
 
             x = w.emb.weight[ctx[-1]]
             if self.RUN_DEVICE == 'cuda':
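The `aa`/`bb`/`pp` triple in `SA` is the numerically safe form of the WKV recurrence: conceptually `a` accumulates `exp(k_t) * v_t` and `b` accumulates `exp(k_t)`, but both are stored scaled by `exp(-pp)`, with `pp` tracking the largest exponent seen so far, so no `torch.exp` call ever overflows. The output applies the `time_first` bonus to the current token, while the state update applies `time_decay` to the history. A self-contained sketch of one step of the same arithmetic (the function name is illustrative):

    import torch

    def wkv_step(aa, bb, pp, k, v, time_first, time_decay):
        # output: include the current token with the time_first bonus
        ww = time_first + k
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)      # rescale the stored numerator/denominator
        e2 = torch.exp(ww - p)      # rescale the current token's contribution
        wkv = (e1 * aa + e2 * v) / (e1 * bb + e2)
        # state: decay the history with time_decay, then fold in the current token
        ww = pp + time_decay
        p = torch.maximum(ww, k)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k - p)
        return e1 * aa + e2 * v, e1 * bb + e2, p, wkv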
@@ -194,15 +178,23 @@ class RWKV_RNN(MyModule):
                 pass
 
             if state == None:
-                state = torch.zeros(self.n_layer * 5, self.n_embd, device=self.RUN_DEVICE)
-                for i in range(self.n_layer):
+                state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
+                for i in range(args.n_layer):
                     state[5*i+4] -= 1e30
 
-            for i in range(self.n_layer):
+            for i in range(args.n_layer):
                 if i == 0:
                     x = self.LN(x, w.blocks[i].ln0)
-                x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, state, i)
-                x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, state, i)
+
+                ww = w.blocks[i].att
+                x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i,
+                    ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
+                    ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
+
+                ww = w.blocks[i].ffn
+                x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
+                    ww.time_mix_k, ww.time_mix_r,
+                    ww.key.weight, ww.value.weight, ww.receptance.weight)
 
             if preprocess_only:
                 return state
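Note that fresh state is initialized with `state[5*i+4] -= 1e30`, putting `pp` at effectively negative infinity so the first token's exponent always dominates, and that `forward` consumes exactly one token per call (`x = w.emb.weight[ctx[-1]]`) with all earlier history carried in `state`. A hypothetical driver loop for the refactored interface, reusing the `model` from the construction sketch above (`prompt_ids` is a placeholder list of token ids):

    state = None
    ctx = []
    for tok in prompt_ids:
        ctx.append(tok)
        # forward only embeds ctx[-1]; everything earlier lives in `state`
        state = model.forward(ctx, state, preprocess_only=True)
    # a final forward(ctx, state) without preprocess_only would then return
    # the model output for the next-token prediction (not shown in this hunk)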