From cf0da5ffae2bf96518b1610d450641273db47cb6 Mon Sep 17 00:00:00 2001
From: randaller
Date: Mon, 6 Mar 2023 12:39:55 +0300
Subject: [PATCH] loading checkpoints

---
 example-cpu.py | 53 +++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 42 insertions(+), 11 deletions(-)

diff --git a/example-cpu.py b/example-cpu.py
index ff37ef9..1c1225d 100644
--- a/example-cpu.py
+++ b/example-cpu.py
@@ -18,12 +18,10 @@ def load(
     max_seq_len: int,
     max_batch_size: int,
 ) -> LLaMA:
+    print("Creating model...")
     start_time = time.time()
     checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
-    ckpt_path = checkpoints[-1]
-    print("Loading model...")
-    checkpoint = torch.load(ckpt_path, map_location="cpu")
     with open(Path(ckpt_dir) / "params.json", "r") as f:
         params = json.loads(f.read())
 
@@ -32,12 +30,49 @@ def load(
     )
     tokenizer = Tokenizer(model_path=tokenizer_path)
-
     model_args.vocab_size = tokenizer.n_words
     model = Transformer(model_args)
+
+    # Original copyright by tloen
+    # https://github.com/tloen/llama-int8/blob/main/example.py
+    key_to_dim = {
+        "w1": 0,
+        "w2": -1,
+        "w3": 0,
+        "wo": -1,
+        "wq": 0,
+        "wk": 0,
+        "wv": 0,
+        "output": 0,
+        "tok_embeddings": -1,
+        "ffn_norm": None,
+        "attention_norm": None,
+        "norm": None,
+        "rope": None,
+    }
+
+    for i, ckpt in enumerate(checkpoints):
+        print(f"Loading checkpoint {i}")
+        checkpoint = torch.load(ckpt, map_location="cpu")
+        for parameter_name, parameter in model.named_parameters():
+            short_name = parameter_name.split(".")[-2]
+            if key_to_dim[short_name] is None and i == 0:
+                parameter.data = checkpoint[parameter_name]
+            elif key_to_dim[short_name] == 0:
+                size = checkpoint[parameter_name].size(0)
+                parameter.data[size * i: size * (i + 1), :] = checkpoint[
+                    parameter_name
+                ]
+            elif key_to_dim[short_name] == -1:
+                size = checkpoint[parameter_name].size(-1)
+                parameter.data[:, size * i: size * (i + 1)] = checkpoint[
+                    parameter_name
+                ]
+            del checkpoint[parameter_name]
+        del checkpoint
+
     model.to("cpu")
-    model.load_state_dict(checkpoint, strict=False)
 
     generator = LLaMA(model, tokenizer)
     print(f"Loaded model in {time.time() - start_time:.2f} seconds")
@@ -49,7 +84,7 @@ def main(
     tokenizer_path: str = './tokenizer/tokenizer.model',
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 512,  # up to 2048
     max_batch_size: int = 32,
 ):
     # torch.manual_seed(1)
@@ -85,17 +120,13 @@ def main(
     # cheese =>""",
     ]
 
-    start_time = time.time()
-
     results = generator.generate(
         prompts, max_gen_len=256, temperature=temperature, top_p=top_p
    )
 
     for result in results:
         print(result)
-        print("\n==================================")
-
-    print(f"Inference took {time.time() - start_time:.2f} seconds")
+        print("\n==================================\n")
 
 
 if __name__ == "__main__":
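
Note: the core change above replaces a single torch.load() of the last shard with a loop that loads each checkpoint file in turn and copies its slice into the full model parameters along the dimension given by key_to_dim, so only one shard has to sit in RAM at a time. The following is a minimal sketch of that shard-stitching idea outside the LLaMA stack; the toy nn.Linear, the fake shard dicts, and the variable names are illustrative assumptions, not part of the patch. It mirrors only the dim-0 branch of the loop (the patch also handles dim -1 splits and unsharded norm/rope weights).

import torch
import torch.nn as nn

torch.manual_seed(0)

# Reference "unsharded" weight we will pretend was split across two checkpoint files.
full = nn.Linear(8, 4, bias=False)
n_shards = 2

# Each fake checkpoint holds one slice of the weight along dim 0, the way the
# column-parallel layers (wq, wk, wv, w1, w3, output) are split in the patch.
shards = [
    {"weight": chunk.clone()}
    for chunk in torch.chunk(full.weight.data, n_shards, dim=0)
]

# Rebuild a fresh model by copying each shard into its slice one file at a time,
# so only a single shard has to be held in memory at once.
model = nn.Linear(8, 4, bias=False)
for i, checkpoint in enumerate(shards):
    for name, parameter in model.named_parameters():
        size = checkpoint[name].size(0)
        parameter.data[size * i: size * (i + 1), :] = checkpoint[name]
        del checkpoint[name]  # drop the shard tensor as soon as it is copied
    del checkpoint

assert torch.equal(model.weight.data, full.weight.data)
print("shards reassembled correctly")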