diff --git a/example-cpu.py b/example-cpu.py
index f03db54..a71e718 100644
--- a/example-cpu.py
+++ b/example-cpu.py
@@ -9,32 +9,17 @@ import fire
 import time
 import json
 from pathlib import Path
-
 from llama import ModelArgs, Transformer, Tokenizer, LLaMA
 
-
-def setup_model() -> Tuple[int, int]:
-    # local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    # world_size = int(os.environ.get("WORLD_SIZE", -1))
-    local_rank = int(-1)
-    world_size = int(1)
-
-    # torch.manual_seed(1)
-    return local_rank, world_size
-
-
 
 def load(
     ckpt_dir: str,
     tokenizer_path: str,
-    local_rank: int,
-    world_size: int,
     max_seq_len: int,
     max_batch_size: int,
 ) -> LLaMA:
     start_time = time.time()
     checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
-    assert world_size == len(checkpoints), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
-    ckpt_path = checkpoints[local_rank]
+    ckpt_path = checkpoints[-1]
     print("Loading models...")
     checkpoint = torch.load(ckpt_path, map_location="cpu")
@@ -57,7 +42,6 @@ def load(
     print(f"Loaded models in {time.time() - start_time:.2f} seconds")
     return generator
 
-
 def main(
     ckpt_dir: str = './model',
     tokenizer_path: str = './tokenizer/tokenizer.model',
@@ -66,22 +50,18 @@ def main(
     max_seq_len: int = 512,
     max_batch_size: int = 32,
 ):
-    local_rank, world_size = setup_model()
-    if local_rank > 0:
-        sys.stdout = open(os.devnull, "w")
+    # torch.manual_seed(1)
 
-    generator = load(
-        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
-    )
+    generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)
 
     prompts = [
-        ######## For these prompts, the expected answer is the natural continuation of the prompt #######
+        ##### For these prompts, the expected answer is the natural continuation of the prompt #####
         "I believe the meaning of life is",
         # "Simply put, the theory of relativity states that ",
         # "Building a website can be done in 10 simple steps:\n",
 
-        ######## Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api ######
+        ##### Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api #####
 
         # """Tweet: "I hate it when my phone battery dies."
         # Sentiment: Negative
 
@@ -96,13 +76,9 @@ def main(
         # Sentiment:""",
 
         # """Translate English to French:
-        #
         # sea otter => loutre de mer
-        #
         # peppermint => menthe poivrée
-        #
         # plush girafe => girafe peluche
-        #
         # cheese =>""",
     ]
 
@@ -118,6 +94,5 @@ def main(
 
     print(f"Inference took {time.time() - start_time:.2f} seconds")
 
-
 if __name__ == "__main__":
     fire.Fire(main)
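
For reference, a minimal sketch of exercising the simplified loader from Python rather than through the fire CLI. The module name `example_cpu` is hypothetical (the file is `example-cpu.py`, so in practice it is run as a script); the paths are the defaults from the diff, and the `generate` arguments mirror the upstream llama example:

```python
# Hypothetical usage sketch, not part of the patch.
# Assumes ./model holds the *.pth checkpoint and ./tokenizer/tokenizer.model
# exists -- these are the defaults declared in main() above.
from example_cpu import load  # hypothetical module name for example-cpu.py

generator = load(
    ckpt_dir="./model",
    tokenizer_path="./tokenizer/tokenizer.model",
    max_seq_len=512,
    max_batch_size=32,
)

results = generator.generate(
    ["I believe the meaning of life is"],  # same prompt as in the diff
    max_gen_len=256,                       # illustrative value
    temperature=0.8,
    top_p=0.95,
)
for result in results:
    print(result)
```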
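
One design note on the new `ckpt_path = checkpoints[-1]`: it loads a single shard, so this path is only correct for checkpoints that ship as one `.pth` file (the smallest model size); a model-parallel checkpoint split across several shards would need its shards merged before this CPU path produces complete weights. That is presumably why the `world_size == len(checkpoints)` assertion could be dropped along with the rest of the distributed setup.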