|
|
|
|
@ -206,9 +206,13 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
print!("{}", prompt.as_str().truecolor(128, 128, 255));
|
|
|
|
|
let _ = std::io::stdout().flush();
|
|
|
|
|
|
|
|
|
|
let mut first_token_time: std::time::Duration = std::time::Duration::new(0, 0);
|
|
|
|
|
let mut times_per_token: Vec<std::time::Duration> = vec![];
|
|
|
|
|
let mut caches = tr.make_caches();
|
|
|
|
|
let mut first: bool = true;
|
|
|
|
|
while toks_id.len() < max_seq_len {
|
|
|
|
|
let now = std::time::Instant::now();
|
|
|
|
|
|
|
|
|
|
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
|
|
|
|
|
let highest_pred_idx = token_sampler.sample(&preds);
|
|
|
|
|
toks_id.push(highest_pred_idx as TokenId);
|
|
|
|
|
@ -224,16 +228,33 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
} else {
|
|
|
|
|
tok_str += tok.replace('▁', " ").as_str();
|
|
|
|
|
}
|
|
|
|
|
if first && tok_idx < toks_id.len() - 1 {
|
|
|
|
|
if first && tok_idx < toks_id.len() - 2 {
|
|
|
|
|
// intentionally left empty
|
|
|
|
|
} else {
|
|
|
|
|
print!("{}", tok_str.truecolor(128, 255, 128));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if first {
|
|
|
|
|
first_token_time = now.elapsed();
|
|
|
|
|
} else {
|
|
|
|
|
times_per_token.push(now.elapsed());
|
|
|
|
|
}
|
|
|
|
|
let _ = std::io::stdout().flush();
|
|
|
|
|
prev_pos = toks_id.len() - 1;
|
|
|
|
|
first = false;
|
|
|
|
|
}
|
|
|
|
|
println!("");
|
|
|
|
|
if !be_quiet {
|
|
|
|
|
println!("---");
|
|
|
|
|
println!(
|
|
|
|
|
"Time taken to generate first token: {:?}ms",
|
|
|
|
|
first_token_time.as_millis()
|
|
|
|
|
);
|
|
|
|
|
println!(
|
|
|
|
|
"Time taken per token (excluding first token): {:?}ms",
|
|
|
|
|
times_per_token.iter().map(|t| t.as_millis()).sum::<u128>()
|
|
|
|
|
/ times_per_token.len() as u128
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|