Add repetition penalty, add colors to outputs based on probabilities, try to make softmax() more numerically stable.

master
Mikko Juola 3 years ago
parent f4629ca987
commit 862d4a15d6

@@ -56,6 +56,12 @@ cast to 32-bit floats.
You can use `--temperature`, `--top-p` and `--top-k` to adjust token sampler
settings.
There is a `--repetition-penalty` setting. 1.0 means no penalty, and the value
likely should be between 0 and 1. Values smaller than 1.0 penalize tokens that
already appear in the context, scaling them by
`x*(repetition_penalty^num_occurrences)` before `softmax()` is applied to the
output probabilities.
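As a rough standalone sketch of that formula (using plain `f32` slices rather than this crate's `Tensor` type, and a hypothetical `apply_repetition_penalty` helper name), the penalty simply scales the logit of every token that has already been seen:

```rust
use std::collections::BTreeMap;

// Sketch only: scale each token's logit by penalty^num_occurrences, where
// num_occurrences is how many times the token already appears in the context.
fn apply_repetition_penalty(logits: &mut [f32], context: &[usize], penalty: f32) {
    let mut counts: BTreeMap<usize, u32> = BTreeMap::new();
    for &tok in context {
        *counts.entry(tok).or_insert(0) += 1;
    }
    for (&tok, &n) in &counts {
        if tok < logits.len() {
            // Values below 1.0 shrink logits of repeated tokens.
            logits[tok] *= penalty.powi(n as i32);
        }
    }
}
```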
You can also use `--prompt-file` to read the prompt from a file instead of from
the command line.

@@ -36,6 +36,8 @@ struct Cli {
top_p: Option<f32>,
#[arg(long)]
top_k: Option<i32>,
#[arg(long)]
repetition_penalty: Option<f32>,
#[cfg(feature = "opencl")]
#[arg(long)]
@@ -185,7 +187,11 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
let mut prev_pos = 0;
let mut token_sampler = TokenSampler::new().temperature(0.8).top_p(0.9).top_k(50);
let mut token_sampler = TokenSampler::new()
.temperature(0.8)
.top_p(0.9)
.top_k(50)
.repetition_penalty(0.8);
if let Some(temperature) = cli.temperature {
token_sampler = token_sampler.temperature(temperature);
@@ -196,6 +202,9 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
if let Some(top_k) = cli.top_k {
token_sampler = token_sampler.top_k(top_k as usize);
}
if let Some(repetition_penalty) = cli.repetition_penalty {
token_sampler = token_sampler.repetition_penalty(repetition_penalty);
}
pln!("---");
pln!(" dim: {}", params.dim);
@@ -209,6 +218,10 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
pln!("Temperature: {}", token_sampler.get_temperature());
pln!("Top P: {}", token_sampler.get_top_p());
pln!("Top K: {}", token_sampler.get_top_k());
pln!(
"Repetition penalty: {}",
token_sampler.get_repetition_penalty()
);
pln!("---");
pln!(
"{}",
@@ -229,9 +242,9 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut stop_seen: bool = false;
while toks_id.len() < max_seq_len {
let now = std::time::Instant::now();
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
let highest_pred_idx = token_sampler.sample(&preds);
let (highest_pred_idx, token_prob) = token_sampler.sample(&preds, &tok, &toks_id);
toks_id.push(highest_pred_idx as TokenId);
for (tok_idx, tok_id) in toks_id[prev_pos + 1..].iter().enumerate() {
@@ -252,7 +265,18 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
if first && tok_idx < toks_id.len() - 2 {
// intentionally left empty
} else {
print!("{}", tok_str.truecolor(128, 255, 128));
let redness: f32 = token_prob * 255.0;
let redness = if redness > 255.0 {
255
} else if redness < 0.0 {
0
} else {
redness as u8
};
print!(
"{}",
tok_str.truecolor(128 + redness / 2, 255 - redness / 2, 128)
);
}
}
if first {

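The `truecolor` calls above presumably come from the `colored` crate. A minimal standalone sketch of the same probability-to-color mapping, with the clamping written via `f32::clamp` (the helper name `colorize_token` is just for illustration):

```rust
use colored::Colorize;

// Sketch: blend the token's color from green toward red as its sampled
// probability rises, mirroring the redness computation in the diff above.
fn colorize_token(tok_str: &str, token_prob: f32) -> colored::ColoredString {
    let redness = (token_prob * 255.0).clamp(0.0, 255.0) as u8;
    tok_str.truecolor(128 + redness / 2, 255 - redness / 2, 128)
}
```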
@@ -1,11 +1,13 @@
use crate::tensor::Tensor;
use crate::tokenizer::TokenId;
use crate::tokenizer::{TokenId, Tokenizer};
use rand::Rng;
use std::collections::BTreeMap;
pub struct TokenSampler {
temperature: f32,
top_p: f32,
top_k: usize,
repetition_penalty: f32,
}
impl Default for TokenSampler {
@@ -17,9 +19,11 @@ impl Default for TokenSampler {
impl TokenSampler {
pub fn new() -> Self {
Self {
temperature: 0.8,
temperature: 0.2,
top_p: 1.0,
top_k: 1, // same as argmax
repetition_penalty: 0.8, // 1.0 = no penalty. Values above 1.0 encourage
// repetition, which can quickly devolve into a repeating loop.
}
}
@@ -35,6 +39,10 @@ impl TokenSampler {
self.top_k
}
pub fn get_repetition_penalty(&self) -> f32 {
self.repetition_penalty
}
pub fn temperature(self, temperature: f32) -> Self {
Self {
temperature,
@@ -50,20 +58,77 @@ impl TokenSampler {
Self { top_k, ..self }
}
pub fn sample(&self, logits: &Tensor) -> TokenId {
pub fn repetition_penalty(self, repetition_penalty: f32) -> Self {
Self {
repetition_penalty,
..self
}
}
pub fn sample(
&self,
logits: &Tensor,
tokenizer: &Tokenizer,
existing_tokens: &[TokenId],
) -> (TokenId, f32) {
let mut times_used: BTreeMap<TokenId, usize> = BTreeMap::new();
for token in existing_tokens {
times_used
.entry(*token)
.and_modify(|e| *e += 1)
.or_insert(1);
}
let nrows = logits.rows();
assert!(logits.cols() == 1);
let mut logits = logits.transpose();
if self.temperature > 0.0 {
logits = logits.scalar_multiply_f32(1.0 / self.temperature);
logits = logits.softmax();
}
if self.repetition_penalty != 1.0 {
for token_idx in 0..logits.rows() {
if let Some(count) = times_used.get(&(token_idx as TokenId)) {
let penalty = self.repetition_penalty.powf(*count as f32);
logits.set_f32(0, token_idx, logits.get_f32(0, token_idx) * penalty);
}
}
}
let mut maxv: f32 = std::f32::NEG_INFINITY;
for token_idx in 0..logits.rows() {
let v = logits.get_f32(0, token_idx);
if v > maxv {
maxv = v;
}
}
// To keep things numerically stable, subtract maxv from all logits:
// softmax(x + c) = softmax(x) where c is a constant, and we make use of that.
for token_idx in 0..logits.rows() {
logits.set_f32(0, token_idx, logits.get_f32(0, token_idx) - maxv);
}
logits = logits.softmax();
let mut logitsf: Vec<(TokenId, f32)> = Vec::with_capacity(nrows as usize);
for i in 0..nrows {
logitsf.push((i as TokenId, logits.get_f32(0, i)));
let score = logits.get_f32(0, i);
logitsf.push((i as TokenId, score));
}
logitsf.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
logitsf.sort_unstable_by(|a, b| {
match b.1.partial_cmp(&a.1) {
Some(c) => c,
None => {
// Sort NaNs to bottom
if b.1.is_nan() {
return std::cmp::Ordering::Less;
} else if a.1.is_nan() {
return std::cmp::Ordering::Greater;
} else {
return std::cmp::Ordering::Equal;
}
}
}
});
logitsf.truncate(self.top_k);
let mut p_accum: f32 = 0.0;
for (idx, v) in logitsf.iter().enumerate() {
@@ -78,14 +143,14 @@ impl TokenSampler {
total_p += v.1;
}
let mut rng = rand::thread_rng();
let p: f32 = rng.gen_range(0.0..total_p);
let p: f32 = rng.gen_range(0.0..=total_p);
p_accum = 0.0;
for v in logitsf.into_iter() {
p_accum += v.1;
if p_accum >= p {
return v.0;
return (v.0, v.1 / total_p);
}
}
0
(0, 0.0)
}
}

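The max-subtraction above relies on softmax being shift-invariant, i.e. softmax(x + c) = softmax(x). A small standalone sketch of the same trick on a plain slice (not the crate's `Tensor::softmax`):

```rust
// Sketch: subtracting the maximum before exp() avoids overflow to infinity
// for large logits while leaving the resulting distribution unchanged.
fn stable_softmax(xs: &[f32]) -> Vec<f32> {
    let maxv = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - maxv).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```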