added progressbar

main
randaller 3 years ago committed by GitHub
parent e2f7dc6e5a
commit db429c6683
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,6 +4,7 @@
from typing import List from typing import List
import torch import torch
from tqdm import tqdm
from llama.tokenizer import Tokenizer from llama.tokenizer import Tokenizer
from llama.model import Transformer from llama.model import Transformer
@ -36,6 +37,9 @@ class LLaMA:
for k, t in enumerate(prompt_tokens): for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t).long() tokens[k, : len(t)] = torch.tensor(t).long()
input_text_mask = tokens != self.tokenizer.pad_id input_text_mask = tokens != self.tokenizer.pad_id
pbar = tqdm(total=total_len)
start_pos = min_prompt_size start_pos = min_prompt_size
prev_pos = 0 prev_pos = 0
for cur_pos in range(start_pos, total_len): for cur_pos in range(start_pos, total_len):
@ -53,6 +57,10 @@ class LLaMA:
tokens[:, cur_pos] = next_token tokens[:, cur_pos] = next_token
prev_pos = cur_pos prev_pos = cur_pos
pbar.update(1)
pbar.close()
decoded = [] decoded = []
for i, t in enumerate(tokens.tolist()): for i, t in enumerate(tokens.tolist()):
# cut to max gen len # cut to max gen len

Loading…
Cancel
Save