Update example-chat.py

main
randaller 3 years ago committed by GitHub
parent b844fd72fb
commit d2487353a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -18,12 +18,10 @@ def load(
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
print("Creating model...")
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
ckpt_path = checkpoints[-1]
print("Loading model...")
checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())
@@ -32,12 +30,49 @@ def load(
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    model = Transformer(model_args)
# Original copyright by tloen
# https://github.com/tloen/llama-int8/blob/main/example.py
key_to_dim = {
"w1": 0,
"w2": -1,
"w3": 0,
"wo": -1,
"wq": 0,
"wk": 0,
"wv": 0,
"output": 0,
"tok_embeddings": -1,
"ffn_norm": None,
"attention_norm": None,
"norm": None,
"rope": None,
}
for i, ckpt in enumerate(checkpoints):
print(f"Loading checkpoint {i}")
checkpoint = torch.load(ckpt, map_location="cpu")
for parameter_name, parameter in model.named_parameters():
short_name = parameter_name.split(".")[-2]
if key_to_dim[short_name] is None and i == 0:
parameter.data = checkpoint[parameter_name]
elif key_to_dim[short_name] == 0:
size = checkpoint[parameter_name].size(0)
parameter.data[size * i: size * (i + 1), :] = checkpoint[
parameter_name
]
elif key_to_dim[short_name] == -1:
size = checkpoint[parameter_name].size(-1)
parameter.data[:, size * i: size * (i + 1)] = checkpoint[
parameter_name
]
del checkpoint[parameter_name]
del checkpoint
    model.to("cpu")
model.load_state_dict(checkpoint, strict=False)
    generator = LLaMA(model, tokenizer)
    print(f"Loaded model in {time.time() - start_time:.2f} seconds")

Loading…
Cancel
Save