From c9c861d199bd2d87d7e883e3087661c1e287f6c4 Mon Sep 17 00:00:00 2001 From: Mikko Juola Date: Mon, 13 Mar 2023 12:59:07 -0700 Subject: [PATCH] Add some measurements so we can get tokens per second. --- src/rllama_main.rs | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/rllama_main.rs b/src/rllama_main.rs index e6c14c4..2698fb1 100644 --- a/src/rllama_main.rs +++ b/src/rllama_main.rs @@ -206,9 +206,13 @@ pub fn main() -> Result<(), Box> { print!("{}", prompt.as_str().truecolor(128, 128, 255)); let _ = std::io::stdout().flush(); + let mut first_token_time: std::time::Duration = std::time::Duration::new(0, 0); + let mut times_per_token: Vec = vec![]; let mut caches = tr.make_caches(); let mut first: bool = true; while toks_id.len() < max_seq_len { + let now = std::time::Instant::now(); + let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches); let highest_pred_idx = token_sampler.sample(&preds); toks_id.push(highest_pred_idx as TokenId); @@ -224,16 +228,33 @@ pub fn main() -> Result<(), Box> { } else { tok_str += tok.replace('▁', " ").as_str(); } - if first && tok_idx < toks_id.len() - 1 { + if first && tok_idx < toks_id.len() - 2 { // intentionally left empty } else { print!("{}", tok_str.truecolor(128, 255, 128)); } } + if first { + first_token_time = now.elapsed(); + } else { + times_per_token.push(now.elapsed()); + } let _ = std::io::stdout().flush(); prev_pos = toks_id.len() - 1; first = false; } println!(""); + if !be_quiet { + println!("---"); + println!( + "Time taken to generate first token: {:?}ms", + first_token_time.as_millis() + ); + println!( + "Time taken per token (excluding first token): {:?}ms", + times_per_token.iter().map(|t| t.as_millis()).sum::() + / times_per_token.len() as u128 + ); + } Ok(()) }