Make the output look nicer.

broken-opencl-code
Mikko Juola 3 years ago
parent d7a3f57510
commit cd28aba5e2

12
Cargo.lock generated

@ -164,6 +164,17 @@ dependencies = [
"os_str_bytes", "os_str_bytes",
] ]
[[package]]
name = "colored"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3616f750b84d8f0de8a58bda93e08e2a81ad3f523089b05f1dffecab48c6cbd"
dependencies = [
"atty",
"lazy_static",
"winapi",
]
[[package]] [[package]]
name = "console" name = "console"
version = "0.15.5" version = "0.15.5"
@ -778,6 +789,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"approx", "approx",
"clap 4.1.8", "clap 4.1.8",
"colored",
"criterion", "criterion",
"embedded-profiling", "embedded-profiling",
"half 2.2.1", "half 2.2.1",

@ -21,6 +21,7 @@ approx = "0.5"
rayon = "1.7" rayon = "1.7"
clap = { version = "4.1", features = ["derive"] } clap = { version = "4.1", features = ["derive"] }
indicatif = "0.17" indicatif = "0.17"
colored = "2"
# We need protobuf compiler # We need protobuf compiler
[build-dependencies] [build-dependencies]

@ -4,7 +4,7 @@ use crate::tokenizer::{TokenId, Tokenizer};
use crate::transformer::Transformer; use crate::transformer::Transformer;
use crate::unpickler; use crate::unpickler;
use clap::Parser; use clap::Parser;
use std::io::Read; use std::io::{Read, Write};
#[derive(Parser)] #[derive(Parser)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
@ -107,16 +107,21 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches); let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
let highest_pred_idx = token_sampler.sample(&preds); let highest_pred_idx = token_sampler.sample(&preds);
toks_id.push(highest_pred_idx as TokenId); toks_id.push(highest_pred_idx as TokenId);
prev_pos = toks_id.len() - 1;
let mut tok_str: String = "".to_string(); let mut tok_str: String = "".to_string();
for tok_id in toks_id.iter() { for tok_id in toks_id[prev_pos + 1..].iter() {
if *tok_id == 1 { if *tok_id == 1 {
continue; continue;
} }
let tok = tok.id_to_str(*tok_id); let tok = tok.id_to_str(*tok_id);
tok_str += tok.replace('▁', " ").as_str(); if tok == "<0x0A>" {
tok_str += "\n";
} else {
tok_str += tok.replace('▁', " ").as_str();
}
} }
println!("{}", tok_str); print!("{}", tok_str);
let _ = std::io::stdout().flush();
prev_pos = toks_id.len() - 1;
} }
} }

Loading…
Cancel
Save