You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
RWKV-LM/RWKV-v4neo/run.py

220 lines
7.6 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math, os
import time
import types
import copy
import torch
from src.utils import TOKENIZER
# Trade a little numeric precision / run-to-run determinism for speed on CUDA:
# allow TF32 math in matmuls and cuDNN ops, and let cuDNN autotune its kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Compact, fixed-point numpy printing (4 decimals, wide lines) for debug output.
np.set_printoptions(linewidth=200, precision=4, suppress=True)
########################################################################################################
# Step 1: set model
#
# Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch.
#
# Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models.
########################################################################################################

TOKEN_MODE = "pile"  # char / bpe / pile

# Default architecture for models trained from scratch; the 'pile' branch below
# overrides these so they match the selected pre-trained checkpoint.
n_layer = 6
n_embd = 512
ctx_len = 1024

if TOKEN_MODE == "char":
    MODEL_NAME = "trained-500"  # your trained model
    WORD_NAME = "vocab"  # the .json vocab (generated by train.py)

    # Set UNKNOWN_CHAR to the rarest token in your vocab.json; all unknown
    # tokens in your prompt will be denoted by it.
    UNKNOWN_CHAR = " "  # here we just set it to ' ' for simplicity

elif TOKEN_MODE == "bpe":
    MODEL_NAME = "trained-500"  # your trained model
    WORD_NAME = [
        "model-vocab.json",
        "model-merges.txt",
    ]  # [vocab, merge] for your BPE model
    UNKNOWN_CHAR = None

elif TOKEN_MODE == "pile":
    WORD_NAME = [
        "20B_tokenizer.json",
        "20B_tokenizer.json",
    ]  # [vocab, vocab] for Pile model
    UNKNOWN_CHAR = None

    # ---> you can set MODEL_NAME to your fine-tuned model <---
    # Each preset pairs a checkpoint with its n_layer / n_embd / ctx_len;
    # keep exactly one preset uncommented and change its four lines together.

    # MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
    # n_layer = 12
    # n_embd = 768
    # ctx_len = 1024

    # MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
    # n_layer = 24
    # n_embd = 1024
    # ctx_len = 1024

    # MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040'
    # n_layer = 24
    # n_embd = 2048
    # ctx_len = 1024

    MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20220925-4537'
    n_layer = 32
    n_embd = 2560
    ctx_len = 1024

else:
    # Fail here with a clear message instead of a NameError on MODEL_NAME /
    # WORD_NAME further down the script.
    raise ValueError(
        f"Unknown TOKEN_MODE {TOKEN_MODE!r}; expected 'char', 'bpe' or 'pile'"
    )

os.environ["RWKV_FLOAT_MODE"] = "fp32"  # currently only supports fp32
os.environ["RWKV_RUN_DEVICE"] = "cpu"  # 'cpu' (already very fast) or 'cuda'
model_type = "RWKV"  # 'RWKV' or 'RWKV-ffnPre'
########################################################################################################
# Step 2: set prompt & sampling settings
########################################################################################################

# Pick a prompt: keep exactly one `context = ...` assignment active.
# context = 'A'
# context = "\nIn the"
# context = '\nSugar:'
# context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
# context = "\n深圳是"  # test Chinese
context = "\n東京は"  # test Japanese

# A good prompt for chatbot (uncomment the whole block below, from `context = '''` to `User:'''`):
# context = '''
# The following is a conversation between a highly knowledgeable and intelligent AI assistant, called RWKV, and a human user, called User. In the following interactions, User and RWKV will converse in natural language, and RWKV will do its best to answer User's questions. RWKV was built to be respectful, polite and inclusive. It knows a lot, and always tells the truth. The conversation begins.
# User: OK RWKV, I'm going to start by quizzing you with a few warm-up questions. Who is currently the president of the USA?
# RWKV: It's Joe Biden; he was sworn in earlier this year.
# User: What year was the French Revolution?
# RWKV: It started in 1789, but it lasted 10 years until 1799.
# User: Can you guess who I might want to marry?
# RWKV: Only if you tell me more about yourself - what are your interests?
# User: Aha, I'm going to refrain from that for now. Now for a science question. What can you tell me about the Large Hadron Collider (LHC)?
# RWKV: It's a large and very expensive piece of science equipment. If I understand correctly, it's a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
# User:'''

NUM_TRIALS = 999  # number of independent continuations sampled from the same prompt
LENGTH_PER_TRIAL = 333  # tokens generated per trial
TEMPERATURE = 1.0
top_p = 0.8
top_p_newline = 0.9  # only used in TOKEN_MODE = char
DEBUG_DEBUG = False  # True --> run one trial of one token and print the raw model output
########################################################################################################
print(f"Loading {MODEL_NAME}...")

# Imported only now, after the RWKV_* environment variables above have been set;
# presumably src.model_run reads them at import time (TODO confirm).
from src.model_run import RWKV_RNN

run_device = os.environ["RWKV_RUN_DEVICE"]
model = RWKV_RNN(MODEL_NAME, run_device, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
########################################################################################################
# Tokenize the prompt.
#   char mode: normalize the text with refine_context, then map each character
#              to its id, falling back to tokenizer.UNKNOWN_CHAR for characters
#              missing from stoi.
#   otherwise: delegate to the wrapped tokenizer's encode().
if not tokenizer.charMode:
    ctx = tokenizer.tokenizer.encode(context)
else:
    context = tokenizer.refine_context(context)
    ctx = [tokenizer.stoi.get(ch, tokenizer.UNKNOWN_CHAR) for ch in context]

src_len = len(ctx)
src_ctx = ctx.copy()  # pristine prompt ids; each trial restarts from a copy

print(f"\nYour prompt has {src_len} tokens.")
print(
    "\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n"
)
# Without at least one prompt token the first trial never stores
# init_state.out, and the generation loop would fail with an AttributeError.
if src_len == 0:
    raise ValueError("The prompt encoded to zero tokens; set a non-empty `context`.")

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    print(("-" * 50) + context, end="")
    ctx = src_ctx.copy()
    model.clear()

    if TRIAL == 0:
        # First trial: feed the prompt through the RNN one position at a time,
        # keeping only the logits of the final prompt token, then snapshot the
        # model state so later trials can skip this work.
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[: i + 1]
            if i == src_len - 1:
                init_state.out = model.forward(x)
            else:
                model.forward(x, preprocess_only=True)
        model.save(init_state)
    else:
        # Later trials: restore the snapshot instead of re-reading the prompt.
        model.load(init_state)

    # Index of the first generated token not yet printed (non-char modes only).
    out_last = src_len
    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        x = ctx[: i + 1]
        x = x[-ctx_len:]  # keep at most ctx_len most recent tokens

        if i == src_len:
            # Reuse the cached prompt logits. Deep-copy them because `out` is
            # modified in place below and the cache must stay intact for
            # later trials.
            out = copy.deepcopy(init_state.out)
        else:
            out = model.forward(x)

        if DEBUG_DEBUG:
            print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))

        if TOKEN_MODE == "pile":
            out[0] = -999999999  # disable <|endoftext|>

        ttt = tokenizer.sample_logits(
            out,
            x,
            ctx_len,
            temperature=TEMPERATURE,
            top_p_usual=top_p,
            top_p_newline=top_p_newline,
        )
        ctx += [ttt]

        if tokenizer.charMode:
            char = tokenizer.itos[ttt]
            print(char, end="", flush=True)
        else:
            # Decode everything generated since the last successful print.
            # If the text still contains U+FFFD, hold it back and retry with
            # more tokens; presumably the tokens so far end partway through a
            # multi-byte character.
            char = tokenizer.tokenizer.decode(ctx[out_last:])
            if '\ufffd' not in char:
                print(char, end="", flush=True)
                out_last = i + 1
    print()