Create example.py

main
randaller 3 years ago committed by GitHub
parent cb395a25ae
commit a7d7801dbd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
import pyarrow as pa
from pathlib import Path
from llama import ModelArgs, Transformer, Tokenizer, LLaMA
def load(
ckpt_dir: str,
tokenizer_path: str,
max_seq_len: int,
max_batch_size: int,
) -> LLaMA:
start_time = time.time()
arrow_dir = Path(ckpt_dir).expanduser() / 'arrow'
if not arrow_dir.exists():
print('Converting checkpoints to arrow format')
checkpoints = sorted(Path(ckpt_dir).expanduser().glob("*.pth"))
for ckpt_file in checkpoints:
print(ckpt_file)
index = ckpt_file.parts[-1].split('.')[-2]
ckpt = torch.load(ckpt_file, map_location='cpu')
(arrow_dir / index).mkdir(parents=True, exist_ok=True)
for k, v in ckpt.items():
tens = pa.Tensor.from_numpy(v.numpy())
with pa.output_stream(arrow_dir / index / k) as f:
pa.ipc.write_tensor(tens, f)
ckpt = None
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
print("Loading checkpoint")
segments = sorted((arrow_dir / '00').glob("*"))
# print(segments)
checkpoint = {}
files = []
for seg in segments:
f = pa.memory_map(str(seg))
files.append(f)
t = pa.ipc.read_tensor(f).to_numpy()
t = torch.from_numpy(t)
checkpoint[seg.parts[-1]] = t
# torch.set_default_tensor_type(torch.cuda.HalfTensor)
torch.set_default_tensor_type(torch.BFloat16Tensor)
# torch.set_default_tensor_type(torch.FloatTensor)
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
)
print("Loading tokenizer")
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
print("Loading model")
model = Transformer(model_args)
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
model.load_state_dict(torch.load(checkpoints[-1]), strict=False)
for f in files:
f.close()
files = None
generator = LLaMA(model, tokenizer)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return generator
def main(
ckpt_dir: str,
tokenizer_path: str,
temperature: float = 0.8,
top_p: float = 0.95,
max_seq_len: int = 2048,
max_batch_size: int = 1, # 16 for 13B, 4 for 30B and 65B (for 512 seq)
):
generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)
prompts = [
# For these prompts, the expected answer is the natural continuation of the prompt
# "I believe the meaning of life is",
"""Write the Python code with detailed comments to generate 256 random integers in the range from -128 to 512, inclusive.
\\begin{code}\n""",
]
results = generator.generate(
prompts, max_gen_len=max_seq_len, temperature=temperature, top_p=top_p
)
for result in results:
print("\n==================================\n")
print(result)
print("\n==================================\n")
if __name__ == "__main__":
fire.Fire(main)
Loading…
Cancel
Save