|
|
|
@ -32,7 +32,7 @@ class LLaMA:
|
|
|
|
|
|
|
|
|
|
|
|
total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
|
|
|
|
total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
|
|
|
|
|
|
|
|
|
|
|
|
tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
|
|
|
|
tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cpu().long()
|
|
|
|
for k, t in enumerate(prompt_tokens):
|
|
|
|
for k, t in enumerate(prompt_tokens):
|
|
|
|
tokens[k, : len(t)] = torch.tensor(t).long()
|
|
|
|
tokens[k, : len(t)] = torch.tensor(t).long()
|
|
|
|
input_text_mask = tokens != self.tokenizer.pad_id
|
|
|
|
input_text_mask = tokens != self.tokenizer.pad_id
|
|
|
|
|