main
BlinkDL 3 years ago
parent daed379db2
commit 8a4a41a3aa

@ -3,11 +3,9 @@
######################################################################################################## ########################################################################################################
import numpy as np import numpy as np
import math, os, sys import math, os, sys, types, time, gc
import time
import torch import torch
from src.utils import TOKENIZER from src.utils import TOKENIZER
try: try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1] os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except: except:
@ -16,72 +14,67 @@ torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200) np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()
######################################################################################################## ########################################################################################################
# Step 1: set model # Step 1: set model & config
# # Do this first: pip install torchdynamo
# Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch.
#
# Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models.
######################################################################################################## ########################################################################################################
TOKEN_MODE = "pile" # char / bpe / pile args.RUN_DEVICE = "cpu" # 'cpu' (already very fast) // 'cuda'
args.FLOAT_MODE = "fp32" # fp32 // bf16 (saves VRAM, slightly less accurate)
n_layer = 6 # if args.RUN_DEVICE == "cuda":
n_embd = 512 # os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
TOKEN_MODE = "pile"
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
vocab_size = 50277
# note; you can set MODEL_NAME to your fine-tuned model
# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
# n_layer = 12
# n_embd = 768
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
# n_layer = 24
# n_embd = 1024
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096'
# n_layer = 24
# n_embd = 2048
# ctx_len = 4096
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
n_layer = 32
n_embd = 2560
ctx_len = 1024 ctx_len = 1024
if TOKEN_MODE == "char": # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221004-3047'
MODEL_NAME = "trained-500" # your trained model # n_layer = 32
WORD_NAME = "vocab" # the .json vocab (generated by train.py) # n_embd = 4096
# set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown tokens in your prompt will be denoted by it # ctx_len = 1024
UNKNOWN_CHAR = " " # here we just set it to ' ' for simplicity
args.MODEL_NAME = MODEL_NAME
elif TOKEN_MODE == "bpe": args.n_layer = n_layer
MODEL_NAME = "trained-500" # your trained model args.n_embd = n_embd
WORD_NAME = [ args.ctx_len = ctx_len
"model-vocab.json", args.vocab_size = vocab_size
"model-merges.txt", args.head_qk = 0
] # [vocab, merge] for your BPE model args.pre_ffn = 0
UNKNOWN_CHAR = None args.grad_cp = 0
args.my_pos_emb = 0
elif TOKEN_MODE == "pile": os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
# ---> you can set MODEL_NAME to your fine-tuned model <---
# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
# n_layer = 12
# n_embd = 768
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
# n_layer = 24
# n_embd = 1024
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
n_layer = 32
n_embd = 2560
ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221004-3047'
# n_layer = 32
# n_embd = 4096
# ctx_len = 1024
os.environ["RWKV_FLOAT_MODE"] = "fp32" # fp32 (faster at this moment) or bf16 (slower but saves VRAM)
os.environ["RWKV_RUN_DEVICE"] = "cpu" # 'cpu' (already very fast) or 'cuda'
model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre'
######################################################################################################## ########################################################################################################
# Step 2: set prompt & sampling stuffs # Step 2: set prompt & sampling stuffs
@ -128,12 +121,15 @@ DEBUG_DEBUG = False # True False --> show softmax output
######################################################################################################## ########################################################################################################
print(f'\nUsing {os.environ["RWKV_RUN_DEVICE"].upper()}. Loading {MODEL_NAME}...') print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN from src.model_run import RWKV_RNN
model = RWKV_RNN( model = RWKV_RNN(args)
MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len
) print(f'\nOptimizing speed...')
model.forward([187], None)
gc.collect()
torch.cuda.empty_cache()
# input(0) # input(0)
@ -185,6 +181,8 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
init_out, init_state = model.forward(x, init_state) init_out, init_state = model.forward(x, init_state)
else: else:
init_state = model.forward(x, init_state, preprocess_only=True) init_state = model.forward(x, init_state, preprocess_only=True)
gc.collect()
torch.cuda.empty_cache()
record_time('preprocess') record_time('preprocess')
out_last = src_len out_last = src_len

@ -7,13 +7,15 @@ import torch
import math, os, gc import math, os, gc
from torch.nn import functional as F from torch.nn import functional as F
import torch.nn as nn import torch.nn as nn
from typing import List, Dict
# try:
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# except:
def __nop(ob): def __nop(ob):
return ob return ob
MyModule = nn.Module
MyFunction = __nop MyFunction = __nop
# MyModule = torch.jit.ScriptModule
# MyFunction = torch.jit.script_method
RWKV_HEAD_QK_DIM = 0 RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
@ -22,57 +24,52 @@ DEBUG_TIME = False # True False - show trained time-coeffs
############################################################################################################ ############################################################################################################
class RWKV_RNN(MyModule): class RWKV_RNN(nn.Module):
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len): def __init__(self, args):
super().__init__() super().__init__()
self.RUN_DEVICE = RUN_DEVICE self.args = args
self.model_type = model_type self.FLOAT_MODE = args.FLOAT_MODE
self.n_layer = n_layer self.RUN_DEVICE = args.RUN_DEVICE
self.n_embd = n_embd
self.ctx_len = ctx_len
w = torch.load(MODEL_NAME + '.pth', map_location='cpu') with torch.no_grad():
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
# refine weights and send to correct device # refine weights and send to correct device
keys = list(w.keys())
keys = list(w.keys()) if 'pos_emb_x' in keys:
if 'pos_emb_x' in keys: w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:] keys = list(w.keys())
print_need_newline = False
keys = list(w.keys()) for x in keys:
print_need_newline = False if '.time_' in x:
for x in keys: w[x] = w[x].squeeze()
if '.time_' in x: if DEBUG_TIME:
w[x] = w[x].squeeze() print(x, w[x].numpy())
if DEBUG_TIME: if '.time_decay' in x:
print(x, w[x].numpy())
if '.time_decay' in x:
w[x] = w[x].float()
w[x] = -torch.exp(w[x])
elif '.time_first' in x:
w[x] = w[x].float()
else:
if os.environ["RWKV_FLOAT_MODE"] == "fp32":
w[x] = w[x].float() w[x] = w[x].float()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16": w[x] = -torch.exp(w[x])
w[x] = w[x].bfloat16() elif '.time_first' in x:
w[x] = w[x].float()
w[x].requires_grad = False else:
if RUN_DEVICE == 'cuda' and x != 'emb.weight': if self.FLOAT_MODE == "fp32":
w[x] = w[x].cuda() w[x] = w[x].float()
elif self.FLOAT_MODE == "bf16":
if ('blocks.' not in x) or ('blocks.0.' in x): w[x] = w[x].bfloat16()
if print_need_newline:
print('\n', end = '') w[x].requires_grad = False
print_need_newline = False if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device) w[x] = w[x].cuda()
else:
print_need_newline = True if ('blocks.' not in x) or ('blocks.0.' in x):
print('.', end = '', flush = True) if print_need_newline:
print('\n', end = '')
print_need_newline = False
print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device)
else:
print_need_newline = True
print('.', end = '', flush = True)
# store weights in self.w # store weights in self.w
keys = list(w.keys()) keys = list(w.keys())
self.w = types.SimpleNamespace() self.w = types.SimpleNamespace()
for x in keys: for x in keys:
@ -98,91 +95,78 @@ class RWKV_RNN(MyModule):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@MyFunction
def LN(self, x, w): def LN(self, x, w):
return F.layer_norm(x, (self.n_embd,), weight=w.weight, bias=w.bias) return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
# state: ffn_xx att_xx att_aa att_bb att_pp # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
@MyFunction @MyFunction
def FF(self, x, w, state, i): def FF(self, x, state, i, time_mix_k, time_mix_r, kw, vw, rw):
if os.environ["RWKV_FLOAT_MODE"] == "bf16": if self.FLOAT_MODE == "bf16":
xk = x * w.time_mix_k + state[5*i+0].bfloat16() * (1 - w.time_mix_k) xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
xr = x * w.time_mix_r + state[5*i+0].bfloat16() * (1 - w.time_mix_r) xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+0] = x.float() state[5*i+0] = x.float()
else: else:
xk = x * w.time_mix_k + state[5*i+0] * (1 - w.time_mix_k) xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
xr = x * w.time_mix_r + state[5*i+0] * (1 - w.time_mix_r) xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
state[5*i+0] = x state[5*i+0] = x
r = torch.sigmoid(w.receptance.weight @ xr) r = torch.sigmoid(rw @ xr)
k = torch.square(torch.relu(w.key.weight @ xk)) k = torch.square(torch.relu(kw @ xk))
kv = w.value.weight @ k kv = vw @ k
return r * kv return r * kv
@MyFunction @MyFunction
def SA(self, x, w, state, i): def SA(self, x, state, i, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
if os.environ["RWKV_FLOAT_MODE"] == "bf16": if self.FLOAT_MODE == "bf16":
xk = x * w.time_mix_k + state[5*i+1].bfloat16() * (1 - w.time_mix_k) xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
xv = x * w.time_mix_v + state[5*i+1].bfloat16() * (1 - w.time_mix_v) xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
xr = x * w.time_mix_r + state[5*i+1].bfloat16() * (1 - w.time_mix_r) xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+1] = x.float() state[5*i+1] = x.float()
else: else:
xk = x * w.time_mix_k + state[5*i+1] * (1 - w.time_mix_k) xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
xv = x * w.time_mix_v + state[5*i+1] * (1 - w.time_mix_v) xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
xr = x * w.time_mix_r + state[5*i+1] * (1 - w.time_mix_r) xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
state[5*i+1] = x state[5*i+1] = x
r = torch.sigmoid(w.receptance.weight @ xr) r = torch.sigmoid(rw @ xr)
k = kw @ xk
k = w.key.weight @ xk v = vw @ xv
v = w.value.weight @ xv
if os.environ["RWKV_FLOAT_MODE"] == "bf16": if self.FLOAT_MODE == "bf16":
kk = k.float() kk = k.float()
vv = v.float() vv = v.float()
aa = state[5*i+2]
bb = state[5*i+3]
pp = state[5*i+4]
ww = w.time_first + kk
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
a = e1 * aa + e2 * vv
b = e1 * bb + e2
ww = pp + w.time_decay
p = torch.maximum(ww, kk)
e1 = torch.exp(ww - p)
e2 = torch.exp(kk - p)
state[5*i+2] = e1 * aa + e2 * vv
state[5*i+3] = e1 * bb + e2
state[5*i+4] = p
rwkv = r * (a / b).bfloat16()
else: else:
aa = state[5*i+2] kk = k
bb = state[5*i+3] vv = v
pp = state[5*i+4] aa = state[5*i+2]
ww = w.time_first + k bb = state[5*i+3]
p = torch.maximum(pp, ww) pp = state[5*i+4]
e1 = torch.exp(pp - p) ww = time_first + kk
e2 = torch.exp(ww - p) p = torch.maximum(pp, ww)
a = e1 * aa + e2 * v e1 = torch.exp(pp - p)
b = e1 * bb + e2 e2 = torch.exp(ww - p)
ww = pp + w.time_decay a = e1 * aa + e2 * vv
p = torch.maximum(ww, k) b = e1 * bb + e2
e1 = torch.exp(ww - p) ww = pp + time_decay
e2 = torch.exp(k - p) p = torch.maximum(ww, kk)
state[5*i+2] = e1 * aa + e2 * v e1 = torch.exp(ww - p)
state[5*i+3] = e1 * bb + e2 e2 = torch.exp(kk - p)
state[5*i+4] = p state[5*i+2] = e1 * aa + e2 * vv
rwkv = r * a / b state[5*i+3] = e1 * bb + e2
state[5*i+4] = p
return w.output.weight @ rwkv if self.FLOAT_MODE == "bf16":
wkv = (a / b).type(torch.bfloat16)
else:
wkv = a / b
return ow @ (r * wkv)
def forward(self, ctx, state, preprocess_only = False): def forward(self, ctx, state, preprocess_only = False):
with torch.no_grad(): with torch.no_grad():
w = self.w w = self.w
args = self.args
x = w.emb.weight[ctx[-1]] x = w.emb.weight[ctx[-1]]
if self.RUN_DEVICE == 'cuda': if self.RUN_DEVICE == 'cuda':
@ -194,15 +178,23 @@ class RWKV_RNN(MyModule):
pass pass
if state == None: if state == None:
state = torch.zeros(self.n_layer * 5, self.n_embd, device=self.RUN_DEVICE) state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
for i in range(self.n_layer): for i in range(args.n_layer):
state[5*i+4] -= 1e30 state[5*i+4] -= 1e30
for i in range(self.n_layer): for i in range(args.n_layer):
if i == 0: if i == 0:
x = self.LN(x, w.blocks[i].ln0) x = self.LN(x, w.blocks[i].ln0)
x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, state, i)
x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, state, i) ww = w.blocks[i].att
x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i,
ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
ww = w.blocks[i].ffn
x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
ww.time_mix_k, ww.time_mix_r,
ww.key.weight, ww.value.weight, ww.receptance.weight)
if preprocess_only: if preprocess_only:
return state return state

Loading…
Cancel
Save