commit c7155525bb (parent fc3bc1eb0e) on main
BlinkDL, 3 years ago

@@ -62,17 +62,17 @@ elif TOKEN_MODE == "pile":
 # n_embd = 1024
 # ctx_len = 1024
-# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040'
-# n_layer = 24
-# n_embd = 2048
-# ctx_len = 1024
-MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20220925-4537'
-n_layer = 32
-n_embd = 2560
-ctx_len = 1024
-os.environ["RWKV_FLOAT_MODE"] = "fp32" # currently only supports fp32
+MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040'
+n_layer = 24
+n_embd = 2048
+ctx_len = 1024
+# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20220925-4537'
+# n_layer = 32
+# n_embd = 2560
+# ctx_len = 1024
+os.environ["RWKV_FLOAT_MODE"] = "fp32" # currently only supports fp32 (it can do bf16 and fp16. just wait a bit... busy these days)
 os.environ["RWKV_RUN_DEVICE"] = "cpu" # 'cpu' (already very fast) or 'cuda'
 model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre'
@@ -83,12 +83,13 @@ model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre'
 # context = 'A'
 # context = "\nIn the"
 # context = '\nSugar:'
-# context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
-context = "\n深圳是" # test Chinese
-context = "\n東京は" # test Japanese
-# context = ''' # A good prompt for chatbot
+context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
+# context = "\n深圳是" # test Chinese
+# context = "\n東京は" # test Japanese
+###### A good prompt for chatbot ######
+# context = '''
 # The following is a conversation between a highly knowledgeable and intelligent AI assistant, called RWKV, and a human user, called User. In the following interactions, User and RWKV will converse in natural language, and RWKV will do its best to answer User's questions. RWKV was built to be respectful, polite and inclusive. It knows a lot, and always tells the truth. The conversation begins.
 # User: OK RWKV, I'm going to start by quizzing you with a few warm-up questions. Who is currently the president of the USA?
@@ -120,7 +121,7 @@ DEBUG_DEBUG = False # True False --> show softmax output
 ########################################################################################################
-print(f"Loading {MODEL_NAME}...")
+print(f'\nUsing {os.environ["RWKV_RUN_DEVICE"].upper()}. Loading {MODEL_NAME}...')
 from src.model_run import RWKV_RNN
 model = RWKV_RNN(
@@ -138,25 +139,25 @@ else:
 src_len = len(ctx)
 src_ctx = ctx.copy()
-print("\nYour prompt has " + str(src_len) + " tokens.")
+print("Your prompt has " + str(src_len) + " tokens.")
 print(
-    "\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n"
+    "\nNote: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. Use GPT to build the hidden state for better speed.\n"
 )
-# time_slot = {}
-# time_ref = time.time_ns()
-# def record_time(name):
-#     if name not in time_slot:
-#         time_slot[name] = 1e20
-#     tt = (time.time_ns() - time_ref) / 1e9
-#     if tt < time_slot[name]:
-#         time_slot[name] = tt
+time_slot = {}
+time_ref = time.time_ns()
+def record_time(name):
+    if name not in time_slot:
+        time_slot[name] = 1e20
+    tt = (time.time_ns() - time_ref) / 1e9
+    if tt < time_slot[name]:
+        time_slot[name] = tt
 for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
-    # time_ref = time.time_ns()
-    print(("-" * 50) + context, end="")
+    print(("-" * 50) + '\n' + context, end="")
+    time_ref = time.time_ns()
     ctx = src_ctx.copy()
     model.clear()
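
Note: the record_time helper this hunk un-comments keeps the *minimum* elapsed time seen for each label across trials, which filters out one-off stalls (cache warm-up, allocator noise) when benchmarking. A self-contained sketch of the same pattern:

    import time

    time_slot = {}
    time_ref = time.time_ns()

    def record_time(name):
        # Keep the best (smallest) elapsed time observed for this label.
        if name not in time_slot:
            time_slot[name] = 1e20  # sentinel: larger than any real measurement
        tt = (time.time_ns() - time_ref) / 1e9  # seconds since the trial started
        if tt < time_slot[name]:
            time_slot[name] = tt

    for trial in range(3):
        time_ref = time.time_ns()   # reset at the start of each trial
        sum(range(1_000_000))       # stand-in for model work
        record_time('demo')
    print(f"best of 3 trials: {time_slot['demo']:.4f}s")
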
@@ -172,11 +173,9 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
     else:
         model.load(init_state)
-    # record_time('model_pre')
+    record_time('preprocess')
     out_last = src_len
     for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
-        # time_ref = time.time_ns()
         x = ctx[: i + 1]
         x = x[-ctx_len:]
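
Note: the two context lines at the end of this hunk implement the model's context window: the prompt plus everything generated so far is trimmed to the last ctx_len tokens before each forward pass. A tiny illustration of that slicing (the token ids are made up):

    ctx_len = 4
    ctx = [10, 11, 12, 13, 14, 15]   # hypothetical token ids
    i = 5
    x = ctx[: i + 1]                 # everything generated so far
    x = x[-ctx_len:]                 # keep only the most recent ctx_len tokens
    assert x == [12, 13, 14, 15]
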
@@ -184,14 +183,12 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
             out = copy.deepcopy(init_state.out)
         else:
             out = model.forward(x)
-        # record_time('model_run')
         if DEBUG_DEBUG:
             print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))
         if TOKEN_MODE == "pile":
             out[0] = -999999999 # disable <|endoftext|>
-        time_ref = time.time_ns()
         ttt = tokenizer.sample_logits(
             out,
             x,
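
Note: out[0] = -999999999 is the standard trick for banning a token: push its logit so low that softmax assigns it effectively zero probability, here token 0, <|endoftext|> in the Pile vocabulary. A minimal sketch of the effect (NumPy, with hypothetical logits):

    import numpy as np

    logits = np.array([2.0, 1.0, 0.5])  # hypothetical; index 0 = <|endoftext|>
    logits[0] = -999999999              # ban the token before sampling
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    assert probs[0] < 1e-12             # the banned token can no longer be sampled
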
@@ -211,9 +208,8 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
                 print(char, end="", flush=True)
                 out_last = i+1
-        # record_time('model_sampling')
-    print()
+        record_time('total')
     # print(f'\n\n{time_slot}\n\n')
-    # print(
-    #     f"\n--- preprocess {round((t_mid - t_begin) / (10 ** 9), 2)}s, generation {round((t_end - t_mid) / (10 ** 9), 2)}s", end = ''
-    # )
+    print(
+        f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = ''
+    )
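
Note: the new report derives generation time by subtraction: 'preprocess' is recorded right after the prompt has been fed through the RNN, 'total' after the last token, so the difference is pure generation. From there a tokens-per-second figure is one line away (a sketch with made-up numbers; LENGTH_PER_TRIAL is the script's own per-trial token budget):

    time_slot = {'preprocess': 1.25, 'total': 9.85}  # hypothetical measurements
    LENGTH_PER_TRIAL = 333
    gen_s = time_slot['total'] - time_slot['preprocess']
    print(f"{LENGTH_PER_TRIAL / gen_s:.1f} tokens/s over {gen_s:.2f}s of generation")
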
