@@ -38,10 +38,12 @@ class LLaMA:
             tokens[k, : len(t)] = torch.tensor(t).long()
         input_text_mask = tokens != self.tokenizer.pad_id
-        pbar = tqdm(total=max_gen_len)
         start_pos = min_prompt_size
         prev_pos = 0
+        steps = total_len - start_pos
+        pbar = tqdm(total=steps)
         for cur_pos in range(start_pos, total_len):
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
             if temperature > 0:
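
Context for the fix above: the decoding loop runs once per position from `start_pos` to `total_len`, and in the surrounding `generate()` code `total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)` with `start_pos = min_prompt_size`, so the old `tqdm(total=max_gen_len)` only matches the real iteration count for short, equal-length prompts. A minimal standalone sketch of the arithmetic, with made-up sizes (only the variable names mirror the patched code):

```python
# Hedged sketch, not the patched file: assumed numbers chosen to show
# why total=max_gen_len undercounts the loop when prompts are long.
from tqdm import tqdm

max_seq_len = 512                          # assumed params.max_seq_len
max_gen_len = 256
min_prompt_size, max_prompt_size = 4, 300  # assumed prompt lengths

total_len = min(max_seq_len, max_gen_len + max_prompt_size)  # -> 512
start_pos = min_prompt_size                                  # -> 4
steps = total_len - start_pos                                # -> 508 iterations

# The loop ticks 508 times: tqdm(total=max_gen_len) would report 508/256
# and overshoot 100%, while tqdm(total=steps) lands exactly on 100%.
pbar = tqdm(total=steps)
for cur_pos in range(start_pos, total_len):
    pbar.update(1)  # one tick per decoded position
pbar.close()
```

Equivalently, `steps = min(params.max_seq_len - min_prompt_size, max_gen_len + max_prompt_size - min_prompt_size)`, which reduces to `max_gen_len` only when every prompt in the batch has the same length and fits well under the sequence limit.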