@@ -18,12 +18,10 @@ def load(
     max_seq_len: int,
     max_batch_size: int,
 ) -> LLaMA:
     print("Creating model...")
     start_time = time.time()
     checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
-    ckpt_path = checkpoints[-1]
     print("Loading model...")
-    checkpoint = torch.load(ckpt_path, map_location="cpu")
     with open(Path(ckpt_dir) / "params.json", "r") as f:
         params = json.loads(f.read())
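With the per-shard loop below doing the actual loading, the single `ckpt_path = checkpoints[-1]` load goes away. The sorted glob matters: it returns the model-parallel shards in rank order, which the loop's offset arithmetic depends on. A minimal sketch of the listing it produces, with a hypothetical directory layout:

```python
from pathlib import Path

# Hypothetical layout: one consolidated.NN.pth file per model-parallel rank.
ckpt_dir = "models/13B"
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
# -> [PosixPath('models/13B/consolidated.00.pth'),
#     PosixPath('models/13B/consolidated.01.pth')]
```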
@@ -32,12 +30,49 @@ def load(
     )
     tokenizer = Tokenizer(model_path=tokenizer_path)
     model_args.vocab_size = tokenizer.n_words
     model = Transformer(model_args)
+    # Original copyright by tloen
+    # https://github.com/tloen/llama-int8/blob/main/example.py
+    key_to_dim = {
+        "w1": 0,
+        "w2": -1,
+        "w3": 0,
+        "wo": -1,
+        "wq": 0,
+        "wk": 0,
+        "wv": 0,
+        "output": 0,
+        "tok_embeddings": -1,
+        "ffn_norm": None,
+        "attention_norm": None,
+        "norm": None,
+        "rope": None,
+    }
+    for i, ckpt in enumerate(checkpoints):
+        print(f"Loading checkpoint {i}")
+        checkpoint = torch.load(ckpt, map_location="cpu")
+        for parameter_name, parameter in model.named_parameters():
+            short_name = parameter_name.split(".")[-2]
+            if key_to_dim[short_name] is None and i == 0:
+                parameter.data = checkpoint[parameter_name]
+            elif key_to_dim[short_name] == 0:
+                size = checkpoint[parameter_name].size(0)
+                parameter.data[size * i: size * (i + 1), :] = checkpoint[
+                    parameter_name
+                ]
+            elif key_to_dim[short_name] == -1:
+                size = checkpoint[parameter_name].size(-1)
+                parameter.data[:, size * i: size * (i + 1)] = checkpoint[
+                    parameter_name
+                ]
+            del checkpoint[parameter_name]
+        del checkpoint
     model.to("cpu")
-    model.load_state_dict(checkpoint, strict=False)
     generator = LLaMA(model, tokenizer)
     print(f"Loaded model in {time.time() - start_time:.2f} seconds")
@@ -49,7 +84,7 @@ def main(
     tokenizer_path: str = './tokenizer/tokenizer.model',
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 512,  # up to 2048
     max_batch_size: int = 32,
 ):
     # torch.manual_seed(1)
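The comment's 2048 ceiling is the model's training context; the conservative 512 default exists because the example preallocates per-layer key/value caches sized `max_batch_size × max_seq_len`, so the two knobs multiply directly into memory. A back-of-the-envelope sketch, assuming hypothetical 7B-scale dimensions (32 layers, model dim 4096, fp16):

```python
# Rough KV-cache footprint at the defaults above (dims are assumed, not read
# from params.json): 2 tensors (K and V) per layer.
n_layers, dim = 32, 4096
max_batch_size, max_seq_len = 32, 512
bytes_per_elem = 2  # fp16

kv_bytes = 2 * n_layers * max_batch_size * max_seq_len * dim * bytes_per_elem
print(f"KV cache: {kv_bytes / 2**30:.1f} GiB")  # 8.0 GiB; 4x that at 2048
```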
@@ -85,17 +120,13 @@ def main(
         # cheese =>""",
     ]

     start_time = time.time()

     results = generator.generate(
         prompts, max_gen_len=256, temperature=temperature, top_p=top_p
     )

     for result in results:
         print(result)
         print("\n==================================")

     print(f"Inference took {time.time() - start_time:.2f} seconds")
     print("\n==================================\n")


 if __name__ == "__main__":
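`top_p` drives nucleus sampling inside `generate`: after the temperature-scaled softmax, tokens are sorted by probability, the smallest set whose cumulative mass exceeds `top_p` is kept, and the rest is zeroed before renormalizing and sampling. A standalone sketch of that step (written to match the repo's `sample_top_p` in spirit, not copied from it):

```python
import torch

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus sampling over a batch of next-token distributions."""
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens once the mass *before* them already exceeds p.
    probs_sort[cumulative - probs_sort > p] = 0.0
    probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)

# Usage: a fake 32000-token distribution at temperature 0.8.
probs = torch.softmax(torch.randn(1, 32000) / 0.8, dim=-1)
print(sample_top_p(probs, p=0.95))  # vocabulary index of the sampled token
```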