|
|
|
|
@ -5,6 +5,7 @@ use crate::transformer::Transformer;
|
|
|
|
|
use crate::unpickler;
|
|
|
|
|
use clap::Parser;
|
|
|
|
|
use colored::Colorize;
|
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
|
use std::io::{Read, Write};
|
|
|
|
|
|
|
|
|
|
#[derive(Parser)]
|
|
|
|
|
@ -14,11 +15,17 @@ struct Cli {
|
|
|
|
|
model_path: String,
|
|
|
|
|
#[arg(long)]
|
|
|
|
|
tokenizer_path: String,
|
|
|
|
|
#[arg(long)]
|
|
|
|
|
param_path: String,
|
|
|
|
|
|
|
|
|
|
#[arg(long)]
|
|
|
|
|
prompt: Option<String>,
|
|
|
|
|
#[arg(long)]
|
|
|
|
|
prompt_file: Option<String>,
|
|
|
|
|
|
|
|
|
|
#[arg(long)]
|
|
|
|
|
max_seq_len: Option<usize>,
|
|
|
|
|
|
|
|
|
|
#[arg(long)]
|
|
|
|
|
temperature: Option<f32>,
|
|
|
|
|
#[arg(long)]
|
|
|
|
|
@ -27,10 +34,21 @@ struct Cli {
|
|
|
|
|
top_k: Option<i32>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Clone, Serialize, Deserialize)]
|
|
|
|
|
struct ModelParams {
|
|
|
|
|
dim: usize,
|
|
|
|
|
multiple_of: usize,
|
|
|
|
|
n_heads: usize,
|
|
|
|
|
n_layers: usize,
|
|
|
|
|
norm_eps: f64,
|
|
|
|
|
vocab_size: i64,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
let cli = Cli::parse();
|
|
|
|
|
let model_path = cli.model_path;
|
|
|
|
|
let tokenizer_path = cli.tokenizer_path;
|
|
|
|
|
let param_path = cli.param_path;
|
|
|
|
|
|
|
|
|
|
let mut be_quiet: bool = false;
|
|
|
|
|
if !colored::control::SHOULD_COLORIZE.should_colorize() {
|
|
|
|
|
@ -46,6 +64,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Read ModelParams from param_path, we expect it to be JSON
|
|
|
|
|
let mut fs = std::fs::File::open(¶m_path)?;
|
|
|
|
|
let mut bs = Vec::new();
|
|
|
|
|
fs.read_to_end(&mut bs)?;
|
|
|
|
|
std::mem::drop(fs);
|
|
|
|
|
let params: ModelParams = serde_json::from_slice(&bs)?;
|
|
|
|
|
pln!("Loaded model parameters from {}.", param_path);
|
|
|
|
|
|
|
|
|
|
let prompt: String = match (cli.prompt, cli.prompt_file) {
|
|
|
|
|
(Some(prompt), None) => {
|
|
|
|
|
pln!("Using prompt: {}", prompt);
|
|
|
|
|
@ -84,19 +110,20 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
pln!("Loading embeddings from {}...", model_data_dir);
|
|
|
|
|
let emb = Embedding::from_unpickled(&result, model_data_dir.clone())?;
|
|
|
|
|
|
|
|
|
|
let max_seq_len = 512;
|
|
|
|
|
let max_seq_len = match cli.max_seq_len {
|
|
|
|
|
Some(max_seq_len) => max_seq_len,
|
|
|
|
|
None => 1024,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
pln!("Loading transformer weights from {}...", model_data_dir);
|
|
|
|
|
let tr = Transformer::from_unpickled(
|
|
|
|
|
&result,
|
|
|
|
|
emb,
|
|
|
|
|
4096,
|
|
|
|
|
32,
|
|
|
|
|
32,
|
|
|
|
|
params.dim,
|
|
|
|
|
params.n_layers,
|
|
|
|
|
params.n_heads,
|
|
|
|
|
max_seq_len,
|
|
|
|
|
1e-6,
|
|
|
|
|
32,
|
|
|
|
|
128,
|
|
|
|
|
params.norm_eps,
|
|
|
|
|
model_data_dir,
|
|
|
|
|
)?;
|
|
|
|
|
pln!("All is loaded. Starting inference.");
|
|
|
|
|
@ -116,6 +143,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pln!("---");
|
|
|
|
|
pln!(" dim: {}", params.dim);
|
|
|
|
|
pln!(" multiple_of: {}", params.multiple_of);
|
|
|
|
|
pln!(" n_heads: {}", params.n_heads);
|
|
|
|
|
pln!(" n_layers: {}", params.n_layers);
|
|
|
|
|
pln!(" norm_eps: {}", params.norm_eps);
|
|
|
|
|
pln!(" vocab_size: {}", params.vocab_size);
|
|
|
|
|
pln!("---");
|
|
|
|
|
pln!("Max sequence length: {}", max_seq_len);
|
|
|
|
|
pln!("Temperature: {}", token_sampler.get_temperature());
|
|
|
|
|
pln!("Top P: {}", token_sampler.get_top_p());
|
|
|
|
|
pln!("Top K: {}", token_sampler.get_top_k());
|
|
|
|
|
@ -126,12 +161,7 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
);
|
|
|
|
|
pln!(
|
|
|
|
|
"{}",
|
|
|
|
|
" This is the color of the generated text while full context is available"
|
|
|
|
|
.truecolor(128, 255, 128)
|
|
|
|
|
);
|
|
|
|
|
pln!(
|
|
|
|
|
"{}",
|
|
|
|
|
" Remaining text is in this color".truecolor(255, 128, 128)
|
|
|
|
|
" This is the color of the generated text".truecolor(128, 255, 128)
|
|
|
|
|
);
|
|
|
|
|
pln!("---");
|
|
|
|
|
print!("{}", prompt.as_str().truecolor(128, 128, 255));
|
|
|
|
|
@ -139,17 +169,8 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
|
|
|
|
|
let mut caches = tr.make_caches();
|
|
|
|
|
let mut first: bool = true;
|
|
|
|
|
let mut shifts: usize = 0;
|
|
|
|
|
loop {
|
|
|
|
|
if toks_id.len() >= max_seq_len {
|
|
|
|
|
toks_id = toks_id[1..].to_vec();
|
|
|
|
|
prev_pos -= 1;
|
|
|
|
|
caches.shift_left(1);
|
|
|
|
|
shifts += 1;
|
|
|
|
|
// TODO: it seems that text beyond context is just broken.
|
|
|
|
|
// Maybe I cannot just go and shift it.
|
|
|
|
|
}
|
|
|
|
|
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches, shifts);
|
|
|
|
|
while toks_id.len() < max_seq_len {
|
|
|
|
|
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
|
|
|
|
|
let highest_pred_idx = token_sampler.sample(&preds);
|
|
|
|
|
toks_id.push(highest_pred_idx as TokenId);
|
|
|
|
|
|
|
|
|
|
@ -166,14 +187,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
}
|
|
|
|
|
if first && tok_idx < toks_id.len() - 1 {
|
|
|
|
|
// intentionally left empty
|
|
|
|
|
} else if shifts == 0 {
|
|
|
|
|
print!("{}", tok_str.truecolor(128, 255, 128));
|
|
|
|
|
} else {
|
|
|
|
|
print!("{}", tok_str.truecolor(255, 128, 128));
|
|
|
|
|
print!("{}", tok_str.truecolor(128, 255, 128));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
let _ = std::io::stdout().flush();
|
|
|
|
|
prev_pos = toks_id.len() - 1;
|
|
|
|
|
first = false;
|
|
|
|
|
}
|
|
|
|
|
println!("");
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|