Update generation.py

main
randaller 3 years ago committed by GitHub
parent c59ee12659
commit ba8df145cd

@@ -1,6 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
# Copyright by Shawn Presser
# https://github.com/shawwn/
# taken from here:
# https://github.com/shawwn/llama/commit/40d99d329a5e38d85904d3a6519c54e6dd6ee9e1
from typing import List
import torch
from tqdm import trange  # provides the progress bar used in generate(); assumed absent from the unshown hunks
@@ -22,6 +27,9 @@ class LLaMA:
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 40,
        repetition_penalty: float = (1.0 / 0.85),
        sampler: str = 'top_k',
    ) -> List[str]:
        bsz = len(prompts)
        params = self.model.params
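The three new keyword arguments select the sampling strategy. A minimal usage sketch, assuming an already-constructed LLaMA instance (the name generator and the prompt are illustrative, not part of this commit):

# hypothetical caller; `generator` is a LLaMA instance built elsewhere
results = generator.generate(
    prompts=["The capital of France is"],
    max_gen_len=64,
    temperature=0.8,
    top_p=0.95,
    top_k=40,                         # keep only the 40 highest-probability tokens
    repetition_penalty=(1.0 / 0.85),  # ~1.18; values > 1.0 discourage repeats
    sampler='top_k',                  # any other value falls back to sample_top_p
)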
@@ -44,10 +52,28 @@ class LLaMA:
        start_pos = min_prompt_size
        prev_pos = 0
        decoded = [None] * bsz
        for cur_pos in trange(start_pos, total_len, desc="forward"):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            # repetition penalty from the CTRL paper (https://arxiv.org/abs/1909.05858);
            # see the vectorized sketch after this hunk
            if repetition_penalty != 1.0:
                logits_new = logits.clone()
                batch_size = len(tokens)
                for i in range(batch_size):
                    for token in set(tokens[i].tolist()):
                        # if the score is < 0, the penalty has to be multiplied to reduce the previous token's probability
                        if logits[i, token] < 0:
                            logits_new[i, token] = logits[i, token] * repetition_penalty
                        else:
                            logits_new[i, token] = logits[i, token] / repetition_penalty
                logits = logits_new
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                if sampler == 'top_k':
                    next_token = sample_top_k(probs, top_p=top_p, top_k=top_k)
                else:
                    next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
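The nested Python loop above is a faithful but slow reading of the CTRL penalty. A vectorized sketch of the same update (my addition, not part of the commit; it assumes logits of shape (bsz, vocab_size) and tokens holding the ids generated so far):

import torch

def apply_repetition_penalty(logits, tokens, penalty):
    # gather the scores of every token already present in each sequence
    score = torch.gather(logits, 1, tokens)
    # CTRL rule: multiply negative scores, divide positive ones
    score = torch.where(score < 0, score * penalty, score / penalty)
    # duplicate ids in `tokens` carry identical values, so scatter is safe
    return logits.scatter(1, tokens, score)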
@@ -73,7 +99,6 @@ class LLaMA:
                pass  # traceback.print_exc()
            try:
                d = self.tokenizer.decode(t)
                print([i] * 20)  # debug marker: prints the batch index 20 times as a separator
                print(d)
                decoded[i] = d
@@ -88,6 +113,7 @@ class LLaMA:
        return decoded

# default sampler
def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
@@ -97,3 +123,19 @@ def sample_top_p(probs, p):
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
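To make the masking rule concrete, a toy run (not from the commit): a token is dropped once the cumulative mass before it already exceeds p.

probs = torch.tensor([[0.5, 0.3, 0.1, 0.07, 0.03]])
# probs_sum - probs_sort = [0.0, 0.5, 0.8, 0.9, 0.97]; only the last
# entry exceeds p=0.9, so token 4 is masked and the rest renormalized
tok = sample_top_p(probs, p=0.9)
print(tok)  # a single id drawn from {0, 1, 2, 3}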
# top-k sampler by Shawn Presser
def sample_top_k(probs, top_p=0.0, top_k=40):
    if top_k > 0:
        probs_sort, probs_idx = torch.topk(probs, top_k)
    else:
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    if top_p > 0.0:
        # apply the nucleus cut to the k survivors
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
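sample_top_k composes both filters: a hard top-k cut first, then the same nucleus mask over the k survivors. A quick sanity-check sketch (toy distribution, not from the commit):

probs = torch.softmax(torch.randn(1, 32000), dim=-1)  # toy vocab distribution
tok = sample_top_k(probs, top_p=0.95, top_k=40)  # k-cut, then nucleus cut
print(tok.shape)  # torch.Size([1, 1]): one sampled id per batch row

Because next_token is gathered from probs_idx, the returned ids index the full vocabulary, not the truncated top-k list.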
