use crate::embedding::Embedding; use crate::semaphore::Semaphore; #[cfg(feature = "opencl")] use crate::tensor_opencl_support::OpenCL; use crate::token_sampler::TokenSampler; use crate::tokenizer::{TokenId, Tokenizer}; use crate::transformer::{DataSettings, Transformer, TransformerCaches}; use crate::unpickler; use crate::unpickler::Value; use clap::Parser; use colored::Colorize; #[cfg(feature = "server")] use rocket::{response::status, response::Stream, Data, State}; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; use std::io::{Read, Write}; use std::path::PathBuf; use std::sync::{Arc, RwLock}; #[derive(Parser, Clone)] #[command(author, version, about, long_about = None)] struct Cli { #[arg(long)] model_path: String, #[arg(long)] tokenizer_path: String, #[arg(long)] param_path: String, #[arg(short, long, action)] quiet: bool, #[arg(long)] prompt: Option, #[arg(long)] prompt_file: Option, #[arg(long)] max_seq_len: Option, #[arg(long)] temperature: Option, #[arg(long)] top_p: Option, #[arg(long)] top_k: Option, #[arg(long)] repetition_penalty: Option, #[arg(long)] max_threads: Option, #[arg(long, action)] f16: bool, #[arg(long, action)] k4: bool, #[cfg(feature = "opencl")] #[arg(long)] opencl_device: Option, #[arg(long, action)] inference_server: bool, #[arg(long)] inference_server_port: Option, #[arg(long)] inference_server_host: Option, #[arg(long)] inference_server_max_concurrent_inferences: Option, #[arg(long)] inference_server_api_path: Option, #[arg(long)] inference_server_prompt_cache_size: Option, #[arg(long, action)] inference_server_exit_after_one_query: bool, } #[derive(Clone, Serialize, Deserialize)] struct ModelParams { dim: usize, multiple_of: usize, n_heads: usize, n_layers: usize, norm_eps: f64, vocab_size: i64, } pub fn main() -> Result<(), Box> { let cli = Cli::parse(); let model_path = cli.model_path.clone(); let tokenizer_path = cli.tokenizer_path.clone(); let param_path = cli.param_path.clone(); #[cfg(not(feature = "server"))] if cli.inference_server { eprintln!("Inference server is not enabled in this build."); return Err("Inference server is not enabled in this build.".into()); } let max_threads: usize = match cli.max_threads { None => rayon::current_num_threads(), Some(max_threads) => { rayon::ThreadPoolBuilder::new() .num_threads(max_threads) .build_global() .unwrap(); max_threads } }; let mut be_quiet: bool = false; if !colored::control::SHOULD_COLORIZE.should_colorize() { be_quiet = true; } if cli.quiet { be_quiet = true; } if be_quiet { colored::control::SHOULD_COLORIZE.set_override(false); } // Custom println-like macro that respects be_quiet macro_rules! pln { ($($arg:tt)*) => { if !be_quiet { std::println!($($arg)*); } }; } #[cfg(feature = "opencl")] let opencl: Option = { let opencl_device = cli.opencl_device.unwrap_or(0); match OpenCL::new(!be_quiet, opencl_device) { Err(openclerr) => { eprintln!("OpenCL error: {}", openclerr); eprintln!("OpenCL is disabled because it failed to initialize."); None } Ok(opencl) => { println!("OpenCL initialized."); Some(opencl) } } }; // Read ModelParams from param_path, we expect it to be JSON let mut fs = std::fs::File::open(¶m_path)?; let mut bs = Vec::new(); fs.read_to_end(&mut bs)?; std::mem::drop(fs); let params: ModelParams = serde_json::from_slice(&bs)?; pln!("Loaded model parameters from {}.", param_path); let prompt: String = match (&cli.prompt, &cli.prompt_file) { (Some(ref prompt), None) => { pln!("Using prompt: {}", prompt); prompt.clone() } (None, Some(ref prompt_file)) => { pln!("Using prompt file: {}", prompt_file); let mut fs = std::fs::File::open(prompt_file)?; let mut bs = Vec::new(); fs.read_to_end(&mut bs)?; std::mem::drop(fs); String::from_utf8(bs)? } _ => { if cli.inference_server { "".to_string() } else { eprintln!("Please provide either a prompt or a prompt file."); return Err("Please provide either a prompt or a prompt file.".into()); } } }; pln!("Starting up. Loading tokenizer from {}...", tokenizer_path); let tok = Tokenizer::load(tokenizer_path.as_str())?; pln!("Tokenizer loaded. Loading model from {}...", model_path); let mut unpickle_results: Vec = vec![]; let mut part: usize = 0; loop { let model_path: PathBuf = model_path.clone().into(); let base_path = model_path.join(format!("consolidated.{:02}", part)); // The data file is in consolidated.XX/data.pkl where XX is the part number. let full_path = base_path.join("data.pkl"); let mut fs = match std::fs::File::open(&full_path) { Ok(fs) => fs, Err(err) => { if err.kind() == std::io::ErrorKind::NotFound { break; } else { return Err(err.into()); } } }; let mut bs = Vec::new(); fs.read_to_end(&mut bs)?; std::mem::drop(fs); pln!("Read data.pkl from path {}", full_path.display()); let result = unpickler::unpickle(&bs)?; unpickle_results.push(result); part += 1; } pln!("Loading embeddings from {}...", model_path); let emb = Embedding::from_unpickled(&unpickle_results, model_path.clone())?; let max_seq_len = cli.max_seq_len.unwrap_or(1024); let mut data_settings = { #[cfg(feature = "opencl")] { if let Some(opencl) = opencl { let ds = DataSettings::new(Some(opencl)); ds.use_opencl() } else { DataSettings::new(None) } } #[cfg(not(feature = "opencl"))] DataSettings::new() }; if cli.f16 { data_settings = data_settings.force_f16(); } if cli.k4 { data_settings = data_settings.force_k4(); } pln!("Loading transformer weights from {}...", model_path); let tr = Transformer::from_unpickled( &unpickle_results, emb, params.dim, params.n_layers, params.n_heads, max_seq_len, params.norm_eps, data_settings, model_path, )?; pln!("All is loaded. Starting inference."); let tr: Arc = Arc::new(tr); let tok: Arc = Arc::new(tok); if cli.inference_server { #[cfg(feature = "server")] { server_inference(cli, tr, tok, be_quiet, max_seq_len, params, max_threads) } #[cfg(not(feature = "server"))] { eprintln!("The inference server feature is not enabled."); eprintln!("Please enable it with the \"inference-server\" feature."); Err("The inference server feature is not enabled.".into()) } } else { command_line_inference( cli.clone(), tr.clone(), tok.clone(), prompt.clone(), be_quiet, max_seq_len, params.clone(), max_threads, ) } } #[cfg(feature = "server")] fn server_inference( cli: Cli, tr: Arc, tok: Arc, be_quiet: bool, max_seq_len: usize, _params: ModelParams, _max_threads: usize, ) -> Result<(), Box> { macro_rules! pln { ($($arg:tt)*) => { if !be_quiet { std::println!($($arg)*); } }; } let inference_server_port = cli.inference_server_port.unwrap_or(8080); let inference_server_host = cli .inference_server_host .clone() .unwrap_or("127.0.0.1".to_string()); let inference_server_max_concurrent_inferences = cli.inference_server_max_concurrent_inferences.unwrap_or(5); let inference_server_api_path = cli .inference_server_api_path .clone() .unwrap_or("/rllama/v1/inference".to_string()); let inference_server_prompt_cache_size = cli.inference_server_prompt_cache_size.unwrap_or(50); pln!( "Maximum concurrent inferences: {}", inference_server_max_concurrent_inferences ); pln!("Prompt cache size: {}", inference_server_prompt_cache_size); pln!("Maximum sequence length: {}", max_seq_len); pln!( "--- Starting HTTP server on {}:{}, answering to requests at {} ---", inference_server_host, inference_server_port, inference_server_api_path ); // If there are too many connections, they will hang until they get their turn. // Maybe can later implement return 503 slow down or something similar. let concurrent_requests_semaphore = Semaphore::new(inference_server_max_concurrent_inferences); let rocket_conf = rocket::Config::build(rocket::config::Environment::Production) .address(inference_server_host) .port(inference_server_port) .finalize() .unwrap(); let app = rocket::custom(rocket_conf) .mount(&inference_server_api_path, routes![handle_request]) .manage(InferenceServerState { transformer: tr, tokenizer: tok, max_seq_len, concurrent_requests_semaphore, attention_cache_repository: Arc::new(RwLock::new(AttentionCacheRepository::empty( inference_server_prompt_cache_size, ))), exit_after_one_query: cli.inference_server_exit_after_one_query, }); app.launch(); panic!("Starting web server failed."); } fn is_false(b: &bool) -> bool { !b } #[derive(Serialize, Deserialize, Clone, Debug)] struct InferenceRequest { temperature: Option, top_k: Option, top_p: Option, repetition_penalty: Option, max_seq_len: Option, max_new_tokens: Option, no_token_sampling: Option, stop_at_end_token: Option, prompt: String, } #[cfg(feature = "server")] #[derive(Serialize, Deserialize, Clone, Debug)] struct PredResult { p: f32, #[serde(skip_serializing_if = "is_false")] is_end_token: bool, } #[cfg(feature = "server")] struct GeneratingSession { transformer: Arc, token_sampler: TokenSampler, tokenizer: Arc, attention_cache_repository: Arc>, tokens: Vec, req_max_seq_len: usize, req_max_new_tokens: usize, new_tokens_generated: usize, prev_pos: usize, no_token_sampling: bool, stop_at_end_token: bool, sent_stuff_last_time: bool, exit_after_one_query: bool, result: Vec, // stores JSONL lines to be returned from read() } #[cfg(feature = "server")] impl GeneratingSession { fn read_from_result(&mut self, buf: &mut [u8]) -> usize { if !self.result.is_empty() { if self.result.len() <= buf.len() { for idx in 0..self.result.len() { buf[idx] = self.result[idx]; } let len = self.result.len(); self.sent_stuff_last_time = true; self.result.truncate(0); return len; } else { for idx in 0..buf.len() { buf[idx] = self.result[idx]; } self.result = self.result[buf.len()..].to_vec(); self.sent_stuff_last_time = true; return buf.len(); } } return 0; } } #[cfg(feature = "server")] impl Read for GeneratingSession { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { if self.sent_stuff_last_time && self.result.is_empty() { // If we return WouldBlock every time we send something, it'll cause Rocket to // flush available data. self.sent_stuff_last_time = false; return Err(std::io::Error::new( std::io::ErrorKind::WouldBlock, "WouldBlock", )); } // Push more data to the upstream if we have something stored. let bytes_read = self.read_from_result(buf); if bytes_read > 0 { return Ok(bytes_read); } if self.tokens.len() >= self.req_max_seq_len { if self.exit_after_one_query { std::process::exit(0); } return Ok(0); } if self.new_tokens_generated >= self.req_max_new_tokens { if self.exit_after_one_query { std::process::exit(0); } return Ok(0); } let (mut caches, update_pos) = { let mut ac = self.attention_cache_repository.write().unwrap(); match ac.get(&self.tokens) { Some((c, pos)) if pos >= self.prev_pos => (c.true_clone(), pos), Some(_) => { std::mem::drop(ac); (self.transformer.make_caches(), 0) } None => { let caches = self.transformer.make_caches(); ac.put(self.tokens.clone(), caches.true_clone(), self.prev_pos); (caches, self.prev_pos) } } }; if update_pos > self.prev_pos { self.prev_pos = update_pos; } assert!(self.result.is_empty()); let predictions = self.transformer .forward(&self.tokens[self.prev_pos..], self.prev_pos, &mut caches); self.prev_pos = self.tokens.len(); let (highest_pred_idx, token_prob) = self.token_sampler .sample(&predictions, self.tokenizer.as_ref(), &self.tokens); self.tokens.push(highest_pred_idx as TokenId); { let mut ac = self.attention_cache_repository.write().unwrap(); ac.put(self.tokens.clone(), caches, self.prev_pos); } self.new_tokens_generated += 1; let token: &str = self.tokenizer.id_to_str(highest_pred_idx as TokenId); let mut is_end_token: bool = false; if token == "" && self.stop_at_end_token { self.new_tokens_generated = self.req_max_new_tokens; is_end_token = true; } let mut result: BTreeMap = BTreeMap::new(); if self.no_token_sampling { // All predictions go the line. let probs = self .token_sampler .logits_to_btreemap(&predictions, self.tokenizer.as_ref()); for (k, v) in probs.into_iter() { let mut is_end_token: bool = false; if k == "" { is_end_token = true; } result.insert( k, PredResult { p: v, is_end_token: is_end_token, }, ); } // Convert to JSON let json = serde_json::to_string(&result).unwrap(); self.result.extend(json.as_bytes()); self.result.push(b'\n'); return Ok(self.read_from_result(buf)); } else { result.insert( token.to_string(), PredResult { p: token_prob, is_end_token, }, ); let json = serde_json::to_string(&result).unwrap(); self.result.extend(json.as_bytes()); self.result.push(b'\n'); return Ok(self.read_from_result(buf)); } } } #[cfg(feature = "server")] struct AttentionCacheRepository { caches: BTreeMap, (TransformerCaches, usize, std::time::Instant)>, max_sz: usize, } #[cfg(feature = "server")] impl AttentionCacheRepository { fn empty(max_size: usize) -> AttentionCacheRepository { AttentionCacheRepository { caches: BTreeMap::new(), max_sz: max_size, } } /// Makes sure the cache repository is not larger than sz, evicts any older items. fn limit_size(&mut self, sz: usize) { if sz == 0 { self.caches = BTreeMap::new(); return; } // Slow algorithm but I guess our cache will never be unimaginably large so it's probably // fine while self.caches.len() > sz { let mut oldest_time = None; let mut oldest_key: Option<&Vec> = None; for (k, (_, _, time)) in self.caches.iter() { if oldest_time.is_none() || time < oldest_time.unwrap() { oldest_time = Some(time); oldest_key = Some(k); } } let oldest_key = oldest_key.unwrap().clone(); self.caches.remove(&oldest_key); } } fn get(&self, tokens: &[TokenId]) -> Option<(&TransformerCaches, usize)> { if let Some((caches, pos, _)) = self.caches.get(tokens) { Some((caches, *pos)) } else { None } } fn put(&mut self, tokens: Vec, caches: TransformerCaches, prev_pos: usize) { self.caches .insert(tokens, (caches, prev_pos, std::time::Instant::now())); self.limit_size(self.max_sz); } } #[cfg(feature = "server")] #[derive(Clone)] struct InferenceServerState { transformer: Arc, tokenizer: Arc, max_seq_len: usize, concurrent_requests_semaphore: Semaphore, attention_cache_repository: Arc>, exit_after_one_query: bool, } #[cfg(feature = "server")] #[post("/", data = "")] fn handle_request( state: State, input: Data, ) -> Result, status::BadRequest> { let _lock = state.concurrent_requests_semaphore.acquire(); let tr = state.transformer.clone(); let tok = state.tokenizer.clone(); let mut data = input.open(); let mut databuf: Vec = Vec::new(); data.read_to_end(&mut databuf).unwrap(); // Parse the JSON out of the request let request: InferenceRequest = match serde_json::from_slice(&databuf) { Err(_e) => { return Err(status::BadRequest(Some("Invalid JSON.".to_string()))); } Ok(ir) => ir, }; let stop_at_end_token = request.stop_at_end_token.unwrap_or(true); let temperature = request.temperature.unwrap_or(1.0); let top_k = request.top_k.unwrap_or(20); let top_p = request.top_p.unwrap_or(1.0); let repetition_penalty = request.repetition_penalty.unwrap_or(1.0); let mut req_max_seq_len = request.max_seq_len.unwrap_or(state.max_seq_len); if req_max_seq_len > state.max_seq_len { req_max_seq_len = state.max_seq_len; } let req_max_new_tokens = request.max_new_tokens.unwrap_or(20); let no_token_sampling = request.no_token_sampling.unwrap_or(false); let prompt = request.prompt; if temperature.is_nan() { return Err(status::BadRequest(Some( "Temperature must be a number.".to_string(), ))); } if top_k == 0 { return Err(status::BadRequest(Some( "Top-k must be greater than 0.".to_string(), ))); } if top_p.is_nan() { return Err(status::BadRequest(Some( "Top-p must be a number.".to_string(), ))); } if repetition_penalty.is_nan() { return Err(status::BadRequest(Some( "Repetition penalty must be a number.".to_string(), ))); } let token_sampler = TokenSampler::new() .temperature(temperature) .top_p(top_p) .top_k(top_k) .repetition_penalty(repetition_penalty); let toks_id: Vec = tok.tokenize_to_ids(prompt.clone()); let gsession = GeneratingSession { transformer: tr, tokenizer: tok, attention_cache_repository: state.attention_cache_repository.clone(), token_sampler: token_sampler, tokens: toks_id, req_max_seq_len: req_max_seq_len, req_max_new_tokens: req_max_new_tokens, new_tokens_generated: 0, prev_pos: 0, no_token_sampling: no_token_sampling, stop_at_end_token: stop_at_end_token, sent_stuff_last_time: false, exit_after_one_query: state.exit_after_one_query, result: Vec::new(), }; return Ok(rocket::response::Stream::chunked(gsession, 1024)); } fn command_line_inference( cli: Cli, tr: Arc, tok: Arc, prompt: String, be_quiet: bool, max_seq_len: usize, params: ModelParams, max_threads: usize, ) -> Result<(), Box> { // Custom println-like macro that respects be_quiet macro_rules! pln { ($($arg:tt)*) => { if !be_quiet { std::println!($($arg)*); } }; } let mut toks_id: Vec = tok.tokenize_to_ids(prompt.clone()); let mut prev_pos = 0; let mut token_sampler = TokenSampler::new() .temperature(1.0) .top_p(1.0) .top_k(20) .repetition_penalty(1.0); if let Some(temperature) = cli.temperature { token_sampler = token_sampler.temperature(temperature); } if let Some(top_p) = cli.top_p { token_sampler = token_sampler.top_p(top_p); } if let Some(top_k) = cli.top_k { token_sampler = token_sampler.top_k(top_k as usize); } if let Some(repetition_penalty) = cli.repetition_penalty { token_sampler = token_sampler.repetition_penalty(repetition_penalty); } pln!("---"); pln!(" dim: {}", params.dim); pln!(" multiple_of: {}", params.multiple_of); pln!(" n_heads: {}", params.n_heads); pln!(" n_layers: {}", params.n_layers); pln!(" norm_eps: {}", params.norm_eps); pln!(" vocab_size: {}", params.vocab_size); pln!("---"); pln!(" maximum number of threads: {}", max_threads); pln!("---"); pln!("Max sequence length: {}", max_seq_len); pln!("Temperature: {}", token_sampler.get_temperature()); pln!("Top P: {}", token_sampler.get_top_p()); pln!("Top K: {}", token_sampler.get_top_k()); pln!( "Repetition penalty: {}", token_sampler.get_repetition_penalty() ); pln!("---"); pln!( "{}", " This is the color of the initial prompt".truecolor(128, 128, 255) ); pln!( "{}", " This is the color of the generated text".truecolor(128, 255, 128) ); pln!("---"); print!("{}", prompt.as_str().truecolor(128, 128, 255)); let _ = std::io::stdout().flush(); let mut first_token_time: std::time::Duration = std::time::Duration::new(0, 0); let mut times_per_token: Vec = vec![]; let mut caches = tr.make_caches(); let mut first: bool = true; let mut stop_seen: bool = false; while toks_id.len() < max_seq_len { let now = std::time::Instant::now(); let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches); let (highest_pred_idx, token_prob) = token_sampler.sample(&preds, &tok, &toks_id); toks_id.push(highest_pred_idx as TokenId); for (tok_idx, tok_id) in toks_id[prev_pos + 1..].iter().enumerate() { if *tok_id == 1 { continue; } let mut tok_str: String = "".to_string(); let tok = tok.id_to_str(*tok_id); if tok == "" { tok_str += ""; stop_seen = true; } if tok == "<0x0A>" { tok_str += "\n"; } else { tok_str += tok.replace('▁', " ").as_str(); } if first && tok_idx < toks_id.len() - 2 { // intentionally left empty } else { let redness: f32 = token_prob * 255.0; let redness = if redness > 255.0 { 255 } else if redness < 0.0 { 0 } else { redness as u8 }; print!( "{}", tok_str.truecolor(128 + redness / 2, 255 - redness / 2, 128) ); } } if first { first_token_time = now.elapsed(); } else { times_per_token.push(now.elapsed()); } let _ = std::io::stdout().flush(); prev_pos = toks_id.len() - 1; first = false; if stop_seen { break; } } println!(); if stop_seen && !be_quiet { println!("Stop token seen. Stopping."); } if !be_quiet { println!("---"); println!( "Time taken to generate first token: {:?}ms", first_token_time.as_millis() ); println!( "Time taken per token (excluding first token): {:?}ms", times_per_token.iter().map(|t| t.as_millis()).sum::() / times_per_token.len() as u128 ); } Ok(()) }