Update example-cpu.py

main
randaller 3 years ago committed by GitHub
parent 2ca6853f10
commit 97eee68d4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,32 +9,17 @@ import fire
import time
import json
from pathlib import Path
from llama import ModelArgs, Transformer, Tokenizer, LLaMA
def setup_model() -> Tuple[int, int]:
# local_rank = int(os.environ.get("LOCAL_RANK", -1))
# world_size = int(os.environ.get("WORLD_SIZE", -1))
local_rank = int(-1)
world_size = int(1)
# torch.manual_seed(1)
return local_rank, world_size
def load(
ckpt_dir: str,
tokenizer_path: str,
local_rank: int,
world_size: int,
max_seq_len: int,
max_batch_size: int,
) -> LLaMA:
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert world_size == len(checkpoints), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
ckpt_path = checkpoints[local_rank]
ckpt_path = checkpoints[-1]
print("Loading models...")
checkpoint = torch.load(ckpt_path, map_location="cpu")
@ -57,7 +42,6 @@ def load(
print(f"Loaded models in {time.time() - start_time:.2f} seconds")
return generator
def main(
ckpt_dir: str = './model',
tokenizer_path: str = './tokenizer/tokenizer.model',
@ -66,22 +50,18 @@ def main(
max_seq_len: int = 512,
max_batch_size: int = 32,
):
local_rank, world_size = setup_model()
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
# torch.manual_seed(1)
generator = load(
ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
)
generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)
prompts = [
######## For these prompts, the expected answer is the natural continuation of the prompt #######
##### For these prompts, the expected answer is the natural continuation of the prompt #####
"I believe the meaning of life is",
# "Simply put, the theory of relativity states that ",
# "Building a website can be done in 10 simple steps:\n",
######## Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api ######
##### Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api #####
# """Tweet: "I hate it when my phone battery dies."
# Sentiment: Negative
@ -96,13 +76,9 @@ def main(
# Sentiment:""",
# """Translate English to French:
#
# sea otter => loutre de mer
#
# peppermint => menthe poivrée
#
# plush girafe => girafe peluche
#
# cheese =>""",
]
@ -118,6 +94,5 @@ def main(
print(f"Inference took {time.time() - start_time:.2f} seconds")
if __name__ == "__main__":
fire.Fire(main)

Loading…
Cancel
Save