diff --git a/RWKV-v4neo/chat.py b/RWKV-v4neo/chat.py index 0e4bc5a..d214ba2 100644 --- a/RWKV-v4neo/chat.py +++ b/RWKV-v4neo/chat.py @@ -17,6 +17,8 @@ torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True np.set_printoptions(precision=4, suppress=True, linewidth=200) +CHAT_LANG = 'English' # English Chinese + WORD_NAME = [ "20B_tokenizer.json", "20B_tokenizer.json", @@ -48,17 +50,15 @@ args.ctx_len = 1024 # args.n_embd = 2560 # args.ctx_len = 1024 -os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE -MODEL_NAME = args.MODEL_NAME - -user = "User" -bot = "Bot" -interface = ":" +if CHAT_LANG == 'English': + user = "User" + bot = "Bot" + interface = ":" -# The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. -# The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth. + # The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. + # The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth. -init_prompt = f''' + init_prompt = f''' The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. {user}{interface} french revolution what year @@ -81,10 +81,58 @@ The following is a verbose and detailed conversation between an AI assistant cal {bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012. +''' + HELP_MSG = '''Commands: +say something --> chat with bot. use \\n for new line. ++alt --> alternate chat reply ++reset --> reset chat + ++gen YOUR PROMPT --> free generation with any prompt. use \\n for new line. ++qa YOUR QUESTION --> free generation - ask any question (just ask the question). use \\n for new line. ++more --> continue last free generation (only for +gen / +qa) ++retry --> retry last free generation (only for +gen / +qa) + +Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results. +This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +gen for free generation. +''' +elif CHAT_LANG == 'Chinese': + args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run3z/rwkv-293' + args.n_layer = 32 + args.n_embd = 4096 + args.ctx_len = 1024 + + user = "Q" + bot = "A" + interface = ":" + + init_prompt = ''' +Q: 企鹅会飞吗? + +A: 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。 + +Q: 西瓜是什么 + +A: 西瓜是一种常见的水果,是一种多年生蔓生藤本植物。西瓜的果实呈圆形或卵形,通常是绿色的,里面有红色或黄色的肉和很多的籽。西瓜味甜,多吃可以增加水分,是夏季非常受欢迎的水果之一。 + +''' + HELP_MSG = '''指令: +直接输入内容 --> 和机器人聊天,用\\n代表换行 ++alt --> 让机器人换个回答 ++reset --> 重置对话 + ++gen 某某内容 --> 续写任何中英文内容,用\\n代表换行 ++qa 某某问题 --> 问独立的问题(忽略上下文),用\\n代表换行 ++more --> 继续 +gen / +qa 的回答 ++retry --> 换个 +gen / +qa 的回答 + +现在可以输入内容和机器人聊天(注意它不怎么懂中文,它可能更懂英文)。请经常使用 +reset 重置机器人记忆。 ''' # Load Model +os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE +MODEL_NAME = args.MODEL_NAME + print(f'loading... {MODEL_NAME}') model = RWKV_RNN(args) @@ -107,8 +155,8 @@ def run_rnn(tokens, newline_adj = 0): out[0] = -999999999 # disable <|endoftext|> out[187] += newline_adj - if newline_adj > 0: - out[15] += newline_adj / 2 # '.' + # if newline_adj > 0: + # out[15] += newline_adj / 2 # '.' return out all_state = {} @@ -144,14 +192,14 @@ for s in srv_list: print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n') def reply_msg(msg): - print('Bot:', msg + '\n') + print(f'{bot}{interface} {msg}\n') def on_message(message): global model_tokens, current_state srv = 'dummy_server' - msg = message.strip() + msg = message.replace('\\n','\n').strip() if len(msg) > 1000: reply_msg('your message is too long (max 1000 tokens)') return @@ -179,16 +227,16 @@ def on_message(message): reply_msg("Chat reset.") return - elif msg[:5] == '+gen ' or msg[:4] == '+qa ' or msg == '+more' or msg == '+retry': + elif msg[:5].lower() == '+gen ' or msg[:4].lower() == '+qa ' or msg.lower() == '+more' or msg.lower() == '+retry': - if msg[:5] == '+gen ': + if msg[:5].lower() == '+gen ': new = '\n' + msg[5:].strip() # print(f'### prompt ###\n[{new}]') current_state = None out = run_rnn(tokenizer.tokenizer.encode(new)) save_all_stat(srv, 'gen_0', out) - elif msg[:4] == '+qa ': + elif msg[:4].lower() == '+qa ': out = load_all_stat('', 'chat_init') real_msg = msg[4:].strip() @@ -204,21 +252,22 @@ def on_message(message): # out = run_rnn(tokenizer.tokenizer.encode(new)) # save_all_stat(srv, 'gen_0', out) - elif msg == '+more': + elif msg.lower() == '+more': try: out = load_all_stat(srv, 'gen_1') save_all_stat(srv, 'gen_0', out) except: return - elif msg == '+retry': + elif msg.lower() == '+retry': try: out = load_all_stat(srv, 'gen_0') except: return begin = len(model_tokens) - for i in range(100): + out_last = begin + for i in range(150): token = tokenizer.sample_logits( out, model_tokens, @@ -227,17 +276,23 @@ def on_message(message): top_p_usual=x_top_p, top_p_newline=x_top_p, ) - if msg[:4] == '+qa ': + if msg[:4].lower() == '+qa ': out = run_rnn([token], newline_adj=-1) else: out = run_rnn([token]) - send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() + + xxx = tokenizer.tokenizer.decode(model_tokens[out_last:]) + if '\ufffd' not in xxx: + print(xxx, end='', flush=True) + out_last = begin + i + 1 + print('\n') + # send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() # print(f'### send ###\n[{send_msg}]') - reply_msg(send_msg) + # reply_msg(send_msg) save_all_stat(srv, 'gen_1', out) else: - if msg == '+alt': + if msg.lower() == '+alt': try: out = load_all_stat(srv, 'chat_pre') except: @@ -250,17 +305,17 @@ def on_message(message): save_all_stat(srv, 'chat_pre', out) begin = len(model_tokens) - for i in range(120): + out_last = begin + print(f'{bot}{interface}', end='', flush=True) + for i in range(999): if i <= 0: newline_adj = -999999999 elif i <= 30: newline_adj = (i - 30) / 10 - elif i <= 80: + elif i <= 130: newline_adj = 0 - elif i <= 117: - newline_adj = (i - 80) * 0.5 else: - newline_adj = 999999999 + newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION token = tokenizer.sample_logits( out, model_tokens, @@ -270,6 +325,12 @@ def on_message(message): top_p_newline=x_top_p, ) out = run_rnn([token], newline_adj=newline_adj) + + xxx = tokenizer.tokenizer.decode(model_tokens[out_last:]) + if '\ufffd' not in xxx: + print(xxx, end='', flush=True) + out_last = begin + i + 1 + send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]) if '\n\n' in send_msg: send_msg = send_msg.strip() @@ -287,24 +348,13 @@ def on_message(message): # print(f'[{tokenizer.tokenizer.decode(model_tokens)}]') # print(f'### send ###\n[{send_msg}]') - reply_msg(send_msg) + # reply_msg(send_msg) save_all_stat(srv, 'chat', out) -print('''Commands: -+alt --> alternate chat reply -+reset --> reset chat - -+gen YOUR PROMPT --> free generation with your prompt -+qa YOUR QUESTION --> free generation - ask any question and get answer (just ask the question) -+more --> continue last free generation [does not work for chat] -+retry --> retry last free generation - -Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results. -This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +rwkv_gen for free generation. -''') +print(HELP_MSG) while True: - msg = input('User: ') + msg = input(f'{user}{interface} ') if len(msg.strip()) > 0: on_message(msg) else: