From cd28aba5e20ceb683bf336ff02973973a73274a5 Mon Sep 17 00:00:00 2001 From: Mikko Juola Date: Sat, 11 Mar 2023 03:03:50 -0800 Subject: [PATCH] Make the output look nicer. --- Cargo.lock | 12 ++++++++++++ Cargo.toml | 1 + src/rllama_main.rs | 15 ++++++++++----- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4e256e8..d3aaf80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -164,6 +164,17 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "colored" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3616f750b84d8f0de8a58bda93e08e2a81ad3f523089b05f1dffecab48c6cbd" +dependencies = [ + "atty", + "lazy_static", + "winapi", +] + [[package]] name = "console" version = "0.15.5" @@ -778,6 +789,7 @@ version = "0.1.0" dependencies = [ "approx", "clap 4.1.8", + "colored", "criterion", "embedded-profiling", "half 2.2.1", diff --git a/Cargo.toml b/Cargo.toml index 9f24c01..ec1209f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ approx = "0.5" rayon = "1.7" clap = { version = "4.1", features = ["derive"] } indicatif = "0.17" +colored = "2" # We need protobuf compiler [build-dependencies] diff --git a/src/rllama_main.rs b/src/rllama_main.rs index 2fffd84..3e4df19 100644 --- a/src/rllama_main.rs +++ b/src/rllama_main.rs @@ -4,7 +4,7 @@ use crate::tokenizer::{TokenId, Tokenizer}; use crate::transformer::Transformer; use crate::unpickler; use clap::Parser; -use std::io::Read; +use std::io::{Read, Write}; #[derive(Parser)] #[command(author, version, about, long_about = None)] @@ -107,16 +107,21 @@ pub fn main() -> Result<(), Box> { let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches); let highest_pred_idx = token_sampler.sample(&preds); toks_id.push(highest_pred_idx as TokenId); - prev_pos = toks_id.len() - 1; let mut tok_str: String = "".to_string(); - for tok_id in toks_id.iter() { + for tok_id in toks_id[prev_pos + 1..].iter() { if *tok_id == 1 { continue; } let tok = tok.id_to_str(*tok_id); - tok_str += tok.replace('▁', " ").as_str(); + if tok == "<0x0A>" { + tok_str += "\n"; + } else { + tok_str += tok.replace('▁', " ").as_str(); + } } - println!("{}", tok_str); + print!("{}", tok_str); + let _ = std::io::stdout().flush(); + prev_pos = toks_id.len() - 1; } }