From 5ab0a330738c34903e22d47af4e47e57280c4dad Mon Sep 17 00:00:00 2001
From: novarobot <88540431+novarobot@users.noreply.github.com>
Date: Thu, 23 Mar 2023 18:16:07 +0100
Subject: [PATCH] Create example-multi.py

---
 example-multi.py | 98 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 98 insertions(+)
 create mode 100644 example-multi.py

diff --git a/example-multi.py b/example-multi.py
new file mode 100644
index 0000000..f1e65b8
--- /dev/null
+++ b/example-multi.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from typing import Tuple
+import os
+import sys
+import torch
+import fire
+import time
+import json
+
+from pathlib import Path
+
+from fairscale.nn.model_parallel.initialize import initialize_model_parallel
+
+from llama import ModelArgs, Transformer, Tokenizer, LLaMA
+
+
+def setup_model_parallel() -> Tuple[int, int]:
+    local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    world_size = int(os.environ.get("WORLD_SIZE", -1))
+
+    torch.distributed.init_process_group("gloo")
+    initialize_model_parallel(world_size)
+    print('Setup parallel complete!')
+    # torch.cuda.set_device(local_rank)
+
+    # seed must be the same in all processes
+    torch.manual_seed(1)
+    return local_rank, world_size
+
+
+def load(
+    ckpt_dir: str,
+    tokenizer_path: str,
+    local_rank: int,
+    world_size: int,
+    max_seq_len: int,
+    max_batch_size: int,
+) -> LLaMA:
+    start_time = time.time()
+    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+    assert world_size == len(
+        checkpoints
+    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
+    ckpt_path = checkpoints[local_rank]
+    print("Loading")
+    checkpoint = torch.load(ckpt_path, map_location="cpu")
+    with open(Path(ckpt_dir) / "params.json", "r") as f:
+        params = json.loads(f.read())
+
+    model_args: ModelArgs = ModelArgs(
+        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+    )
+    tokenizer = Tokenizer(model_path=tokenizer_path)
+    model_args.vocab_size = tokenizer.n_words
+    # torch.set_default_tensor_type(torch.cuda.HalfTensor)
+    torch.set_default_tensor_type(torch.BFloat16Tensor)
+    model = Transformer(model_args)
+    torch.set_default_tensor_type(torch.FloatTensor)
+    model.load_state_dict(checkpoint, strict=False)
+
+    generator = LLaMA(model, tokenizer)
+    print(f"Loaded in {time.time() - start_time:.2f} seconds")
+    return generator
+
+
+def main(
+    ckpt_dir: str,
+    tokenizer_path: str,
+    temperature: float = 0.8,
+    top_p: float = 0.95,
+    max_seq_len: int = 512,
+    max_batch_size: int = 32,
+):
+    local_rank, world_size = setup_model_parallel()
+    if local_rank > 0:
+        sys.stdout = open(os.devnull, "w")
+
+    generator = load(
+        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
+    )
+
+    prompts = ["I believe the meaning of life is"]
+    # results = generator.generate(
+    #     prompts, max_gen_len=256, temperature=temperature, top_p=top_p
+    # )
+    results = generator.generate(
+        prompts, max_gen_len=512, temperature=temperature, top_p=top_p
+    )
+
+    for result in results:
+        print(result)
+        print("\n==================================\n")
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
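
Usage note (not part of the patch): the sketch below is one way to drive example-multi.py in a single process, assuming a one-shard checkpoint directory so that WORLD_SIZE=1 satisfies the assert in load(). The environment variables stand in for what torchrun would normally export for the env:// rendezvous behind init_process_group("gloo"), and the paths are placeholders.

# Minimal single-process launch sketch (an assumption, not part of the patch).
import os
import runpy
import sys

# torchrun normally exports these; set them by hand for a one-rank gloo group.
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# Placeholder paths: point these at a one-shard checkpoint and its tokenizer.
sys.argv = [
    "example-multi.py",
    "--ckpt_dir", "path/to/7B",
    "--tokenizer_path", "path/to/tokenizer.model",
    "--max_batch_size", "1",
]
# fire.Fire(main) in the script picks up the arguments from sys.argv.
runpy.run_path("example-multi.py", run_name="__main__")

For multi-shard checkpoints the script would instead be launched with torchrun, with --nproc_per_node set to the number of *.pth shards, matching the assert in load() that ties world size to the shard count.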