@ -38,7 +38,7 @@ class LLaMA:
         for k, t in enumerate(prompt_tokens):
             tokens[k, : len(t)] = torch.tensor(t).long()
         input_text_mask = tokens != self.tokenizer.pad_id
-        pbar = tqdm(total=total_len)
+        pbar = tqdm(total=max_gen_len)
         start_pos = min_prompt_size
         prev_pos = 0
         for cur_pos in range(start_pos, total_len):
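
For reference, a minimal self-contained sketch of the tqdm pattern this hunk adjusts, assuming the bar is advanced once per decoding step; the value of `max_gen_len` and the loop body are illustrative stand-ins, not repository code. With `total=max_gen_len` the bar tracks tokens to be generated rather than the full sequence length `total_len`, which also counts prompt positions.

```python
from tqdm import tqdm

max_gen_len = 32  # hypothetical value for illustration

# Bar total is the number of new tokens to generate; each decoding
# step advances it by one, so the bar reaches 100% when generation
# stops after max_gen_len steps.
pbar = tqdm(total=max_gen_len)
for _ in range(max_gen_len):
    # one decoding step would go here (model forward + token selection)
    pbar.update(1)
pbar.close()
```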