diff --git a/llama/generation.py b/llama/generation.py
index 860134a..555d4b6 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -38,8 +38,11 @@ class LLaMA:
             tokens[k, : len(t)] = torch.tensor(t).long()
         input_text_mask = tokens != self.tokenizer.pad_id
-        pbar = tqdm(total=max_gen_len)
         start_pos = min_prompt_size
         prev_pos = 0
+
+        steps = total_len - start_pos
+        pbar = tqdm(total=steps)
+
         for cur_pos in range(start_pos, total_len):
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
             if temperature > 0: