From eaac3f7e6625dd9d64dd0c03604bf0e0eb8a6c67 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Mon, 9 Jan 2023 19:25:27 +0000
Subject: [PATCH] chatbot

---
 RWKV-v4neo/chat.py | 289 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 289 insertions(+)
 create mode 100644 RWKV-v4neo/chat.py

diff --git a/RWKV-v4neo/chat.py b/RWKV-v4neo/chat.py
new file mode 100644
index 0000000..4e95bd7
--- /dev/null
+++ b/RWKV-v4neo/chat.py
@@ -0,0 +1,289 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+print('Loading...')
+from src.model_run import RWKV_RNN
+import numpy as np
+import os, copy, types, gc, sys
+import torch
+from src.utils import TOKENIZER
+try:
+    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]  # optional first CLI argument: CUDA device index
+except IndexError:
+    pass
+torch.backends.cudnn.benchmark = True
+torch.backends.cudnn.allow_tf32 = True
+torch.backends.cuda.matmul.allow_tf32 = True
+np.set_printoptions(precision=4, suppress=True, linewidth=200)
+
+WORD_NAME = [
+    "20B_tokenizer.json",
+    "20B_tokenizer.json",
+]  # [vocab, vocab] for Pile model
+UNKNOWN_CHAR = None
+tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
+
+args = types.SimpleNamespace()
+args.RUN_DEVICE = "cuda"  # 'cpu' (already very fast) // 'cuda'
+args.FLOAT_MODE = "fp16"  # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate)
+args.vocab_size = 50277
+args.head_qk = 0
+args.pre_ffn = 0
+args.grad_cp = 0
+args.my_pos_emb = 0
+
+args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230108-5170'
+args.n_layer = 40
+args.n_embd = 5120
+args.ctx_len = 1024
+
+# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
+# args.n_layer = 32
+# args.n_embd = 4096
+# args.ctx_len = 1024
+
+# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
+# args.n_layer = 32
+# args.n_embd = 2560
+# args.ctx_len = 1024
+
+os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
+MODEL_NAME = args.MODEL_NAME
+
+user = "User"
+bot = "Bot"
+interface = ":"
+
+init_prompt = f'''
+The following is a conversation between a highly knowledgeable and intelligent AI assistant called {bot}, and a human user called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} always answers {user}'s questions. {bot} is very smart, polite and humorous. {bot} knows a lot, and always tells the truth. The conversation begins.
+
+{user}{interface} who is president of usa?
+
+{bot}{interface} It’s Joe Biden; he was sworn in in January 2021.
+
+{user}{interface} french revolution what year
+
+{bot}{interface} It started in 1789, but it lasted 10 years until 1799.
+
+{user}{interface} guess i marry who ?
+
+{bot}{interface} Only if you tell me more about yourself - what are your interests?
+
+{user}{interface} wat is lhc
+
+{bot}{interface} It’s a large and very expensive piece of science equipment. If I understand correctly, it’s a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
+
+'''
+
+# Load Model
+
+print(f'loading... {MODEL_NAME}')
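+# RWKV_RNN runs the model in RNN mode: it carries a recurrent state between
+# calls, so each generated token costs a single forward step regardless of
+# how long the conversation has grown (see src/model_run.py).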
+model = RWKV_RNN(args)
+
+model_tokens = []
+
+current_state = None
+
+########################################################################################################
+
+def run_rnn(tokens, newline_adj = 0):
+    global model_tokens, current_state
+    for i in range(len(tokens)):
+        model_tokens += [tokens[i]]
+        if i == len(tokens) - 1:
+            out, current_state = model.forward(model_tokens, current_state)
+        else:
+            current_state = model.forward(model_tokens, current_state, preprocess_only = True)
+
+    # print(f'### model ###\n[{tokenizer.tokenizer.decode(model_tokens)}]')
+
+    out[0] = -999999999  # disable <|endoftext|>
+    out[187] += newline_adj  # 187 = '\n'
+    if newline_adj > 0:
+        out[15] += newline_adj / 2  # 15 = '.'
+    return out
+
+all_state = {}
+def save_all_stat(srv, name, last_out):
+    n = f'{name}_{srv}'
+    all_state[n] = {}
+    all_state[n]['out'] = last_out
+    all_state[n]['rnn'] = copy.deepcopy(current_state)
+    all_state[n]['token'] = copy.deepcopy(model_tokens)
+
+def load_all_stat(srv, name):
+    global model_tokens, current_state
+    n = f'{name}_{srv}'
+    current_state = copy.deepcopy(all_state[n]['rnn'])
+    model_tokens = copy.deepcopy(all_state[n]['token'])
+    return all_state[n]['out']
+
+########################################################################################################
+
+# Run inference
+print(f'\nRun prompt...')
+
+out = run_rnn(tokenizer.tokenizer.encode(init_prompt))
+gc.collect()
+torch.cuda.empty_cache()
+
+save_all_stat('', 'chat_init', out)
+
+srv_list = ['dummy_server']
+for s in srv_list:
+    save_all_stat(s, 'chat', out)
+
+print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n')
+
+def reply_msg(msg):
+    print('Bot:', msg + '\n')
+
+def on_message(message):
+    global model_tokens, current_state  # the +rwkv_gen branch below resets these
+    srv = 'dummy_server'
+
+    msg = message.strip()
+    if len(msg) > 1000:
+        reply_msg('your message is too long (max 1000 characters)')
+        return
+
+    x_temp = 1.0
+    x_top_p = 0.8
+    if ("-temp=" in msg):
+        x_temp_str = msg.split("-temp=")[1].split(" ")[0]
+        x_temp = float(x_temp_str)
+        msg = msg.replace("-temp=" + x_temp_str, "")  # strip the flag exactly as typed
+        # print(f"temp: {x_temp}")
+    if ("-top_p=" in msg):
+        x_top_p_str = msg.split("-top_p=")[1].split(" ")[0]
+        x_top_p = float(x_top_p_str)
+        msg = msg.replace("-top_p=" + x_top_p_str, "")
+        # print(f"top_p: {x_top_p}")
+    if x_temp <= 0.2:
+        x_temp = 0.2
+    if x_temp >= 5:
+        x_temp = 5
+    if x_top_p <= 0:
+        x_top_p = 0
+
+    if msg == '+reset_rwkv' or msg == '+rwkv_reset':
+        out = load_all_stat('', 'chat_init')
+        save_all_stat(srv, 'chat', out)
+        reply_msg("Chat reset.")
+        return
+
+    elif msg[:10] == '+rwkv_gen ' or msg[:9] == '+rwkv_qa ' or msg == '+rwkv_more' or msg == '+rwkv_retry' or msg == '+rwkv_again':
+
+        if msg[:10] == '+rwkv_gen ':
+            new = '\n' + msg[10:].strip()
+            # print(f'### prompt ###\n[{new}]')
+            model_tokens = []     # start free generation from a blank context
+            current_state = None  # and a blank RNN state
+            out = run_rnn(tokenizer.tokenizer.encode(new))
+            save_all_stat(srv, 'gen_0', out)
+
+        elif msg[:9] == '+rwkv_qa ':
+            out = load_all_stat('', 'chat_init')
+
+            real_msg = msg[9:].strip()
+            new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
+            # print(f'### qa ###\n[{new}]')
+
+            out = run_rnn(tokenizer.tokenizer.encode(new))
+            save_all_stat(srv, 'gen_0', out)
+
+            # new = f"\nThe following is an excellent Q&A session consisting of detailed and factual information.\n\nQ: What is 3+5?\nA: 3+5=8.\n\nQ: {msg[9:].strip()}\nA:"
+            # print(f'### prompt ###\n[{new}]')
+            # current_state = None
+            # out = run_rnn(tokenizer.tokenizer.encode(new))
+            # save_all_stat(srv, 'gen_0', out)
+
+        elif msg == '+rwkv_more':
+            try:
+                out = load_all_stat(srv, 'gen_1')
+                save_all_stat(srv, 'gen_0', out)
+            except KeyError:
+                return
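+        # 'gen_0' holds the state from *before* the most recent free generation
+        # and 'gen_1' the state from *after* it: +rwkv_more (above) resumes from
+        # gen_1, while +rwkv_retry (below) re-samples starting from gen_0.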
+        elif msg == '+rwkv_retry' or msg == '+rwkv_again':
+            try:
+                out = load_all_stat(srv, 'gen_0')
+            except KeyError:
+                return
+
+        begin = len(model_tokens)
+        for i in range(100):
+            token = tokenizer.sample_logits(
+                out,
+                model_tokens,
+                args.ctx_len,
+                temperature=x_temp,
+                top_p_usual=x_top_p,
+                top_p_newline=x_top_p,
+            )
+            if msg[:9] == '+rwkv_qa ':
+                out = run_rnn([token], newline_adj=-2)
+            else:
+                out = run_rnn([token])
+        send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
+        # print(f'### send ###\n[{send_msg}]')
+        reply_msg(send_msg)
+        save_all_stat(srv, 'gen_1', out)
+
+    else:
+        if msg == '+rwkv_alt':
+            try:
+                out = load_all_stat(srv, 'chat_pre')
+            except KeyError:
+                return
+        else:
+            out = load_all_stat(srv, 'chat')
+            new = f"{user}{interface} {msg}\n\n{bot}{interface}"
+            # print(f'### add ###\n[{new}]')
+            out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999)
+            save_all_stat(srv, 'chat_pre', out)
+
+        begin = len(model_tokens)
+        for i in range(120):
+            # newline_adj schedule: never end immediately, discourage ending
+            # early, then push increasingly hard toward '\n' after ~80 tokens
+            if i <= 0:
+                newline_adj = -999999999
+            elif i <= 30:
+                newline_adj = -2
+            elif i <= 80:
+                newline_adj = 0
+            elif i <= 117:
+                newline_adj = (i - 80) * 0.5
+            else:
+                newline_adj = 999999999
+            token = tokenizer.sample_logits(
+                out,
+                model_tokens,
+                args.ctx_len,
+                temperature=x_temp,
+                top_p_usual=x_top_p,
+                top_p_newline=x_top_p,
+            )
+            out = run_rnn([token], newline_adj=newline_adj)
+            if tokenizer.tokenizer.decode(model_tokens[-10:]).endswith('\n\n'):
+                break
+
+        send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
+        # print(f'### send ###\n[{send_msg}]')
+        reply_msg(send_msg)
+        save_all_stat(srv, 'chat', out)
+
+print('''Commands:
++rwkv_alt --> alternate chat reply
++rwkv_reset --> reset chat
+
++rwkv_gen YOUR PROMPT --> free generation with your prompt
++rwkv_qa YOUR QUESTION --> free generation - ask any question and get an answer (just ask the question)
++rwkv_more --> continue the last free generation [does not work for chat]
++rwkv_retry --> retry the last free generation
+
+Now talk with the bot and enjoy. Remember to +rwkv_reset periodically to clean up the bot's memory!!! Use RWKV-4 14B for best results.
+''')
+
+while True:
+    msg = input('User: ')
+    if len(msg.strip()) > 0:
+        on_message(msg)
+    else:
+        print('Error: please say something')
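
Note on sampling: tokenizer.sample_logits (from src/utils.py) implements
temperature plus top-p ("nucleus") sampling, and takes separate top_p values
for newline vs. other continuations (chat.py passes the same x_top_p for
both). As a rough, self-contained sketch of the idea -- illustrative only,
with an invented name; src/utils.py is the reference implementation:

    import numpy as np

    def sample_top_p_sketch(logits, temperature=1.0, top_p=0.8):
        # logits: 1-D numpy array of raw model outputs
        # softmax to probabilities
        probs = np.exp(logits - np.max(logits))
        probs /= probs.sum()
        # keep the smallest set of tokens whose probabilities sum to >= top_p
        sorted_probs = np.sort(probs)[::-1]
        cutoff = sorted_probs[np.argmax(np.cumsum(sorted_probs) >= top_p)]
        probs[probs < cutoff] = 0.0
        # apply temperature to the surviving mass, renormalize, and sample
        if temperature != 1.0:
            probs = probs ** (1.0 / temperature)
        probs /= probs.sum()
        return int(np.random.choice(len(probs), p=probs))

With top_p clamped to 0 (as the '-top_p=' handler allows), only the single
most likely token survives the cutoff, i.e. sampling becomes greedy.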