From d2487353a70ccdd063e1c8bbed55b63b14f1e847 Mon Sep 17 00:00:00 2001
From: randaller
Date: Mon, 6 Mar 2023 12:39:27 +0300
Subject: [PATCH] Update example-chat.py

---
 example-chat.py | 45 ++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 40 insertions(+), 5 deletions(-)

diff --git a/example-chat.py b/example-chat.py
index fe7e768..949864b 100644
--- a/example-chat.py
+++ b/example-chat.py
@@ -18,12 +18,10 @@ def load(
     max_seq_len: int,
     max_batch_size: int,
 ) -> LLaMA:
+    print("Creating model...")
     start_time = time.time()
     checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
-    ckpt_path = checkpoints[-1]
-    print("Loading model...")
-    checkpoint = torch.load(ckpt_path, map_location="cpu")
     with open(Path(ckpt_dir) / "params.json", "r") as f:
         params = json.loads(f.read())
 
@@ -32,12 +30,49 @@ def load(
     )
     tokenizer = Tokenizer(model_path=tokenizer_path)
-    model_args.vocab_size = tokenizer.n_words
 
     model = Transformer(model_args)
+
+    # Original copyright by tloen
+    # https://github.com/tloen/llama-int8/blob/main/example.py
+    key_to_dim = {
+        "w1": 0,
+        "w2": -1,
+        "w3": 0,
+        "wo": -1,
+        "wq": 0,
+        "wk": 0,
+        "wv": 0,
+        "output": 0,
+        "tok_embeddings": -1,
+        "ffn_norm": None,
+        "attention_norm": None,
+        "norm": None,
+        "rope": None,
+    }
+
+    for i, ckpt in enumerate(checkpoints):
+        print(f"Loading checkpoint {i}")
+        checkpoint = torch.load(ckpt, map_location="cpu")
+        for parameter_name, parameter in model.named_parameters():
+            short_name = parameter_name.split(".")[-2]
+            if key_to_dim[short_name] is None and i == 0:
+                parameter.data = checkpoint[parameter_name]
+            elif key_to_dim[short_name] == 0:
+                size = checkpoint[parameter_name].size(0)
+                parameter.data[size * i: size * (i + 1), :] = checkpoint[
+                    parameter_name
+                ]
+            elif key_to_dim[short_name] == -1:
+                size = checkpoint[parameter_name].size(-1)
+                parameter.data[:, size * i: size * (i + 1)] = checkpoint[
+                    parameter_name
+                ]
+            del checkpoint[parameter_name]
+        del checkpoint
+
     model.to("cpu")
-    model.load_state_dict(checkpoint, strict=False)
 
     generator = LLaMA(model, tokenizer)
     print(f"Loaded model in {time.time() - start_time:.2f} seconds")
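
Note for reviewers: the loop above is the shard-stitching trick from tloen's llama-int8 example. Each tensor in the sharded checkpoints is either replicated across shards (the norms and rope frequencies, mapped to None and copied once from shard 0) or split along one dimension: dim-0 tensors are stacked row-wise into the full weight, dim -1 tensors column-wise. Deleting each tensor right after it is copied keeps only one shard in RAM at a time. A minimal standalone sketch of the slicing, using toy shapes and fake shard data rather than the real LLaMA dimensions:

    import torch

    num_shards = 2

    # A weight sharded along dim 0: each shard holds a block of rows.
    full_rows = torch.empty(8, 4)
    row_shards = [torch.full((4, 4), float(i)) for i in range(num_shards)]
    for i, shard in enumerate(row_shards):
        size = shard.size(0)
        full_rows[size * i: size * (i + 1), :] = shard  # same slicing as the patch

    # A weight sharded along dim -1: each shard holds a block of columns.
    full_cols = torch.empty(4, 8)
    col_shards = [torch.full((4, 4), float(i)) for i in range(num_shards)]
    for i, shard in enumerate(col_shards):
        size = shard.size(-1)
        full_cols[:, size * i: size * (i + 1)] = shard

    # Shard 0 filled the first row block, shard 1 the second column block.
    assert torch.equal(full_rows[:4], torch.zeros(4, 4))
    assert torch.equal(full_cols[:, 4:], torch.ones(4, 4))

Writing into parameter.data slice-by-slice this way avoids materializing a torch.cat of all shards at once, which is why the patch can drop the single torch.load / load_state_dict pair and still assemble the full model.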