@@ -4,6 +4,7 @@ use crate::tokenizer::{TokenId, Tokenizer};
 use crate::transformer::Transformer;
 use crate::unpickler;
 use clap::Parser;
+use colored::Colorize;
 use std::io::{Read, Write};
 
 #[derive(Parser)]
@@ -31,13 +32,27 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     let model_path = cli.model_path;
     let tokenizer_path = cli.tokenizer_path;
 
+    let mut be_quiet: bool = false;
+    if !colored::control::SHOULD_COLORIZE.should_colorize() {
+        be_quiet = true;
+    }
+
+    // Custom println-like macro that respects be_quiet
+    macro_rules! pln {
+        ($($arg:tt)*) => {
+            if !be_quiet {
+                std::println!($($arg)*);
+            }
+        };
+    }
+
     let prompt: String = match (cli.prompt, cli.prompt_file) {
         (Some(prompt), None) => {
-            println!("Using prompt: {}", prompt);
+            pln!("Using prompt: {}", prompt);
             prompt
         }
         (None, Some(prompt_file)) => {
-            println!("Using prompt file: {}", prompt_file);
+            pln!("Using prompt file: {}", prompt_file);
             let mut fs = std::fs::File::open(prompt_file)?;
             let mut bs = Vec::new();
             fs.read_to_end(&mut bs)?;
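A side note on the quiet-mode macro added above: because `macro_rules! pln` is declared inside `main`, its expansion refers to the local `be_quiet` variable textually, so no function argument or global flag is needed. A minimal, self-contained sketch of the same pattern (the `QUIET` environment variable here stands in for the patch's `colored::control::SHOULD_COLORIZE` check and is purely illustrative):

    fn main() {
        // Stand-in for the TTY/color detection used in the patch.
        let be_quiet: bool = std::env::var("QUIET").is_ok();

        // Expands textually, so `be_quiet` resolves to the local above.
        macro_rules! pln {
            ($($arg:tt)*) => {
                if !be_quiet {
                    std::println!($($arg)*);
                }
            };
        }

        pln!("visible unless QUIET is set: {}", 1 + 1);
    }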
@@ -45,14 +60,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
             String::from_utf8(bs)?
         }
         _ => {
-            println!("Please provide either a prompt or a prompt file.");
-            return Ok(());
+            eprintln!("Please provide either a prompt or a prompt file.");
+            return Err("Please provide either a prompt or a prompt file.".into());
         }
     };
 
-    println!("Starting up. Loading tokenizer from {}...", tokenizer_path);
+    pln!("Starting up. Loading tokenizer from {}...", tokenizer_path);
     let tok = Tokenizer::load(tokenizer_path.as_str())?;
-    println!("Tokenizer loaded. Loading model from {}...", model_path);
+    pln!("Tokenizer loaded. Loading model from {}...", model_path);
     let mut fs = std::fs::File::open(model_path.as_str())?;
     let mut bs = Vec::new();
     fs.read_to_end(&mut bs)?;
@@ -66,25 +81,27 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
         .join("/")
         + "/data/";
     let result = unpickler::unpickle(&bs)?;
-    println!("Loading embeddings from {}...", model_data_dir);
+    pln!("Loading embeddings from {}...", model_data_dir);
     let emb = Embedding::from_unpickled(&result, model_data_dir.clone())?;
 
-    println!("Loading transformer weights from {}...", model_data_dir);
+    let max_seq_len = 512;
+
+    pln!("Loading transformer weights from {}...", model_data_dir);
     let tr = Transformer::from_unpickled(
         &result,
         emb,
         4096,
         32,
         32,
-        512,
+        max_seq_len,
         1e-6,
         32,
         128,
         model_data_dir,
     )?;
-    println!("All is loaded. Starting inference.");
+    pln!("All is loaded. Starting inference.");
 
-    let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt);
+    let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
     let mut prev_pos = 0;
     let mut token_sampler = TokenSampler::new().temperature(0.8).top_p(0.9).top_k(50);
 
@@ -98,30 +115,65 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
         token_sampler = token_sampler.top_k(top_k as usize);
     }
 
-    println!("Temperature: {}", token_sampler.get_temperature());
-    println!("Top P: {}", token_sampler.get_top_p());
-    println!("Top K: {}", token_sampler.get_top_k());
+    pln!("---");
+    pln!("Temperature: {}", token_sampler.get_temperature());
+    pln!("Top P: {}", token_sampler.get_top_p());
+    pln!("Top K: {}", token_sampler.get_top_k());
+    pln!("---");
+    pln!(
+        "{}",
+        " This is the color of the initial prompt".truecolor(128, 128, 255)
+    );
+    pln!(
+        "{}",
+        " This is the color of the generated text while full context is available"
+            .truecolor(128, 255, 128)
+    );
+    pln!(
+        "{}",
+        " Remaining text is in this color".truecolor(255, 128, 128)
+    );
+    pln!("---");
+    print!("{}", prompt.as_str().truecolor(128, 128, 255));
+    let _ = std::io::stdout().flush();
 
     let mut caches = tr.make_caches();
+    let mut first: bool = true;
+    let mut shifts: usize = 0;
     loop {
-        let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
+        if toks_id.len() >= max_seq_len {
+            toks_id = toks_id[1..].to_vec();
+            prev_pos -= 1;
+            caches.shift_left(1);
+            shifts += 1;
+            // TODO: it seems that text beyond context is just broken.
+            // Maybe I cannot just go and shift it.
+        }
+        let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches, shifts);
        let highest_pred_idx = token_sampler.sample(&preds);
        toks_id.push(highest_pred_idx as TokenId);
 
-        let mut tok_str: String = "".to_string();
-        for tok_id in toks_id[prev_pos + 1..].iter() {
+        for (tok_idx, tok_id) in toks_id[prev_pos + 1..].iter().enumerate() {
             if *tok_id == 1 {
                 continue;
             }
+            let mut tok_str: String = "".to_string();
             let tok = tok.id_to_str(*tok_id);
             if tok == "<0x0A>" {
                 tok_str += "\n";
             } else {
                 tok_str += tok.replace('▁', " ").as_str();
             }
+            if first && tok_idx < toks_id.len() - 1 {
+                // intentionally left empty
+            } else if shifts == 0 {
+                print!("{}", tok_str.truecolor(128, 255, 128));
+            } else {
+                print!("{}", tok_str.truecolor(255, 128, 128));
+            }
         }
-        print!("{}", tok_str);
         let _ = std::io::stdout().flush();
         prev_pos = toks_id.len() - 1;
+        first = false;
     }
 }
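The sliding-window handling in the last hunk boils down to index bookkeeping: once the token buffer reaches `max_seq_len`, the oldest token is dropped, the next read position moves back by one, and a shift counter records that the window no longer starts at the true beginning (the in-diff TODO notes that text generated past the original context is still broken, so this is bookkeeping rather than a fix). A small stand-alone sketch of that bookkeeping with plain vectors, leaving out the model and cache types entirely; the token values and loop here are purely illustrative:

    fn main() {
        let max_seq_len = 4; // tiny window, for illustration only
        let mut toks_id: Vec<u32> = vec![1, 10, 20];
        let mut prev_pos: usize = 0;
        let mut shifts: usize = 0;

        for next_tok in [30, 40, 50, 60] {
            if toks_id.len() >= max_seq_len {
                // Drop the oldest token and slide positions left by one,
                // mirroring what the patch does alongside caches.shift_left(1).
                toks_id = toks_id[1..].to_vec();
                prev_pos -= 1;
                shifts += 1;
            }
            // The patch feeds this suffix to tr.forward(); here we just show it.
            println!("feed={:?} prev_pos={} shifts={}", &toks_id[prev_pos..], prev_pos, shifts);
            toks_id.push(next_tok); // stand-in for the sampled token
            prev_pos = toks_id.len() - 1;
        }
    }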