Update example-cpu.py

main
randaller 3 years ago committed by GitHub
parent 417c576a9f
commit 1a7b31831b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,8 +19,7 @@ def setup_model() -> Tuple[int, int]:
local_rank = int(-1) local_rank = int(-1)
world_size = int(1) world_size = int(1)
# seed must be the same in all processes # torch.manual_seed(1)
torch.manual_seed(1)
return local_rank, world_size return local_rank, world_size
@ -45,15 +44,13 @@ def load(
model_args: ModelArgs = ModelArgs( model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
) )
tokenizer = Tokenizer(model_path=tokenizer_path) tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words model_args.vocab_size = tokenizer.n_words
torch.set_default_tensor_type(torch.FloatTensor)
model = Transformer(model_args) model = Transformer(model_args)
model.to("cpu") model.to("cpu")
torch.set_default_tensor_type(torch.FloatTensor)
model.load_state_dict(checkpoint, strict=False) model.load_state_dict(checkpoint, strict=False)
generator = LLaMA(model, tokenizer) generator = LLaMA(model, tokenizer)
@ -78,31 +75,35 @@ def main(
) )
prompts = [ prompts = [
# For these prompts, the expected answer is the natural continuation of the prompt ######## For these prompts, the expected answer is the natural continuation of the prompt #######
"I believe the meaning of life is", "I believe the meaning of life is",
"Simply put, the theory of relativity states that ", # "Simply put, the theory of relativity states that ",
"Building a website can be done in 10 simple steps:\n", # "Building a website can be done in 10 simple steps:\n",
# Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
"""Tweet: "I hate it when my phone battery dies." ######## Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api ######
Sentiment: Negative
### # """Tweet: "I hate it when my phone battery dies."
Tweet: "My day has been 👍" # Sentiment: Negative
Sentiment: Positive # ###
### # Tweet: "My day has been 👍"
Tweet: "This is the link to the article" # Sentiment: Positive
Sentiment: Neutral # ###
### # Tweet: "This is the link to the article"
Tweet: "This new music video was incredibile" # Sentiment: Neutral
Sentiment:""", # ###
"""Translate English to French: # Tweet: "This new music video was incredibile"
# Sentiment:""",
sea otter => loutre de mer
# """Translate English to French:
peppermint => menthe poivrée #
# sea otter => loutre de mer
plush girafe => girafe peluche #
# peppermint => menthe poivrée
cheese =>""", #
# plush girafe => girafe peluche
#
# cheese =>""",
] ]
start_time = time.time() start_time = time.time()
@ -113,7 +114,7 @@ cheese =>""",
for result in results: for result in results:
print(result) print(result)
print("\n==================================\n") print("\n==================================")
print(f"Inference took {time.time() - start_time:.2f} seconds") print(f"Inference took {time.time() - start_time:.2f} seconds")

Loading…
Cancel
Save