diff --git a/llama/generation.py b/llama/generation.py
index 493a16b..860134a 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -38,7 +38,7 @@ class LLaMA:
             tokens[k, : len(t)] = torch.tensor(t).long()
 
         input_text_mask = tokens != self.tokenizer.pad_id
-        pbar = tqdm(total=total_len)
+        pbar = tqdm(total=max_gen_len)
         start_pos = min_prompt_size
         prev_pos = 0
 