@@ -19,8 +19,7 @@ def setup_model() -> Tuple[int, int]:
     local_rank = int(-1)
     world_size = int(1)
 
-    # seed must be the same in all processes
-    # torch.manual_seed(1)
+    torch.manual_seed(1)
     return local_rank, world_size
 
 
@@ -45,15 +44,13 @@ def load(
     model_args: ModelArgs = ModelArgs(
         max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
     )
 
     tokenizer = Tokenizer(model_path=tokenizer_path)
 
     model_args.vocab_size = tokenizer.n_words
 
-    torch.set_default_tensor_type(torch.FloatTensor)
     model = Transformer(model_args)
     model.to("cpu")
-    torch.set_default_tensor_type(torch.FloatTensor)
     model.load_state_dict(checkpoint, strict=False)
 
     generator = LLaMA(model, tokenizer)
@@ -78,31 +75,35 @@ def main(
     )
 
     prompts = [
-        # For these prompts, the expected answer is the natural continuation of the prompt
+        ######## For these prompts, the expected answer is the natural continuation of the prompt #######
         "I believe the meaning of life is",
-        "Simply put, the theory of relativity states that ",
-        "Building a website can be done in 10 simple steps:\n",
-        # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
-        """Tweet: "I hate it when my phone battery dies."
-Sentiment: Negative
-###
-Tweet: "My day has been 👍"
-Sentiment: Positive
-###
-Tweet: "This is the link to the article"
-Sentiment: Neutral
-###
-Tweet: "This new music video was incredibile"
-Sentiment:""",
-        """Translate English to French:
-
-sea otter => loutre de mer
-
-peppermint => menthe poivrée
-
-plush girafe => girafe peluche
-
-cheese =>""",
+        # "Simply put, the theory of relativity states that ",
+        # "Building a website can be done in 10 simple steps:\n",
+
+        ######## Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api ######
+
+        # """Tweet: "I hate it when my phone battery dies."
+        # Sentiment: Negative
+        # ###
+        # Tweet: "My day has been 👍"
+        # Sentiment: Positive
+        # ###
+        # Tweet: "This is the link to the article"
+        # Sentiment: Neutral
+        # ###
+        # Tweet: "This new music video was incredibile"
+        # Sentiment:""",
+
+        # """Translate English to French:
+        #
+        # sea otter => loutre de mer
+        #
+        # peppermint => menthe poivrée
+        #
+        # plush girafe => girafe peluche
+        #
+        # cheese =>""",
+
     ]
 
     start_time = time.time()
@@ -113,7 +114,7 @@ cheese =>""",
 
     for result in results:
         print(result)
-        print("\n==================================\n")
+        print("\n==================================")
 
     print(f"Inference took {time.time() - start_time:.2f} seconds")
 