commit 8a4a41a3aa (parent daed379db2) by BlinkDL, 3 years ago, on branch main

@@ -3,11 +3,9 @@
########################################################################################################
import numpy as np
import math, os, sys
import time
import math, os, sys, types, time, gc
import torch
from src.utils import TOKENIZER
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
@@ -16,44 +14,27 @@ torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()
########################################################################################################
# Step 1: set model
#
# Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch.
#
# Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models.
# Step 1: set model & config
# Do this first: pip install torchdynamo
########################################################################################################
TOKEN_MODE = "pile" # char / bpe / pile
n_layer = 6
n_embd = 512
ctx_len = 1024
if TOKEN_MODE == "char":
MODEL_NAME = "trained-500" # your trained model
WORD_NAME = "vocab" # the .json vocab (generated by train.py)
# set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown tokens in your prompt will be denoted by it
UNKNOWN_CHAR = " " # here we just set it to ' ' for simplicity
args.RUN_DEVICE = "cpu" # 'cpu' (already very fast) // 'cuda'
args.FLOAT_MODE = "fp32" # fp32 // bf16 (saves VRAM, slightly less accurate)
# if args.RUN_DEVICE == "cuda":
# os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
elif TOKEN_MODE == "bpe":
MODEL_NAME = "trained-500" # your trained model
WORD_NAME = [
"model-vocab.json",
"model-merges.txt",
] # [vocab, merge] for your BPE model
UNKNOWN_CHAR = None
elif TOKEN_MODE == "pile":
TOKEN_MODE = "pile"
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
vocab_size = 50277
# ---> you can set MODEL_NAME to your fine-tuned model <---
# note: you can set MODEL_NAME to your fine-tuned model
# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
# n_layer = 12
# n_embd = 768
@@ -64,11 +45,16 @@ elif TOKEN_MODE == "pile":
# n_embd = 1024
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096'
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096'
# n_layer = 24
# n_embd = 2048
# ctx_len = 4096
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
n_layer = 32
n_embd = 2560
@@ -79,9 +65,16 @@ elif TOKEN_MODE == "pile":
# n_embd = 4096
# ctx_len = 1024
os.environ["RWKV_FLOAT_MODE"] = "fp32" # fp32 (faster at this moment) or bf16 (slower but saves VRAM)
os.environ["RWKV_RUN_DEVICE"] = "cpu" # 'cpu' (already very fast) or 'cuda'
model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre'
args.MODEL_NAME = MODEL_NAME
args.n_layer = n_layer
args.n_embd = n_embd
args.ctx_len = ctx_len
args.vocab_size = vocab_size
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
########################################################################################################
# Step 2: set prompt & sampling settings
@@ -128,12 +121,15 @@ DEBUG_DEBUG = False # True False --> show softmax output
########################################################################################################
print(f'\nUsing {os.environ["RWKV_RUN_DEVICE"].upper()}. Loading {MODEL_NAME}...')
print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN
model = RWKV_RNN(
MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len
)
model = RWKV_RNN(args)
print(f'\nOptimizing speed...')
model.forward([187], None)
gc.collect()
torch.cuda.empty_cache()
# input(0)
@@ -185,6 +181,8 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
init_out, init_state = model.forward(x, init_state)
else:
init_state = model.forward(x, init_state, preprocess_only=True)
gc.collect()
torch.cuda.empty_cache()
record_time('preprocess')
out_last = src_len
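# The hunk above runs the whole prompt through the model once, either for its
# logits or with preprocess_only=True, purely to build the recurrent state.
# A minimal continuation of the trial loop is sketched below; the greedy
# sampler is a toy stand-in for the repo's real sampling helper, and LENGTH
# is an assumed constant.

import torch

def sample_greedy(logits):
    # toy sampler: plain argmax (the real script samples with temperature/top-p)
    return int(torch.argmax(logits).item())

LENGTH = 100
ctx = list(x)                                  # prompt token ids from above
out, state = init_out, init_state
for _ in range(LENGTH):
    token = sample_greedy(out)
    ctx.append(token)
    out, state = model.forward(ctx, state)     # RNN mode embeds only ctx[-1]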

@@ -7,13 +7,15 @@ import torch
import math, os, gc
from torch.nn import functional as F
import torch.nn as nn
from typing import List, Dict
# try:
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# except:
def __nop(ob):
return ob
MyModule = nn.Module
MyFunction = __nop
# MyModule = torch.jit.ScriptModule
# MyFunction = torch.jit.script_method
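# The __nop / MyModule / MyFunction aliases above let the same class run either
# as a plain nn.Module or as a TorchScript module just by flipping two names.
# A minimal, self-contained sketch of the pattern (the class below is made up
# for illustration and assumes the legacy ScriptModule API is available):

import torch
import torch.nn as nn

USE_JIT = False                        # set True to compile methods with TorchScript
if USE_JIT:
    JitModule = torch.jit.ScriptModule
    JitFunction = torch.jit.script_method
else:
    JitModule = nn.Module
    JitFunction = lambda f: f          # no-op decorator, same idea as __nop

class TinyNet(JitModule):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    @JitFunction
    def forward(self, x):
        return torch.relu(self.lin(x))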
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
@@ -22,24 +24,20 @@ DEBUG_TIME = False # True False - show trained time-coeffs
############################################################################################################
class RWKV_RNN(MyModule):
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
class RWKV_RNN(nn.Module):
def __init__(self, args):
super().__init__()
self.RUN_DEVICE = RUN_DEVICE
self.model_type = model_type
self.n_layer = n_layer
self.n_embd = n_embd
self.ctx_len = ctx_len
w = torch.load(MODEL_NAME + '.pth', map_location='cpu')
self.args = args
self.FLOAT_MODE = args.FLOAT_MODE
self.RUN_DEVICE = args.RUN_DEVICE
with torch.no_grad():
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
# refine weights and send to correct device
keys = list(w.keys())
if 'pos_emb_x' in keys:
w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:]
w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
keys = list(w.keys())
print_need_newline = False
for x in keys:
@@ -53,13 +51,13 @@ class RWKV_RNN(MyModule):
elif '.time_first' in x:
w[x] = w[x].float()
else:
if os.environ["RWKV_FLOAT_MODE"] == "fp32":
if self.FLOAT_MODE == "fp32":
w[x] = w[x].float()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
elif self.FLOAT_MODE == "bf16":
w[x] = w[x].bfloat16()
w[x].requires_grad = False
if RUN_DEVICE == 'cuda' and x != 'emb.weight':
if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
w[x] = w[x].cuda()
if ('blocks.' not in x) or ('blocks.0.' in x):
@@ -72,7 +70,6 @@ class RWKV_RNN(MyModule):
print('.', end = '', flush = True)
# store weights in self.w
keys = list(w.keys())
self.w = types.SimpleNamespace()
for x in keys:
@@ -98,91 +95,78 @@ class RWKV_RNN(MyModule):
gc.collect()
torch.cuda.empty_cache()
@MyFunction
def LN(self, x, w):
return F.layer_norm(x, (self.n_embd,), weight=w.weight, bias=w.bias)
return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
# state: ffn_xx att_xx att_aa att_bb att_pp
# state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
@MyFunction
def FF(self, x, w, state, i):
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
xk = x * w.time_mix_k + state[5*i+0].bfloat16() * (1 - w.time_mix_k)
xr = x * w.time_mix_r + state[5*i+0].bfloat16() * (1 - w.time_mix_r)
def FF(self, x, state, i, time_mix_k, time_mix_r, kw, vw, rw):
if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+0] = x.float()
else:
xk = x * w.time_mix_k + state[5*i+0] * (1 - w.time_mix_k)
xr = x * w.time_mix_r + state[5*i+0] * (1 - w.time_mix_r)
xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
state[5*i+0] = x
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.square(torch.relu(w.key.weight @ xk))
kv = w.value.weight @ k
r = torch.sigmoid(rw @ xr)
k = torch.square(torch.relu(kw @ xk))
kv = vw @ k
return r * kv
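# FF() above is RWKV's channel-mixing path: interpolate the current input with
# the previous step's input (token shift, kept in state[5*i+0]), apply a
# squared-ReLU key, and gate the value projection with a sigmoid receptance.
# A tiny fp32 sketch with random square weights (the real key/value projections
# change dimensionality, but the arithmetic is the same):

import torch

n_embd = 8                                        # toy size
kw, vw, rw = (torch.randn(n_embd, n_embd) for _ in range(3))
time_mix_k, time_mix_r = torch.rand(n_embd), torch.rand(n_embd)

x_prev = torch.zeros(n_embd)                      # plays the role of state[5*i+0]
x = torch.randn(n_embd)

xk = x * time_mix_k + x_prev * (1 - time_mix_k)   # token shift for the key path
xr = x * time_mix_r + x_prev * (1 - time_mix_r)   # token shift for the receptance
r = torch.sigmoid(rw @ xr)                        # gate in (0, 1)
k = torch.square(torch.relu(kw @ xk))             # squared-ReLU activation
out = r * (vw @ k)                                # gated value projection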
@MyFunction
def SA(self, x, w, state, i):
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
xk = x * w.time_mix_k + state[5*i+1].bfloat16() * (1 - w.time_mix_k)
xv = x * w.time_mix_v + state[5*i+1].bfloat16() * (1 - w.time_mix_v)
xr = x * w.time_mix_r + state[5*i+1].bfloat16() * (1 - w.time_mix_r)
def SA(self, x, state, i, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+1] = x.float()
else:
xk = x * w.time_mix_k + state[5*i+1] * (1 - w.time_mix_k)
xv = x * w.time_mix_v + state[5*i+1] * (1 - w.time_mix_v)
xr = x * w.time_mix_r + state[5*i+1] * (1 - w.time_mix_r)
xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
state[5*i+1] = x
r = torch.sigmoid(w.receptance.weight @ xr)
k = w.key.weight @ xk
v = w.value.weight @ xv
r = torch.sigmoid(rw @ xr)
k = kw @ xk
v = vw @ xv
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
if self.FLOAT_MODE == "bf16":
kk = k.float()
vv = v.float()
else:
kk = k
vv = v
aa = state[5*i+2]
bb = state[5*i+3]
pp = state[5*i+4]
ww = w.time_first + kk
ww = time_first + kk
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
a = e1 * aa + e2 * vv
b = e1 * bb + e2
ww = pp + w.time_decay
ww = pp + time_decay
p = torch.maximum(ww, kk)
e1 = torch.exp(ww - p)
e2 = torch.exp(kk - p)
state[5*i+2] = e1 * aa + e2 * vv
state[5*i+3] = e1 * bb + e2
state[5*i+4] = p
rwkv = r * (a / b).bfloat16()
if self.FLOAT_MODE == "bf16":
wkv = (a / b).type(torch.bfloat16)
else:
aa = state[5*i+2]
bb = state[5*i+3]
pp = state[5*i+4]
ww = w.time_first + k
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
a = e1 * aa + e2 * v
b = e1 * bb + e2
ww = pp + w.time_decay
p = torch.maximum(ww, k)
e1 = torch.exp(ww - p)
e2 = torch.exp(k - p)
state[5*i+2] = e1 * aa + e2 * v
state[5*i+3] = e1 * bb + e2
state[5*i+4] = p
rwkv = r * a / b
wkv = a / b
return w.output.weight @ rwkv
return ow @ (r * wkv)
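# SA() above keeps the WKV weighted average numerically stable: instead of
# accumulating exp(k) terms directly, it stores a numerator aa and denominator
# bb that are both scaled by exp(-pp), where pp tracks the largest exponent
# seen so far (initialised to -1e30 in forward()). The same max-trick update,
# written out for a single scalar channel as an illustrative sketch:

import math

def wkv_step(k, v, aa, bb, pp, time_first, time_decay):
    # output for the current token, which gets the extra time_first bonus
    ww = time_first + k
    p = max(pp, ww)
    e1 = math.exp(pp - p)
    e2 = math.exp(ww - p)
    wkv = (e1 * aa + e2 * v) / (e1 * bb + e2)
    # state update: decay the history by time_decay, then fold in the new token
    ww = pp + time_decay
    p = max(ww, k)
    e1 = math.exp(ww - p)
    e2 = math.exp(k - p)
    return wkv, e1 * aa + e2 * v, e1 * bb + e2, p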
def forward(self, ctx, state, preprocess_only = False):
with torch.no_grad():
w = self.w
args = self.args
x = w.emb.weight[ctx[-1]]
if self.RUN_DEVICE == 'cuda':
@@ -194,15 +178,23 @@ class RWKV_RNN(MyModule):
pass
if state == None:
state = torch.zeros(self.n_layer * 5, self.n_embd, device=self.RUN_DEVICE)
for i in range(self.n_layer):
state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
for i in range(args.n_layer):
state[5*i+4] -= 1e30
for i in range(self.n_layer):
for i in range(args.n_layer):
if i == 0:
x = self.LN(x, w.blocks[i].ln0)
x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, state, i)
x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, state, i)
ww = w.blocks[i].att
x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i,
ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
ww = w.blocks[i].ffn
x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
ww.time_mix_k, ww.time_mix_r,
ww.key.weight, ww.value.weight, ww.receptance.weight)
if preprocess_only:
return state
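# forward() above advances the network one token per call: only ctx[-1] is
# embedded, and everything remembered about earlier tokens lives in `state`,
# a (n_layer * 5, n_embd) tensor whose per-layer slots are
# [ffn_xx, att_xx, att_aa, att_bb, att_pp]. A sketch of building a fresh state
# outside the model, mirroring the initialisation above:

import torch

def new_state(n_layer, n_embd, device="cpu"):
    state = torch.zeros(n_layer * 5, n_embd, device=device)
    for i in range(n_layer):
        state[5 * i + 4] -= 1e30     # att_pp starts hugely negative (running max)
    return state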
