run on cpu

main
randaller 3 years ago committed by GitHub
parent b998e985a6
commit e04d724890
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,11 +15,11 @@ class LLaMA:
self.tokenizer = tokenizer self.tokenizer = tokenizer
def generate( def generate(
self, self,
prompts: List[str], prompts: List[str],
max_gen_len: int, max_gen_len: int,
temperature: float = 0.8, temperature: float = 0.8,
top_p: float = 0.95, top_p: float = 0.95,
) -> List[str]: ) -> List[str]:
bsz = len(prompts) bsz = len(prompts)
params = self.model.params params = self.model.params
@ -32,7 +32,7 @@ class LLaMA:
total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long() tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cpu().long()
for k, t in enumerate(prompt_tokens): for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t).long() tokens[k, : len(t)] = torch.tensor(t).long()
input_text_mask = tokens != self.tokenizer.pad_id input_text_mask = tokens != self.tokenizer.pad_id

Loading…
Cancel
Save