bf16 inference - 15G VRAM for 7b model

main
BlinkDL 3 years ago
parent 1479315677
commit daed379db2

@@ -3,13 +3,15 @@
########################################################################################################
import numpy as np
import math, os
import math, os, sys
import time
import types
import copy
import torch
from src.utils import TOKENIZER
try:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
    pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -62,17 +64,22 @@ elif TOKEN_MODE == "pile":
# n_embd = 1024
# ctx_len = 1024
MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040'
n_layer = 24
n_embd = 2048
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
n_layer = 32
n_embd = 2560
ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20220925-4537'
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221004-3047'
# n_layer = 32
# n_embd = 2560
# n_embd = 4096
# ctx_len = 1024
os.environ["RWKV_FLOAT_MODE"] = "fp32" # currently only supprts fp32 (it can do bf16 and fp16. just wait a bit... busy these days)
os.environ["RWKV_FLOAT_MODE"] = "fp32" # fp32 (faster at this moment) or bf16 (slower but saves VRAM)
os.environ["RWKV_RUN_DEVICE"] = "cpu" # 'cpu' (already very fast) or 'cuda'
model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre'
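
A note on the new float-mode switch: bf16 stores weights at 2 bytes per parameter instead of 4, which is what makes 7B-scale inference fit in the ~15G VRAM the commit title mentions. A quick back-of-the-envelope check (the parameter count is approximate, not taken from this diff):

    params = 7.0e9                  # roughly RWKV-4-Pile-7B
    gib_fp32 = params * 4 / 2**30   # ~26 GiB of weights in fp32
    gib_bf16 = params * 2 / 2**30   # ~13 GiB in bf16; plus state and buffers, ~15G total
    print(f'fp32: {gib_fp32:.1f} GiB, bf16: {gib_bf16:.1f} GiB')
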
@@ -127,7 +134,13 @@ from src.model_run import RWKV_RNN
model = RWKV_RNN(
    MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len
)
# input(0)

print(f'\nLoading tokenizer {WORD_NAME}...')
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == "pile":
    assert tokenizer.tokenizer.decode([187]) == '\n'
########################################################################################################
@@ -139,9 +152,9 @@ else:
src_len = len(ctx)
src_ctx = ctx.copy()

print("Your prompt has " + str(src_len) + " tokens.")
print("\nYour prompt has " + str(src_len) + " tokens.")
print(
    "\nNote: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. Use GPT to build the hidden state for better speed.\n"
    "Note: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. Use GPT to build the hidden state for better speed.\n"
)
time_slot = {}
@@ -154,24 +167,24 @@ def record_time(name):
    if tt < time_slot[name]:
        time_slot[name] = tt

init_state = None
init_out = None
state = None
out = None

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    print(("-" * 50) + '\n' + context, end="")
    time_ref = time.time_ns()
    ctx = src_ctx.copy()
    model.clear()
    if TRIAL == 0:
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[: i + 1]
            if i == src_len - 1:
                init_state.out = model.forward(x)
                init_out, init_state = model.forward(x, init_state)
            else:
                model.forward(x, preprocess_only=True)
        model.save(init_state)
    else:
        model.load(init_state)
                init_state = model.forward(x, init_state, preprocess_only=True)

    record_time('preprocess')
    out_last = src_len
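
The loop above is the new calling convention in action: forward() now threads an explicit state through every call and returns (logits, state), while preprocess_only=True returns only the updated state as the prompt is consumed. A minimal sketch of the same pattern in isolation (prompt_tokens is an illustrative name, not from this diff):

    state, out = None, None
    for i in range(len(prompt_tokens)):
        if i < len(prompt_tokens) - 1:
            state = model.forward(prompt_tokens[: i + 1], state, preprocess_only=True)
        else:
            out, state = model.forward(prompt_tokens[: i + 1], state)
    # out holds next-token logits; cloning (out, state) lets later trials
    # reuse the preprocessed prompt instead of replaying it token by token.
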
@@ -180,12 +193,12 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
        x = x[-ctx_len:]

        if i == src_len:
            out = copy.deepcopy(init_state.out)
            out = init_out.clone()
            state = init_state.clone()
        else:
            out = model.forward(x)
            out, state = model.forward(x, state)

        if DEBUG_DEBUG:
            print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))
        if TOKEN_MODE == "pile":
            out[0] = -999999999 # disable <|endoftext|>
@@ -213,3 +226,5 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    print(
        f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = ''
    )

print(("-" * 50) + '\n')
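
One more consequence of the state rewrite visible in run.py: because the recurrent state is now one flat tensor, checkpointing it after the prompt is a cheap clone() instead of the old copy.deepcopy over per-layer dicts. The pattern in isolation (shapes are illustrative):

    import torch

    init_state = torch.zeros(32 * 5, 2560)   # [n_layer*5, n_embd], sized for the 3B config
    snapshot = init_state.clone()            # one contiguous copy of the whole RNN state
    snapshot[0, 0] = 1.0
    assert init_state[0, 0] == 0.0           # the cached prompt state stays untouched
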

@@ -3,16 +3,13 @@
########################################################################################################

import types
import copy
import torch
import math, os
import math, os, gc
from torch.nn import functional as F
import torch.nn as nn

def __nop(ob):
    return ob
MyModule = nn.Module
MyFunction = __nop
# MyModule = torch.jit.ScriptModule
@@ -25,7 +22,7 @@ DEBUG_TIME = False # True False - show trained time-coeffs
############################################################################################################

class RWKV_RNN(MyModule): # this is running in FP32 at this moment
class RWKV_RNN(MyModule):
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
        super().__init__()
@@ -35,20 +32,50 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
        self.n_embd = n_embd
        self.ctx_len = ctx_len

        self.w = types.SimpleNamespace()
        w = torch.load(MODEL_NAME + '.pth', map_location='cpu')
        w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))
        for x in w.keys():
            w[x] = w[x].float()

        # refine weights and send to correct device
        keys = list(w.keys())
        if 'pos_emb_x' in keys:
            w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:]
        keys = list(w.keys())
        print_need_newline = False
        for x in keys:
            if '.time_' in x:
                w[x] = w[x].squeeze()
                if DEBUG_TIME:
                    print(x, w[x].numpy())
            if '.time_decay' in x:
                w[x] = w[x].float()
                w[x] = -torch.exp(w[x])
            if 'pos_emb_x' in x:
                self.w.pos_emb = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:]
            if DEBUG_TIME and '.time_' in x:
                print(x, w[x].squeeze().cpu().numpy())
            elif '.time_first' in x:
                w[x] = w[x].float()
            else:
                if os.environ["RWKV_FLOAT_MODE"] == "fp32":
                    w[x] = w[x].float()
                elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                    w[x] = w[x].bfloat16()

            w[x].requires_grad = False
            if RUN_DEVICE == 'cuda' and x != 'emb.weight':
                w[x] = w[x].cuda()

            if ('blocks.' not in x) or ('blocks.0.' in x):
                if print_need_newline:
                    print('\n', end = '')
                    print_need_newline = False
                print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device)
            else:
                print_need_newline = True
                print('.', end = '', flush = True)

        # store weights in self.w
        keys = list(w.keys())
        self.w = types.SimpleNamespace()
        for x in keys:
            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
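
The loading loop above encodes a per-key dtype policy: time_decay is pre-negated and exponentiated and, like time_first, kept in fp32 (both feed the exp-based recurrence), while every other tensor follows RWKV_FLOAT_MODE; weights move to the GPU except emb.weight, which stays on the CPU so the embedding table costs no VRAM. A condensed sketch of that policy (refine is an illustrative helper, assuming RUN_DEVICE == 'cuda'):

    import os, torch

    def refine(name, t):
        if '.time_decay' in name:
            return -torch.exp(t.float())   # fp32; consumed as exp(time_decay) by the recurrence
        if '.time_first' in name:
            return t.float()               # fp32 for the stable WKV math
        t = t.float() if os.environ["RWKV_FLOAT_MODE"] == "fp32" else t.bfloat16()
        return t if name == 'emb.weight' else t.cuda()   # embedding lookup stays on CPU
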
@@ -67,41 +94,26 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
                setattr(here, xx[i], types.SimpleNamespace())
                here = getattr(here, xx[i])

        self.clear()
        self.eval()

    def clear(self):
        self.xx = {}
        self.aa = {}
        self.bb = {}
        self.pp = {}
        self.hk = None

    def save(self, target):
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.pp = copy.deepcopy(self.pp)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.pp = copy.deepcopy(target.pp)
        self.hk = copy.deepcopy(target.hk)
        gc.collect()
        torch.cuda.empty_cache()

    @MyFunction
    def LN(self, xx, w):
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
    def LN(self, x, w):
        return F.layer_norm(x, (self.n_embd,), weight=w.weight, bias=w.bias)

    # state: ffn_xx att_xx att_aa att_bb att_pp

    @MyFunction
    def FF(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx
    def FF(self, x, w, state, i):
        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
            xk = x * w.time_mix_k + state[5*i+0].bfloat16() * (1 - w.time_mix_k)
            xr = x * w.time_mix_r + state[5*i+0].bfloat16() * (1 - w.time_mix_r)
            state[5*i+0] = x.float()
        else:
            xk = x * w.time_mix_k + state[5*i+0] * (1 - w.time_mix_k)
            xr = x * w.time_mix_r + state[5*i+0] * (1 - w.time_mix_r)
            state[5*i+0] = x

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.square(torch.relu(w.key.weight @ xk))
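
The state[5*i+j] indexing in FF follows the layout declared by the "# state: ffn_xx att_xx att_aa att_bb att_pp" comment: five rows per layer in a single [n_layer*5, n_embd] tensor, held in fp32 even in bf16 mode (note the .bfloat16() / .float() casts at the boundary). Spelled out with illustrative sizes:

    import torch

    n_layer, n_embd = 32, 2560
    state = torch.zeros(n_layer * 5, n_embd)
    i = 0                      # any layer index
    ffn_xx = state[5*i + 0]    # previous input to this layer's FFN (token shift)
    att_xx = state[5*i + 1]    # previous input to this layer's attention
    att_aa = state[5*i + 2]    # WKV numerator accumulator
    att_bb = state[5*i + 3]    # WKV denominator accumulator
    att_pp = state[5*i + 4]    # running max exponent; forward() seeds it with -1e30
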
@@ -110,90 +122,92 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
        return r * kv

    @MyFunction
    def SA(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx
    def SA(self, x, w, state, i):
        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
            xk = x * w.time_mix_k + state[5*i+1].bfloat16() * (1 - w.time_mix_k)
            xv = x * w.time_mix_v + state[5*i+1].bfloat16() * (1 - w.time_mix_v)
            xr = x * w.time_mix_r + state[5*i+1].bfloat16() * (1 - w.time_mix_r)
            state[5*i+1] = x.float()
        else:
            xk = x * w.time_mix_k + state[5*i+1] * (1 - w.time_mix_k)
            xv = x * w.time_mix_v + state[5*i+1] * (1 - w.time_mix_v)
            xr = x * w.time_mix_r + state[5*i+1] * (1 - w.time_mix_r)
            state[5*i+1] = x

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = w.key.weight @ xk
        v = w.value.weight @ xv

        pp = self.pp[name]
        aa = self.aa[name]
        bb = self.bb[name]
        ww = w.time_first + k
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        ww = pp + w.time_decay
        p = torch.maximum(ww, k)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k - p)
        self.aa[name] = e1 * aa + e2 * v
        self.bb[name] = e1 * bb + e2
        self.pp[name] = p
        rwkv = r * a / b
        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
            kk = k.float()
            vv = v.float()
            aa = state[5*i+2]
            bb = state[5*i+3]
            pp = state[5*i+4]
            ww = w.time_first + kk
            p = torch.maximum(pp, ww)
            e1 = torch.exp(pp - p)
            e2 = torch.exp(ww - p)
            a = e1 * aa + e2 * vv
            b = e1 * bb + e2
            ww = pp + w.time_decay
            p = torch.maximum(ww, kk)
            e1 = torch.exp(ww - p)
            e2 = torch.exp(kk - p)
            state[5*i+2] = e1 * aa + e2 * vv
            state[5*i+3] = e1 * bb + e2
            state[5*i+4] = p
            rwkv = r * (a / b).bfloat16()
        else:
            aa = state[5*i+2]
            bb = state[5*i+3]
            pp = state[5*i+4]
            ww = w.time_first + k
            p = torch.maximum(pp, ww)
            e1 = torch.exp(pp - p)
            e2 = torch.exp(ww - p)
            a = e1 * aa + e2 * v
            b = e1 * bb + e2
            ww = pp + w.time_decay
            p = torch.maximum(ww, k)
            e1 = torch.exp(ww - p)
            e2 = torch.exp(k - p)
            state[5*i+2] = e1 * aa + e2 * v
            state[5*i+3] = e1 * bb + e2
            state[5*i+4] = p
            rwkv = r * a / b

        return w.output.weight @ rwkv

    def forward(self, ctx, preprocess_only = False):
    def forward(self, ctx, state, preprocess_only = False):
        with torch.no_grad():
            w = self.w
            x = w.emb.weight[ctx[-1]]
            if self.RUN_DEVICE == 'cuda':
                x = x.cuda()
            try:
                pos_emb = w.pos_emb[len(ctx)-1]
                x = x + pos_emb
            except:
                pass

            if state == None:
                state = torch.zeros(self.n_layer * 5, self.n_embd, device=self.RUN_DEVICE)
                for i in range(self.n_layer):
                    state[5*i+4] -= 1e30

            for i in range(self.n_layer):
                if i == 0:
                    x = self.LN(x, w.blocks[i].ln0)
                if i == 0 and self.model_type == 'RWKV-ffnPre':
                    x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
                else:
                    x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
                x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')

            x = self.LN(x, w.ln_out)
            if RWKV_HEAD_QK_DIM > 0:
                if self.hk == None:
                    self.hk = (w.head_k.weight @ x).unsqueeze(0)
                else:
                    self.hk = torch.cat(
                        [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
                if self.hk.shape[0] > self.ctx_len:
                    self.hk = self.hk[-self.ctx_len:, :]
                    x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, state, i)
                x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, state, i)

                if preprocess_only:
                    return None
            if preprocess_only:
                return state

                q = w.head_q.weight @ x
                x = w.head.weight @ x
                x = x
                c = (self.hk @ q) / RWKV_HEAD_QK_DIM
                for i in range(len(c)):
                    x[ctx[i]] += c[i]
            else:
                if preprocess_only:
                    return None
                x = w.head.weight @ x
                x = x

            x = self.LN(x, w.ln_out)
            x = w.head.weight @ x

            return x
            return x.float(), state
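
Both branches of SA above implement the same numerically stable WKV update: instead of accumulating raw exp(k) terms, they carry a running maximum exponent pp and rescale both accumulators whenever it changes, a streaming log-sum-exp trick; that is also why the bf16 path keeps the accumulators in fp32 and only casts the final a / b. A scalar restatement of one step (the real code does this element-wise over n_embd):

    import math

    def wkv_step(aa, bb, pp, k, v, time_first, time_decay):
        # output: exp-weighted average of past values plus the current token's bonus
        ww = time_first + k
        p = max(pp, ww)                              # new shared exponent
        e1, e2 = math.exp(pp - p), math.exp(ww - p)
        out = (e1 * aa + e2 * v) / (e1 * bb + e2)
        # state update: decay the history (time_decay < 0) and fold in the current token
        ww = pp + time_decay
        p = max(ww, k)
        e1, e2 = math.exp(ww - p), math.exp(k - p)
        return e1 * aa + e2 * v, e1 * bb + e2, p, out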
