@@ -6,6 +6,7 @@
 import numpy as np
 import types
 import copy
+import time
 import torch
 from torch.nn import functional as F
 from src.utils import TOKENIZER
@@ -20,8 +21,9 @@ n_layer = 6
 n_embd = 512
 model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
-MODEL_NAME = 'trained-31' # your trained model
-WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
+# your trained model
+MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
+WORD_NAME = 'enwik8-vocab' # the .json vocab (generated by train.py)
 # --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
 # --> unknown tokens in your context will be denoted by it <--
@@ -44,12 +46,13 @@ top_p_newline = 0.9
 ########################################################################################################
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
 context = tokenizer.refine_context(context)
-print('Your context has ' + str(len(context)) + ' tokens')
+print('\nYour prompt has ' + str(len(context)) + ' tokens.')
+print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')
 print(f'Loading {MODEL_NAME}...')
 ##############################################################################################################
@@ -65,7 +68,7 @@ class RWKV_RNN():
         self.w = types.SimpleNamespace()
         w = torch.load(MODEL_NAME + '.pth',
-                       map_location=torch.device(RUN_DEVICE))  # .state_dict()
+                       map_location=torch.device(RUN_DEVICE))
         for x in w.keys():
             if '.time_' in x:
                 w[x] = w[x].squeeze()
@@ -195,12 +198,12 @@ class RWKV_RNN():
 model = RWKV_RNN(MODEL_NAME)
-print('\n\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')
 for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
+    t_begin = time.time_ns()
     src_len = len(context)
     ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
-    print(context.replace('\n', '\n'), end='')
+    print(('-' * 30) + context, end='')
     model.clear()
     if TRIAL == 0:
@@ -230,7 +233,7 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
         char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
                                        top_p_usual=top_p, top_p_newline=top_p_newline)
         char = char.item()
-        print(tokenizer.itos[int(char)].replace(
-            '\n', '\n'), end='', flush=True)
+        print(tokenizer.itos[int(char)], end='', flush=True)
         ctx += [char]
-    print('\n' + '-' * 40, end='')
+    t_end = time.time_ns()
+    print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')