diff --git a/RWKV-v4neo/run.py b/RWKV-v4neo/run.py
index 1ef9073..2f28e63 100644
--- a/RWKV-v4neo/run.py
+++ b/RWKV-v4neo/run.py
@@ -3,11 +3,9 @@
 ########################################################################################################
 
 import numpy as np
-import math, os, sys
-import time
+import math, os, sys, types, time, gc
 import torch
 from src.utils import TOKENIZER
-
 try:
     os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
 except:
@@ -16,72 +14,67 @@ torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
+args = types.SimpleNamespace()
 
 ########################################################################################################
-# Step 1: set model
-#
-# Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch.
-#
-# Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models.
+# Step 1: set model & config
+# Do this first: pip install torchdynamo
 ########################################################################################################
 
-TOKEN_MODE = "pile"  # char / bpe / pile
-
-n_layer = 6
-n_embd = 512
+args.RUN_DEVICE = "cpu"  # 'cpu' (already very fast) // 'cuda'
+args.FLOAT_MODE = "fp32"  # fp32 // bf16 (saves VRAM, slightly less accurate)
+# if args.RUN_DEVICE == "cuda":
+#     os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
+
+TOKEN_MODE = "pile"
+WORD_NAME = [
+    "20B_tokenizer.json",
+    "20B_tokenizer.json",
+]  # [vocab, vocab] for Pile model
+UNKNOWN_CHAR = None
+vocab_size = 50277
+
+# note: you can set MODEL_NAME to your fine-tuned model
+# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
+# n_layer = 12
+# n_embd = 768
+# ctx_len = 1024
+
+# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
+# n_layer = 24
+# n_embd = 1024
+# ctx_len = 1024
+
+# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
+# n_layer = 24
+# n_embd = 2048
+# ctx_len = 1024
+
+# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096'
+# n_layer = 24
+# n_embd = 2048
+# ctx_len = 4096
+
+MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
+n_layer = 32
+n_embd = 2560
 ctx_len = 1024
 
-if TOKEN_MODE == "char":
-    MODEL_NAME = "trained-500"  # your trained model
-    WORD_NAME = "vocab"  # the .json vocab (generated by train.py)
-    # set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown tokens in your prompt will be denoted by it
-    UNKNOWN_CHAR = " "  # here we just set it to ' ' for simplicity
-
-elif TOKEN_MODE == "bpe":
-    MODEL_NAME = "trained-500"  # your trained model
-    WORD_NAME = [
-        "model-vocab.json",
-        "model-merges.txt",
-    ]  # [vocab, merge] for your BPE model
-    UNKNOWN_CHAR = None
-
-elif TOKEN_MODE == "pile":
-    WORD_NAME = [
-        "20B_tokenizer.json",
-        "20B_tokenizer.json",
-    ]  # [vocab, vocab] for Pile model
-    UNKNOWN_CHAR = None
-
-    # ---> you can set MODEL_NAME to your fine-tuned model <---
-
-    # MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
-    # n_layer = 12
-    # n_embd = 768
-    # ctx_len = 1024
-
-    # MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
-    # n_layer = 24
-    # n_embd = 1024
-    # ctx_len = 1024
-
-    # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096'
-    # n_layer = 24
-    # n_embd = 2048
-    # ctx_len = 1024
-
-    MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
-    n_layer = 32
-    n_embd = 2560
-    ctx_len = 1024
-
-    # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221004-3047'
-    # n_layer = 32
-    # n_embd = 4096
-    # ctx_len = 1024
-
-os.environ["RWKV_FLOAT_MODE"] = "fp32"  # fp32 (faster at this moment) or bf16 (slower but saves VRAM)
-os.environ["RWKV_RUN_DEVICE"] = "cpu"  # 'cpu' (already very fast) or 'cuda'
-model_type = "RWKV"  # 'RWKV' or 'RWKV-ffnPre'
+# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221004-3047'
+# n_layer = 32
+# n_embd = 4096
+# ctx_len = 1024
+
+args.MODEL_NAME = MODEL_NAME
+args.n_layer = n_layer
+args.n_embd = n_embd
+args.ctx_len = ctx_len
+args.vocab_size = vocab_size
+args.head_qk = 0
+args.pre_ffn = 0
+args.grad_cp = 0
+args.my_pos_emb = 0
+os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
 
 ########################################################################################################
 # Step 2: set prompt & sampling stuffs
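For quick reference, the commented-out checkpoints in the hunk above differ only in three hyperparameters. A hypothetical summary table (not part of the patch), collecting the values the diff lists:

```python
# Hypothetical lookup, not part of the patch: hyperparameters of the Pile
# checkpoints listed in the hunk above, keyed by model size.
PILE_CONFIGS = {
    "169M":        dict(n_layer=12, n_embd=768,  ctx_len=1024),
    "430M":        dict(n_layer=24, n_embd=1024, ctx_len=1024),
    "1B5":         dict(n_layer=24, n_embd=2048, ctx_len=1024),
    "1B5-ctx4096": dict(n_layer=24, n_embd=2048, ctx_len=4096),
    "3B":          dict(n_layer=32, n_embd=2560, ctx_len=1024),
    "7B":          dict(n_layer=32, n_embd=4096, ctx_len=1024),
}
```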
@@ -128,12 +121,15 @@ DEBUG_DEBUG = False  # True False --> show softmax output
 
 ########################################################################################################
 
-print(f'\nUsing {os.environ["RWKV_RUN_DEVICE"].upper()}. Loading {MODEL_NAME}...')
+print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...')
 from src.model_run import RWKV_RNN
 
-model = RWKV_RNN(
-    MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len
-)
+model = RWKV_RNN(args)
+
+print(f'\nOptimizing speed...')
+model.forward([187], None)
+gc.collect()
+torch.cuda.empty_cache()
 
 # input(0)
 
@@ -185,6 +181,8 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
                 init_out, init_state = model.forward(x, init_state)
             else:
                 init_state = model.forward(x, init_state, preprocess_only=True)
+        gc.collect()
+        torch.cuda.empty_cache()
 
     record_time('preprocess')
     out_last = src_len
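Net effect of the run.py changes: the model is configured through a single `types.SimpleNamespace` instead of six positional constructor arguments, and a one-token warm-up forward runs before generation. A minimal usage sketch under those assumptions (the checkpoint path is a placeholder):

```python
import os, gc, types, torch
from src.model_run import RWKV_RNN

args = types.SimpleNamespace()
args.RUN_DEVICE = "cpu"       # or 'cuda'
args.FLOAT_MODE = "fp32"      # or 'bf16'
args.MODEL_NAME = "/path/to/RWKV-4-Pile-3B-20221003-6783"  # placeholder path
args.n_layer, args.n_embd, args.ctx_len = 32, 2560, 1024
args.vocab_size = 50277
args.head_qk = args.pre_ffn = args.grad_cp = args.my_pos_emb = 0
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE

model = RWKV_RNN(args)
model.forward([187], None)    # one-token warm-up pass, as run.py does
gc.collect()
torch.cuda.empty_cache()
```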
diff --git a/RWKV-v4neo/src/model_run.py b/RWKV-v4neo/src/model_run.py
index 209ae21..f3325eb 100644
--- a/RWKV-v4neo/src/model_run.py
+++ b/RWKV-v4neo/src/model_run.py
@@ -7,13 +7,15 @@
 import torch
 import math, os, gc
 from torch.nn import functional as F
 import torch.nn as nn
+from typing import List, Dict
 
+# try:
+#     import torchdynamo
+#     MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
+# except:
 def __nop(ob):
     return ob
-MyModule = nn.Module
 MyFunction = __nop
-# MyModule = torch.jit.ScriptModule
-# MyFunction = torch.jit.script_method
 
 RWKV_HEAD_QK_DIM = 0
 print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
@@ -22,57 +24,52 @@ DEBUG_TIME = False  # True False - show trained time-coeffs
 
 ############################################################################################################
 
-class RWKV_RNN(MyModule):
-    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
+class RWKV_RNN(nn.Module):
+    def __init__(self, args):
         super().__init__()
-        self.RUN_DEVICE = RUN_DEVICE
-        self.model_type = model_type
-        self.n_layer = n_layer
-        self.n_embd = n_embd
-        self.ctx_len = ctx_len
+        self.args = args
+        self.FLOAT_MODE = args.FLOAT_MODE
+        self.RUN_DEVICE = args.RUN_DEVICE
 
-        w = torch.load(MODEL_NAME + '.pth', map_location='cpu')
-
-        # refine weights and send to correct device
-
-        keys = list(w.keys())
-        if 'pos_emb_x' in keys:
-            w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:]
-
-        keys = list(w.keys())
-        print_need_newline = False
-        for x in keys:
-            if '.time_' in x:
-                w[x] = w[x].squeeze()
-                if DEBUG_TIME:
-                    print(x, w[x].numpy())
-            if '.time_decay' in x:
-                w[x] = w[x].float()
-                w[x] = -torch.exp(w[x])
-            elif '.time_first' in x:
-                w[x] = w[x].float()
-            else:
-                if os.environ["RWKV_FLOAT_MODE"] == "fp32":
+        with torch.no_grad():
+            w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
+            # refine weights and send to correct device
+            keys = list(w.keys())
+            if 'pos_emb_x' in keys:
+                w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
+            keys = list(w.keys())
+            print_need_newline = False
+            for x in keys:
+                if '.time_' in x:
+                    w[x] = w[x].squeeze()
+                    if DEBUG_TIME:
+                        print(x, w[x].numpy())
+                if '.time_decay' in x:
                     w[x] = w[x].float()
-                elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
-                    w[x] = w[x].bfloat16()
-
-            w[x].requires_grad = False
-            if RUN_DEVICE == 'cuda' and x != 'emb.weight':
-                w[x] = w[x].cuda()
-
-            if ('blocks.' not in x) or ('blocks.0.' in x):
-                if print_need_newline:
-                    print('\n', end = '')
-                    print_need_newline = False
-                print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device)
-            else:
-                print_need_newline = True
-                print('.', end = '', flush = True)
+                    w[x] = -torch.exp(w[x])
+                elif '.time_first' in x:
+                    w[x] = w[x].float()
+                else:
+                    if self.FLOAT_MODE == "fp32":
+                        w[x] = w[x].float()
+                    elif self.FLOAT_MODE == "bf16":
+                        w[x] = w[x].bfloat16()
+
+                w[x].requires_grad = False
+                if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
+                    w[x] = w[x].cuda()
+
+                if ('blocks.' not in x) or ('blocks.0.' in x):
+                    if print_need_newline:
+                        print('\n', end = '')
+                        print_need_newline = False
+                    print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device)
+                else:
+                    print_need_newline = True
+                    print('.', end = '', flush = True)
 
         # store weights in self.w
-        keys = list(w.keys())
         self.w = types.SimpleNamespace()
         for x in keys:
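The `# store weights in self.w` step above (unchanged apart from dropping a redundant `keys = list(w.keys())`) is what lets the new `SA`/`FF` signatures receive plain tensors: flat checkpoint keys such as `blocks.0.att.key.weight` become nested attribute paths. A simplified sketch of that idea, not the file's exact loop:

```python
import types

# Simplified sketch (illustrative, not the patch's actual code): turn flat
# checkpoint keys like 'blocks.0.att.key.weight' into nested namespaces,
# so callers can read w.blocks[0].att.key.weight directly.
def build_namespace(flat):
    root = types.SimpleNamespace(blocks={})
    for key, tensor in flat.items():
        parts = key.split('.')
        here = root
        if parts[0] == 'blocks':                 # per-layer weights
            here = root.blocks.setdefault(int(parts[1]), types.SimpleNamespace())
            parts = parts[2:]
        for p in parts[:-1]:                     # walk/create intermediate nodes
            if not hasattr(here, p):
                setattr(here, p, types.SimpleNamespace())
            here = getattr(here, p)
        setattr(here, parts[-1], tensor)         # leaf: the tensor itself
    return root
```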
@@ -98,91 +95,78 @@ class RWKV_RNN(MyModule):
         gc.collect()
         torch.cuda.empty_cache()
 
-    @MyFunction
     def LN(self, x, w):
-        return F.layer_norm(x, (self.n_embd,), weight=w.weight, bias=w.bias)
+        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
 
-    # state: ffn_xx att_xx att_aa att_bb att_pp
+    # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
     @MyFunction
-    def FF(self, x, w, state, i):
-        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
-            xk = x * w.time_mix_k + state[5*i+0].bfloat16() * (1 - w.time_mix_k)
-            xr = x * w.time_mix_r + state[5*i+0].bfloat16() * (1 - w.time_mix_r)
+    def FF(self, x, state, i, time_mix_k, time_mix_r, kw, vw, rw):
+        if self.FLOAT_MODE == "bf16":
+            xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
+            xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
             state[5*i+0] = x.float()
         else:
-            xk = x * w.time_mix_k + state[5*i+0] * (1 - w.time_mix_k)
-            xr = x * w.time_mix_r + state[5*i+0] * (1 - w.time_mix_r)
+            xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
+            xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
             state[5*i+0] = x
 
-        r = torch.sigmoid(w.receptance.weight @ xr)
-        k = torch.square(torch.relu(w.key.weight @ xk))
-        kv = w.value.weight @ k
+        r = torch.sigmoid(rw @ xr)
+        k = torch.square(torch.relu(kw @ xk))
+        kv = vw @ k
 
         return r * kv
 
     @MyFunction
-    def SA(self, x, w, state, i):
-        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
-            xk = x * w.time_mix_k + state[5*i+1].bfloat16() * (1 - w.time_mix_k)
-            xv = x * w.time_mix_v + state[5*i+1].bfloat16() * (1 - w.time_mix_v)
-            xr = x * w.time_mix_r + state[5*i+1].bfloat16() * (1 - w.time_mix_r)
+    def SA(self, x, state, i, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
+        if self.FLOAT_MODE == "bf16":
+            xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
+            xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
+            xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
             state[5*i+1] = x.float()
         else:
-            xk = x * w.time_mix_k + state[5*i+1] * (1 - w.time_mix_k)
-            xv = x * w.time_mix_v + state[5*i+1] * (1 - w.time_mix_v)
-            xr = x * w.time_mix_r + state[5*i+1] * (1 - w.time_mix_r)
+            xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
+            xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
+            xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
             state[5*i+1] = x
 
-        r = torch.sigmoid(w.receptance.weight @ xr)
-
-        k = w.key.weight @ xk
-        v = w.value.weight @ xv
+        r = torch.sigmoid(rw @ xr)
+        k = kw @ xk
+        v = vw @ xv
 
-        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
+        if self.FLOAT_MODE == "bf16":
             kk = k.float()
             vv = v.float()
-            aa = state[5*i+2]
-            bb = state[5*i+3]
-            pp = state[5*i+4]
-            ww = w.time_first + kk
-            p = torch.maximum(pp, ww)
-            e1 = torch.exp(pp - p)
-            e2 = torch.exp(ww - p)
-            a = e1 * aa + e2 * vv
-            b = e1 * bb + e2
-            ww = pp + w.time_decay
-            p = torch.maximum(ww, kk)
-            e1 = torch.exp(ww - p)
-            e2 = torch.exp(kk - p)
-            state[5*i+2] = e1 * aa + e2 * vv
-            state[5*i+3] = e1 * bb + e2
-            state[5*i+4] = p
-            rwkv = r * (a / b).bfloat16()
         else:
-            aa = state[5*i+2]
-            bb = state[5*i+3]
-            pp = state[5*i+4]
-            ww = w.time_first + k
-            p = torch.maximum(pp, ww)
-            e1 = torch.exp(pp - p)
-            e2 = torch.exp(ww - p)
-            a = e1 * aa + e2 * v
-            b = e1 * bb + e2
-            ww = pp + w.time_decay
-            p = torch.maximum(ww, k)
-            e1 = torch.exp(ww - p)
-            e2 = torch.exp(k - p)
-            state[5*i+2] = e1 * aa + e2 * v
-            state[5*i+3] = e1 * bb + e2
-            state[5*i+4] = p
-            rwkv = r * a / b
-
-        return w.output.weight @ rwkv
+            kk = k
+            vv = v
+        aa = state[5*i+2]
+        bb = state[5*i+3]
+        pp = state[5*i+4]
+        ww = time_first + kk
+        p = torch.maximum(pp, ww)
+        e1 = torch.exp(pp - p)
+        e2 = torch.exp(ww - p)
+        a = e1 * aa + e2 * vv
+        b = e1 * bb + e2
+        ww = pp + time_decay
+        p = torch.maximum(ww, kk)
+        e1 = torch.exp(ww - p)
+        e2 = torch.exp(kk - p)
+        state[5*i+2] = e1 * aa + e2 * vv
+        state[5*i+3] = e1 * bb + e2
+        state[5*i+4] = p
+        if self.FLOAT_MODE == "bf16":
+            wkv = (a / b).type(torch.bfloat16)
+        else:
+            wkv = a / b
+
+        return ow @ (r * wkv)
 
     def forward(self, ctx, state, preprocess_only = False):
         with torch.no_grad():
             w = self.w
+            args = self.args
 
             x = w.emb.weight[ctx[-1]]
             if self.RUN_DEVICE == 'cuda':
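The merged branch above is the point of the SA rewrite: both float modes now share one copy of the WKV recurrence, which carries a running maximum `pp` so every `exp()` argument stays non-positive and never overflows. A standalone sketch of one step of that recurrence (fp32, mirroring the algebra in `SA`):

```python
import torch

# One step of the numerically stable WKV recurrence from SA above.
# aa/bb accumulate exp-weighted sums of v (numerator/denominator);
# pp tracks the largest exponent seen so far, so exp() never overflows.
def wkv_step(k, v, aa, bb, pp, time_first, time_decay):
    ww = time_first + k
    p = torch.maximum(pp, ww)
    e1 = torch.exp(pp - p)
    e2 = torch.exp(ww - p)
    wkv = (e1 * aa + e2 * v) / (e1 * bb + e2)   # this step's output, a / b
    ww = pp + time_decay                         # decay the running exponent
    p = torch.maximum(ww, k)
    e1 = torch.exp(ww - p)
    e2 = torch.exp(k - p)
    return wkv, e1 * aa + e2 * v, e1 * bb + e2, p  # wkv, new aa, bb, pp
```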
@@ -194,15 +178,23 @@ class RWKV_RNN(MyModule):
                 pass
 
             if state == None:
-                state = torch.zeros(self.n_layer * 5, self.n_embd, device=self.RUN_DEVICE)
-                for i in range(self.n_layer):
+                state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
+                for i in range(args.n_layer):
                     state[5*i+4] -= 1e30
 
-            for i in range(self.n_layer):
+            for i in range(args.n_layer):
                 if i == 0:
                     x = self.LN(x, w.blocks[i].ln0)
-                x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, state, i)
-                x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, state, i)
+
+                ww = w.blocks[i].att
+                x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i,
+                    ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
+                    ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
+
+                ww = w.blocks[i].ffn
+                x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
+                    ww.time_mix_k, ww.time_mix_r,
+                    ww.key.weight, ww.value.weight, ww.receptance.weight)
 
             if preprocess_only:
                 return state
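Finally, the state bookkeeping that `forward` initializes: five rows per layer, with `att_pp` pushed to a huge negative value so the first `torch.maximum` in `SA` always picks the real exponent. A small sketch of the layout (values for the 3B model configured above):

```python
import torch

# Recurrent state layout used above: 5 rows per layer.
# Row 0=ffn_xx, 1=att_xx, 2=att_aa, 3=att_bb, 4=att_pp (running max exponent).
n_layer, n_embd = 32, 2560               # 3B model, as configured in run.py
state = torch.zeros(n_layer * 5, n_embd)
for i in range(n_layer):
    state[5 * i + 4] -= 1e30             # att_pp ~ -inf, so the first
                                         # maximum() keeps the real value
```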