@@ -15,11 +15,11 @@ class LLaMA:
         self.tokenizer = tokenizer
 
     def generate(
         self,
         prompts: List[str],
         max_gen_len: int,
         temperature: float = 0.8,
         top_p: float = 0.95,
     ) -> List[str]:
         bsz = len(prompts)
         params = self.model.params
@@ -32,7 +32,7 @@ class LLaMA:
 
         total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
 
-        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
+        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cpu().long()
         for k, t in enumerate(prompt_tokens):
             tokens[k, : len(t)] = torch.tensor(t).long()
         input_text_mask = tokens != self.tokenizer.pad_id
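
Below is a minimal, self-contained sketch of the prompt-padding step this hunk touches, runnable on CPU. It is not taken from the patch: the helper name pad_prompts is invented for illustration, and the example pad_id of -1 is only an assumption about the tokenizer.

    from typing import List

    import torch


    def pad_prompts(prompt_tokens: List[List[int]], total_len: int, pad_id: int):
        """Pack variable-length prompts into one left-aligned, padded batch tensor."""
        bsz = len(prompt_tokens)
        # Fill the whole (bsz, total_len) buffer with pad_id; .cpu() mirrors the patched line.
        tokens = torch.full((bsz, total_len), pad_id).cpu().long()
        for k, t in enumerate(prompt_tokens):
            # Copy each prompt's token ids into the left edge of its row.
            tokens[k, : len(t)] = torch.tensor(t).long()
        # True over real prompt tokens, False over padding.
        input_text_mask = tokens != pad_id
        return tokens, input_text_mask


    # Two prompts of different lengths, padded out to total_len = 6.
    tokens, mask = pad_prompts([[1, 5, 9], [1, 7]], total_len=6, pad_id=-1)
    print(tokens)  # tensor([[ 1,  5,  9, -1, -1, -1],
                   #         [ 1,  7, -1, -1, -1, -1]])

In the original generate loop, input_text_mask guards prompt positions so that sampled tokens never overwrite prompt tokens while the batch is still inside its longest prompt.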
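If the goal is a single code path that runs on either device, an alternative to hand-editing .cuda() into .cpu() is to pick the device once and allocate there directly. This is a suggestion, not what the patch does:

    import torch

    bsz, total_len, pad_id = 2, 6, -1  # toy values for the sketch
    # Choose the device once; torch.full can allocate there directly,
    # avoiding the .cuda()/.cpu() chaining that this patch has to edit.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)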