@@ -17,6 +17,8 @@ torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 
+CHAT_LANG = 'English' # English Chinese
+
 WORD_NAME = [
     "20B_tokenizer.json",
     "20B_tokenizer.json",
@@ -48,9 +50,7 @@ args.ctx_len = 1024
 # args.n_embd = 2560
 # args.ctx_len = 1024
 
-os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
-MODEL_NAME = args.MODEL_NAME
 if CHAT_LANG == 'English':
     user = "User"
     bot = "Bot"
     interface = ":"
@@ -81,10 +81,58 @@ The following is a verbose and detailed conversation between an AI assistant cal
 {bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
 
 '''
 
+    HELP_MSG = '''Commands:
+say something --> chat with bot. use \\n for new line.
++alt --> alternate chat reply
++reset --> reset chat
+
++gen YOUR PROMPT --> free generation with any prompt. use \\n for new line.
++qa YOUR QUESTION --> free generation - ask any question (just ask the question). use \\n for new line.
++more --> continue last free generation (only for +gen / +qa)
++retry --> retry last free generation (only for +gen / +qa)
+
+Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results.
+This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +gen for free generation.
+'''
+
+elif CHAT_LANG == 'Chinese':
+    args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run3z/rwkv-293'
+    args.n_layer = 32
+    args.n_embd = 4096
+    args.ctx_len = 1024
+
+    user = "Q"
+    bot = "A"
+    interface = ":"
+
+    init_prompt = '''
+Q: 企鹅会飞吗?
+
+A: 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。
+
+Q: 西瓜是什么
+
+A: 西瓜是一种常见的水果,是一种多年生蔓生藤本植物。西瓜的果实呈圆形或卵形,通常是绿色的,里面有红色或黄色的肉和很多的籽。西瓜味甜,多吃可以增加水分,是夏季非常受欢迎的水果之一。
+
+'''
+
+    HELP_MSG = '''指令:
+直接输入内容 --> 和机器人聊天,用\\n代表换行
++alt --> 让机器人换个回答
++reset --> 重置对话
+
++gen 某某内容 --> 续写任何中英文内容,用\\n代表换行
++qa 某某问题 --> 问独立的问题(忽略上下文),用\\n代表换行
++more --> 继续 +gen / +qa 的回答
++retry --> 换个 +gen / +qa 的回答
+
+现在可以输入内容和机器人聊天(注意它不怎么懂中文,它可能更懂英文)。请经常使用 +reset 重置机器人记忆。
+'''
 
 # Load Model
 
+os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
+MODEL_NAME = args.MODEL_NAME
+
 print(f'loading... {MODEL_NAME}')
 model = RWKV_RNN(args)
@@ -107,8 +155,8 @@ def run_rnn(tokens, newline_adj = 0):
     out[0] = -999999999 # disable <|endoftext|>
     out[187] += newline_adj
-    if newline_adj > 0:
-        out[15] += newline_adj / 2 # '.'
+    # if newline_adj > 0:
+    #     out[15] += newline_adj / 2 # '.'
     return out
 
 
 all_state = {}
@@ -144,14 +192,14 @@ for s in srv_list:
 print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n')
 
 def reply_msg(msg):
-    print('Bot:', msg + '\n')
+    print(f'{bot}{interface} {msg}\n')
 
 def on_message(message):
     global model_tokens, current_state
 
     srv = 'dummy_server'
 
-    msg = message.strip()
+    msg = message.replace('\\n','\n').strip()
     if len(msg) > 1000:
         reply_msg('your message is too long (max 1000 tokens)')
         return
@@ -179,16 +227,16 @@ def on_message(message):
         reply_msg("Chat reset.")
         return
 
-    elif msg[:5] == '+gen ' or msg[:4] == '+qa ' or msg == '+more' or msg == '+retry':
+    elif msg[:5].lower() == '+gen ' or msg[:4].lower() == '+qa ' or msg.lower() == '+more' or msg.lower() == '+retry':
 
-        if msg[:5] == '+gen ':
+        if msg[:5].lower() == '+gen ':
             new = '\n' + msg[5:].strip()
             # print(f'### prompt ###\n[{new}]')
             current_state = None
             out = run_rnn(tokenizer.tokenizer.encode(new))
             save_all_stat(srv, 'gen_0', out)
 
-        elif msg[:4] == '+qa ':
+        elif msg[:4].lower() == '+qa ':
             out = load_all_stat('', 'chat_init')
 
             real_msg = msg[4:].strip()
@@ -204,21 +252,22 @@ def on_message(message):
             # out = run_rnn(tokenizer.tokenizer.encode(new))
             # save_all_stat(srv, 'gen_0', out)
 
-        elif msg == '+more':
+        elif msg.lower() == '+more':
             try:
                 out = load_all_stat(srv, 'gen_1')
                 save_all_stat(srv, 'gen_0', out)
             except:
                 return
 
-        elif msg == '+retry':
+        elif msg.lower() == '+retry':
             try:
                 out = load_all_stat(srv, 'gen_0')
             except:
                 return
 
         begin = len(model_tokens)
-        for i in range(100):
+        out_last = begin
+        for i in range(150):
             token = tokenizer.sample_logits(
                 out,
                 model_tokens,
@@ -227,17 +276,23 @@ def on_message(message):
                 top_p_usual=x_top_p,
                 top_p_newline=x_top_p,
             )
-            if msg[:4] == '+qa ':
+            if msg[:4].lower() == '+qa ':
                 out = run_rnn([token], newline_adj=-1)
             else:
                 out = run_rnn([token])
 
-        send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
+            xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
+            if '\ufffd' not in xxx:
+                print(xxx, end='', flush=True)
+                out_last = begin + i + 1
+        print('\n')
+
+        # send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
         # print(f'### send ###\n[{send_msg}]')
-        reply_msg(send_msg)
+        # reply_msg(send_msg)
         save_all_stat(srv, 'gen_1', out)
 
     else:
-        if msg == '+alt':
+        if msg.lower() == '+alt':
             try:
                 out = load_all_stat(srv, 'chat_pre')
             except:
@@ -250,17 +305,17 @@ def on_message(message):
         save_all_stat(srv, 'chat_pre', out)
 
         begin = len(model_tokens)
-        for i in range(120):
+        out_last = begin
+        print(f'{bot}{interface}', end='', flush=True)
+        for i in range(999):
             if i <= 0:
                 newline_adj = -999999999
             elif i <= 30:
                 newline_adj = (i - 30) / 10
-            elif i <= 80:
+            elif i <= 130:
                 newline_adj = 0
-            elif i <= 117:
-                newline_adj = (i - 80) * 0.5
             else:
-                newline_adj = 999999999
+                newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION
             token = tokenizer.sample_logits(
                 out,
                 model_tokens,
@@ -270,6 +325,12 @@ def on_message(message):
                 top_p_newline=x_top_p,
             )
             out = run_rnn([token], newline_adj=newline_adj)
+
+            xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
+            if '\ufffd' not in xxx:
+                print(xxx, end='', flush=True)
+                out_last = begin + i + 1
+
             send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
             if '\n\n' in send_msg:
                 send_msg = send_msg.strip()
@@ -287,24 +348,13 @@ def on_message(message):
         # print(f'[{tokenizer.tokenizer.decode(model_tokens)}]')
 
         # print(f'### send ###\n[{send_msg}]')
-        reply_msg(send_msg)
+        # reply_msg(send_msg)
         save_all_stat(srv, 'chat', out)
 
-print('''Commands:
-+alt --> alternate chat reply
-+reset --> reset chat
-
-+gen YOUR PROMPT --> free generation with your prompt
-+qa YOUR QUESTION --> free generation - ask any question and get answer (just ask the question)
-+more --> continue last free generation [does not work for chat]
-+retry --> retry last free generation
-
-Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results.
-This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +rwkv_gen for free generation.
-''')
+print(HELP_MSG)
 
 while True:
-    msg = input('User: ')
+    msg = input(f'{user}{interface} ')
     if len(msg.strip()) > 0:
         on_message(msg)
     else:
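
Note on the patch (editor's sketch, not part of the diff): the new generation loops rely on two tricks. First, a newline-logit schedule that starts strongly negative (a reply may not end immediately), fades to neutral, and then grows positive so the chat reply is forced to terminate. Second, a UTF-8-safe streaming print that withholds output while the decoded tail still contains U+FFFD, i.e. while a token boundary has split a multi-byte character. Below is a minimal, self-contained Python illustration of both ideas; it assumes nothing from the repo beyond what the diff shows (the byte values and buffering strategy are made up for demonstration, and the real code re-decodes the token span rather than buffering bytes).

# sketch.py -- illustrative only; mirrors the ideas above, not the repo's API
def newline_adj_schedule(i: int) -> float:
    """Logit bias for the newline token at step i, following the patch's chat loop."""
    if i <= 0:
        return -999999999          # never end on the very first token
    elif i <= 30:
        return (i - 30) / 10       # discourage early newlines, fading to 0
    elif i <= 130:
        return 0                   # neutral zone
    else:
        return (i - 130) * 0.25    # growing pressure to end the generation

def stream_print(byte_chunks):
    """Print chunks only once they decode cleanly (no U+FFFD), like the
    `if '\\ufffd' not in xxx` check in the patch."""
    pending = b''
    for chunk in byte_chunks:
        pending += chunk
        text = pending.decode('utf-8', errors='replace')
        if '\ufffd' not in text:   # a partial multi-byte char decodes to U+FFFD
            print(text, end='', flush=True)
            pending = b''
    print()

# '企鹅' is six UTF-8 bytes; splitting them mid-character must not print garbage.
stream_print([b'Q: ', b'\xe4\xbc', b'\x81\xe9\xb9\x85', b'?'])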