########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math, os, sys, types, time, gc
import torch
from src.utils import TOKENIZER

try:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]  # optionally pick the GPU via the first CLI argument
except IndexError:  # no CLI argument given, use the default device
    pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()

########################################################################################################
# Step 1: set model & config
# Do this first: pip install torchdynamo
########################################################################################################

args.RUN_DEVICE = "cpu"  # 'cpu' (already very fast) // 'cuda'
args.FLOAT_MODE = "fp32"  # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate)
# if args.RUN_DEVICE == "cuda":
#     os.environ["RWKV_RUN_BACKEND"] = 'nvfuser'  # !!!BUGGY!!! wrong output

TOKEN_MODE = "pile"
WORD_NAME = [
    "20B_tokenizer.json",
    "20B_tokenizer.json",
]  # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
vocab_size = 50277

# Download Pile models: https://huggingface.co/BlinkDL
# or, set MODEL_NAME to your fine-tuned model

# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
# n_layer = 12
# n_embd = 768
# ctx_len = 1024

# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
# n_layer = 24
# n_embd = 1024
# ctx_len = 1024

# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024

# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
# n_layer = 32
# n_embd = 2560
# ctx_len = 1024

MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
n_layer = 32
n_embd = 4096
ctx_len = 1024
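
# Note: n_layer / n_embd / ctx_len must match the checkpoint being loaded;
# a mismatched config will typically fail with a shape error when the weights are read.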

args.MODEL_NAME = MODEL_NAME
args.n_layer = n_layer
args.n_embd = n_embd
args.ctx_len = ctx_len
args.vocab_size = vocab_size
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE

########################################################################################################
# Step 2: set prompt & sampling settings
########################################################################################################

# context = 'A'
# context = "\nIn the"
# context = '\nSugar:'
context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

# context = "\n深圳是"  # test Chinese
# context = "\n東京は"  # test Japanese

###### A good prompt for chatbot ######
# context = '''
# The following is a conversation between a highly knowledgeable and intelligent AI assistant called Bot, and a human user called User. In the following interactions, User and Bot converse in natural language, and Bot always answers User's questions. Bot is very smart, polite and humorous. Bot knows a lot, and always tells the truth. The conversation begins.

# User: who is president of usa?

# Bot: It’s Joe Biden; he was sworn in earlier this year.

# User: french revolution what year

# Bot: It started in 1789, but it lasted 10 years until 1799.

# User: guess i marry who ?

# Bot: Only if you tell me more about yourself - what are your interests?

# User: wat is lhc

# Bot: It’s a large and very expensive piece of science equipment. If I understand correctly, it’s a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.

# User:'''  # type your question here

NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333

TEMPERATURE = 1.0
top_p = 0.8
top_p_newline = 0.9  # only used in TOKEN_MODE = "char"

DEBUG_DEBUG = False  # True/False --> show softmax output
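
# For reference, a rough sketch of how temperature + top-p (nucleus) sampling usually works.
# The real implementation used below is tokenizer.sample_logits in src/utils.py and may
# differ in detail:
#
#   probs = torch.softmax(logits / TEMPERATURE, dim=-1)
#   sorted_probs = torch.sort(probs, descending=True)[0]
#   cumulative = torch.cumsum(sorted_probs, dim=-1)
#   cutoff = sorted_probs[torch.searchsorted(cumulative, top_p)]  # smallest prob kept in the nucleus
#   probs[probs < cutoff] = 0  # drop the tail outside the top-p nucleus
#   token = torch.multinomial(probs / probs.sum(), num_samples=1).item()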

########################################################################################################

print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN

model = RWKV_RNN(args)

print('\nOptimizing speed...')
out, _ = model.forward([187], None)  # warm-up pass on a single token (187 = '\n' in the Pile tokenizer)
# print(out)
gc.collect()
torch.cuda.empty_cache()

# input(0)

print(f'\nLoading tokenizer {WORD_NAME}...')
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == "pile":
    assert tokenizer.tokenizer.decode([187]) == '\n'  # sanity check: token 187 must decode to a newline

########################################################################################################

if tokenizer.charMode:
    context = tokenizer.refine_context(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
    ctx = tokenizer.tokenizer.encode(context)
src_len = len(ctx)
src_ctx = ctx.copy()

print("\nYour prompt has " + str(src_len) + " tokens.")
print(
    "Note: currently the first run takes a while if your prompt is long, as the prompt is fed through the RNN token by token. Use the GPT mode to build the hidden state for better speed.\n"
)

time_slot = {}
time_ref = time.time_ns()

def record_time(name):
    # keep the best (fastest) wall-clock time in seconds observed for each named stage
    if name not in time_slot:
        time_slot[name] = 1e20
    tt = (time.time_ns() - time_ref) / 1e9
    if tt < time_slot[name]:
        time_slot[name] = tt
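
# Generation flow: the first trial builds the prompt's hidden state token by token
# (see TRIAL == 0 below); later trials clone that cached state, so only newly
# sampled tokens need a forward pass.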

init_state = None
init_out = None
state = None
out = None

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    print(("-" * 50) + '\n' + context, end="")

    time_ref = time.time_ns()
    ctx = src_ctx.copy()

    if TRIAL == 0:
        # first trial: feed the prompt through the RNN one token at a time to build init_state
        for i in range(src_len):
            x = ctx[: i + 1]
            if i == src_len - 1:
                init_out, init_state = model.forward(x, init_state)
            else:
                init_state = model.forward(x, init_state, preprocess_only=True)
        gc.collect()
        torch.cuda.empty_cache()

    record_time('preprocess')
    out_last = src_len
    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        x = ctx[: i + 1]
        x = x[-ctx_len:]

        if i == src_len:
            # reuse the cached prompt state instead of re-reading the prompt
            out = init_out.clone()
            state = init_state.clone()
        else:
            out, state = model.forward(x, state)
        if DEBUG_DEBUG:
            print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))
        if TOKEN_MODE == "pile":
            out[0] = -999999999  # disable <|endoftext|>

        ttt = tokenizer.sample_logits(
            out,
            x,
            ctx_len,
            temperature=TEMPERATURE,
            top_p_usual=top_p,
            top_p_newline=top_p_newline,
        )
        ctx += [ttt]

        if tokenizer.charMode:
            char = tokenizer.itos[ttt]
            print(char, end="", flush=True)
        else:
            char = tokenizer.tokenizer.decode(ctx[out_last:])
            if '\ufffd' not in char:  # only print once the bytes form a complete UTF-8 character
                print(char, end="", flush=True)
                out_last = i + 1

    record_time('total')
    # print(f'\n\n{time_slot}\n\n')
    print(
        f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end=''
    )

print(("-" * 50) + '\n')