Update hf-chat-example.py

main
randaller 3 years ago committed by GitHub
parent 85864cf9e4
commit 21313098fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,8 +1,9 @@
import llamahf import llamahf
import os import os
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
# # to save memory use bfloat16 # # to save memory use bfloat16
# import torch
# torch.set_default_dtype(torch.bfloat16) # torch.set_default_dtype(torch.bfloat16)
MODEL = 'decapoda-research/llama-7b-hf' MODEL = 'decapoda-research/llama-7b-hf'
@ -17,7 +18,17 @@ tokenizer = llamahf.LLaMATokenizer.from_pretrained(MODEL)
model = llamahf.LLaMAForCausalLM.from_pretrained(MODEL, low_cpu_mem_usage=True) model = llamahf.LLaMAForCausalLM.from_pretrained(MODEL, low_cpu_mem_usage=True)
model.to('cpu') model.to('cpu')
n = tokenizer.encode('\n', return_tensors='pt')[0]
class StoppingCriteriaSub(StoppingCriteria):
    """Stopping criterion that ends generation at the first newline token.

    The chat loop generates one dialog line per call to ``generate``;
    cutting off as soon as the most recent token is a newline keeps the
    model from continuing past the AI's reply.
    """

    # Token id that LLaMA's tokenizer assigns to "\n".
    # NOTE(review): presumably model/tokenizer-specific — verify (e.g. via
    # tokenizer.encode('\n')) if a different model is used.
    NEWLINE_TOKEN_ID = 13

    def __init__(self):
        super().__init__()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor,
                 stops=None, **kwargs) -> bool:
        """Return True when the last generated token is a newline.

        ``stops`` is kept for signature compatibility with the original
        (which used a mutable ``[]`` default — a Python pitfall) but is
        unused; ``**kwargs`` absorbs any extra args transformers passes.
        """
        # Only the first sequence in the batch is checked — the chat loop
        # generates a single sequence at a time.
        return bool(input_ids[0][-1] == self.NEWLINE_TOKEN_ID)
ctx = """A dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits. ctx = """A dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.
User: Hello, AI. User: Hello, AI.
@ -43,7 +54,8 @@ while True:
max_length=2048, max_length=2048,
top_p=0.95, top_p=0.95,
temperature=1.0, temperature=1.0,
eos_token_id=n stopping_criteria=StoppingCriteriaList([StoppingCriteriaSub()]),
# repetition_penalty=1.17
) )
decoded = tokenizer.decode(result[0]) decoded = tokenizer.decode(result[0])
ctx = decoded + "\n" ctx = decoded + "\n"

Loading…
Cancel
Save