@@ -38,10 +38,12 @@ class LLaMA:
             tokens[k, : len(t)] = torch.tensor(t).long()
         input_text_mask = tokens != self.tokenizer.pad_id
-        pbar = tqdm(total=max_gen_len)
         start_pos = min_prompt_size
         prev_pos = 0
+        steps = total_len - start_pos
+        pbar = tqdm(total=steps)
         for cur_pos in range(start_pos, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0: