better chatbot

main
BlinkDL 3 years ago
parent 511b7adb4f
commit 0d0cedfcd9

@ -97,7 +97,7 @@ current_state = None
def run_rnn(tokens, newline_adj = 0): def run_rnn(tokens, newline_adj = 0):
global model_tokens, current_state global model_tokens, current_state
for i in range(len(tokens)): for i in range(len(tokens)):
model_tokens += [tokens[i]] model_tokens += [int(tokens[i])]
if i == len(tokens) - 1: if i == len(tokens) - 1:
out, current_state = model.forward(model_tokens, current_state) out, current_state = model.forward(model_tokens, current_state)
else: else:
@ -198,7 +198,7 @@ def on_message(message):
out = run_rnn(tokenizer.tokenizer.encode(new)) out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out) save_all_stat(srv, 'gen_0', out)
# new = f"\nThe following is an excellent Q&A session consists of detailed and factual information.\n\nQ: What is 3+5?\nA: 3+5=8.\n\nQ: {msg[9:].strip()}\nA:" # new = f"\nThe following is an excellent Q&A session consists of detailed and factual information.\n\nQ: What is 3+5?\nA: The answer is 8.\n\nQ: {msg[9:].strip()}\nA:"
# print(f'### prompt ###\n[{new}]') # print(f'### prompt ###\n[{new}]')
# current_state = None # current_state = None
# out = run_rnn(tokenizer.tokenizer.encode(new)) # out = run_rnn(tokenizer.tokenizer.encode(new))
@ -254,7 +254,7 @@ def on_message(message):
if i <= 0: if i <= 0:
newline_adj = -999999999 newline_adj = -999999999
elif i <= 30: elif i <= 30:
newline_adj = -1 newline_adj = (i - 30) / 10
elif i <= 80: elif i <= 80:
newline_adj = 0 newline_adj = 0
elif i <= 117: elif i <= 117:
@ -270,17 +270,22 @@ def on_message(message):
top_p_newline=x_top_p, top_p_newline=x_top_p,
) )
out = run_rnn([token], newline_adj=newline_adj) out = run_rnn([token], newline_adj=newline_adj)
if tokenizer.tokenizer.decode(model_tokens[-10:]).endswith(f'\n\n'): send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
if '\n\n' in send_msg:
send_msg = send_msg.strip()
break break
# tail = tokenizer.tokenizer.decode(model_tokens[-10:]).strip()
# if tail.endswith(f'User:') or tail.endswith(f'Bot:'): # send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!!
# send_msg = send_msg[:-len(f'{user}{interface}')].strip()
# break
# if send_msg.endswith(f'{bot}{interface}'):
# send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
# break # break
send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() # print(f'{model_tokens}')
# if send_msg.endswith(f'User:'): # print(f'[{tokenizer.tokenizer.decode(model_tokens)}]')
# send_msg = send_msg[:-5].strip()
# if send_msg.endswith(f'Bot:'):
# send_msg = send_msg[:-4].strip()
# print(f'### send ###\n[{send_msg}]') # print(f'### send ###\n[{send_msg}]')
reply_msg(send_msg) reply_msg(send_msg)
save_all_stat(srv, 'chat', out) save_all_stat(srv, 'chat', out)

Loading…
Cancel
Save