run on cpu

main
randaller 3 years ago committed by GitHub
parent b998e985a6
commit e04d724890
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,11 +15,11 @@ class LLaMA:
self.tokenizer = tokenizer
def generate(
self,
prompts: List[str],
max_gen_len: int,
temperature: float = 0.8,
top_p: float = 0.95,
self,
prompts: List[str],
max_gen_len: int,
temperature: float = 0.8,
top_p: float = 0.95,
) -> List[str]:
bsz = len(prompts)
params = self.model.params
@ -32,7 +32,7 @@ class LLaMA:
total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cpu().long()
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t).long()
input_text_mask = tokens != self.tokenizer.pad_id

Loading…
Cancel
Save