From 968b4dcb4f3a200fb5268c064ec4108024414782 Mon Sep 17 00:00:00 2001
From: Matteo Croce
Date: Thu, 9 Mar 2023 13:14:28 +0100
Subject: [PATCH] add bfloat16 chat

---
 example-chat-bfloat16.py | 74 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 74 insertions(+)
 create mode 100644 example-chat-bfloat16.py

diff --git a/example-chat-bfloat16.py b/example-chat-bfloat16.py
new file mode 100644
index 0000000..ecb33da
--- /dev/null
+++ b/example-chat-bfloat16.py
@@ -0,0 +1,74 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from typing import Tuple
+import os
+import sys
+import torch
+import fire
+import time
+import json
+from pathlib import Path
+from llama import ModelArgs, Transformer, Tokenizer, LLaMA
+
+
+def load(
+    ckpt_dir: str,
+    tokenizer_path: str,
+    max_seq_len: int,
+    max_batch_size: int,
+) -> LLaMA:
+    print("Creating model...")
+    start_time = time.time()
+    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+
+    with open(Path(ckpt_dir) / "params.json", "r") as f:
+        params = json.loads(f.read())
+
+    model_args: ModelArgs = ModelArgs(
+        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+    )
+
+    tokenizer = Tokenizer(model_path=tokenizer_path)
+    model_args.vocab_size = tokenizer.n_words
+
+    model = Transformer(model_args)
+    model.to("cpu")
+
+    print("Loading merged checkpoint...")
+    checkpoint = torch.load(checkpoints[-1], map_location="cpu")
+    model.load_state_dict(checkpoint, strict=False)
+    del checkpoint
+
+    generator = LLaMA(model, tokenizer)
+    print(f"Loaded model in {time.time() - start_time:.2f} seconds")
+    return generator
+
+
+def main(
+    ckpt_dir: str = './model',
+    tokenizer_path: str = './tokenizer/tokenizer.model',
+    temperature: float = 0.8,
+    top_p: float = 0.95,
+    max_seq_len: int = 256,  # up to 2048
+    max_batch_size: int = 32,
+):
+    # torch.manual_seed(1)
+    torch.set_default_dtype(torch.bfloat16)
+
+    generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)
+
+    while True:
+        prompt = input(f'prompt> ')
+        if len(prompt.strip()) > 0:
+            prompts = [prompt]
+            results = generator.generate(
+                prompts, max_gen_len=256, temperature=temperature, top_p=top_p
+            )
+
+            for result in results:
+                print(result)
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
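
Note on the bfloat16 switch: the script relies on torch.set_default_dtype(torch.bfloat16) being called before load(), so the parameters the Transformer allocates on CPU come out as bfloat16, and the checkpoint tensors are converted to that dtype when load_state_dict() copies them in. Below is a minimal sketch of that mechanism; it is not part of the patch and assumes a PyTorch build that accepts bfloat16 as the default floating-point dtype:

    import torch

    # After this call, new floating-point tensors and module parameters
    # default to bfloat16 (assumes the installed PyTorch allows it here).
    torch.set_default_dtype(torch.bfloat16)

    # A module built after the call allocates its parameters in bfloat16,
    # just like the Transformer constructed inside load().
    layer = torch.nn.Linear(8, 8)
    print(layer.weight.dtype)  # torch.bfloat16

    # Copying a float32 checkpoint tensor into such a parameter converts it
    # to the parameter's dtype, so the loaded model stays in bfloat16.
    with torch.no_grad():
        layer.weight.copy_(torch.randn(8, 8, dtype=torch.float32))
    print(layer.weight.dtype)  # still torch.bfloat16

Since main() is wrapped with fire.Fire, its keyword defaults can be overridden from the command line, for example: python example-chat-bfloat16.py --ckpt_dir=./model --max_seq_len=512 (./model and ./tokenizer/tokenizer.model are only the script's default paths).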