diff --git a/llama/generation.py b/llama/generation.py
index 3abd3ed..40a551c 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -15,11 +15,11 @@ class LLaMA:
         self.tokenizer = tokenizer

     def generate(
-        self,
-        prompts: List[str],
-        max_gen_len: int,
-        temperature: float = 0.8,
-        top_p: float = 0.95,
+        self,
+        prompts: List[str],
+        max_gen_len: int,
+        temperature: float = 0.8,
+        top_p: float = 0.95,
     ) -> List[str]:
         bsz = len(prompts)
         params = self.model.params
@@ -32,7 +32,7 @@ class LLaMA:

         total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

-        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
+        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cpu().long()
         for k, t in enumerate(prompt_tokens):
             tokens[k, : len(t)] = torch.tensor(t).long()
         input_text_mask = tokens != self.tokenizer.pad_id
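For context, here is a minimal standalone sketch of the behavior the second hunk changes: swapping `.cuda()` for `.cpu()` keeps the token buffer in host memory, so the prompt-fill step can run on machines without a CUDA device. The `pad_id`, shapes, and `prompt_tokens` below are hypothetical stand-ins for the values `generate()` derives from `self.tokenizer` and the prompts, not the repo's actual objects.

```python
import torch

# Hypothetical stand-ins for self.tokenizer.pad_id and the
# prompt-derived sizes inside generate().
pad_id = -1
bsz, total_len = 2, 16

# Before the patch: buffer allocated on the GPU (requires CUDA).
# tokens = torch.full((bsz, total_len), pad_id).cuda().long()

# After the patch: buffer stays in host memory; no GPU needed.
tokens = torch.full((bsz, total_len), pad_id).cpu().long()

# Copy each tokenized prompt into the left edge of its row,
# mirroring the loop in the diff.
prompt_tokens = [[1, 5, 9], [1, 7]]  # toy tokenized prompts
for k, t in enumerate(prompt_tokens):
    tokens[k, : len(t)] = torch.tensor(t).long()

# Positions still holding pad_id are slots to be generated.
input_text_mask = tokens != pad_id
```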