|
|
|
|
@ -9,32 +9,17 @@ import fire
|
|
|
|
|
import time
|
|
|
|
|
import json
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
from llama import ModelArgs, Transformer, Tokenizer, LLaMA
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_model() -> Tuple[int, int]:
    """Return (local_rank, world_size) for single-process CPU inference.

    Distributed initialization is intentionally disabled: the usual
    LOCAL_RANK / WORLD_SIZE environment-variable reads are kept below as
    commented-out reference, and fixed values are returned instead.

    Returns:
        Tuple[int, int]: ``local_rank`` of -1 (no distributed rank) and
        ``world_size`` of 1 (a single process).
    """
    # local_rank = int(os.environ.get("LOCAL_RANK", -1))
    # world_size = int(os.environ.get("WORLD_SIZE", -1))

    # Fixed single-process values; the redundant int(...) casts on the
    # integer literals were removed.
    local_rank = -1
    world_size = 1

    # torch.manual_seed(1)
    return local_rank, world_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): this region is a mangled unified diff, not runnable Python —
# context lines were replaced by "|" and a hunk header ("@ -57,…") remains
# inline below, hiding the model-construction code. Comments only; all
# original bytes are left untouched.
# Purpose (from visible code): locate a *.pth checkpoint shard under
# ckpt_dir, load it onto CPU, and return a LLaMA generator, printing timing.
def load(
|
|
|
|
|
ckpt_dir: str,
|
|
|
|
|
tokenizer_path: str,
|
|
|
|
|
local_rank: int,
|
|
|
|
|
world_size: int,
|
|
|
|
|
max_seq_len: int,
|
|
|
|
|
max_batch_size: int,
|
|
|
|
|
) -> LLaMA:
|
|
|
|
|
# Wall-clock timer for the "Loaded models in …" report below.
start_time = time.time()
|
|
|
|
|
# Sorted so shard files map deterministically to ranks.
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
|
|
|
|
assert world_size == len(checkpoints), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
|
|
|
|
|
# NOTE(review): the next two assignments are the old/new sides of the diff;
# the later checkpoints[-1] overwrites the per-rank selection — confirm
# which side is intended once the clean file is available.
ckpt_path = checkpoints[local_rank]
|
|
|
|
|
ckpt_path = checkpoints[-1]
|
|
|
|
|
|
|
|
|
|
print("Loading models...")
|
|
|
|
|
# Weights are loaded onto CPU explicitly (map_location="cpu").
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
|
|
|
|
@ -57,7 +42,6 @@ def load(
|
|
|
|
|
print(f"Loaded models in {time.time() - start_time:.2f} seconds")
|
|
|
|
|
# `generator` is defined in the hunk hidden above — presumably a LLaMA
# instance per the return annotation; TODO confirm against the full file.
return generator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): mangled unified diff — context replaced by "|" lines and
# hunk headers ("@ -66,…", "@ -96,…", "@ -118,…") remain inline, hiding
# parts of the parameter list and the inference loop. Comments only; all
# original bytes are left untouched.
# Purpose (from visible code): CLI entry point — set up single-process
# ranks, load the LLaMA generator, and run it over the prompt list.
def main(
|
|
|
|
|
ckpt_dir: str = './model',
|
|
|
|
|
tokenizer_path: str = './tokenizer/tokenizer.model',
|
|
|
|
|
@ -66,22 +50,18 @@ def main(
|
|
|
|
|
max_seq_len: int = 512,
|
|
|
|
|
max_batch_size: int = 32,
|
|
|
|
|
):
|
|
|
|
|
local_rank, world_size = setup_model()
|
|
|
|
|
# Silence stdout on non-zero ranks (dead path while setup_model
# returns -1, so rank 0 semantics always apply here).
if local_rank > 0:
|
|
|
|
|
sys.stdout = open(os.devnull, "w")
|
|
|
|
|
# torch.manual_seed(1)
|
|
|
|
|
|
|
|
|
|
# NOTE(review): the two load(...) calls below are the old/new sides of the
# diff; the second (4-argument) call would override the first but does not
# match the visible load() signature — confirm which survives.
generator = load(
|
|
|
|
|
ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
|
|
|
|
|
)
|
|
|
|
|
generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)
|
|
|
|
|
|
|
|
|
|
prompts = [
|
|
|
|
|
######## For these prompts, the expected answer is the natural continuation of the prompt #######
|
|
|
|
|
##### For these prompts, the expected answer is the natural continuation of the prompt #####
|
|
|
|
|
|
|
|
|
|
"I believe the meaning of life is",
|
|
|
|
|
# "Simply put, the theory of relativity states that ",
|
|
|
|
|
# "Building a website can be done in 10 simple steps:\n",
|
|
|
|
|
|
|
|
|
|
######## Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api ######
|
|
|
|
|
##### Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api #####
|
|
|
|
|
|
|
|
|
|
# """Tweet: "I hate it when my phone battery dies."
|
|
|
|
|
# Sentiment: Negative
|
|
|
|
|
@ -96,13 +76,9 @@ def main(
|
|
|
|
|
# Sentiment:""",
|
|
|
|
|
|
|
|
|
|
# """Translate English to French:
|
|
|
|
|
#
|
|
|
|
|
# sea otter => loutre de mer
|
|
|
|
|
#
|
|
|
|
|
# peppermint => menthe poivrée
|
|
|
|
|
#
|
|
|
|
|
# plush girafe => girafe peluche
|
|
|
|
|
#
|
|
|
|
|
# cheese =>""",
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
@ -118,6 +94,5 @@ def main(
|
|
|
|
|
|
|
|
|
|
# `start_time` here is set in the hunk hidden above (around the generation
# call) — not visible in this view.
print(f"Inference took {time.time() - start_time:.2f} seconds")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: python-fire turns main()'s keyword arguments
# (ckpt_dir, tokenizer_path, max_seq_len, …) into command-line flags.
if __name__ == "__main__":
|
|
|
|
|
fire.Fire(main)
|
|
|
|
|
|