From 862d4a15d62e3fa8a87b39bb7aacef4890051388 Mon Sep 17 00:00:00 2001
From: Mikko Juola
Date: Tue, 14 Mar 2023 00:40:08 -0700
Subject: [PATCH] Add repetition penalty, add colors to outputs based on
 probabilities, try to make softmax() more numerically stable.

---
 README.md            |  6 ++++
 src/rllama_main.rs   | 32 ++++++++++++++---
 src/token_sampler.rs | 83 +++++++++++++++++++++++++++++++++++++++-----
 3 files changed, 108 insertions(+), 13 deletions(-)

diff --git a/README.md b/README.md
index d4595e2..3f4b99c 100644
--- a/README.md
+++ b/README.md
@@ -56,6 +56,12 @@ cast to 32-bit floats.
 You can use `--temperature`, `--top-p` and `--top-k` to adjust token sampler
 settings.
 
+There is a `--repetition-penalty` setting. 1.0 means no penalty. This value
+should likely be between 0 and 1. Values smaller than 1.0 penalize tokens
+that already appear in the context: each such output `x` becomes
+`x*(repetition_penalty^num_occurrences)` before `softmax()` is applied to
+the outputs. In other words, values smaller than 1.0 apply a penalty.
+
 You can also use `--prompt-file` to read the prompt from a file instead from
 the command line.
 
diff --git a/src/rllama_main.rs b/src/rllama_main.rs
index 1c1e3fb..a3e3203 100644
--- a/src/rllama_main.rs
+++ b/src/rllama_main.rs
@@ -36,6 +36,8 @@ struct Cli {
     top_p: Option<f32>,
     #[arg(long)]
     top_k: Option<i32>,
+    #[arg(long)]
+    repetition_penalty: Option<f32>,
 
     #[cfg(feature = "opencl")]
     #[arg(long)]
@@ -185,7 +187,11 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
     let mut prev_pos = 0;
 
-    let mut token_sampler = TokenSampler::new().temperature(0.8).top_p(0.9).top_k(50);
+    let mut token_sampler = TokenSampler::new()
+        .temperature(0.8)
+        .top_p(0.9)
+        .top_k(50)
+        .repetition_penalty(0.8);
 
     if let Some(temperature) = cli.temperature {
         token_sampler = token_sampler.temperature(temperature);
@@ -196,6 +202,9 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     if let Some(top_k) = cli.top_k {
         token_sampler = token_sampler.top_k(top_k as usize);
     }
+    if let Some(repetition_penalty) = cli.repetition_penalty {
+        token_sampler = token_sampler.repetition_penalty(repetition_penalty);
+    }
 
     pln!("---");
     pln!(" dim: {}", params.dim);
@@ -209,6 +218,10 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     pln!("Temperature: {}", token_sampler.get_temperature());
     pln!("Top P: {}", token_sampler.get_top_p());
     pln!("Top K: {}", token_sampler.get_top_k());
+    pln!(
+        "Repetition penalty: {}",
+        token_sampler.get_repetition_penalty()
+    );
     pln!("---");
     pln!(
         "{}",
@@ -229,9 +242,9 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     let mut stop_seen: bool = false;
     while toks_id.len() < max_seq_len {
         let now = std::time::Instant::now();
-
         let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
-        let highest_pred_idx = token_sampler.sample(&preds);
+
+        let (highest_pred_idx, token_prob) = token_sampler.sample(&preds, &tok, &toks_id);
         toks_id.push(highest_pred_idx as TokenId);
 
         for (tok_idx, tok_id) in toks_id[prev_pos + 1..].iter().enumerate() {
@@ -252,7 +265,18 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
             if first && tok_idx < toks_id.len() - 2 {
                 // intentionally left empty
             } else {
-                print!("{}", tok_str.truecolor(128, 255, 128));
+                let redness: f32 = token_prob * 255.0;
+                let redness = if redness > 255.0 {
+                    255
+                } else if redness < 0.0 {
+                    0
+                } else {
+                    redness as u8
+                };
+                print!(
+                    "{}",
+                    tok_str.truecolor(128 + redness / 2, 255 - redness / 2, 128)
+                );
             }
         }
         if first {
diff --git a/src/token_sampler.rs b/src/token_sampler.rs
index 21afe03..5426478 100644
--- a/src/token_sampler.rs
+++ b/src/token_sampler.rs
@@ -1,11 +1,13 @@
 use crate::tensor::Tensor;
-use crate::tokenizer::TokenId;
+use crate::tokenizer::{TokenId, Tokenizer};
 use rand::Rng;
+use std::collections::BTreeMap;
 
 pub struct TokenSampler {
     temperature: f32,
     top_p: f32,
     top_k: usize,
+    repetition_penalty: f32,
 }
 
 impl Default for TokenSampler {
@@ -17,9 +19,11 @@
 impl TokenSampler {
     pub fn new() -> Self {
         Self {
-            temperature: 0.8,
+            temperature: 0.2,
            top_p: 1.0,
             top_k: 1, // same as argmax
+            repetition_penalty: 0.8, // 1.0 = no penalty. Values above 1.0 encourage
+                                     // repetition, which can quickly devolve into a repeating loop.
         }
     }
 
@@ -35,6 +39,10 @@ impl TokenSampler {
         self.top_k
     }
 
+    pub fn get_repetition_penalty(&self) -> f32 {
+        self.repetition_penalty
+    }
+
     pub fn temperature(self, temperature: f32) -> Self {
         Self {
             temperature,
@@ -50,20 +58,77 @@ impl TokenSampler {
         Self { top_k, ..self }
     }
 
-    pub fn sample(&self, logits: &Tensor) -> TokenId {
+    pub fn repetition_penalty(self, repetition_penalty: f32) -> Self {
+        Self {
+            repetition_penalty,
+            ..self
+        }
+    }
+
+    pub fn sample(
+        &self,
+        logits: &Tensor,
+        tokenizer: &Tokenizer,
+        existing_tokens: &[TokenId],
+    ) -> (TokenId, f32) {
+        let mut times_used: BTreeMap<TokenId, usize> = BTreeMap::new();
+        for token in existing_tokens {
+            times_used
+                .entry(*token)
+                .and_modify(|e| *e += 1)
+                .or_insert(1);
+        }
+
         let nrows = logits.rows();
         assert!(logits.cols() == 1);
         let mut logits = logits.transpose();
         if self.temperature > 0.0 {
             logits = logits.scalar_multiply_f32(1.0 / self.temperature);
-            logits = logits.softmax();
         }
 
+        if self.repetition_penalty != 1.0 {
+            for token_idx in 0..logits.rows() {
+                if let Some(count) = times_used.get(&(token_idx as TokenId)) {
+                    let penalty = self.repetition_penalty.powf(*count as f32);
+                    logits.set_f32(0, token_idx, logits.get_f32(0, token_idx) * penalty);
+                }
+            }
+        }
+
+        let mut maxv: f32 = std::f32::NEG_INFINITY;
+        for token_idx in 0..logits.rows() {
+            let v = logits.get_f32(0, token_idx);
+            if v > maxv {
+                maxv = v;
+            }
+        }
+        // To numerically stabilize, subtract maxv from all logits.
+        // softmax(x + c) = softmax(x) where c is a constant, and we make use of that.
+        for token_idx in 0..logits.rows() {
+            logits.set_f32(0, token_idx, logits.get_f32(0, token_idx) - maxv);
+        }
+        logits = logits.softmax();
         let mut logitsf: Vec<(TokenId, f32)> = Vec::with_capacity(nrows as usize);
         for i in 0..nrows {
-            logitsf.push((i as TokenId, logits.get_f32(0, i)));
+            let score = logits.get_f32(0, i);
+            logitsf.push((i as TokenId, score));
         }
-        logitsf.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
+        logitsf.sort_unstable_by(|a, b| {
+            match b.1.partial_cmp(&a.1) {
+                Some(c) => c,
+                None => {
+                    // Sort NaNs to the bottom.
+                    if b.1.is_nan() {
+                        return std::cmp::Ordering::Less;
+                    } else if a.1.is_nan() {
+                        return std::cmp::Ordering::Greater;
+                    } else {
+                        return std::cmp::Ordering::Equal;
+                    }
+                }
+            }
+        });
+
         logitsf.truncate(self.top_k);
         let mut p_accum: f32 = 0.0;
         for (idx, v) in logitsf.iter().enumerate() {
@@ -78,14 +143,14 @@
             total_p += v.1;
         }
         let mut rng = rand::thread_rng();
-        let p: f32 = rng.gen_range(0.0..total_p);
+        let p: f32 = rng.gen_range(0.0..=total_p);
         p_accum = 0.0;
         for v in logitsf.into_iter() {
             p_accum += v.1;
             if p_accum >= p {
-                return v.0;
+                return (v.0, v.1 / total_p);
             }
         }
-        0
+        (0, 0.0)
     }
 }
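
For reference, the core math this patch adds to `TokenSampler::sample` can be sketched on its own, independent of the crate's `Tensor` type. The sketch below is illustrative only: the function name `penalized_softmax`, the plain `Vec<f32>`/`&[u32]` representation, and the example values are assumptions, not part of the patch. It follows the same order of operations as the patch: temperature scaling, multiplying each already-seen token's value by `repetition_penalty^num_occurrences`, then a max-subtracted softmax for numerical stability. Top-k truncation, top-p filtering, and the final weighted draw are omitted to keep the focus on the penalty and the stabilized softmax.

```rust
use std::collections::BTreeMap;

/// Applies the multiplicative repetition penalty and a max-subtracted
/// (numerically stable) softmax, mirroring the steps in the patch above.
fn penalized_softmax(
    logits: &[f32],
    existing_tokens: &[u32],
    temperature: f32,
    repetition_penalty: f32,
) -> Vec<f32> {
    // Count how many times each token id already appears in the context.
    let mut times_used: BTreeMap<u32, usize> = BTreeMap::new();
    for &token in existing_tokens {
        *times_used.entry(token).or_insert(0) += 1;
    }

    // Temperature scaling, then x * repetition_penalty^num_occurrences for
    // tokens that already occurred, in the same order as the patch.
    let mut scaled: Vec<f32> = logits
        .iter()
        .enumerate()
        .map(|(idx, &x)| {
            let mut v = if temperature > 0.0 { x / temperature } else { x };
            if let Some(&count) = times_used.get(&(idx as u32)) {
                v *= repetition_penalty.powf(count as f32);
            }
            v
        })
        .collect();

    // softmax(x - max) == softmax(x); subtracting the maximum keeps exp()
    // from overflowing on large inputs.
    let maxv = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    for v in scaled.iter_mut() {
        *v = (*v - maxv).exp();
    }
    let sum: f32 = scaled.iter().sum();
    scaled.iter().map(|v| v / sum).collect()
}

fn main() {
    // Token id 2 occurs twice in the context, so its value is scaled by 0.8^2.
    let probs = penalized_softmax(&[1.0, 2.0, 3.0, 0.5], &[2, 2], 0.8, 0.8);
    println!("{:?}", probs);
    assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-5);
}
```

The probability returned alongside the token id in the patch (`v.1 / total_p`) is what drives the new output coloring in `rllama_main.rs`: higher-probability tokens are printed redder, lower-probability ones greener.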