diff --git a/llama/generation.py b/llama/generation.py
index 40a551c..493a16b 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -4,6 +4,7 @@
 from typing import List
 
 import torch
+from tqdm import tqdm
 
 from llama.tokenizer import Tokenizer
 from llama.model import Transformer
@@ -36,6 +37,9 @@ class LLaMA:
         for k, t in enumerate(prompt_tokens):
             tokens[k, : len(t)] = torch.tensor(t).long()
         input_text_mask = tokens != self.tokenizer.pad_id
+
+        pbar = tqdm(total=total_len - min_prompt_size)
+
         start_pos = min_prompt_size
         prev_pos = 0
         for cur_pos in range(start_pos, total_len):
@@ -53,6 +57,10 @@ class LLaMA:
             tokens[:, cur_pos] = next_token
             prev_pos = cur_pos
+            pbar.update(1)
+
+        pbar.close()
+
 
         decoded = []
         for i, t in enumerate(tokens.tolist()):
             # cut to max gen len