loading checkpoints

main
randaller 3 years ago committed by GitHub
parent d2487353a7
commit cf0da5ffae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,12 +18,10 @@ def load(
max_seq_len: int,
max_batch_size: int,
) -> LLaMA:
print("Creating model...")
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
ckpt_path = checkpoints[-1]
print("Loading model...")
checkpoint = torch.load(ckpt_path, map_location="cpu")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
@ -32,12 +30,49 @@ def load(
)
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
model = Transformer(model_args)
# Original copyright by tloen
# https://github.com/tloen/llama-int8/blob/main/example.py
key_to_dim = {
"w1": 0,
"w2": -1,
"w3": 0,
"wo": -1,
"wq": 0,
"wk": 0,
"wv": 0,
"output": 0,
"tok_embeddings": -1,
"ffn_norm": None,
"attention_norm": None,
"norm": None,
"rope": None,
}
for i, ckpt in enumerate(checkpoints):
print(f"Loading checkpoint {i}")
checkpoint = torch.load(ckpt, map_location="cpu")
for parameter_name, parameter in model.named_parameters():
short_name = parameter_name.split(".")[-2]
if key_to_dim[short_name] is None and i == 0:
parameter.data = checkpoint[parameter_name]
elif key_to_dim[short_name] == 0:
size = checkpoint[parameter_name].size(0)
parameter.data[size * i: size * (i + 1), :] = checkpoint[
parameter_name
]
elif key_to_dim[short_name] == -1:
size = checkpoint[parameter_name].size(-1)
parameter.data[:, size * i: size * (i + 1)] = checkpoint[
parameter_name
]
del checkpoint[parameter_name]
del checkpoint
model.to("cpu")
model.load_state_dict(checkpoint, strict=False)
generator = LLaMA(model, tokenizer)
print(f"Loaded model in {time.time() - start_time:.2f} seconds")
@ -49,7 +84,7 @@ def main(
tokenizer_path: str = './tokenizer/tokenizer.model',
temperature: float = 0.8,
top_p: float = 0.95,
max_seq_len: int = 512,
max_seq_len: int = 512, # up to 2048
max_batch_size: int = 32,
):
# torch.manual_seed(1)
@ -85,17 +120,13 @@ def main(
# cheese =>""",
]
start_time = time.time()
results = generator.generate(
prompts, max_gen_len=256, temperature=temperature, top_p=top_p
)
for result in results:
print(result)
print("\n==================================")
print(f"Inference took {time.time() - start_time:.2f} seconds")
print("\n==================================\n")
if __name__ == "__main__":

Loading…
Cancel
Save