You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
rllama/src/rllama_main.rs

128 lines
3.8 KiB
Rust

use crate::embedding::Embedding;
use crate::token_sampler::TokenSampler;
use crate::tokenizer::{TokenId, Tokenizer};
use crate::transformer::Transformer;
use crate::unpickler;
use clap::Parser;
use std::io::{Read, Write};
// Command-line options for the `rllama` binary. The `///` comments on the
// fields are picked up by clap and shown as each flag's `--help` text.
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    /// Path to the pickled model file. Tensor data is loaded from a `data/`
    /// directory located next to this file.
    #[arg(long)]
    model_path: String,
    /// Path to the tokenizer file.
    #[arg(long)]
    tokenizer_path: String,
    /// Prompt text to complete. Give exactly one of --prompt or --prompt-file.
    #[arg(long)]
    prompt: Option<String>,
    /// File whose UTF-8 contents are used as the prompt. Give exactly one of
    /// --prompt or --prompt-file.
    #[arg(long)]
    prompt_file: Option<String>,
    /// Temperature passed to the token sampler (default: 0.8).
    #[arg(long)]
    temperature: Option<f32>,
    /// Top-p value passed to the token sampler (default: 0.9).
    #[arg(long)]
    top_p: Option<f32>,
    /// Top-k value passed to the token sampler (default: 50).
    #[arg(long)]
    top_k: Option<i32>,
}
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse();
let model_path = cli.model_path;
let tokenizer_path = cli.tokenizer_path;
let prompt: String = match (cli.prompt, cli.prompt_file) {
(Some(prompt), None) => {
println!("Using prompt: {}", prompt);
prompt
}
(None, Some(prompt_file)) => {
println!("Using prompt file: {}", prompt_file);
let mut fs = std::fs::File::open(prompt_file)?;
let mut bs = Vec::new();
fs.read_to_end(&mut bs)?;
std::mem::drop(fs);
String::from_utf8(bs)?
}
_ => {
println!("Please provide either a prompt or a prompt file.");
return Ok(());
}
};
println!("Starting up. Loading tokenizer from {}...", tokenizer_path);
let tok = Tokenizer::load(tokenizer_path.as_str())?;
println!("Tokenizer loaded. Loading model from {}...", model_path);
let mut fs = std::fs::File::open(model_path.as_str())?;
let mut bs = Vec::new();
fs.read_to_end(&mut bs)?;
std::mem::drop(fs);
// We chop off file name from model_path and append "data/"
let model_data_dir = model_path
.split('/')
.take(model_path.split('/').count() - 1)
.collect::<Vec<&str>>()
.join("/")
+ "/data/";
let result = unpickler::unpickle(&bs)?;
println!("Loading embeddings from {}...", model_data_dir);
let emb = Embedding::from_unpickled(&result, model_data_dir.clone())?;
println!("Loading transformer weights from {}...", model_data_dir);
let tr = Transformer::from_unpickled(
&result,
emb,
4096,
32,
32,
512,
1e-6,
32,
128,
model_data_dir,
)?;
println!("All is loaded. Starting inference.");
let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt);
let mut prev_pos = 0;
let mut token_sampler = TokenSampler::new().temperature(0.8).top_p(0.9).top_k(50);
if let Some(temperature) = cli.temperature {
token_sampler = token_sampler.temperature(temperature);
}
if let Some(top_p) = cli.top_p {
token_sampler = token_sampler.top_p(top_p);
}
if let Some(top_k) = cli.top_k {
token_sampler = token_sampler.top_k(top_k as usize);
}
println!("Temperature: {}", token_sampler.get_temperature());
println!("Top P: {}", token_sampler.get_top_p());
println!("Top K: {}", token_sampler.get_top_k());
let mut caches = tr.make_caches();
loop {
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
let highest_pred_idx = token_sampler.sample(&preds);
toks_id.push(highest_pred_idx as TokenId);
let mut tok_str: String = "".to_string();
for tok_id in toks_id[prev_pos + 1..].iter() {
if *tok_id == 1 {
continue;
}
let tok = tok.id_to_str(*tok_id);
if tok == "<0x0A>" {
tok_str += "\n";
} else {
tok_str += tok.replace('▁', " ").as_str();
}
}
print!("{}", tok_str);
let _ = std::io::stdout().flush();
prev_pos = toks_id.len() - 1;
}
}