@@ -11,6 +11,7 @@ import json
from pathlib import Path

from llama import ModelArgs, Transformer, Tokenizer, LLaMA


def load(
    ckpt_dir: str,
    tokenizer_path: str,
@@ -42,6 +43,7 @@ def load(
    print(f"Loaded models in {time.time() - start_time:.2f} seconds")
    return generator


def main(
    ckpt_dir: str = './model',
    tokenizer_path: str = './tokenizer/tokenizer.model',
@@ -94,5 +96,6 @@ def main(
    print(f"Inference took {time.time() - start_time:.2f} seconds")


if __name__ == "__main__":
    fire.Fire(main)
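For context, here is a minimal, self-contained sketch of how the entry point above is driven from the command line. It assumes python-fire is installed and that the patched script is saved as example.py (the hunks above do not show the filename); fire.Fire(main) exposes main's keyword arguments, including the ckpt_dir and tokenizer_path defaults, as CLI flags. The body of main below is a stand-in, not the real inference code.

import fire


def main(
    ckpt_dir: str = './model',
    tokenizer_path: str = './tokenizer/tokenizer.model',
):
    # Stand-in body: the real main() builds the LLaMA generator and runs inference.
    print(f"Would load checkpoints from {ckpt_dir} with tokenizer {tokenizer_path}")


if __name__ == "__main__":
    # Example invocation (script name assumed):
    #   python example.py --ckpt_dir ./model --tokenizer_path ./tokenizer/tokenizer.model
    fire.Fire(main)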