From daed379db29be534c6a42ddba8e6b6be34980566 Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Wed, 5 Oct 2022 21:03:35 +0000 Subject: [PATCH] bf16 inference - 15G VRAM for 7b model --- RWKV-v4neo/run.py | 57 +++++---- RWKV-v4neo/src/model_run.py | 228 +++++++++++++++++++----------------- 2 files changed, 157 insertions(+), 128 deletions(-) diff --git a/RWKV-v4neo/run.py b/RWKV-v4neo/run.py index dc43838..1ef9073 100644 --- a/RWKV-v4neo/run.py +++ b/RWKV-v4neo/run.py @@ -3,13 +3,15 @@ ######################################################################################################## import numpy as np -import math, os +import math, os, sys import time -import types -import copy import torch from src.utils import TOKENIZER +try: + os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1] +except: + pass torch.backends.cudnn.benchmark = True torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True @@ -62,17 +64,22 @@ elif TOKEN_MODE == "pile": # n_embd = 1024 # ctx_len = 1024 - MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040' - n_layer = 24 - n_embd = 2048 + # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096' + # n_layer = 24 + # n_embd = 2048 + # ctx_len = 1024 + + MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783' + n_layer = 32 + n_embd = 2560 ctx_len = 1024 - # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20220925-4537' + # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221004-3047' # n_layer = 32 - # n_embd = 2560 + # n_embd = 4096 # ctx_len = 1024 -os.environ["RWKV_FLOAT_MODE"] = "fp32" # currently only supprts fp32 (it can do bf16 and fp16. just wait a bit... busy these days) +os.environ["RWKV_FLOAT_MODE"] = "fp32" # fp32 (faster at this moment) or bf16 (slower but saves VRAM) os.environ["RWKV_RUN_DEVICE"] = "cpu" # 'cpu' (already very fast) or 'cuda' model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre' @@ -127,7 +134,13 @@ from src.model_run import RWKV_RNN model = RWKV_RNN( MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len ) + +# input(0) + +print(f'\nLoading tokenizer {WORD_NAME}...') tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) +if TOKEN_MODE == "pile": + assert tokenizer.tokenizer.decode([187]) == '\n' ######################################################################################################## @@ -139,9 +152,9 @@ else: src_len = len(ctx) src_ctx = ctx.copy() -print("Your prompt has " + str(src_len) + " tokens.") +print("\nYour prompt has " + str(src_len) + " tokens.") print( - "\nNote: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. Use GPT to build the hidden state for better speed.\n" + "Note: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. 
Use GPT to build the hidden state for better speed.\n" ) time_slot = {} @@ -154,24 +167,24 @@ def record_time(name): if tt < time_slot[name]: time_slot[name] = tt +init_state = None +init_out = None +state = None +out = None + for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): print(("-" * 50) + '\n' + context, end="") time_ref = time.time_ns() ctx = src_ctx.copy() - model.clear() if TRIAL == 0: - init_state = types.SimpleNamespace() for i in range(src_len): x = ctx[: i + 1] if i == src_len - 1: - init_state.out = model.forward(x) + init_out, init_state = model.forward(x, init_state) else: - model.forward(x, preprocess_only=True) - model.save(init_state) - else: - model.load(init_state) + init_state = model.forward(x, init_state, preprocess_only=True) record_time('preprocess') out_last = src_len @@ -180,12 +193,12 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): x = x[-ctx_len:] if i == src_len: - out = copy.deepcopy(init_state.out) + out = init_out.clone() + state = init_state.clone() else: - out = model.forward(x) + out, state = model.forward(x, state) if DEBUG_DEBUG: print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy())) - if TOKEN_MODE == "pile": out[0] = -999999999 # disable <|endoftext|> @@ -213,3 +226,5 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): print( f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = '' ) + +print(("-" * 50) + '\n') diff --git a/RWKV-v4neo/src/model_run.py b/RWKV-v4neo/src/model_run.py index 18be189..209ae21 100644 --- a/RWKV-v4neo/src/model_run.py +++ b/RWKV-v4neo/src/model_run.py @@ -3,16 +3,13 @@ ######################################################################################################## import types -import copy import torch -import math, os +import math, os, gc from torch.nn import functional as F import torch.nn as nn def __nop(ob): return ob - - MyModule = nn.Module MyFunction = __nop # MyModule = torch.jit.ScriptModule @@ -25,7 +22,7 @@ DEBUG_TIME = False # True False - show trained time-coeffs ############################################################################################################ -class RWKV_RNN(MyModule): # this is running in FP32 at this moment +class RWKV_RNN(MyModule): def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len): super().__init__() @@ -35,20 +32,50 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment self.n_embd = n_embd self.ctx_len = ctx_len - self.w = types.SimpleNamespace() + w = torch.load(MODEL_NAME + '.pth', map_location='cpu') - w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE)) - for x in w.keys(): - w[x] = w[x].float() + # refine weights and send to correct device + + keys = list(w.keys()) + if 'pos_emb_x' in keys: + w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:] + + keys = list(w.keys()) + print_need_newline = False + for x in keys: if '.time_' in x: w[x] = w[x].squeeze() + if DEBUG_TIME: + print(x, w[x].numpy()) if '.time_decay' in x: + w[x] = w[x].float() w[x] = -torch.exp(w[x]) - if 'pos_emb_x' in x: - self.w.pos_emb = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:] - if DEBUG_TIME and '.time_' in x: - print(x, w[x].squeeze().cpu().numpy()) + elif '.time_first' in x: + w[x] = w[x].float() + else: + if os.environ["RWKV_FLOAT_MODE"] == "fp32": + w[x] = w[x].float() + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + w[x] = 
w[x].bfloat16() + + w[x].requires_grad = False + if RUN_DEVICE == 'cuda' and x != 'emb.weight': + w[x] = w[x].cuda() + + if ('blocks.' not in x) or ('blocks.0.' in x): + if print_need_newline: + print('\n', end = '') + print_need_newline = False + print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device) + else: + print_need_newline = True + print('.', end = '', flush = True) + + # store weights in self.w + keys = list(w.keys()) + self.w = types.SimpleNamespace() + for x in keys: xx = x.split('.') here = self.w for i in range(len(xx)): @@ -67,41 +94,26 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment setattr(here, xx[i], types.SimpleNamespace()) here = getattr(here, xx[i]) - self.clear() self.eval() - - def clear(self): - self.xx = {} - self.aa = {} - self.bb = {} - self.pp = {} - self.hk = None - - def save(self, target): - target.xx = copy.deepcopy(self.xx) - target.aa = copy.deepcopy(self.aa) - target.bb = copy.deepcopy(self.bb) - target.pp = copy.deepcopy(self.pp) - target.hk = copy.deepcopy(self.hk) - - def load(self, target): - self.xx = copy.deepcopy(target.xx) - self.aa = copy.deepcopy(target.aa) - self.bb = copy.deepcopy(target.bb) - self.pp = copy.deepcopy(target.pp) - self.hk = copy.deepcopy(target.hk) + gc.collect() + torch.cuda.empty_cache() @MyFunction - def LN(self, xx, w): - return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias) + def LN(self, x, w): + return F.layer_norm(x, (self.n_embd,), weight=w.weight, bias=w.bias) + + # state: ffn_xx att_xx att_aa att_bb att_pp @MyFunction - def FF(self, xx, w, name): - if name not in self.xx: - self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k) - xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r) - self.xx[name] = xx + def FF(self, x, w, state, i): + if os.environ["RWKV_FLOAT_MODE"] == "bf16": + xk = x * w.time_mix_k + state[5*i+0].bfloat16() * (1 - w.time_mix_k) + xr = x * w.time_mix_r + state[5*i+0].bfloat16() * (1 - w.time_mix_r) + state[5*i+0] = x.float() + else: + xk = x * w.time_mix_k + state[5*i+0] * (1 - w.time_mix_k) + xr = x * w.time_mix_r + state[5*i+0] * (1 - w.time_mix_r) + state[5*i+0] = x r = torch.sigmoid(w.receptance.weight @ xr) k = torch.square(torch.relu(w.key.weight @ xk)) @@ -110,90 +122,92 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment return r * kv @MyFunction - def SA(self, xx, w, name): - if name not in self.xx: - self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30 - - xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k) - xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v) - xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r) - self.xx[name] = xx + def SA(self, x, w, state, i): + if os.environ["RWKV_FLOAT_MODE"] == "bf16": + xk = x * w.time_mix_k + state[5*i+1].bfloat16() * (1 - w.time_mix_k) + xv = x * w.time_mix_v + state[5*i+1].bfloat16() * (1 - w.time_mix_v) + xr = x * w.time_mix_r + state[5*i+1].bfloat16() * (1 - w.time_mix_r) + state[5*i+1] = x.float() + else: + xk = x * w.time_mix_k + state[5*i+1] * (1 - w.time_mix_k) + xv = x * w.time_mix_v + state[5*i+1] * (1 - w.time_mix_v) + xr = x * w.time_mix_r + state[5*i+1] * (1 - w.time_mix_r) + state[5*i+1] = x r = torch.sigmoid(w.receptance.weight @ xr) 
k = w.key.weight @ xk v = w.value.weight @ xv - pp = self.pp[name] - aa = self.aa[name] - bb = self.bb[name] - ww = w.time_first + k - p = torch.maximum(pp, ww) - e1 = torch.exp(pp - p) - e2 = torch.exp(ww - p) - a = e1 * aa + e2 * v - b = e1 * bb + e2 - ww = pp + w.time_decay - p = torch.maximum(ww, k) - e1 = torch.exp(ww - p) - e2 = torch.exp(k - p) - self.aa[name] = e1 * aa + e2 * v - self.bb[name] = e1 * bb + e2 - self.pp[name] = p - - rwkv = r * a / b + if os.environ["RWKV_FLOAT_MODE"] == "bf16": + kk = k.float() + vv = v.float() + aa = state[5*i+2] + bb = state[5*i+3] + pp = state[5*i+4] + ww = w.time_first + kk + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + a = e1 * aa + e2 * vv + b = e1 * bb + e2 + ww = pp + w.time_decay + p = torch.maximum(ww, kk) + e1 = torch.exp(ww - p) + e2 = torch.exp(kk - p) + state[5*i+2] = e1 * aa + e2 * vv + state[5*i+3] = e1 * bb + e2 + state[5*i+4] = p + rwkv = r * (a / b).bfloat16() + else: + aa = state[5*i+2] + bb = state[5*i+3] + pp = state[5*i+4] + ww = w.time_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + a = e1 * aa + e2 * v + b = e1 * bb + e2 + ww = pp + w.time_decay + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + state[5*i+2] = e1 * aa + e2 * v + state[5*i+3] = e1 * bb + e2 + state[5*i+4] = p + rwkv = r * a / b return w.output.weight @ rwkv - def forward(self, ctx, preprocess_only = False): + def forward(self, ctx, state, preprocess_only = False): with torch.no_grad(): w = self.w + x = w.emb.weight[ctx[-1]] + if self.RUN_DEVICE == 'cuda': + x = x.cuda() try: pos_emb = w.pos_emb[len(ctx)-1] x = x + pos_emb except: - pass + pass + + if state == None: + state = torch.zeros(self.n_layer * 5, self.n_embd, device=self.RUN_DEVICE) + for i in range(self.n_layer): + state[5*i+4] -= 1e30 for i in range(self.n_layer): if i == 0: x = self.LN(x, w.blocks[i].ln0) - if i == 0 and self.model_type == 'RWKV-ffnPre': - x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}') - else: - x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}') - x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}') - - x = self.LN(x, w.ln_out) - - if RWKV_HEAD_QK_DIM > 0: - if self.hk == None: - self.hk = (w.head_k.weight @ x).unsqueeze(0) - else: - self.hk = torch.cat( - [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0) - if self.hk.shape[0] > self.ctx_len: - self.hk = self.hk[-self.ctx_len:, :] + x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, state, i) + x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, state, i) - if preprocess_only: - return None + if preprocess_only: + return state - q = w.head_q.weight @ x - - x = w.head.weight @ x - x = x - - c = (self.hk @ q) / RWKV_HEAD_QK_DIM - for i in range(len(c)): - x[ctx[i]] += c[i] - else: - if preprocess_only: - return None - - x = w.head.weight @ x - x = x + x = self.LN(x, w.ln_out) + x = w.head.weight @ x - return x + return x.float(), state
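
Usage note (not part of the patch): the diff replaces the old clear()/save()/load() bookkeeping with an explicit state tensor of shape (n_layer*5, n_embd), one row per layer for ffn_xx, att_xx, att_aa, att_bb and att_pp, with att_pp initialised to -1e30. Even in bf16 mode that state stays in fp32: the token-shift rows are written back with .float() and cast to bf16 on read, and k/v are upcast to fp32 before the aa/bb/pp accumulation, whose result is cast back to bf16 only for the output matmul. Below is a minimal sketch of how the new stateful forward() API is driven, pieced together from the run.py changes above; the checkpoint path, the token ids and the greedy argmax are placeholders (run.py itself uses TOKENIZER and sample_logits()).

import os
os.environ["RWKV_FLOAT_MODE"] = "bf16"   # "fp32" (faster) or "bf16" (saves VRAM), as in run.py
os.environ["RWKV_RUN_DEVICE"] = "cuda"   # "cpu" or "cuda"

import torch
from src.model_run import RWKV_RNN

# mirrors the 3B config enabled in run.py; the checkpoint path is a placeholder
MODEL_NAME = '/path/to/RWKV-4-Pile-3B-20221003-6783'
n_layer, n_embd, ctx_len = 32, 2560, 1024

model = RWKV_RNN(MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], "RWKV", n_layer, n_embd, ctx_len)

prompt_tokens = [187, 510, 3158]  # placeholder token ids; tokenize a real prompt in practice

# preprocess the prompt token by token; preprocess_only=True returns only the updated state
state = None
for i in range(len(prompt_tokens) - 1):
    state = model.forward(prompt_tokens[: i + 1], state, preprocess_only=True)
out, state = model.forward(prompt_tokens, state)

# autoregressive generation: each forward() call consumes ctx[-1] and threads the state through
ctx = list(prompt_tokens)
for _ in range(100):
    token = int(torch.argmax(out))  # greedy pick for the sketch; run.py samples with tokenizer.sample_logits()
    ctx.append(token)
    out, state = model.forward(ctx[-ctx_len:], state)

The clone() calls in run.py (init_out.clone(), init_state.clone()) matter for the multi-trial loop: forward() updates the state tensor in place, so each trial has to start from a copy of the preprocessed prompt state rather than the tensor itself.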