Update example-cpu.py

main
randaller 3 years ago committed by GitHub
parent 2ca6853f10
commit 97eee68d4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,32 +9,17 @@ import fire
import time import time
import json import json
from pathlib import Path from pathlib import Path
from llama import ModelArgs, Transformer, Tokenizer, LLaMA from llama import ModelArgs, Transformer, Tokenizer, LLaMA
def setup_model() -> Tuple[int, int]:
# local_rank = int(os.environ.get("LOCAL_RANK", -1))
# world_size = int(os.environ.get("WORLD_SIZE", -1))
local_rank = int(-1)
world_size = int(1)
# torch.manual_seed(1)
return local_rank, world_size
def load( def load(
ckpt_dir: str, ckpt_dir: str,
tokenizer_path: str, tokenizer_path: str,
local_rank: int,
world_size: int,
max_seq_len: int, max_seq_len: int,
max_batch_size: int, max_batch_size: int,
) -> LLaMA: ) -> LLaMA:
start_time = time.time() start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert world_size == len(checkpoints), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}" ckpt_path = checkpoints[-1]
ckpt_path = checkpoints[local_rank]
print("Loading models...") print("Loading models...")
checkpoint = torch.load(ckpt_path, map_location="cpu") checkpoint = torch.load(ckpt_path, map_location="cpu")
@ -57,7 +42,6 @@ def load(
print(f"Loaded models in {time.time() - start_time:.2f} seconds") print(f"Loaded models in {time.time() - start_time:.2f} seconds")
return generator return generator
def main( def main(
ckpt_dir: str = './model', ckpt_dir: str = './model',
tokenizer_path: str = './tokenizer/tokenizer.model', tokenizer_path: str = './tokenizer/tokenizer.model',
@ -66,22 +50,18 @@ def main(
max_seq_len: int = 512, max_seq_len: int = 512,
max_batch_size: int = 32, max_batch_size: int = 32,
): ):
local_rank, world_size = setup_model() # torch.manual_seed(1)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
generator = load( generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)
ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
)
prompts = [ prompts = [
######## For these prompts, the expected answer is the natural continuation of the prompt ####### ##### For these prompts, the expected answer is the natural continuation of the prompt #####
"I believe the meaning of life is", "I believe the meaning of life is",
# "Simply put, the theory of relativity states that ", # "Simply put, the theory of relativity states that ",
# "Building a website can be done in 10 simple steps:\n", # "Building a website can be done in 10 simple steps:\n",
######## Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api ###### ##### Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api #####
# """Tweet: "I hate it when my phone battery dies." # """Tweet: "I hate it when my phone battery dies."
# Sentiment: Negative # Sentiment: Negative
@ -96,13 +76,9 @@ def main(
# Sentiment:""", # Sentiment:""",
# """Translate English to French: # """Translate English to French:
#
# sea otter => loutre de mer # sea otter => loutre de mer
#
# peppermint => menthe poivrée # peppermint => menthe poivrée
#
# plush girafe => girafe peluche # plush girafe => girafe peluche
#
# cheese =>""", # cheese =>""",
] ]
@ -118,6 +94,5 @@ def main(
print(f"Inference took {time.time() - start_time:.2f} seconds") print(f"Inference took {time.time() - start_time:.2f} seconds")
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(main) fire.Fire(main)

Loading…
Cancel
Save