From ba8df145cd45511ca6ac2728812ff4c3c1be74f4 Mon Sep 17 00:00:00 2001
From: randaller
Date: Sat, 11 Mar 2023 13:55:28 +0300
Subject: [PATCH] Update generation.py

---
 llama/generation.py | 48 ++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 45 insertions(+), 3 deletions(-)

diff --git a/llama/generation.py b/llama/generation.py
index 5ba68f5..20fcf81 100644
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -1,6 +1,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
+# Copyright by Shawn Presser
+# https://github.com/shawwn/
+# taken here
+# https://github.com/shawwn/llama/commit/40d99d329a5e38d85904d3a6519c54e6dd6ee9e1
+
 from typing import List
 
 import torch
@@ -22,6 +27,9 @@ class LLaMA:
         max_gen_len: int,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        top_k: int = 40,
+        repetition_penalty: float = (1.0 / 0.85),
+        sampler: str = 'top_k',
     ) -> List[str]:
         bsz = len(prompts)
         params = self.model.params
@@ -44,11 +52,29 @@ class LLaMA:
         start_pos = min_prompt_size
         prev_pos = 0
         decoded = [None] * bsz
+
         for cur_pos in trange(start_pos, total_len, desc="forward"):
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+
+            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
+            if repetition_penalty != 1.0:
+                logits_new = logits.clone()
+                batch_size = len(tokens)
+                for i in range(batch_size):
+                    for token in set(tokens[i].tolist()):
+                        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+                        if logits[i, token] < 0:
+                            logits_new[i, token] = logits[i, token] * repetition_penalty
+                        else:
+                            logits_new[i, token] = logits[i, token] / repetition_penalty
+                logits = logits_new
+
             if temperature > 0:
                 probs = torch.softmax(logits / temperature, dim=-1)
-                next_token = sample_top_p(probs, top_p)
+                if sampler == 'top_k':
+                    next_token = sample_top_k(probs, top_p=top_p, top_k=top_k)
+                else:
+                    next_token = sample_top_p(probs, top_p)
             else:
                 next_token = torch.argmax(logits, dim=-1)
             next_token = next_token.reshape(-1).cpu()
@@ -73,10 +99,9 @@ class LLaMA:
                     pass # traceback.print_exc()
                 try:
                     d = self.tokenizer.decode(t)
-                    print([i] * 20)
                     print(d)
                     decoded[i] = d
-
+
                     result_count_newlines = d.count("\n")
                     if result_count_newlines > count_newlines:
                         return decoded
@@ -88,6 +113,7 @@ class LLaMA:
         return decoded
 
 
+# default sampler
 def sample_top_p(probs, p):
     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
@@ -97,3 +123,19 @@ def sample_top_p(probs, p):
     next_token = torch.multinomial(probs_sort, num_samples=1)
     next_token = torch.gather(probs_idx, -1, next_token)
     return next_token
+
+
+# sampler by Shawn
+def sample_top_k(probs, top_p=0.0, top_k=40):
+    if top_k > 0:
+        probs_sort, probs_idx = torch.topk(probs, top_k)
+    else:
+        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    if top_p > 0.0:
+        probs_sum = torch.cumsum(probs_sort, dim=-1)
+        mask = probs_sum - probs_sort > top_p
+        probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
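
Standalone usage sketch (not part of the patch): the snippet below reproduces the two techniques this change wires into generate(), the CTRL-style repetition penalty and the top-k/top-p sampler, on toy logits so they can be exercised without loading a LLaMA checkpoint. The vocabulary size, token history, temperature and random seed are illustrative assumptions, not values taken from the repository.

import torch


def apply_repetition_penalty(logits, prev_tokens, penalty=1.0 / 0.85):
    # Penalize tokens that already appeared: divide positive scores and
    # multiply negative scores, as in the CTRL paper (https://arxiv.org/abs/1909.05858).
    out = logits.clone()
    for i in range(logits.size(0)):
        for token in set(prev_tokens[i].tolist()):
            if logits[i, token] < 0:
                out[i, token] = logits[i, token] * penalty
            else:
                out[i, token] = logits[i, token] / penalty
    return out


def sample_top_k(probs, top_p=0.0, top_k=40):
    # Keep the top_k most likely tokens, optionally trim further with top-p,
    # renormalize, then draw one token id per batch row.
    if top_k > 0:
        probs_sort, probs_idx = torch.topk(probs, top_k)
    else:
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    if top_p > 0.0:
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        probs_sort[probs_sum - probs_sort > top_p] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 64)             # batch of 2 over a toy 64-token vocabulary
    history = torch.randint(0, 64, (2, 8))  # token ids "generated so far" (made up)
    logits = apply_repetition_penalty(logits, history)
    probs = torch.softmax(logits / 0.8, dim=-1)  # temperature = 0.8, the patch default
    next_token = sample_top_k(probs, top_p=0.95, top_k=40)
    print(next_token.reshape(-1).tolist())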