Update generation.py

main
randaller 3 years ago committed by GitHub
parent 753d3fbb98
commit 8b0513bf7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -38,10 +38,12 @@ class LLaMA:
tokens[k, : len(t)] = torch.tensor(t).long() tokens[k, : len(t)] = torch.tensor(t).long()
input_text_mask = tokens != self.tokenizer.pad_id input_text_mask = tokens != self.tokenizer.pad_id
pbar = tqdm(total=max_gen_len)
start_pos = min_prompt_size start_pos = min_prompt_size
prev_pos = 0 prev_pos = 0
steps = total_len - start_pos
pbar = tqdm(total=steps)
for cur_pos in range(start_pos, total_len): for cur_pos in range(start_pos, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0: if temperature > 0:

Loading…
Cancel
Save