You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
RWKV-LM/RWKV-v4neo/run.py

185 lines
6.0 KiB
Python

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math, os
import time
import types
import copy
import torch
from src.utils import TOKENIZER
# Backend tuning. These flags only affect CUDA/cuDNN kernels: TF32 is allowed
# for both cuDNN and CUDA matmul, and the cuDNN benchmark flag is switched on.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Compact numpy printing (4 decimals, no scientific notation, wide lines) so the
# arrays printed in debug mode stay readable.
np.set_printoptions(linewidth=200, suppress=True, precision=4)
########################################################################################################
# Step 1: set model
#
# Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch.
#
# Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models.
########################################################################################################
# Model / tokenizer configuration.
#
# TOKEN_MODE selects which tokenizer files and which checkpoint are used; the
# names defined here (MODEL_NAME, WORD_NAME, UNKNOWN_CHAR, n_layer, n_embd,
# ctx_len, model_type) are consumed when the model and tokenizer are built.
TOKEN_MODE = "pile"  # char / bpe / pile

# Default model shape; the "pile" branch below overrides these to match the
# selected pre-trained checkpoint.
n_layer = 6
n_embd = 512
ctx_len = 1024

if TOKEN_MODE == "char":
    MODEL_NAME = "trained-500"  # your trained model
    WORD_NAME = "vocab"  # the .json vocab (generated by train.py)
    # set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown
    # tokens in your prompt will be denoted by it
    UNKNOWN_CHAR = " "  # here we just set it to ' ' for simplicity
elif TOKEN_MODE == "bpe":
    MODEL_NAME = "trained-500"  # your trained model
    WORD_NAME = [
        "model-vocab.json",
        "model-merges.txt",
    ]  # [vocab, merge] for your BPE model
    UNKNOWN_CHAR = None
elif TOKEN_MODE == "pile":
    WORD_NAME = [
        "20B_tokenizer.json",
        "20B_tokenizer.json",
    ]  # [vocab, vocab] for Pile model
    UNKNOWN_CHAR = None

    # ---> you can set MODEL_NAME to your fine-tuned model <---

    # MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
    # n_layer = 12
    # n_embd = 768
    # ctx_len = 1024

    # MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
    # n_layer = 24
    # n_embd = 1024
    # ctx_len = 1024

    MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040'
    n_layer = 24
    n_embd = 2048
    ctx_len = 1024
else:
    # Fail fast: without this, a mistyped mode would only surface later as a
    # NameError on WORD_NAME / MODEL_NAME.
    raise ValueError(
        f"Unsupported TOKEN_MODE {TOKEN_MODE!r}; expected 'char', 'bpe' or 'pile'"
    )

os.environ["RWKV_FLOAT_MODE"] = "fp32"  # currently only supports fp32
os.environ["RWKV_RUN_DEVICE"] = "cpu"  # 'cpu' (already very fast) or 'cuda'
model_type = "RWKV"  # 'RWKV' or 'RWKV-ffnPre'
########################################################################################################
# Step 2: set prompt & sampling stuffs
########################################################################################################
# context = 'A'
# context = "\nIn the"
# context = '\nSugar:'
# Prompt fed to the model before sampling starts.
context = (
    "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, "
    "previously unexplored valley, in Tibet. Even more surprising to the researchers was "
    "the fact that the dragons spoke perfect Chinese."
)

# How many completions to generate, and how many tokens each one gets.
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333

# Sampling knobs forwarded to tokenizer.sample_logits.
TEMPERATURE = 1.0
top_p = 0.8
top_p_newline = 0.9  # only used in TOKEN_MODE = char

# True False --> show softmax output; when True the run is also cut down to a
# single trial with a single generation step.
DEBUG_DEBUG = False
########################################################################################################
# Build the RNN-mode model and the tokenizer from the configuration above.
print(f"Loading {MODEL_NAME}...")

# NOTE(review): this import deliberately sits after the RWKV_* environment
# variables are assigned; presumably src.model_run reads them at import time.
# Confirm before hoisting it to the top of the file.
from src.model_run import RWKV_RNN

model = RWKV_RNN(
    MODEL_NAME,
    os.environ["RWKV_RUN_DEVICE"],
    model_type,
    n_layer,
    n_embd,
    ctx_len,
)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
########################################################################################################
# Turn the prompt into a list of token ids (`ctx`) and keep a pristine copy
# (`src_ctx`) so every trial can restart from the same prompt.
if tokenizer.charMode:
    # Character-level vocab: normalise the prompt, then map each character to
    # its id, falling back to the tokenizer's UNKNOWN_CHAR for unseen ones.
    context = tokenizer.refine_context(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
    ctx = tokenizer.tokenizer.encode(context)

src_len = len(ctx)
if src_len == 0:
    # The generation loop seeds sampling from the output of the last prompt
    # token; with no tokens that output is never computed, so stop here with a
    # clear message instead of failing later on a missing attribute.
    raise ValueError("The prompt produced no tokens; provide a non-empty prompt.")
src_ctx = ctx.copy()

print(f"\nYour prompt has {src_len} tokens.")
print(
    "\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n"
)
# time_slot = {}
# time_ref = time.time_ns()
# def record_time(name):
# if name not in time_slot:
# time_slot[name] = 1e20
# tt = (time.time_ns() - time_ref) / 1e9
# if tt < time_slot[name]:
# time_slot[name] = tt
# Main sampling loop: each trial prints the prompt, then generates
# LENGTH_PER_TRIAL tokens one at a time. In DEBUG_DEBUG mode only one trial with
# a single generation step is run.
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    # time_ref = time.time_ns()
    print(("-" * 50) + context, end="")
    ctx = src_ctx.copy()
    model.clear()

    if TRIAL == 0:
        # First trial: feed the prompt through the model token by token. Only
        # the output for the last prompt token is kept; earlier positions are
        # run with preprocess_only=True (presumably skipping the output
        # computation -- confirm in RWKV_RNN.forward).
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[: i + 1]
            if i == src_len - 1:
                init_state.out = model.forward(x)
            else:
                model.forward(x, preprocess_only=True)
        # Save the post-prompt state so later trials can restore it instead of
        # re-processing the prompt.
        model.save(init_state)
    else:
        model.load(init_state)
    # record_time('model_pre')

    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        # time_ref = time.time_ns()
        x = ctx[: i + 1]
        x = x[-ctx_len:]  # keep at most ctx_len most recent tokens

        if i == src_len:
            # Copy the cached prompt output: it is modified in place below and
            # must stay intact for the following trials.
            out = copy.deepcopy(init_state.out)
        else:
            out = model.forward(x)
        # record_time('model_run')

        if DEBUG_DEBUG:
            # Convert once via .cpu() so the dump also works when the model
            # output lives on a CUDA device.
            out_np = out.cpu().numpy()
            print("model", np.array(x), "==>", out_np, np.max(out_np), np.min(out_np))

        if TOKEN_MODE == "pile":
            out[0] = -999999999  # disable <|endoftext|>

        # time_ref = time.time_ns()
        char = tokenizer.sample_logits(
            out,
            x,
            ctx_len,
            temperature=TEMPERATURE,
            top_p_usual=top_p,
            top_p_newline=top_p_newline,
        )
        if tokenizer.charMode:
            print(tokenizer.itos[char], end="", flush=True)
        else:
            print(tokenizer.tokenizer.decode(char), end="", flush=True)
        ctx += [char]
        # record_time('model_sampling')

    # Finish the line after each trial's streamed output.
    print()
# print(f'\n\n{time_slot}\n\n')
# print(
# f"\n--- preprocess {round((t_mid - t_begin) / (10 ** 9), 2)}s, generation {round((t_end - t_mid) / (10 ** 9), 2)}s", end = ''
# )