main
BlinkDL 3 years ago
parent daed379db2
commit 8a4a41a3aa

run.py
@@ -3,11 +3,9 @@
 ########################################################################################################
 import numpy as np
-import math, os, sys
-import time
+import math, os, sys, types, time, gc
 import torch
 from src.utils import TOKENIZER
 try:
     os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
 except:
@@ -16,72 +14,67 @@ torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
+args = types.SimpleNamespace()

 ########################################################################################################
-# Step 1: set model
-#
-# Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch.
-#
-# Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models.
+# Step 1: set model & config
+# Do this first: pip install torchdynamo
 ########################################################################################################

-TOKEN_MODE = "pile" # char / bpe / pile
-
-n_layer = 6
-n_embd = 512
-ctx_len = 1024
-
-if TOKEN_MODE == "char":
-    MODEL_NAME = "trained-500" # your trained model
-    WORD_NAME = "vocab" # the .json vocab (generated by train.py)
-    # set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown tokens in your prompt will be denoted by it
-    UNKNOWN_CHAR = " " # here we just set it to ' ' for simplicity
-elif TOKEN_MODE == "bpe":
-    MODEL_NAME = "trained-500" # your trained model
-    WORD_NAME = [
-        "model-vocab.json",
-        "model-merges.txt",
-    ] # [vocab, merge] for your BPE model
-    UNKNOWN_CHAR = None
-elif TOKEN_MODE == "pile":
-    WORD_NAME = [
-        "20B_tokenizer.json",
-        "20B_tokenizer.json",
-    ] # [vocab, vocab] for Pile model
-    UNKNOWN_CHAR = None
+args.RUN_DEVICE = "cpu" # 'cpu' (already very fast) // 'cuda'
+args.FLOAT_MODE = "fp32" # fp32 // bf16 (saves VRAM, slightly less accurate)
+# if args.RUN_DEVICE == "cuda":
+#     os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
+
+TOKEN_MODE = "pile"
+WORD_NAME = [
+    "20B_tokenizer.json",
+    "20B_tokenizer.json",
+] # [vocab, vocab] for Pile model
+UNKNOWN_CHAR = None
+vocab_size = 50277

-# note; you can set MODEL_NAME to your fine-tuned model
+# ---> you can set MODEL_NAME to your fine-tuned model <---

 # MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
 # n_layer = 12
 # n_embd = 768
 # ctx_len = 1024

 # MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
 # n_layer = 24
 # n_embd = 1024
 # ctx_len = 1024

+# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
+# n_layer = 24
+# n_embd = 2048
+# ctx_len = 1024
+
 # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096'
 # n_layer = 24
 # n_embd = 2048
-# ctx_len = 1024
+# ctx_len = 4096

 MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
 n_layer = 32
 n_embd = 2560
 ctx_len = 1024

 # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221004-3047'
 # n_layer = 32
 # n_embd = 4096
 # ctx_len = 1024

-os.environ["RWKV_FLOAT_MODE"] = "fp32" # fp32 (faster at this moment) or bf16 (slower but saves VRAM)
-os.environ["RWKV_RUN_DEVICE"] = "cpu" # 'cpu' (already very fast) or 'cuda'
-model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre'
+args.MODEL_NAME = MODEL_NAME
+args.n_layer = n_layer
+args.n_embd = n_embd
+args.ctx_len = ctx_len
+args.vocab_size = vocab_size
+args.head_qk = 0
+args.pre_ffn = 0
+args.grad_cp = 0
+args.my_pos_emb = 0
+os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE

 ########################################################################################################
 # Step 2: set prompt & sampling stuffs
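
As a usage note: after this change, everything RWKV_RNN needs travels in a single types.SimpleNamespace instead of loose globals and environment variables. A minimal sketch of an equivalent config for the smaller 169M checkpoint (values taken from the commented-out block above; pick the block matching your checkpoint):

    import types

    args = types.SimpleNamespace()
    args.RUN_DEVICE = "cpu"    # 'cpu' or 'cuda'
    args.FLOAT_MODE = "fp32"   # 'fp32' or 'bf16' (bf16 saves VRAM)
    args.MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
    args.n_layer = 12          # must match the checkpoint
    args.n_embd = 768
    args.ctx_len = 1024
    args.vocab_size = 50277    # Pile models use the 20B_tokenizer vocab
    args.head_qk = 0           # head-QK trick disabled at inference
    args.pre_ffn = 0           # ffnPre variant disabled
    args.grad_cp = 0           # gradient checkpointing (training-only)
    args.my_pos_emb = 0        # no learned positional embedding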
@@ -128,12 +121,15 @@ DEBUG_DEBUG = False # True False --> show softmax output
 ########################################################################################################

-print(f'\nUsing {os.environ["RWKV_RUN_DEVICE"].upper()}. Loading {MODEL_NAME}...')
+print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...')
 from src.model_run import RWKV_RNN

-model = RWKV_RNN(
-    MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len
-)
+model = RWKV_RNN(args)
+
+print(f'\nOptimizing speed...')
+model.forward([187], None)
+gc.collect()
+torch.cuda.empty_cache()

 # input(0)
@@ -185,6 +181,8 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
         init_out, init_state = model.forward(x, init_state)
     else:
         init_state = model.forward(x, init_state, preprocess_only=True)
+    gc.collect()
+    torch.cuda.empty_cache()
     record_time('preprocess')
     out_last = src_len
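
For context on the new warm-up: model.forward([187], None) pushes a single dummy token through the network once, so any lazy initialization happens before timing starts, and the per-trial gc.collect() / torch.cuda.empty_cache() calls keep memory stable across trials. The surrounding script (not shown in this diff) drives the model token by token, threading the recurrent state through every call, as the context lines above suggest. A minimal sketch of that pattern, with a hypothetical prompt_tokens list (the real run.py adds sampling, temperature/top-p, and timing):

    # preprocess the prompt: feed it one token at a time, updating only the state
    state = None
    for i in range(len(prompt_tokens) - 1):
        state = model.forward(prompt_tokens[: i + 1], state, preprocess_only=True)

    # compute logits for the last prompt token, then generate (greedy, for simplicity)
    out, state = model.forward(prompt_tokens, state)
    tokens = list(prompt_tokens)
    for _ in range(100):
        tokens.append(int(torch.argmax(out)))
        out, state = model.forward([tokens[-1]], state)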

src/model_run.py
@@ -7,13 +7,15 @@ import torch
 import math, os, gc
 from torch.nn import functional as F
 import torch.nn as nn
+from typing import List, Dict

+# try:
+#     import torchdynamo
+#     MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
+# except:
 def __nop(ob):
     return ob
-MyModule = nn.Module
 MyFunction = __nop
+# MyModule = torch.jit.ScriptModule
+# MyFunction = torch.jit.script_method

 RWKV_HEAD_QK_DIM = 0
 print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
@@ -22,24 +24,20 @@ DEBUG_TIME = False # True False - show trained time-coeffs
 ############################################################################################################

-class RWKV_RNN(MyModule):
-    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
+class RWKV_RNN(nn.Module):
+    def __init__(self, args):
         super().__init__()
-        self.RUN_DEVICE = RUN_DEVICE
-        self.model_type = model_type
-        self.n_layer = n_layer
-        self.n_embd = n_embd
-        self.ctx_len = ctx_len
-
-        w = torch.load(MODEL_NAME + '.pth', map_location='cpu')
-        # refine weights and send to correct device
-        keys = list(w.keys())
-        if 'pos_emb_x' in keys:
-            w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:]
-            keys = list(w.keys())
-        print_need_newline = False
-        for x in keys:
+        self.args = args
+        self.FLOAT_MODE = args.FLOAT_MODE
+        self.RUN_DEVICE = args.RUN_DEVICE
+
+        with torch.no_grad():
+            w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
+            # refine weights and send to correct device
+            keys = list(w.keys())
+            if 'pos_emb_x' in keys:
+                w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
+                keys = list(w.keys())
+            print_need_newline = False
+            for x in keys:
@@ -53,13 +51,13 @@ class RWKV_RNN(MyModule):
                 elif '.time_first' in x:
                     w[x] = w[x].float()
                 else:
-                    if os.environ["RWKV_FLOAT_MODE"] == "fp32":
+                    if self.FLOAT_MODE == "fp32":
                         w[x] = w[x].float()
-                    elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+                    elif self.FLOAT_MODE == "bf16":
                         w[x] = w[x].bfloat16()

                 w[x].requires_grad = False
-                if RUN_DEVICE == 'cuda' and x != 'emb.weight':
+                if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
                     w[x] = w[x].cuda()

                 if ('blocks.' not in x) or ('blocks.0.' in x):
@@ -72,7 +70,6 @@ class RWKV_RNN(MyModule):
                     print('.', end = '', flush = True)
-
             # store weights in self.w
             keys = list(w.keys())
             self.w = types.SimpleNamespace()
             for x in keys:
@@ -98,91 +95,78 @@ class RWKV_RNN(MyModule):
         gc.collect()
         torch.cuda.empty_cache()

-    @MyFunction
     def LN(self, x, w):
-        return F.layer_norm(x, (self.n_embd,), weight=w.weight, bias=w.bias)
+        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

-    # state: ffn_xx att_xx att_aa att_bb att_pp
+    # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp

     @MyFunction
-    def FF(self, x, w, state, i):
-        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
-            xk = x * w.time_mix_k + state[5*i+0].bfloat16() * (1 - w.time_mix_k)
-            xr = x * w.time_mix_r + state[5*i+0].bfloat16() * (1 - w.time_mix_r)
+    def FF(self, x, state, i, time_mix_k, time_mix_r, kw, vw, rw):
+        if self.FLOAT_MODE == "bf16":
+            xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
+            xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
             state[5*i+0] = x.float()
         else:
-            xk = x * w.time_mix_k + state[5*i+0] * (1 - w.time_mix_k)
-            xr = x * w.time_mix_r + state[5*i+0] * (1 - w.time_mix_r)
+            xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
+            xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
             state[5*i+0] = x
-        r = torch.sigmoid(w.receptance.weight @ xr)
-        k = torch.square(torch.relu(w.key.weight @ xk))
-        kv = w.value.weight @ k
+        r = torch.sigmoid(rw @ xr)
+        k = torch.square(torch.relu(kw @ xk))
+        kv = vw @ k
         return r * kv
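
For reference, FF is RWKV's channel-mix block: a token-shift that blends the current input with the previous timestep's input (kept in state slot 5*i+0), a squared-ReLU feedforward, and a sigmoid "receptance" gate. A standalone NumPy sketch of the fp32 path, with assumed shapes (x, x_prev of size n_embd; kw, vw, rw matching projection matrices):

    import numpy as np

    def channel_mix(x, x_prev, time_mix_k, time_mix_r, kw, vw, rw):
        xk = x * time_mix_k + x_prev * (1 - time_mix_k)  # token-shift for keys
        xr = x * time_mix_r + x_prev * (1 - time_mix_r)  # token-shift for receptance
        r = 1 / (1 + np.exp(-(rw @ xr)))                 # sigmoid gate
        k = np.square(np.maximum(kw @ xk, 0))            # squared ReLU
        return r * (vw @ k)                              # gated value projection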
     @MyFunction
-    def SA(self, x, w, state, i):
-        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
-            xk = x * w.time_mix_k + state[5*i+1].bfloat16() * (1 - w.time_mix_k)
-            xv = x * w.time_mix_v + state[5*i+1].bfloat16() * (1 - w.time_mix_v)
-            xr = x * w.time_mix_r + state[5*i+1].bfloat16() * (1 - w.time_mix_r)
+    def SA(self, x, state, i, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
+        if self.FLOAT_MODE == "bf16":
+            xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
+            xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
+            xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
             state[5*i+1] = x.float()
         else:
-            xk = x * w.time_mix_k + state[5*i+1] * (1 - w.time_mix_k)
-            xv = x * w.time_mix_v + state[5*i+1] * (1 - w.time_mix_v)
-            xr = x * w.time_mix_r + state[5*i+1] * (1 - w.time_mix_r)
+            xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
+            xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
+            xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
             state[5*i+1] = x

-        r = torch.sigmoid(w.receptance.weight @ xr)
-        k = w.key.weight @ xk
-        v = w.value.weight @ xv
+        r = torch.sigmoid(rw @ xr)
+        k = kw @ xk
+        v = vw @ xv

-        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
+        if self.FLOAT_MODE == "bf16":
             kk = k.float()
             vv = v.float()
-            aa = state[5*i+2]
-            bb = state[5*i+3]
-            pp = state[5*i+4]
-            ww = w.time_first + kk
-            p = torch.maximum(pp, ww)
-            e1 = torch.exp(pp - p)
-            e2 = torch.exp(ww - p)
-            a = e1 * aa + e2 * vv
-            b = e1 * bb + e2
-            ww = pp + w.time_decay
-            p = torch.maximum(ww, kk)
-            e1 = torch.exp(ww - p)
-            e2 = torch.exp(kk - p)
-            state[5*i+2] = e1 * aa + e2 * vv
-            state[5*i+3] = e1 * bb + e2
-            state[5*i+4] = p
-            rwkv = r * (a / b).bfloat16()
         else:
-            aa = state[5*i+2]
-            bb = state[5*i+3]
-            pp = state[5*i+4]
-            ww = w.time_first + k
-            p = torch.maximum(pp, ww)
-            e1 = torch.exp(pp - p)
-            e2 = torch.exp(ww - p)
-            a = e1 * aa + e2 * v
-            b = e1 * bb + e2
-            ww = pp + w.time_decay
-            p = torch.maximum(ww, k)
-            e1 = torch.exp(ww - p)
-            e2 = torch.exp(k - p)
-            state[5*i+2] = e1 * aa + e2 * v
-            state[5*i+3] = e1 * bb + e2
-            state[5*i+4] = p
-            rwkv = r * a / b
-        return w.output.weight @ rwkv
+            kk = k
+            vv = v
+        aa = state[5*i+2]
+        bb = state[5*i+3]
+        pp = state[5*i+4]
+        ww = time_first + kk
+        p = torch.maximum(pp, ww)
+        e1 = torch.exp(pp - p)
+        e2 = torch.exp(ww - p)
+        a = e1 * aa + e2 * vv
+        b = e1 * bb + e2
+        ww = pp + time_decay
+        p = torch.maximum(ww, kk)
+        e1 = torch.exp(ww - p)
+        e2 = torch.exp(kk - p)
+        state[5*i+2] = e1 * aa + e2 * vv
+        state[5*i+3] = e1 * bb + e2
+        state[5*i+4] = p
+        if self.FLOAT_MODE == "bf16":
+            wkv = (a / b).type(torch.bfloat16)
+        else:
+            wkv = a / b
+        return ow @ (r * wkv)
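
The refactor above also deduplicates the WKV recurrence that was previously written out twice (once per float mode): a and b are the numerator and denominator of an exponentially decaying weighted average of past values, and state slot 5*i+4 (att_pp) carries the running maximum exponent so that every exp() is taken on a non-positive argument and cannot overflow, log-sum-exp style. A scalar sketch of one step (the real code runs this elementwise over all n_embd channels):

    import math

    def wkv_step(k, v, aa, bb, pp, time_first, time_decay):
        # output: current token enters the average with a one-off time_first bonus
        ww = time_first + k
        p = max(pp, ww)
        e1, e2 = math.exp(pp - p), math.exp(ww - p)
        wkv = (e1 * aa + e2 * v) / (e1 * bb + e2)
        # state update: decay the running sums, then fold the current token in
        ww = pp + time_decay
        p = max(ww, k)
        e1, e2 = math.exp(ww - p), math.exp(k - p)
        return wkv, (e1 * aa + e2 * v, e1 * bb + e2, p)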
     def forward(self, ctx, state, preprocess_only = False):
         with torch.no_grad():
             w = self.w
+            args = self.args

             x = w.emb.weight[ctx[-1]]
             if self.RUN_DEVICE == 'cuda':
@@ -194,15 +178,23 @@ class RWKV_RNN(MyModule):
                 pass

             if state == None:
-                state = torch.zeros(self.n_layer * 5, self.n_embd, device=self.RUN_DEVICE)
-                for i in range(self.n_layer):
+                state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
+                for i in range(args.n_layer):
                     state[5*i+4] -= 1e30

-            for i in range(self.n_layer):
+            for i in range(args.n_layer):
                 if i == 0:
                     x = self.LN(x, w.blocks[i].ln0)
-                x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, state, i)
-                x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, state, i)
+
+                ww = w.blocks[i].att
+                x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i,
+                    ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
+                    ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
+
+                ww = w.blocks[i].ffn
+                x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
+                    ww.time_mix_k, ww.time_mix_r,
+                    ww.key.weight, ww.value.weight, ww.receptance.weight)

             if preprocess_only:
                 return state
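
A detail that is easy to miss in the hunk above: a fresh state is a (5 * n_layer, n_embd) tensor of zeros, except that slot 4 of every layer (att_pp, the running max exponent) is pushed down by 1e30, i.e. to effectively negative infinity, so the first real token fully determines the initial weighted average. As a standalone sketch:

    import torch

    def make_init_state(n_layer, n_embd, device="cpu"):
        # 5 slots per layer: ffn_xx, att_xx, att_aa, att_bb, att_pp
        state = torch.zeros(n_layer * 5, n_embd, device=device)
        for i in range(n_layer):
            state[5 * i + 4] -= 1e30  # att_pp starts at ~-inf (empty log-sum)
        return state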
