You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
3.5 KiB
Python
102 lines
3.5 KiB
Python
# -*- coding:utf-8 -*-
|
|
########################################################################################################
|
|
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
|
|
########################################################################################################
|
|
|
|
import numpy as np
|
|
import time
|
|
import types
|
|
import copy
|
|
import torch
|
|
from torch.nn import functional as F
|
|
from src.utils import TOKENIZER
|
|
from src.model_run import RWKV_RNN
|
|
# Speed-oriented flags for GPU runs: let cuDNN benchmark and pick its kernels,
# and permit TF32 math for cuDNN ops and CUDA matmuls (faster, at reduced
# float32 precision on hardware that supports TF32). These flags target the
# cuDNN / CUDA backends only.
torch.backends.cudnn.benchmark = torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
### Step 1: set model ##################################################################################

# Architecture hyper-parameters, passed to RWKV_RNN when the model is built;
# presumably they must match the checkpoint being loaded.
ctx_len, n_layer, n_embd = 1024, 6, 512
model_type = 'RWKV'  # one of: 'RWKV', 'RWKV-ffnPre'

# Your trained model, and the .json vocab generated by train.py.
MODEL_NAME = 'trained-31'
WORD_NAME = 'vocab'

# To test the enwik8 model instead, uncompress enwik8-model.zip and use:
# MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
# WORD_NAME = 'enwik8-vocab'

# Set UNKNOWN_CHAR to the rarest token in your vocab.json: every prompt character
# that is missing from the vocab is encoded as this token. A space is used here
# only to keep the example simple.
UNKNOWN_CHAR = ' '

RUN_DEVICE = 'cpu'  # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False  # True: run a single trial of a single step and print the raw model output
### Step 2: set context ################################################################################

context = "\nIn the"  # ==> this is your prompt

# Sampling schedule: NUM_TRIALS independent continuations, LENGTH_PER_TRIAL tokens each.
NUM_TRIALS, LENGTH_PER_TRIAL = 999, 500

# Knobs forwarded to tokenizer.sample_logits(): the sampling temperature, the
# usual top-p threshold, and a separate top-p threshold whose use (presumably
# around newlines) is decided inside sample_logits.
TEMPERATURE = 1.0
top_p, top_p_newline = 0.7, 0.9

########################################################################################################
# Compact NumPy printing (used by the DEBUG_DEBUG dump in the sampling loop).
np.set_printoptions(precision=4, suppress=True, linewidth=200)

# Build the tokenizer from the vocab file and pass the prompt through its
# refine_context(); len(context) counts characters, and each character is
# looked up as one token when the prompt is encoded.
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
context = tokenizer.refine_context(context)
print(f'\nYour prompt has {len(context)} tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')

print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
# Encode the prompt once (it does not change between trials): each character
# becomes its vocab index, and characters missing from the vocab map to the
# tokenizer's UNKNOWN_CHAR index.
prompt_tokens = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
src_len = len(prompt_tokens)
if src_len == 0:
    # With no prompt tokens the priming loop never records an output, so the
    # first sampling step would fail with an obscure AttributeError.
    raise ValueError('The prompt is empty after tokenizer.refine_context(); '
                     'provide a non-empty `context`.')

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    # Monotonic clock: elapsed time is not skewed by wall-clock adjustments.
    t_begin = time.perf_counter_ns()

    ctx = list(prompt_tokens)  # fresh copy; this trial's samples are appended to it
    print(('-' * 30) + context, end='')

    model.clear()
    if TRIAL == 0:
        # Feed the prompt through the RNN one step at a time, keep the output of
        # the final step, and snapshot the model so later trials can restore it
        # instead of re-processing the prompt.
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[:i + 1]
            if i == src_len - 1:
                init_state.out = model.run(x)
            else:
                model.run(x)
        model.save(init_state)
    else:
        model.load(init_state)

    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        # ctx holds exactly i tokens here, so this is its last ctx_len tokens
        # (same as the former ctx[:i+1][-ctx_len:], without copying the history).
        x = ctx[-ctx_len:]

        if i == src_len:
            # First step reuses the cached prompt output; copied so that any
            # in-place change downstream cannot corrupt the cache used by later trials.
            out = copy.deepcopy(init_state.out)
        else:
            out = model.run(x)
        if DEBUG_DEBUG:
            print('model', np.array(x), '==>', np.array(out), np.max(out), np.min(out))

        sampled = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
                                          top_p_usual=top_p, top_p_newline=top_p_newline)
        token = int(sampled.item())  # plain int, used both for itos lookup and as context
        print(tokenizer.itos[token], end='', flush=True)
        ctx.append(token)

    t_end = time.perf_counter_ns()
    print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')