--- a/llama/generation.py
+++ b/llama/generation.py
@@ -4,6 +4,7 @@
 from typing import List
 
 import torch
+from tqdm import tqdm
 
 from llama.tokenizer import Tokenizer
 from llama.model import Transformer
@@ -36,6 +37,9 @@ class LLaMA:
         for k, t in enumerate(prompt_tokens):
             tokens[k, : len(t)] = torch.tensor(t).long()
         input_text_mask = tokens != self.tokenizer.pad_id
+
+        pbar = tqdm(total=total_len)
+
         start_pos = min_prompt_size
         prev_pos = 0
         for cur_pos in range(start_pos, total_len):
@@ -53,6 +57,10 @@ class LLaMA:
             tokens[:, cur_pos] = next_token
             prev_pos = cur_pos
+
+            pbar.update(1)
+
+        pbar.close()
 
         decoded = []
         for i, t in enumerate(tokens.tolist()):
             # cut to max gen len
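Note on the diff above: the bar is created with total=total_len, but the decoding loop only runs from min_prompt_size to total_len, so it stops short of 100% whenever the shortest prompt is non-empty. A minimal standalone sketch of one way to account for that, using tqdm's initial parameter; the total_len and min_prompt_size values below are placeholders standing in for the locals of LLaMA.generate:

from tqdm import tqdm

# Placeholder values; in llama/generation.py these are locals of
# LLaMA.generate() derived from the batch of prompts.
total_len = 512
min_prompt_size = 32

# Counting the prompt positions as already completed (initial=) lets
# the bar reach 100% even though decoding starts at min_prompt_size.
pbar = tqdm(total=total_len, initial=min_prompt_size)
for cur_pos in range(min_prompt_size, total_len):
    # ... model forward pass and next-token sampling elided ...
    pbar.update(1)
pbar.close()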