bf16 inference - 15G VRAM for 7b model

main
BlinkDL 3 years ago
parent 1479315677
commit daed379db2

@ -3,13 +3,15 @@
######################################################################################################## ########################################################################################################
import numpy as np import numpy as np
import math, os import math, os, sys
import time import time
import types
import copy
import torch import torch
from src.utils import TOKENIZER from src.utils import TOKENIZER
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
pass
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
@ -62,17 +64,22 @@ elif TOKEN_MODE == "pile":
# n_embd = 1024 # n_embd = 1024
# ctx_len = 1024 # ctx_len = 1024
MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040' # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096'
n_layer = 24 # n_layer = 24
n_embd = 2048 # n_embd = 2048
# ctx_len = 1024
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
n_layer = 32
n_embd = 2560
ctx_len = 1024 ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20220925-4537' # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221004-3047'
# n_layer = 32 # n_layer = 32
# n_embd = 2560 # n_embd = 4096
# ctx_len = 1024 # ctx_len = 1024
os.environ["RWKV_FLOAT_MODE"] = "fp32" # currently only supprts fp32 (it can do bf16 and fp16. just wait a bit... busy these days) os.environ["RWKV_FLOAT_MODE"] = "fp32" # fp32 (faster at this moment) or bf16 (slower but saves VRAM)
os.environ["RWKV_RUN_DEVICE"] = "cpu" # 'cpu' (already very fast) or 'cuda' os.environ["RWKV_RUN_DEVICE"] = "cpu" # 'cpu' (already very fast) or 'cuda'
model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre' model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre'
@ -127,7 +134,13 @@ from src.model_run import RWKV_RNN
model = RWKV_RNN( model = RWKV_RNN(
MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len
) )
# input(0)
print(f'\nLoading tokenizer {WORD_NAME}...')
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == "pile":
assert tokenizer.tokenizer.decode([187]) == '\n'
######################################################################################################## ########################################################################################################
@ -139,9 +152,9 @@ else:
src_len = len(ctx) src_len = len(ctx)
src_ctx = ctx.copy() src_ctx = ctx.copy()
print("Your prompt has " + str(src_len) + " tokens.") print("\nYour prompt has " + str(src_len) + " tokens.")
print( print(
"\nNote: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. Use GPT to build the hidden state for better speed.\n" "Note: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. Use GPT to build the hidden state for better speed.\n"
) )
time_slot = {} time_slot = {}
@ -154,24 +167,24 @@ def record_time(name):
if tt < time_slot[name]: if tt < time_slot[name]:
time_slot[name] = tt time_slot[name] = tt
init_state = None
init_out = None
state = None
out = None
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
print(("-" * 50) + '\n' + context, end="") print(("-" * 50) + '\n' + context, end="")
time_ref = time.time_ns() time_ref = time.time_ns()
ctx = src_ctx.copy() ctx = src_ctx.copy()
model.clear()
if TRIAL == 0: if TRIAL == 0:
init_state = types.SimpleNamespace()
for i in range(src_len): for i in range(src_len):
x = ctx[: i + 1] x = ctx[: i + 1]
if i == src_len - 1: if i == src_len - 1:
init_state.out = model.forward(x) init_out, init_state = model.forward(x, init_state)
else: else:
model.forward(x, preprocess_only=True) init_state = model.forward(x, init_state, preprocess_only=True)
model.save(init_state)
else:
model.load(init_state)
record_time('preprocess') record_time('preprocess')
out_last = src_len out_last = src_len
@ -180,12 +193,12 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
x = x[-ctx_len:] x = x[-ctx_len:]
if i == src_len: if i == src_len:
out = copy.deepcopy(init_state.out) out = init_out.clone()
state = init_state.clone()
else: else:
out = model.forward(x) out, state = model.forward(x, state)
if DEBUG_DEBUG: if DEBUG_DEBUG:
print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy())) print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))
if TOKEN_MODE == "pile": if TOKEN_MODE == "pile":
out[0] = -999999999 # disable <|endoftext|> out[0] = -999999999 # disable <|endoftext|>
@ -213,3 +226,5 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
print( print(
f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = '' f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = ''
) )
print(("-" * 50) + '\n')

@ -3,16 +3,13 @@
######################################################################################################## ########################################################################################################
import types import types
import copy
import torch import torch
import math, os import math, os, gc
from torch.nn import functional as F from torch.nn import functional as F
import torch.nn as nn import torch.nn as nn
def __nop(ob): def __nop(ob):
return ob return ob
MyModule = nn.Module MyModule = nn.Module
MyFunction = __nop MyFunction = __nop
# MyModule = torch.jit.ScriptModule # MyModule = torch.jit.ScriptModule
@ -25,7 +22,7 @@ DEBUG_TIME = False # True False - show trained time-coeffs
############################################################################################################ ############################################################################################################
class RWKV_RNN(MyModule): # this is running in FP32 at this moment class RWKV_RNN(MyModule):
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len): def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
super().__init__() super().__init__()
@ -35,20 +32,50 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
self.n_embd = n_embd self.n_embd = n_embd
self.ctx_len = ctx_len self.ctx_len = ctx_len
self.w = types.SimpleNamespace() w = torch.load(MODEL_NAME + '.pth', map_location='cpu')
w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE)) # refine weights and send to correct device
for x in w.keys():
w[x] = w[x].float() keys = list(w.keys())
if 'pos_emb_x' in keys:
w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:]
keys = list(w.keys())
print_need_newline = False
for x in keys:
if '.time_' in x: if '.time_' in x:
w[x] = w[x].squeeze() w[x] = w[x].squeeze()
if DEBUG_TIME:
print(x, w[x].numpy())
if '.time_decay' in x: if '.time_decay' in x:
w[x] = w[x].float()
w[x] = -torch.exp(w[x]) w[x] = -torch.exp(w[x])
if 'pos_emb_x' in x: elif '.time_first' in x:
self.w.pos_emb = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:] w[x] = w[x].float()
if DEBUG_TIME and '.time_' in x: else:
print(x, w[x].squeeze().cpu().numpy()) if os.environ["RWKV_FLOAT_MODE"] == "fp32":
w[x] = w[x].float()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
w[x] = w[x].bfloat16()
w[x].requires_grad = False
if RUN_DEVICE == 'cuda' and x != 'emb.weight':
w[x] = w[x].cuda()
if ('blocks.' not in x) or ('blocks.0.' in x):
if print_need_newline:
print('\n', end = '')
print_need_newline = False
print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device)
else:
print_need_newline = True
print('.', end = '', flush = True)
# store weights in self.w
keys = list(w.keys())
self.w = types.SimpleNamespace()
for x in keys:
xx = x.split('.') xx = x.split('.')
here = self.w here = self.w
for i in range(len(xx)): for i in range(len(xx)):
@ -67,41 +94,26 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
setattr(here, xx[i], types.SimpleNamespace()) setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i]) here = getattr(here, xx[i])
self.clear()
self.eval() self.eval()
gc.collect()
def clear(self): torch.cuda.empty_cache()
self.xx = {}
self.aa = {}
self.bb = {}
self.pp = {}
self.hk = None
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
target.pp = copy.deepcopy(self.pp)
target.hk = copy.deepcopy(self.hk)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
self.pp = copy.deepcopy(target.pp)
self.hk = copy.deepcopy(target.hk)
@MyFunction @MyFunction
def LN(self, xx, w): def LN(self, x, w):
return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias) return F.layer_norm(x, (self.n_embd,), weight=w.weight, bias=w.bias)
# state: ffn_xx att_xx att_aa att_bb att_pp
@MyFunction @MyFunction
def FF(self, xx, w, name): def FF(self, x, w, state, i):
if name not in self.xx: if os.environ["RWKV_FLOAT_MODE"] == "bf16":
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) xk = x * w.time_mix_k + state[5*i+0].bfloat16() * (1 - w.time_mix_k)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k) xr = x * w.time_mix_r + state[5*i+0].bfloat16() * (1 - w.time_mix_r)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r) state[5*i+0] = x.float()
self.xx[name] = xx else:
xk = x * w.time_mix_k + state[5*i+0] * (1 - w.time_mix_k)
xr = x * w.time_mix_r + state[5*i+0] * (1 - w.time_mix_r)
state[5*i+0] = x
r = torch.sigmoid(w.receptance.weight @ xr) r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.square(torch.relu(w.key.weight @ xk)) k = torch.square(torch.relu(w.key.weight @ xk))
@ -110,90 +122,92 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
return r * kv return r * kv
@MyFunction @MyFunction
def SA(self, xx, w, name): def SA(self, x, w, state, i):
if name not in self.xx: if os.environ["RWKV_FLOAT_MODE"] == "bf16":
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) xk = x * w.time_mix_k + state[5*i+1].bfloat16() * (1 - w.time_mix_k)
self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) xv = x * w.time_mix_v + state[5*i+1].bfloat16() * (1 - w.time_mix_v)
self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) xr = x * w.time_mix_r + state[5*i+1].bfloat16() * (1 - w.time_mix_r)
self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30 state[5*i+1] = x.float()
else:
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k) xk = x * w.time_mix_k + state[5*i+1] * (1 - w.time_mix_k)
xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v) xv = x * w.time_mix_v + state[5*i+1] * (1 - w.time_mix_v)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r) xr = x * w.time_mix_r + state[5*i+1] * (1 - w.time_mix_r)
self.xx[name] = xx state[5*i+1] = x
r = torch.sigmoid(w.receptance.weight @ xr) r = torch.sigmoid(w.receptance.weight @ xr)
k = w.key.weight @ xk k = w.key.weight @ xk
v = w.value.weight @ xv v = w.value.weight @ xv
pp = self.pp[name] if os.environ["RWKV_FLOAT_MODE"] == "bf16":
aa = self.aa[name] kk = k.float()
bb = self.bb[name] vv = v.float()
ww = w.time_first + k aa = state[5*i+2]
p = torch.maximum(pp, ww) bb = state[5*i+3]
e1 = torch.exp(pp - p) pp = state[5*i+4]
e2 = torch.exp(ww - p) ww = w.time_first + kk
a = e1 * aa + e2 * v p = torch.maximum(pp, ww)
b = e1 * bb + e2 e1 = torch.exp(pp - p)
ww = pp + w.time_decay e2 = torch.exp(ww - p)
p = torch.maximum(ww, k) a = e1 * aa + e2 * vv
e1 = torch.exp(ww - p) b = e1 * bb + e2
e2 = torch.exp(k - p) ww = pp + w.time_decay
self.aa[name] = e1 * aa + e2 * v p = torch.maximum(ww, kk)
self.bb[name] = e1 * bb + e2 e1 = torch.exp(ww - p)
self.pp[name] = p e2 = torch.exp(kk - p)
state[5*i+2] = e1 * aa + e2 * vv
rwkv = r * a / b state[5*i+3] = e1 * bb + e2
state[5*i+4] = p
rwkv = r * (a / b).bfloat16()
else:
aa = state[5*i+2]
bb = state[5*i+3]
pp = state[5*i+4]
ww = w.time_first + k
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
a = e1 * aa + e2 * v
b = e1 * bb + e2
ww = pp + w.time_decay
p = torch.maximum(ww, k)
e1 = torch.exp(ww - p)
e2 = torch.exp(k - p)
state[5*i+2] = e1 * aa + e2 * v
state[5*i+3] = e1 * bb + e2
state[5*i+4] = p
rwkv = r * a / b
return w.output.weight @ rwkv return w.output.weight @ rwkv
def forward(self, ctx, preprocess_only = False): def forward(self, ctx, state, preprocess_only = False):
with torch.no_grad(): with torch.no_grad():
w = self.w w = self.w
x = w.emb.weight[ctx[-1]] x = w.emb.weight[ctx[-1]]
if self.RUN_DEVICE == 'cuda':
x = x.cuda()
try: try:
pos_emb = w.pos_emb[len(ctx)-1] pos_emb = w.pos_emb[len(ctx)-1]
x = x + pos_emb x = x + pos_emb
except: except:
pass pass
if state == None:
state = torch.zeros(self.n_layer * 5, self.n_embd, device=self.RUN_DEVICE)
for i in range(self.n_layer):
state[5*i+4] -= 1e30
for i in range(self.n_layer): for i in range(self.n_layer):
if i == 0: if i == 0:
x = self.LN(x, w.blocks[i].ln0) x = self.LN(x, w.blocks[i].ln0)
if i == 0 and self.model_type == 'RWKV-ffnPre': x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, state, i)
x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}') x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, state, i)
else:
x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
if RWKV_HEAD_QK_DIM > 0:
if self.hk == None:
self.hk = (w.head_k.weight @ x).unsqueeze(0)
else:
self.hk = torch.cat(
[self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
if self.hk.shape[0] > self.ctx_len:
self.hk = self.hk[-self.ctx_len:, :]
if preprocess_only: if preprocess_only:
return None return state
q = w.head_q.weight @ x x = self.LN(x, w.ln_out)
x = w.head.weight @ x
x = w.head.weight @ x
x = x
c = (self.hk @ q) / RWKV_HEAD_QK_DIM
for i in range(len(c)):
x[ctx[i]] += c[i]
else:
if preprocess_only:
return None
x = w.head.weight @ x
x = x
return x return x.float(), state

Loading…
Cancel
Save