added progressbar

main
randaller 3 years ago committed by GitHub
parent e2f7dc6e5a
commit db429c6683
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,6 +4,7 @@
from typing import List
import torch
from tqdm import tqdm
from llama.tokenizer import Tokenizer
from llama.model import Transformer
@ -36,6 +37,9 @@ class LLaMA:
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t).long()
input_text_mask = tokens != self.tokenizer.pad_id
pbar = tqdm(total=total_len)
start_pos = min_prompt_size
prev_pos = 0
for cur_pos in range(start_pos, total_len):
@ -53,6 +57,10 @@ class LLaMA:
tokens[:, cur_pos] = next_token
prev_pos = cur_pos
pbar.update(1)
pbar.close()
decoded = []
for i, t in enumerate(tokens.tolist()):
# cut to max gen len

Loading…
Cancel
Save