########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math, os
import time
import types
import copy
import torch
from src.utils import TOKENIZER
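
# cuDNN autotuning and TF32 matmuls trade a little precision for speed on recent
# NVIDIA GPUs; they are no-ops when running on CPU.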
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)

########################################################################################################
# Step 1: set model
#
# Set TOKEN_MODE to 'char' or 'bpe' if the model was trained from scratch by 'train.py'.
#
# Set TOKEN_MODE to 'pile' if you want to test the pre-trained Pile models.
########################################################################################################

TOKEN_MODE = "pile"  # char / bpe / pile

n_layer = 6
n_embd = 512
ctx_len = 1024

if TOKEN_MODE == "char":
    MODEL_NAME = "trained-500"  # your trained model
    WORD_NAME = "vocab"  # the .json vocab (generated by train.py)
    # set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown
    # tokens in your prompt will be denoted by it
    UNKNOWN_CHAR = " "  # here we just set it to ' ' for simplicity
elif TOKEN_MODE == "bpe":
    MODEL_NAME = "trained-500"  # your trained model
    WORD_NAME = [
        "model-vocab.json",
        "model-merges.txt",
    ]  # [vocab, merge] for your BPE model
    UNKNOWN_CHAR = None
elif TOKEN_MODE == "pile":
    WORD_NAME = [
        "20B_tokenizer.json",
        "20B_tokenizer.json",
    ]  # [vocab, vocab] for Pile model
    UNKNOWN_CHAR = None

# ---> you can set MODEL_NAME to your fine-tuned model <---

# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
# n_layer = 12
# n_embd = 768
# ctx_len = 1024

# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
# n_layer = 24
# n_embd = 1024
# ctx_len = 1024

# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024

MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20220925-4537'
n_layer = 32
n_embd = 2560
ctx_len = 1024

os.environ["RWKV_FLOAT_MODE"] = "fp32"  # currently only supports fp32
os.environ["RWKV_RUN_DEVICE"] = "cpu"  # 'cpu' (already very fast) or 'cuda'
model_type = "RWKV"  # 'RWKV' or 'RWKV-ffnPre'

########################################################################################################
# Step 2: set prompt & sampling settings
########################################################################################################

# context = 'A'
# context = "\nIn the"
# context = '\nSugar:'
# context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

context = "\n深圳是"  # test Chinese ("Shenzhen is")
context = "\n東京は"  # test Japanese ("Tokyo is"); note this overrides the line above

# context = '''  # A good prompt for chatbot
# The following is a conversation between a highly knowledgeable and intelligent AI assistant, called RWKV, and a human user, called User. In the following interactions, User and RWKV will converse in natural language, and RWKV will do its best to answer User’s questions. RWKV was built to be respectful, polite and inclusive. It knows a lot, and always tells the truth. The conversation begins.

# User: OK RWKV, I’m going to start by quizzing you with a few warm-up questions. Who is currently the president of the USA?

# RWKV: It’s Joe Biden; he was sworn in earlier this year.

# User: What year was the French Revolution?

# RWKV: It started in 1789, but it lasted 10 years until 1799.

# User: Can you guess who I might want to marry?

# RWKV: Only if you tell me more about yourself - what are your interests?

# User: Aha, I’m going to refrain from that for now. Now for a science question. What can you tell me about the Large Hadron Collider (LHC)?

# RWKV: It’s a large and very expensive piece of science equipment. If I understand correctly, it’s a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.

# User:'''

NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333

TEMPERATURE = 1.0
top_p = 0.8
top_p_newline = 0.9  # only used when TOKEN_MODE = "char"

DEBUG_DEBUG = False  # set True to show the raw model output for debugging
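
# For reference, "top_p" is nucleus sampling: keep only the most probable tokens whose
# cumulative probability reaches top_p, renormalize, and sample among them; TEMPERATURE
# scales the logits before the softmax. A minimal sketch of the idea (an illustration
# only -- the real logic lives in tokenizer.sample_logits, whose internals may differ):
#
# def sample_top_p(logits, p, temperature=1.0):
#     z = (logits - np.max(logits)) / temperature
#     probs = np.exp(z) / np.exp(z).sum()        # softmax of temperature-scaled logits
#     order = np.argsort(probs)[::-1]            # most probable tokens first
#     cutoff = np.searchsorted(np.cumsum(probs[order]), p) + 1
#     keep = order[:cutoff]                      # smallest set with cumulative prob >= p
#     return np.random.choice(keep, p=probs[keep] / probs[keep].sum())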

########################################################################################################

print(f"Loading {MODEL_NAME}...")
from src.model_run import RWKV_RNN

model = RWKV_RNN(
    MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len
)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
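# (TOKENIZER wraps either a char-level vocab -- exposing charMode, stoi and itos --
#  or a tokenizer file with encode/decode; both paths are exercised below)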

########################################################################################################

if tokenizer.charMode:
    context = tokenizer.refine_context(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
    ctx = tokenizer.tokenizer.encode(context)
src_len = len(ctx)
src_ctx = ctx.copy()

print("\nYour prompt has " + str(src_len) + " tokens.")
print(
    "\n--> Currently the first run takes a while if your prompt is long, because the RNN mode processes the prompt token by token. Use the GPT mode to build the hidden state for better speed. <--\n"
)

# time_slot = {}
# time_ref = time.time_ns()

# def record_time(name):
#     if name not in time_slot:
#         time_slot[name] = 1e20
#     tt = (time.time_ns() - time_ref) / 1e9
#     if tt < time_slot[name]:
#         time_slot[name] = tt

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    # time_ref = time.time_ns()

    print(("-" * 50) + context, end="")
    ctx = src_ctx.copy()
    model.clear()
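
    # On the first trial, feed the prompt through the RNN token by token and cache the
    # final hidden state; later trials reload the cached state instead of re-reading the prompt.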
    if TRIAL == 0:
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[: i + 1]
            if i == src_len - 1:
                init_state.out = model.forward(x)
            else:
                model.forward(x, preprocess_only=True)
        model.save(init_state)
    else:
        model.load(init_state)

    # record_time('model_pre')
    out_last = src_len
    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        # time_ref = time.time_ns()

        x = ctx[: i + 1]
        x = x[-ctx_len:]

        if i == src_len:
            out = copy.deepcopy(init_state.out)
        else:
            out = model.forward(x)
        # record_time('model_run')
        if DEBUG_DEBUG:
            print(
                "model", np.array(x), "==>", np.array(out),
                np.max(out.cpu().numpy()), np.min(out.cpu().numpy()),
            )

        if TOKEN_MODE == "pile":
            out[0] = -999999999  # disable <|endoftext|>
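            # (token 0 is the <|endoftext|> token of the 20B tokenizer; pushing its
            # logit very low keeps the model from terminating the text early)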

        time_ref = time.time_ns()
        ttt = tokenizer.sample_logits(
            out,
            x,
            ctx_len,
            temperature=TEMPERATURE,
            top_p_usual=top_p,
            top_p_newline=top_p_newline,
        )
        ctx += [ttt]

        if tokenizer.charMode:
            char = tokenizer.itos[ttt]
            print(char, end="", flush=True)
        else:
            char = tokenizer.tokenizer.decode(ctx[out_last:])
            if '\ufffd' not in char:  # print only once the bytes decode to a valid string
                print(char, end="", flush=True)
                out_last = i + 1

        # record_time('model_sampling')
    print()

    # print(f'\n\n{time_slot}\n\n')
    # print(
    #     f"\n--- preprocess {round((t_mid - t_begin) / (10 ** 9), 2)}s, generation {round((t_end - t_mid) / (10 ** 9), 2)}s", end=''
    # )