You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

112 lines
3.3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
import pyarrow as pa
from pathlib import Path
from llama import ModelArgs, Transformer, Tokenizer, LLaMA
def load(
ckpt_dir: str,
tokenizer_path: str,
max_seq_len: int,
max_batch_size: int,
) -> LLaMA:
start_time = time.time()
arrow_dir = Path(ckpt_dir).expanduser() / 'arrow'
if not arrow_dir.exists():
print('Converting checkpoints to arrow format')
checkpoints = sorted(Path(ckpt_dir).expanduser().glob("*.pth"))
for ckpt_file in checkpoints:
print(ckpt_file)
index = ckpt_file.parts[-1].split('.')[-2]
ckpt = torch.load(ckpt_file, map_location='cpu')
(arrow_dir / index).mkdir(parents=True, exist_ok=True)
for k, v in ckpt.items():
tens = pa.Tensor.from_numpy(v.numpy())
with pa.output_stream(arrow_dir / index / k) as f:
pa.ipc.write_tensor(tens, f)
ckpt = None
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
print("Loading checkpoint")
segments = sorted((arrow_dir / '00').glob("*"))
# print(segments)
checkpoint = {}
files = []
for seg in segments:
f = pa.memory_map(str(seg))
files.append(f)
t = pa.ipc.read_tensor(f).to_numpy()
t = torch.from_numpy(t)
checkpoint[seg.parts[-1]] = t
# torch.set_default_tensor_type(torch.cuda.HalfTensor)
torch.set_default_tensor_type(torch.BFloat16Tensor)
# torch.set_default_tensor_type(torch.FloatTensor)
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
)
print("Loading tokenizer")
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
print("Loading model")
model = Transformer(model_args)
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
model.load_state_dict(torch.load(checkpoints[-1]), strict=False)
for f in files:
f.close()
files = None
generator = LLaMA(model, tokenizer)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return generator
def main(
ckpt_dir: str,
tokenizer_path: str,
temperature: float = 0.8,
top_p: float = 0.95,
max_seq_len: int = 2048,
max_batch_size: int = 1, # 16 for 13B, 4 for 30B and 65B (for 512 seq)
):
generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)
prompts = [
# For these prompts, the expected answer is the natural continuation of the prompt
# "I believe the meaning of life is",
"""Write the Python code with detailed comments to generate 256 random integers in the range from -128 to 512, inclusive.
\\begin{code}\n""",
]
results = generator.generate(
prompts, max_gen_len=max_seq_len, temperature=temperature, top_p=top_p
)
for result in results:
print("\n==================================\n")
print(result)
print("\n==================================\n")
if __name__ == "__main__":
fire.Fire(main)