# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
# Copyright by Steve Manuatu
# https://github.com/venuatu
# Copyright by Shawn Presser
# https://github.com/shawwn
# taken from
# https://github.com/shawwn/llama/commit/40d99d329a5e38d85904d3a6519c54e6dd6ee9e1
from typing import List
import traceback

import torch
from tqdm import trange

from llama.tokenizer import Tokenizer
from llama.model import Transformer


class LLaMA:
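    # Thin wrapper that pairs a Transformer with its Tokenizer and exposes
    # batched, sampled text generation.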
    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 40,
        repetition_penalty: float = (1.0 / 0.85),
        sampler: str = 'top_k',
    ) -> List[str]:
        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        # newline count of the first prompt; generation stops early once any
        # decoded output contains more newlines than this (see the loop below)
        count_newlines = prompts[0].count("\n")

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

        min_prompt_size = min([len(t) for t in prompt_tokens])
        max_prompt_size = max([len(t) for t in prompt_tokens])

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        # fill the buffer with pad tokens, copy each prompt in, and mark the
        # final position with eos
        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
            tokens[k, -1] = self.tokenizer.eos_id
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        decoded = [None] * bsz
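
        # Incremental decoding: each step feeds only tokens[prev_pos:cur_pos]
        # through the model, which caches attention keys/values internally, so
        # the first step consumes the shared prompt prefix and later steps
        # consume one token at a time.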
        for cur_pos in trange(start_pos, total_len, desc="forward"):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            # repetition penalty from the CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                logits_new = logits.clone()
                batch_size = len(tokens)
                for i in range(batch_size):
                    for token in set(tokens[i].tolist()):
                        # if the score is < 0, multiply by the penalty (rather than
                        # divide) so that the previous token's probability is reduced
                        if logits[i, token] < 0:
                            logits_new[i, token] = logits[i, token] * repetition_penalty
                        else:
                            logits_new[i, token] = logits[i, token] / repetition_penalty
                logits = logits_new
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                if sampler == 'top_k':
                    next_token = sample_top_k(probs, top_p=top_p, top_k=top_k)
                else:
                    next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1).cpu()
            # only replace the token if the prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos
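
            # Streaming output: decode and print every sequence after each step,
            # returning early as soon as an output has gained a newline beyond
            # the (first) prompt.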
print("-" * 30)
for i, t in enumerate(tokens.tolist()):
# i = cur_pos
# t = next_token
# cut to max gen len
# t = t[: len(pr-ompt_tokens[i]) + max_gen_len]
t = t[: min(cur_pos, len(prompt_tokens[i]) + max_gen_len)]
# cut to eos tok if any
try:
t = t[: t.index(self.tokenizer.eos_id)]
except ValueError:
pass # traceback.print_exc()
try:
d = self.tokenizer.decode(t)
print(d)
decoded[i] = d
result_count_newlines = d.count("\n")
if result_count_newlines > count_newlines:
return decoded
except IndexError:
traceback.print_exc()
print(t)
print("-" * 30)
return decoded


# default sampler: nucleus (top-p) sampling
def sample_top_p(probs, p):
    # sort the distribution, keep the smallest prefix whose cumulative mass
    # reaches p, renormalize it, and sample a token from what remains
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
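
# Worked example of the top-p mask above (illustrative numbers): with sorted
# probs (0.5, 0.3, 0.1, 0.1) and p = 0.6, the cumulative sums are
# (0.5, 0.8, 0.9, 1.0) and (cumsum - prob) = (0.0, 0.5, 0.8, 0.9), so the mask
# zeroes the last two entries and the survivors renormalize to (0.625, 0.375).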


# sampler by Shawn: top-k filtering first, then (optionally) top-p on the
# surviving candidates
def sample_top_k(probs, top_p=0.0, top_k=40):
    if top_k > 0:
        # keep only the top_k highest-probability tokens
        probs_sort, probs_idx = torch.topk(probs, top_k)
    else:
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    if top_p > 0.0:
        # apply the same nucleus mask as sample_top_p to the remaining tokens
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
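

# Usage sketch: generate() expects an already-constructed Transformer and
# Tokenizer; in the upstream llama repo that setup typically lives in
# example.py's load(). The paths and model_args below are placeholders.
#
#     tokenizer = Tokenizer(model_path="/path/to/tokenizer.model")
#     model = Transformer(model_args)  # weights loaded from a checkpoint elsewhere
#     generator = LLaMA(model, tokenizer)
#     out = generator.generate(
#         ["The meaning of life is"],
#         max_gen_len=64,
#         temperature=0.8,
#         top_p=0.95,
#         top_k=40,
#         sampler='top_k',
#     )
#     print(out[0])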