From 8b0513bf7cbed21ecfbe622456d9327aaafc1865 Mon Sep 17 00:00:00 2001
From: randaller
Date: Mon, 6 Mar 2023 16:30:36 +0300
Subject: [PATCH] Update generation.py

---
 llama/generation.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llama/generation.py b/llama/generation.py
index 860134a..555d4b6 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -38,10 +38,12 @@ class LLaMA:
             tokens[k, : len(t)] = torch.tensor(t).long()
         input_text_mask = tokens != self.tokenizer.pad_id
 
-        pbar = tqdm(total=max_gen_len)
-
         start_pos = min_prompt_size
         prev_pos = 0
+
+        steps = total_len - start_pos
+        pbar = tqdm(total=steps)
+
         for cur_pos in range(start_pos, total_len):
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
             if temperature > 0: