From 5b65d6507ff98c5db4e9fe4d8c88accc9a30de7b Mon Sep 17 00:00:00 2001
From: randaller
Date: Thu, 9 Mar 2023 10:45:57 +0300
Subject: [PATCH] Add files via upload

---
 example-chat.py | 125 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 125 insertions(+)
 create mode 100644 example-chat.py

diff --git a/example-chat.py b/example-chat.py
new file mode 100644
index 0000000..eaf4f83
--- /dev/null
+++ b/example-chat.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+import json
+import time
+from pathlib import Path
+
+import fire
+import pyarrow as pa
+import torch
+
+from llama import ModelArgs, Transformer, Tokenizer, LLaMA
+
+
+def load(
+    ckpt_dir: str,
+    tokenizer_path: str,
+    max_seq_len: int,
+    max_batch_size: int,
+) -> LLaMA:
+    start_time = time.time()
+    arrow_dir = Path(ckpt_dir).expanduser() / 'arrow'
+
+    # One-time conversion: dump every tensor of every .pth shard into its
+    # own Arrow IPC file, so later runs can memory-map the weights instead
+    # of reading the whole checkpoint into RAM at once.
+    if not arrow_dir.exists():
+        print('Converting checkpoints to arrow format')
+        checkpoints = sorted(Path(ckpt_dir).expanduser().glob("*.pth"))
+        for ckpt_file in checkpoints:
+            print(ckpt_file)
+            # Shard index, e.g. "consolidated.00.pth" -> "00"
+            index = ckpt_file.parts[-1].split('.')[-2]
+
+            ckpt = torch.load(ckpt_file, map_location='cpu')
+            (arrow_dir / index).mkdir(parents=True, exist_ok=True)
+            for k, v in ckpt.items():
+                tens = pa.Tensor.from_numpy(v.numpy())
+                with pa.output_stream(arrow_dir / index / k) as f:
+                    pa.ipc.write_tensor(tens, f)
+            ckpt = None  # release the shard before converting the next one
+
+    with open(Path(ckpt_dir) / "params.json", "r") as f:
+        params = json.load(f)
+
+    print("Loading checkpoint")
+    # Build the state dict from the memory-mapped Arrow files of shard "00"
+    # (this script expects a single, merged shard).
+    segments = sorted((arrow_dir / '00').glob("*"))
+
+    checkpoint = {}
+    files = []
+    for seg in segments:
+        f = pa.memory_map(str(seg))
+        files.append(f)
+        t = pa.ipc.read_tensor(f).to_numpy()
+        t = torch.from_numpy(t)
+        checkpoint[seg.parts[-1]] = t
+
+    # Default dtype for the parameters created below; choose per hardware.
+    # torch.set_default_tensor_type(torch.cuda.HalfTensor)
+    torch.set_default_tensor_type(torch.BFloat16Tensor)
+    # torch.set_default_tensor_type(torch.FloatTensor)
+
+    model_args: ModelArgs = ModelArgs(
+        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+    )
+    print("Loading tokenizer")
+    tokenizer = Tokenizer(model_path=tokenizer_path)
+    model_args.vocab_size = tokenizer.n_words
+    print("Loading model")
+    model = Transformer(model_args)
+
+    # Copy the memory-mapped tensors into the model's parameters; only after
+    # this is it safe to close the maps backing the checkpoint dict.
+    model.load_state_dict(checkpoint, strict=False)
+
+    for f in files:
+        f.close()
+    files = None
+
+    generator = LLaMA(model, tokenizer)
+    print(f"Loaded in {time.time() - start_time:.2f} seconds")
+    return generator
+
+
+def main(
+    ckpt_dir: str,
+    tokenizer_path: str,
+    temperature: float = 0.8,
+    top_p: float = 0.95,
+    max_seq_len: int = 2048,
+    max_batch_size: int = 1,  # 16 for 13B, 4 for 30B and 65B; 2 for 30B at seq_len 1024
+):
+    generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)
+
+    ctx = """A dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.
+User: Hello, AI.
+AI: Hello! How can I assist you today?
+"""
+
+    while True:
+        prompt = input('User: ')
+        if ctx != "":
+            ctx = ctx + "User: " + prompt + "\n"
+        else:
+            ctx = prompt + "\n"
+
+        # Rough character-level cap so the running transcript does not
+        # overflow max_seq_len tokens.
+        ctx = ctx[-1920:] if len(ctx) >= 2048 else ctx
+
+        if len(ctx.strip()) > 0:
+            prompts = [ctx]
+            results = generator.generate(
+                prompts, max_gen_len=2048, temperature=temperature, top_p=top_p
+            )
+            # generate() returns prompt plus completion: print only the newly
+            # generated text, and keep the full transcript for the next turn.
+            print(results[0][len(ctx):])
+            ctx = results[0]
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
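
Note on the Arrow loading scheme: load() round-trips every tensor through
Arrow's IPC tensor format and reads it back through a memory map, so the
weights are paged in from disk on demand instead of being held in RAM
twice. Below is a minimal, self-contained sketch of that round-trip; the
file name demo.arrow and the toy tensor are illustrative only, not part
of the patch.

    import numpy as np
    import pyarrow as pa
    import torch

    # Serialize one tensor to the Arrow IPC tensor format on disk.
    weight = torch.randn(4, 4).numpy()
    with pa.output_stream('demo.arrow') as f:
        pa.ipc.write_tensor(pa.Tensor.from_numpy(weight), f)

    # Read it back through a memory map: to_numpy() is zero-copy, so the
    # bytes stay on disk until they are actually touched.
    with pa.memory_map('demo.arrow') as f:
        restored = torch.from_numpy(pa.ipc.read_tensor(f).to_numpy())
        # The tensor is backed by the mapped buffer, so use or copy it
        # while the map is open; load() likewise closes its maps only
        # after load_state_dict() has copied the data into the model.
        assert np.allclose(weight, restored.numpy())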
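
Usage note: fire.Fire(main) turns the parameters of main() into
command-line flags. A typical invocation might look like the line below;
the paths are placeholders for wherever the merged checkpoint (params.json
plus the .pth shard) and the tokenizer actually live.

    python example-chat.py --ckpt_dir ./model --tokenizer_path ./tokenizer/tokenizer.model

The first run converts the checkpoint to Arrow files under
<ckpt_dir>/arrow; subsequent runs skip the conversion and memory-map those
files directly.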