@@ -1,6 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
# Copyright by Shawn Presser
# https://github.com/shawwn/
# taken from here:
# https://github.com/shawwn/llama/commit/40d99d329a5e38d85904d3a6519c54e6dd6ee9e1

from typing import List

import torch
@@ -22,6 +27,9 @@ class LLaMA:
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 40,
        repetition_penalty: float = (1.0 / 0.85),
        sampler: str = 'top_k',
    ) -> List[str]:
        bsz = len(prompts)
        params = self.model.params
@@ -44,11 +52,29 @@ class LLaMA:
        start_pos = min_prompt_size
        prev_pos = 0
        decoded = [None] * bsz
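        # autoregressive decoding loop: pick one new token per step for every prompt in the batch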
        for cur_pos in trange(start_pos, total_len, desc="forward"):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                logits_new = logits.clone()
                batch_size = len(tokens)
                for i in range(batch_size):
                    for token in set(tokens[i].tolist()):
                        # if the score is < 0, the repetition penalty has to be multiplied to reduce the previously seen token's probability
                        if logits[i, token] < 0:
                            logits_new[i, token] = logits[i, token] * repetition_penalty
                        else:
                            logits_new[i, token] = logits[i, token] / repetition_penalty
                logits = logits_new
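
            # choose the next token: greedy argmax when temperature == 0, otherwise sample from the temperature-scaled distribution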
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                if sampler == 'top_k':
                    next_token = sample_top_k(probs, top_p=top_p, top_k=top_k)
                else:
                    next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1).cpu()
@@ -73,10 +99,9 @@ class LLaMA:
                pass  # traceback.print_exc()
            try:
                d = self.tokenizer.decode(t)
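                # debug output: a marker row of the prompt index, then the text decoded so far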
                print([i] * 20)
                print(d)
                decoded[i] = d

                result_count_newlines = d.count("\n")
                if result_count_newlines > count_newlines:
                    return decoded
@@ -88,6 +113,7 @@ class LLaMA:
        return decoded


# default sampler
def sample_top_p(probs, p):
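    # nucleus (top-p) sampling: keep the smallest set of top tokens whose cumulative probability exceeds p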
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
@@ -97,3 +123,19 @@ def sample_top_p(probs, p):
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


# sampler by Shawn
def sample_top_k(probs, top_p=0.0, top_k=40):
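    # top-k sampling, optionally combined with a nucleus (top-p) cutoff over the retained candidates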
    if top_k > 0:
        probs_sort, probs_idx = torch.topk(probs, top_k)
    else:
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
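    # apply the top-p cutoff (if requested) on top of the retained top-k candidates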
    if top_p > 0.0:
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0
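    # renormalize the surviving probabilities, sample, and map the index back to a vocabulary id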
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
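

# Hypothetical usage sketch (the `generator` instance and the prompt are assumptions, not taken from this diff):
# with a constructed LLaMA object, the new sampling arguments might be passed like this:
#
#     results = generator.generate(
#         ["Once upon a time"],
#         max_gen_len=256,
#         temperature=0.8,
#         top_p=0.95,
#         top_k=40,
#         repetition_penalty=(1.0 / 0.85),
#         sampler='top_k',
#     )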