@ -17,6 +17,8 @@ torch.backends.cudnn.allow_tf32 = True
torch . backends . cuda . matmul . allow_tf32 = True
np . set_printoptions ( precision = 4 , suppress = True , linewidth = 200 )
CHAT_LANG = ' English ' # English Chinese
WORD_NAME = [
" 20B_tokenizer.json " ,
" 20B_tokenizer.json " ,
@ -48,17 +50,15 @@ args.ctx_len = 1024
# args.n_embd = 2560
# args.ctx_len = 1024
os . environ [ " RWKV_RUN_DEVICE " ] = args . RUN_DEVICE
MODEL_NAME = args . MODEL_NAME
user = " User "
bot = " Bot "
interface = " : "
if CHAT_LANG == ' English ' :
user = " User "
bot = " Bot "
interface = " : "
# The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
# The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth.
# The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
# The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth.
init_prompt = f '''
init_prompt = f '''
The following is a verbose and detailed conversation between an AI assistant called { bot } , and a human user called { user } . { bot } is intelligent , knowledgeable , wise and polite .
{ user } { interface } french revolution what year
@ -81,10 +81,58 @@ The following is a verbose and detailed conversation between an AI assistant cal
{ bot } { interface } LHC is a high - energy particle collider , built by CERN , and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
'''
HELP_MSG = ''' Commands:
say something - - > chat with bot . use \\n for new line .
+ alt - - > alternate chat reply
+ reset - - > reset chat
+ gen YOUR PROMPT - - > free generation with any prompt . use \\n for new line .
+ qa YOUR QUESTION - - > free generation - ask any question ( just ask the question ) . use \\n for new line .
+ more - - > continue last free generation ( only for + gen / + qa )
+ retry - - > retry last free generation ( only for + gen / + qa )
Now talk with the bot and enjoy . Remember to + reset periodically to clean up the bot ' s memory. Use RWKV-4 14B for best results.
This is not instruct - tuned for conversation yet , so don ' t expect good quality. Better use +gen for free generation.
'''
elif CHAT_LANG == ' Chinese ' :
args . MODEL_NAME = ' /fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run3z/rwkv-293 '
args . n_layer = 32
args . n_embd = 4096
args . ctx_len = 1024
user = " Q "
bot = " A "
interface = " : "
init_prompt = '''
Q : 企鹅会飞吗 ?
A : 企鹅是不会飞的 。 它们的翅膀主要用于游泳和平衡 , 而不是飞行 。
Q : 西瓜是什么
A : 西瓜是一种常见的水果 , 是一种多年生蔓生藤本植物 。 西瓜的果实呈圆形或卵形 , 通常是绿色的 , 里面有红色或黄色的肉和很多的籽 。 西瓜味甜 , 多吃可以增加水分 , 是夏季非常受欢迎的水果之一 。
'''
HELP_MSG = ''' 指令:
直接输入内容 - - > 和机器人聊天 , 用 \\n代表换行
+ alt - - > 让机器人换个回答
+ reset - - > 重置对话
+ gen 某某内容 - - > 续写任何中英文内容 , 用 \\n代表换行
+ qa 某某问题 - - > 问独立的问题 ( 忽略上下文 ) , 用 \\n代表换行
+ more - - > 继续 + gen / + qa 的回答
+ retry - - > 换个 + gen / + qa 的回答
现在可以输入内容和机器人聊天 ( 注意它不怎么懂中文 , 它可能更懂英文 ) 。 请经常使用 + reset 重置机器人记忆 。
'''
# Load Model
os . environ [ " RWKV_RUN_DEVICE " ] = args . RUN_DEVICE
MODEL_NAME = args . MODEL_NAME
print ( f ' loading... { MODEL_NAME } ' )
model = RWKV_RNN ( args )
@ -107,8 +155,8 @@ def run_rnn(tokens, newline_adj = 0):
out [ 0 ] = - 999999999 # disable <|endoftext|>
out [ 187 ] + = newline_adj
if newline_adj > 0 :
out [ 15 ] + = newline_adj / 2 # '.'
# if newline_adj > 0 :
# out[15] += newline_adj / 2 # '.'
return out
all_state = { }
@ -144,14 +192,14 @@ for s in srv_list:
print ( f ' ### prompt ### \n [ { tokenizer . tokenizer . decode ( model_tokens ) } ] \n ' )
def reply_msg ( msg ) :
print ( ' Bot: ' , msg + ' \n ' )
print ( f' { bot } { interface } { msg } \n ' )
def on_message ( message ) :
global model_tokens , current_state
srv = ' dummy_server '
msg = message . strip( )
msg = message . replace( ' \\ n ' , ' \n ' ) . strip( )
if len ( msg ) > 1000 :
reply_msg ( ' your message is too long (max 1000 tokens) ' )
return
@ -179,16 +227,16 @@ def on_message(message):
reply_msg ( " Chat reset. " )
return
elif msg [ : 5 ] == ' +gen ' or msg [ : 4 ] == ' +qa ' or msg == ' +more ' or msg == ' +retry ' :
elif msg [ : 5 ] . lower ( ) == ' +gen ' or msg [ : 4 ] . lower ( ) == ' +qa ' or msg . lower ( ) == ' +more ' or msg . lower ( ) == ' +retry ' :
if msg [ : 5 ] == ' +gen ' :
if msg [ : 5 ] . lower ( ) == ' +gen ' :
new = ' \n ' + msg [ 5 : ] . strip ( )
# print(f'### prompt ###\n[{new}]')
current_state = None
out = run_rnn ( tokenizer . tokenizer . encode ( new ) )
save_all_stat ( srv , ' gen_0 ' , out )
elif msg [ : 4 ] == ' +qa ' :
elif msg [ : 4 ] . lower ( ) == ' +qa ' :
out = load_all_stat ( ' ' , ' chat_init ' )
real_msg = msg [ 4 : ] . strip ( )
@ -204,21 +252,22 @@ def on_message(message):
# out = run_rnn(tokenizer.tokenizer.encode(new))
# save_all_stat(srv, 'gen_0', out)
elif msg == ' +more ' :
elif msg . lower ( ) == ' +more ' :
try :
out = load_all_stat ( srv , ' gen_1 ' )
save_all_stat ( srv , ' gen_0 ' , out )
except :
return
elif msg == ' +retry ' :
elif msg . lower ( ) == ' +retry ' :
try :
out = load_all_stat ( srv , ' gen_0 ' )
except :
return
begin = len ( model_tokens )
for i in range ( 100 ) :
out_last = begin
for i in range ( 150 ) :
token = tokenizer . sample_logits (
out ,
model_tokens ,
@ -227,17 +276,23 @@ def on_message(message):
top_p_usual = x_top_p ,
top_p_newline = x_top_p ,
)
if msg [ : 4 ] == ' +qa ' :
if msg [ : 4 ] . lower ( ) == ' +qa ' :
out = run_rnn ( [ token ] , newline_adj = - 1 )
else :
out = run_rnn ( [ token ] )
send_msg = tokenizer . tokenizer . decode ( model_tokens [ begin : ] ) . strip ( )
xxx = tokenizer . tokenizer . decode ( model_tokens [ out_last : ] )
if ' \ufffd ' not in xxx :
print ( xxx , end = ' ' , flush = True )
out_last = begin + i + 1
print ( ' \n ' )
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# print(f'### send ###\n[{send_msg}]')
reply_msg ( send_msg )
# reply_msg(send_msg )
save_all_stat ( srv , ' gen_1 ' , out )
else :
if msg == ' +alt ' :
if msg . lower ( ) == ' +alt ' :
try :
out = load_all_stat ( srv , ' chat_pre ' )
except :
@ -250,17 +305,17 @@ def on_message(message):
save_all_stat ( srv , ' chat_pre ' , out )
begin = len ( model_tokens )
for i in range ( 120 ) :
out_last = begin
print ( f ' { bot } { interface } ' , end = ' ' , flush = True )
for i in range ( 999 ) :
if i < = 0 :
newline_adj = - 999999999
elif i < = 30 :
newline_adj = ( i - 30 ) / 10
elif i < = 8 0:
elif i < = 13 0:
newline_adj = 0
elif i < = 117 :
newline_adj = ( i - 80 ) * 0.5
else :
newline_adj = 999999999
newline_adj = ( i - 130 ) * 0.25 # MUST END THE GENERATION
token = tokenizer . sample_logits (
out ,
model_tokens ,
@ -270,6 +325,12 @@ def on_message(message):
top_p_newline = x_top_p ,
)
out = run_rnn ( [ token ] , newline_adj = newline_adj )
xxx = tokenizer . tokenizer . decode ( model_tokens [ out_last : ] )
if ' \ufffd ' not in xxx :
print ( xxx , end = ' ' , flush = True )
out_last = begin + i + 1
send_msg = tokenizer . tokenizer . decode ( model_tokens [ begin : ] )
if ' \n \n ' in send_msg :
send_msg = send_msg . strip ( )
@ -287,24 +348,13 @@ def on_message(message):
# print(f'[{tokenizer.tokenizer.decode(model_tokens)}]')
# print(f'### send ###\n[{send_msg}]')
reply_msg ( send_msg )
# reply_msg(send_msg )
save_all_stat ( srv , ' chat ' , out )
print ( ''' Commands:
+ alt - - > alternate chat reply
+ reset - - > reset chat
+ gen YOUR PROMPT - - > free generation with your prompt
+ qa YOUR QUESTION - - > free generation - ask any question and get answer ( just ask the question )
+ more - - > continue last free generation [ does not work for chat ]
+ retry - - > retry last free generation
Now talk with the bot and enjoy . Remember to + reset periodically to clean up the bot ' s memory. Use RWKV-4 14B for best results.
This is not instruct - tuned for conversation yet , so don ' t expect good quality. Better use +rwkv_gen for free generation.
''' )
print ( HELP_MSG )
while True :
msg = input ( ' User: ' )
msg = input ( f ' { user } { interface } ' )
if len ( msg . strip ( ) ) > 0 :
on_message ( msg )
else :