rllama/src/rllama_main.rs

use crate::embedding::Embedding;
use crate::semaphore::Semaphore;
#[cfg(feature = "opencl")]
use crate::tensor_opencl_support::OpenCL;
use crate::token_sampler::TokenSampler;
use crate::tokenizer::{TokenId, Tokenizer};
use crate::transformer::{DataSettings, Transformer, TransformerCaches};
use crate::unpickler;
use crate::unpickler::Value;
use clap::Parser;
use colored::Colorize;
#[cfg(feature = "server")]
use rocket::{response::status, response::Stream, Data, State};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::io::{Read, Write};
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
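/// Command-line arguments, parsed with clap's derive API. The model, tokenizer
/// and parameter paths are required; everything else is optional.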
#[derive(Parser, Clone)]
#[command(author, version, about, long_about = None)]
struct Cli {
    #[arg(long)]
    model_path: String,
    #[arg(long)]
    tokenizer_path: String,
    #[arg(long)]
    param_path: String,
    #[arg(short, long, action)]
    quiet: bool,
    #[arg(long)]
    prompt: Option<String>,
    #[arg(long)]
    prompt_file: Option<String>,
    #[arg(long)]
    max_seq_len: Option<usize>,
    #[arg(long)]
    temperature: Option<f32>,
    #[arg(long)]
    top_p: Option<f32>,
    #[arg(long)]
    top_k: Option<i32>,
    #[arg(long)]
    repetition_penalty: Option<f32>,
    #[arg(long)]
    max_threads: Option<usize>,
    #[arg(long, action)]
    f16: bool,
    #[arg(long, action)]
    k4: bool,
    #[cfg(feature = "opencl")]
    #[arg(long)]
    opencl_device: Option<usize>,
    #[arg(long, action)]
    inference_server: bool,
    #[arg(long)]
    inference_server_port: Option<u16>,
    #[arg(long)]
    inference_server_host: Option<String>,
    #[arg(long)]
    inference_server_max_concurrent_inferences: Option<usize>,
    #[arg(long)]
    inference_server_api_path: Option<String>,
    #[arg(long)]
    inference_server_prompt_cache_size: Option<usize>,
    #[arg(long, action)]
    inference_server_exit_after_one_query: bool,
}
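/// Model hyperparameters, read as JSON from the file given with `--param-path`.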
#[derive(Clone, Serialize, Deserialize)]
struct ModelParams {
    dim: usize,
    multiple_of: usize,
    n_heads: usize,
    n_layers: usize,
    norm_eps: f64,
    vocab_size: i64,
}
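/// Entry point: parses command-line arguments, loads the tokenizer, model
/// parameters and transformer weights, and then runs either the HTTP inference
/// server or command-line inference.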
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    let cli = Cli::parse();
    let model_path = cli.model_path.clone();
    let tokenizer_path = cli.tokenizer_path.clone();
    let param_path = cli.param_path.clone();

    #[cfg(not(feature = "server"))]
    if cli.inference_server {
        eprintln!("Inference server is not enabled in this build.");
        return Err("Inference server is not enabled in this build.".into());
    }

    let max_threads: usize = match cli.max_threads {
        None => rayon::current_num_threads(),
        Some(max_threads) => {
            rayon::ThreadPoolBuilder::new()
                .num_threads(max_threads)
                .build_global()
                .unwrap();
            max_threads
        }
    };

    let mut be_quiet: bool = false;
    if !colored::control::SHOULD_COLORIZE.should_colorize() {
        be_quiet = true;
    }
    if cli.quiet {
        be_quiet = true;
    }
    if be_quiet {
        colored::control::SHOULD_COLORIZE.set_override(false);
    }

    // Custom println-like macro that respects be_quiet.
    macro_rules! pln {
        ($($arg:tt)*) => {
            if !be_quiet {
                std::println!($($arg)*);
            }
        };
    }

    #[cfg(feature = "opencl")]
    let opencl: Option<OpenCL> = {
        let opencl_device = cli.opencl_device.unwrap_or(0);
        match OpenCL::new(!be_quiet, opencl_device) {
            Err(openclerr) => {
                eprintln!("OpenCL error: {}", openclerr);
                eprintln!("OpenCL is disabled because it failed to initialize.");
                None
            }
            Ok(opencl) => {
                println!("OpenCL initialized.");
                Some(opencl)
            }
        }
    };

    // Read ModelParams from param_path; we expect it to be JSON.
    let mut fs = std::fs::File::open(&param_path)?;
    let mut bs = Vec::new();
    fs.read_to_end(&mut bs)?;
    std::mem::drop(fs);

    let params: ModelParams = serde_json::from_slice(&bs)?;
    pln!("Loaded model parameters from {}.", param_path);
    let prompt: String = match (&cli.prompt, &cli.prompt_file) {
        (Some(ref prompt), None) => {
            pln!("Using prompt: {}", prompt);
            prompt.clone()
        }
        (None, Some(ref prompt_file)) => {
            pln!("Using prompt file: {}", prompt_file);
            let mut fs = std::fs::File::open(prompt_file)?;
            let mut bs = Vec::new();
            fs.read_to_end(&mut bs)?;
            std::mem::drop(fs);
            String::from_utf8(bs)?
        }
        _ => {
            if cli.inference_server {
                "".to_string()
            } else {
                eprintln!("Please provide either a prompt or a prompt file.");
                return Err("Please provide either a prompt or a prompt file.".into());
            }
        }
    };

    pln!("Starting up. Loading tokenizer from {}...", tokenizer_path);
    let tok = Tokenizer::load(tokenizer_path.as_str())?;
    pln!("Tokenizer loaded. Loading model from {}...", model_path);

    let mut unpickle_results: Vec<Value> = vec![];
    let mut part: usize = 0;
    loop {
        let model_path: PathBuf = model_path.clone().into();
        let base_path = model_path.join(format!("consolidated.{:02}", part));
        // The data file is in consolidated.XX/data.pkl where XX is the part number.
        let full_path = base_path.join("data.pkl");
        let mut fs = match std::fs::File::open(&full_path) {
            Ok(fs) => fs,
            Err(err) => {
                if err.kind() == std::io::ErrorKind::NotFound {
                    break;
                } else {
                    return Err(err.into());
                }
            }
        };
        let mut bs = Vec::new();
        fs.read_to_end(&mut bs)?;
        std::mem::drop(fs);
        pln!("Read data.pkl from path {}", full_path.display());
        let result = unpickler::unpickle(&bs)?;
        unpickle_results.push(result);
        part += 1;
    }

    pln!("Loading embeddings from {}...", model_path);
    let emb = Embedding::from_unpickled(&unpickle_results, model_path.clone())?;

    let max_seq_len = cli.max_seq_len.unwrap_or(1024);

    let mut data_settings = {
        #[cfg(feature = "opencl")]
        {
            if let Some(opencl) = opencl {
                let ds = DataSettings::new(Some(opencl));
                ds.use_opencl()
            } else {
                DataSettings::new(None)
            }
        }
        #[cfg(not(feature = "opencl"))]
        DataSettings::new()
    };
    if cli.f16 {
        data_settings = data_settings.force_f16();
    }
    if cli.k4 {
        data_settings = data_settings.force_k4();
    }

    pln!("Loading transformer weights from {}...", model_path);
    let tr = Transformer::from_unpickled(
        &unpickle_results,
        emb,
        params.dim,
        params.n_layers,
        params.n_heads,
        max_seq_len,
        params.norm_eps,
        data_settings,
        model_path,
    )?;
    pln!("All is loaded. Starting inference.");

    let tr: Arc<Transformer> = Arc::new(tr);
    let tok: Arc<Tokenizer> = Arc::new(tok);

    if cli.inference_server {
        #[cfg(feature = "server")]
        {
            server_inference(cli, tr, tok, be_quiet, max_seq_len, params, max_threads)
        }
        #[cfg(not(feature = "server"))]
        {
            eprintln!("The inference server feature is not enabled.");
            eprintln!("Please enable it with the \"inference-server\" feature.");
            Err("The inference server feature is not enabled.".into())
        }
    } else {
        command_line_inference(
            cli.clone(),
            tr.clone(),
            tok.clone(),
            prompt.clone(),
            be_quiet,
            max_seq_len,
            params.clone(),
            max_threads,
        )
    }
}
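/// Runs the HTTP inference server: sets up Rocket with the configured host,
/// port and API path, and streams generation results back as JSON lines.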
#[cfg(feature = "server")]
fn server_inference(
    cli: Cli,
    tr: Arc<Transformer>,
    tok: Arc<Tokenizer>,
    be_quiet: bool,
    max_seq_len: usize,
    _params: ModelParams,
    _max_threads: usize,
) -> Result<(), Box<dyn std::error::Error>> {
    macro_rules! pln {
        ($($arg:tt)*) => {
            if !be_quiet {
                std::println!($($arg)*);
            }
        };
    }

    let inference_server_port = cli.inference_server_port.unwrap_or(8080);
    let inference_server_host = cli
        .inference_server_host
        .clone()
        .unwrap_or("127.0.0.1".to_string());
    let inference_server_max_concurrent_inferences =
        cli.inference_server_max_concurrent_inferences.unwrap_or(5);
    let inference_server_api_path = cli
        .inference_server_api_path
        .clone()
        .unwrap_or("/rllama/v1/inference".to_string());
    let inference_server_prompt_cache_size = cli.inference_server_prompt_cache_size.unwrap_or(50);

    pln!(
        "Maximum concurrent inferences: {}",
        inference_server_max_concurrent_inferences
    );
    pln!("Prompt cache size: {}", inference_server_prompt_cache_size);
    pln!("Maximum sequence length: {}", max_seq_len);
    pln!(
        "--- Starting HTTP server on {}:{}, answering to requests at {} ---",
        inference_server_host,
        inference_server_port,
        inference_server_api_path
    );
    // If there are too many connections, they will hang until they get their turn.
    // Maybe we can later return a 503 to tell clients to slow down, or something similar.
    let concurrent_requests_semaphore = Semaphore::new(inference_server_max_concurrent_inferences);

    let rocket_conf = rocket::Config::build(rocket::config::Environment::Production)
        .address(inference_server_host)
        .port(inference_server_port)
        .finalize()
        .unwrap();

    let app = rocket::custom(rocket_conf)
        .mount(&inference_server_api_path, routes![handle_request])
        .manage(InferenceServerState {
            transformer: tr,
            tokenizer: tok,
            max_seq_len,
            concurrent_requests_semaphore,
            attention_cache_repository: Arc::new(RwLock::new(AttentionCacheRepository::empty(
                inference_server_prompt_cache_size,
            ))),
            exit_after_one_query: cli.inference_server_exit_after_one_query,
        });
    app.launch();
    // launch() only returns if the server failed to start.
    panic!("Starting web server failed.");
}
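/// Serde helper for `skip_serializing_if`: lets false flags be omitted from the
/// JSON output.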
fn is_false(b: &bool) -> bool {
    !b
}
#[derive(Serialize, Deserialize, Clone, Debug)]
struct InferenceRequest {
    temperature: Option<f32>,
    top_k: Option<usize>,
    top_p: Option<f32>,
    repetition_penalty: Option<f32>,
    max_seq_len: Option<usize>,
    max_new_tokens: Option<usize>,
    no_token_sampling: Option<bool>,
    stop_at_end_token: Option<bool>,
    prompt: String,
}
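/// One prediction in the streamed output: a token's probability and whether it
/// is the end-of-sequence token.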
#[cfg(feature = "server")]
#[derive(Serialize, Deserialize, Clone, Debug)]
struct PredResult {
    p: f32,
    #[serde(skip_serializing_if = "is_false")]
    is_end_token: bool,
}
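/// State for a single in-flight generation request. Implements `Read` so that
/// Rocket can stream newly generated tokens to the client as JSON lines.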
#[cfg(feature = "server")]
struct GeneratingSession {
    transformer: Arc<Transformer>,
    token_sampler: TokenSampler,
    tokenizer: Arc<Tokenizer>,
    attention_cache_repository: Arc<RwLock<AttentionCacheRepository>>,
    tokens: Vec<TokenId>,
    req_max_seq_len: usize,
    req_max_new_tokens: usize,
    new_tokens_generated: usize,
    prev_pos: usize,
    no_token_sampling: bool,
    stop_at_end_token: bool,
    sent_stuff_last_time: bool,
    exit_after_one_query: bool,
    result: Vec<u8>, // stores JSONL lines to be returned from read()
}
#[cfg(feature = "server")]
impl GeneratingSession {
    /// Copies as much buffered output as fits into `buf` and returns how many
    /// bytes were written.
    fn read_from_result(&mut self, buf: &mut [u8]) -> usize {
        if !self.result.is_empty() {
            if self.result.len() <= buf.len() {
                for idx in 0..self.result.len() {
                    buf[idx] = self.result[idx];
                }
                let len = self.result.len();
                self.sent_stuff_last_time = true;
                self.result.truncate(0);
                return len;
            } else {
                for idx in 0..buf.len() {
                    buf[idx] = self.result[idx];
                }
                self.result = self.result[buf.len()..].to_vec();
                self.sent_stuff_last_time = true;
                return buf.len();
            }
        }
        0
    }
}
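// Each call to read() either flushes buffered JSONL output or generates one
// more token, reusing attention caches from the repository when the current
// token sequence has been cached.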
#[cfg(feature = "server")]
impl Read for GeneratingSession {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if self.sent_stuff_last_time && self.result.is_empty() {
            // If we return WouldBlock every time we send something, it'll cause Rocket to
            // flush available data.
            self.sent_stuff_last_time = false;
            return Err(std::io::Error::new(
                std::io::ErrorKind::WouldBlock,
                "WouldBlock",
            ));
        }
        // Push more data to the upstream if we have something stored.
        let bytes_read = self.read_from_result(buf);
        if bytes_read > 0 {
            return Ok(bytes_read);
        }
        if self.tokens.len() >= self.req_max_seq_len {
            if self.exit_after_one_query {
                std::process::exit(0);
            }
            return Ok(0);
        }
        if self.new_tokens_generated >= self.req_max_new_tokens {
            if self.exit_after_one_query {
                std::process::exit(0);
            }
            return Ok(0);
        }
        let (mut caches, update_pos) = {
            let mut ac = self.attention_cache_repository.write().unwrap();
            match ac.get(&self.tokens) {
                Some((c, pos)) if pos >= self.prev_pos => (c.true_clone(), pos),
                Some(_) => {
                    std::mem::drop(ac);
                    (self.transformer.make_caches(), 0)
                }
                None => {
                    let caches = self.transformer.make_caches();
                    ac.put(self.tokens.clone(), caches.true_clone(), self.prev_pos);
                    (caches, self.prev_pos)
                }
            }
        };
        if update_pos > self.prev_pos {
            self.prev_pos = update_pos;
        }
        assert!(self.result.is_empty());
        let predictions =
            self.transformer
                .forward(&self.tokens[self.prev_pos..], self.prev_pos, &mut caches);
        self.prev_pos = self.tokens.len();
        let (highest_pred_idx, token_prob) =
            self.token_sampler
                .sample(&predictions, self.tokenizer.as_ref(), &self.tokens);
        self.tokens.push(highest_pred_idx as TokenId);
        {
            let mut ac = self.attention_cache_repository.write().unwrap();
            ac.put(self.tokens.clone(), caches, self.prev_pos);
        }
        self.new_tokens_generated += 1;
        let token: &str = self.tokenizer.id_to_str(highest_pred_idx as TokenId);
        let mut is_end_token: bool = false;
        if token == "</s>" && self.stop_at_end_token {
            self.new_tokens_generated = self.req_max_new_tokens;
            is_end_token = true;
        }
        let mut result: BTreeMap<String, PredResult> = BTreeMap::new();
        if self.no_token_sampling {
            // All predictions go on the line.
            let probs = self
                .token_sampler
                .logits_to_btreemap(&predictions, self.tokenizer.as_ref());
            for (k, v) in probs.into_iter() {
                let mut is_end_token: bool = false;
                if k == "</s>" {
                    is_end_token = true;
                }
                result.insert(
                    k,
                    PredResult {
                        p: v,
                        is_end_token,
                    },
                );
            }
            // Convert to JSON.
            let json = serde_json::to_string(&result).unwrap();
            self.result.extend(json.as_bytes());
            self.result.push(b'\n');
            return Ok(self.read_from_result(buf));
        } else {
            result.insert(
                token.to_string(),
                PredResult {
                    p: token_prob,
                    is_end_token,
                },
            );
            let json = serde_json::to_string(&result).unwrap();
            self.result.extend(json.as_bytes());
            self.result.push(b'\n');
            return Ok(self.read_from_result(buf));
        }
    }
}
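/// Cache of transformer attention caches keyed by the exact token sequence that
/// produced them. When the cache grows past its size limit, the entries that
/// were stored longest ago are evicted first.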
#[cfg(feature = "server")]
struct AttentionCacheRepository {
    caches: BTreeMap<Vec<TokenId>, (TransformerCaches, usize, std::time::Instant)>,
    max_sz: usize,
}
#[cfg(feature = "server")]
impl AttentionCacheRepository {
    fn empty(max_size: usize) -> AttentionCacheRepository {
        AttentionCacheRepository {
            caches: BTreeMap::new(),
            max_sz: max_size,
        }
    }

    /// Makes sure the cache repository holds at most `sz` entries, evicting the
    /// oldest items first.
    fn limit_size(&mut self, sz: usize) {
        if sz == 0 {
            self.caches = BTreeMap::new();
            return;
        }
        // Slow algorithm, but the cache will never be unimaginably large so it's
        // probably fine.
        while self.caches.len() > sz {
            let mut oldest_time = None;
            let mut oldest_key: Option<&Vec<TokenId>> = None;
            for (k, (_, _, time)) in self.caches.iter() {
                if oldest_time.is_none() || time < oldest_time.unwrap() {
                    oldest_time = Some(time);
                    oldest_key = Some(k);
                }
            }
            let oldest_key = oldest_key.unwrap().clone();
            self.caches.remove(&oldest_key);
        }
    }

    fn get(&self, tokens: &[TokenId]) -> Option<(&TransformerCaches, usize)> {
        if let Some((caches, pos, _)) = self.caches.get(tokens) {
            Some((caches, *pos))
        } else {
            None
        }
    }

    fn put(&mut self, tokens: Vec<TokenId>, caches: TransformerCaches, prev_pos: usize) {
        self.caches
            .insert(tokens, (caches, prev_pos, std::time::Instant::now()));
        self.limit_size(self.max_sz);
    }
}
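/// Shared server state managed by Rocket and available to every request handler.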
#[cfg(feature = "server")]
#[derive(Clone)]
struct InferenceServerState {
    transformer: Arc<Transformer>,
    tokenizer: Arc<Tokenizer>,
    max_seq_len: usize,
    concurrent_requests_semaphore: Semaphore,
    attention_cache_repository: Arc<RwLock<AttentionCacheRepository>>,
    exit_after_one_query: bool,
}
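/// POST handler for the inference endpoint: parses and validates the request,
/// then returns a chunked stream that generates tokens as the client reads it.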
#[cfg(feature = "server")]
#[post("/", data = "<input>")]
fn handle_request(
    state: State<InferenceServerState>,
    input: Data,
) -> Result<Stream<GeneratingSession>, status::BadRequest<String>> {
    let _lock = state.concurrent_requests_semaphore.acquire();
    let tr = state.transformer.clone();
    let tok = state.tokenizer.clone();

    let mut data = input.open();
    let mut databuf: Vec<u8> = Vec::new();
    data.read_to_end(&mut databuf).unwrap();

    // Parse the JSON out of the request.
    let request: InferenceRequest = match serde_json::from_slice(&databuf) {
        Err(_e) => {
            return Err(status::BadRequest(Some("Invalid JSON.".to_string())));
        }
        Ok(ir) => ir,
    };

    let stop_at_end_token = request.stop_at_end_token.unwrap_or(true);
    let temperature = request.temperature.unwrap_or(1.0);
    let top_k = request.top_k.unwrap_or(20);
    let top_p = request.top_p.unwrap_or(1.0);
    let repetition_penalty = request.repetition_penalty.unwrap_or(1.0);
    let mut req_max_seq_len = request.max_seq_len.unwrap_or(state.max_seq_len);
    if req_max_seq_len > state.max_seq_len {
        req_max_seq_len = state.max_seq_len;
    }
    let req_max_new_tokens = request.max_new_tokens.unwrap_or(20);
    let no_token_sampling = request.no_token_sampling.unwrap_or(false);
    let prompt = request.prompt;

    if temperature.is_nan() {
        return Err(status::BadRequest(Some(
            "Temperature must be a number.".to_string(),
        )));
    }
    if top_k == 0 {
        return Err(status::BadRequest(Some(
            "Top-k must be greater than 0.".to_string(),
        )));
    }
    if top_p.is_nan() {
        return Err(status::BadRequest(Some(
            "Top-p must be a number.".to_string(),
        )));
    }
    if repetition_penalty.is_nan() {
        return Err(status::BadRequest(Some(
            "Repetition penalty must be a number.".to_string(),
        )));
    }

    let token_sampler = TokenSampler::new()
        .temperature(temperature)
        .top_p(top_p)
        .top_k(top_k)
        .repetition_penalty(repetition_penalty);
    let toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());

    let gsession = GeneratingSession {
        transformer: tr,
        tokenizer: tok,
        attention_cache_repository: state.attention_cache_repository.clone(),
        token_sampler,
        tokens: toks_id,
        req_max_seq_len,
        req_max_new_tokens,
        new_tokens_generated: 0,
        prev_pos: 0,
        no_token_sampling,
        stop_at_end_token,
        sent_stuff_last_time: false,
        exit_after_one_query: state.exit_after_one_query,
        result: Vec::new(),
    };
    Ok(rocket::response::Stream::chunked(gsession, 1024))
}
fn command_line_inference(
    cli: Cli,
    tr: Arc<Transformer>,
    tok: Arc<Tokenizer>,
    prompt: String,
    be_quiet: bool,
    max_seq_len: usize,
    params: ModelParams,
    max_threads: usize,
) -> Result<(), Box<dyn std::error::Error>> {
    // Custom println-like macro that respects be_quiet.
    macro_rules! pln {
        ($($arg:tt)*) => {
            if !be_quiet {
                std::println!($($arg)*);
            }
        };
    }

    let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
    let mut prev_pos = 0;
    let mut token_sampler = TokenSampler::new()
        .temperature(1.0)
        .top_p(1.0)
        .top_k(20)
        .repetition_penalty(1.0);
    if let Some(temperature) = cli.temperature {
        token_sampler = token_sampler.temperature(temperature);
    }
    if let Some(top_p) = cli.top_p {
        token_sampler = token_sampler.top_p(top_p);
    }
    if let Some(top_k) = cli.top_k {
        token_sampler = token_sampler.top_k(top_k as usize);
    }
    if let Some(repetition_penalty) = cli.repetition_penalty {
        token_sampler = token_sampler.repetition_penalty(repetition_penalty);
    }

    pln!("---");
    pln!(" dim: {}", params.dim);
    pln!(" multiple_of: {}", params.multiple_of);
    pln!(" n_heads: {}", params.n_heads);
    pln!(" n_layers: {}", params.n_layers);
    pln!(" norm_eps: {}", params.norm_eps);
    pln!(" vocab_size: {}", params.vocab_size);
    pln!("---");
    pln!(" maximum number of threads: {}", max_threads);
    pln!("---");
    pln!("Max sequence length: {}", max_seq_len);
    pln!("Temperature: {}", token_sampler.get_temperature());
    pln!("Top P: {}", token_sampler.get_top_p());
    pln!("Top K: {}", token_sampler.get_top_k());
    pln!(
        "Repetition penalty: {}",
        token_sampler.get_repetition_penalty()
    );
    pln!("---");
    pln!(
        "{}",
        " This is the color of the initial prompt".truecolor(128, 128, 255)
    );
    pln!(
        "{}",
        " This is the color of the generated text".truecolor(128, 255, 128)
    );
    pln!("---");
    print!("{}", prompt.as_str().truecolor(128, 128, 255));
    let _ = std::io::stdout().flush();

    let mut first_token_time: std::time::Duration = std::time::Duration::new(0, 0);
    let mut times_per_token: Vec<std::time::Duration> = vec![];
    let mut caches = tr.make_caches();
    let mut first: bool = true;
    let mut stop_seen: bool = false;
    while toks_id.len() < max_seq_len {
        let now = std::time::Instant::now();
        let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
        let (highest_pred_idx, token_prob) = token_sampler.sample(&preds, &tok, &toks_id);
        toks_id.push(highest_pred_idx as TokenId);

        for (tok_idx, tok_id) in toks_id[prev_pos + 1..].iter().enumerate() {
            if *tok_id == 1 {
                continue;
            }
            let mut tok_str: String = "".to_string();
            let tok = tok.id_to_str(*tok_id);
            if tok == "</s>" {
                tok_str += "";
                stop_seen = true;
            }
            if tok == "<0x0A>" {
                tok_str += "\n";
            } else {
                tok_str += tok.replace('▁', " ").as_str();
            }
            if first && tok_idx < toks_id.len() - 2 {
                // intentionally left empty
            } else {
                let redness: f32 = token_prob * 255.0;
                let redness = if redness > 255.0 {
                    255
                } else if redness < 0.0 {
                    0
                } else {
                    redness as u8
                };
                print!(
                    "{}",
                    tok_str.truecolor(128 + redness / 2, 255 - redness / 2, 128)
                );
            }
        }
        if first {
            first_token_time = now.elapsed();
        } else {
            times_per_token.push(now.elapsed());
        }
        let _ = std::io::stdout().flush();
        prev_pos = toks_id.len() - 1;
        first = false;
        if stop_seen {
            break;
        }
    }
    println!();
    if stop_seen && !be_quiet {
        println!("Stop token seen. Stopping.");
    }
    if !be_quiet {
        println!("---");
        println!(
            "Time taken to generate first token: {:?}ms",
            first_token_time.as_millis()
        );
        // Avoid dividing by zero when only a single token was generated.
        if !times_per_token.is_empty() {
            println!(
                "Time taken per token (excluding first token): {:?}ms",
                times_per_token.iter().map(|t| t.as_millis()).sum::<u128>()
                    / times_per_token.len() as u128
            );
        }
    }
    Ok(())
}