diff --git a/RWKV-v4neo/chat.py b/RWKV-v4neo/chat.py index b2aad52..0e4bc5a 100644 --- a/RWKV-v4neo/chat.py +++ b/RWKV-v4neo/chat.py @@ -97,7 +97,7 @@ current_state = None def run_rnn(tokens, newline_adj = 0): global model_tokens, current_state for i in range(len(tokens)): - model_tokens += [tokens[i]] + model_tokens += [int(tokens[i])] if i == len(tokens) - 1: out, current_state = model.forward(model_tokens, current_state) else: @@ -198,7 +198,7 @@ def on_message(message): out = run_rnn(tokenizer.tokenizer.encode(new)) save_all_stat(srv, 'gen_0', out) - # new = f"\nThe following is an excellent Q&A session consists of detailed and factual information.\n\nQ: What is 3+5?\nA: 3+5=8.\n\nQ: {msg[9:].strip()}\nA:" + # new = f"\nThe following is an excellent Q&A session consists of detailed and factual information.\n\nQ: What is 3+5?\nA: The answer is 8.\n\nQ: {msg[9:].strip()}\nA:" # print(f'### prompt ###\n[{new}]') # current_state = None # out = run_rnn(tokenizer.tokenizer.encode(new)) @@ -254,7 +254,7 @@ def on_message(message): if i <= 0: newline_adj = -999999999 elif i <= 30: - newline_adj = -1 + newline_adj = (i - 30) / 10 elif i <= 80: newline_adj = 0 elif i <= 117: @@ -270,17 +270,22 @@ def on_message(message): top_p_newline=x_top_p, ) out = run_rnn([token], newline_adj=newline_adj) - if tokenizer.tokenizer.decode(model_tokens[-10:]).endswith(f'\n\n'): + send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]) + if '\n\n' in send_msg: + send_msg = send_msg.strip() break - # tail = tokenizer.tokenizer.decode(model_tokens[-10:]).strip() - # if tail.endswith(f'User:') or tail.endswith(f'Bot:'): + + # send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() + # if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!! + # send_msg = send_msg[:-len(f'{user}{interface}')].strip() + # break + # if send_msg.endswith(f'{bot}{interface}'): + # send_msg = send_msg[:-len(f'{bot}{interface}')].strip() # break - send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() - # if send_msg.endswith(f'User:'): - # send_msg = send_msg[:-5].strip() - # if send_msg.endswith(f'Bot:'): - # send_msg = send_msg[:-4].strip() + # print(f'{model_tokens}') + # print(f'[{tokenizer.tokenizer.decode(model_tokens)}]') + # print(f'### send ###\n[{send_msg}]') reply_msg(send_msg) save_all_stat(srv, 'chat', out)