@@ -1,18 +1,23 @@
 use crate::embedding::Embedding;
+use crate::semaphore::Semaphore;
 #[cfg(feature = "opencl")]
 use crate::tensor_opencl_support::OpenCL;
 use crate::token_sampler::TokenSampler;
 use crate::tokenizer::{TokenId, Tokenizer};
-use crate::transformer::{DataSettings, Transformer};
+use crate::transformer::{DataSettings, Transformer, TransformerCaches};
 use crate::unpickler;
 use crate::unpickler::Value;
 use clap::Parser;
 use colored::Colorize;
+#[cfg(feature = "server")]
+use rocket::{response::status, response::Stream, Data, State};
 use serde::{Deserialize, Serialize};
+use std::collections::BTreeMap;
 use std::io::{Read, Write};
 use std::path::PathBuf;
+use std::sync::{Arc, RwLock};
 
-#[derive(Parser)]
+#[derive(Parser, Clone)]
 #[command(author, version, about, long_about = None)]
 struct Cli {
     #[arg(long)]
@@ -51,6 +56,24 @@ struct Cli {
     #[cfg(feature = "opencl")]
     #[arg(long)]
     opencl_device: Option<usize>,
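+    // Flags for running rllama as an HTTP inference server; defaults for the optional
+    // values below are applied in server_inference().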
+    #[arg(long, action)]
+    inference_server: bool,
+    #[arg(long)]
+    inference_server_port: Option<u16>,
+    #[arg(long)]
+    inference_server_host: Option<String>,
+    #[arg(long)]
+    inference_server_max_concurrent_inferences: Option<usize>,
+    #[arg(long)]
+    inference_server_api_path: Option<String>,
+    #[arg(long)]
+    inference_server_prompt_cache_size: Option<usize>,
 }
 
 #[derive(Clone, Serialize, Deserialize)]
@@ -65,9 +88,15 @@ struct ModelParams {
 pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     let cli = Cli::parse();
-    let model_path = cli.model_path;
-    let tokenizer_path = cli.tokenizer_path;
-    let param_path = cli.param_path;
+    let model_path = cli.model_path.clone();
+    let tokenizer_path = cli.tokenizer_path.clone();
+    let param_path = cli.param_path.clone();
+
+    #[cfg(not(feature = "server"))]
+    if cli.inference_server {
+        eprintln!("Inference server is not enabled in this build.");
+        return Err("Inference server is not enabled in this build.".into());
+    }
+
     let max_threads: usize = match cli.max_threads {
         None => rayon::current_num_threads(),
@@ -91,6 +120,15 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
         colored::control::SHOULD_COLORIZE.set_override(false);
     }
 
+    // Custom println-like macro that respects be_quiet
+    macro_rules! pln {
+        ($($arg:tt)*) => {
+            if !be_quiet {
+                std::println!($($arg)*);
+            }
+        };
+    }
+
     #[cfg(feature = "opencl")]
     let opencl: Option<OpenCL> = {
         let opencl_device = cli.opencl_device.unwrap_or(0);
@@ -107,15 +145,6 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
         }
     };
-
-    // Custom println-like macro that respects be_quiet
-    macro_rules! pln {
-        ($($arg:tt)*) => {
-            if !be_quiet {
-                std::println!($($arg)*);
-            }
-        };
-    }
 
     // Read ModelParams from param_path, we expect it to be JSON
     let mut fs = std::fs::File::open(&param_path)?;
     let mut bs = Vec::new();
@@ -124,12 +153,12 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     let params: ModelParams = serde_json::from_slice(&bs)?;
     pln!("Loaded model parameters from {}.", param_path);
 
-    let prompt: String = match (cli.prompt, cli.prompt_file) {
-        (Some(prompt), None) => {
+    let prompt: String = match (&cli.prompt, &cli.prompt_file) {
+        (Some(ref prompt), None) => {
             pln!("Using prompt: {}", prompt);
-            prompt
+            prompt.clone()
         }
-        (None, Some(prompt_file)) => {
+        (None, Some(ref prompt_file)) => {
             pln!("Using prompt file: {}", prompt_file);
             let mut fs = std::fs::File::open(prompt_file)?;
             let mut bs = Vec::new();
@@ -138,8 +167,12 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
             String::from_utf8(bs)?
         }
         _ => {
-            eprintln!("Please provide either a prompt or a prompt file.");
-            return Err("Please provide either a prompt or a prompt file.".into());
+            if cli.inference_server {
+                "".to_string()
+            } else {
+                eprintln!("Please provide either a prompt or a prompt file.");
+                return Err("Please provide either a prompt or a prompt file.".into());
+            }
         }
     };
@@ -212,13 +245,445 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     )?;
     pln!("All is loaded. Starting inference.");
+    let tr: Arc<Transformer> = Arc::new(tr);
+    let tok: Arc<Tokenizer> = Arc::new(tok);
+
+    if cli.inference_server {
+        #[cfg(feature = "server")]
+        {
+            server_inference(cli, tr, tok, be_quiet, max_seq_len, params, max_threads)
+        }
+        #[cfg(not(feature = "server"))]
+        {
+            eprintln!("The inference server feature is not enabled.");
+            eprintln!("Please enable it with the \"inference-server\" feature.");
+            Err("The inference server feature is not enabled.".into())
+        }
+    } else {
+        command_line_inference(
+            cli.clone(),
+            tr.clone(),
+            tok.clone(),
+            prompt.clone(),
+            be_quiet,
+            max_seq_len,
+            params.clone(),
+            max_threads,
+        )
+    }
+}
+
+#[cfg(feature = "server")]
+fn server_inference(
+    cli: Cli,
+    tr: Arc<Transformer>,
+    tok: Arc<Tokenizer>,
+    be_quiet: bool,
+    max_seq_len: usize,
+    _params: ModelParams,
+    _max_threads: usize,
+) -> Result<(), Box<dyn std::error::Error>> {
+    macro_rules! pln {
+        ($($arg:tt)*) => {
+            if !be_quiet {
+                std::println!($($arg)*);
+            }
+        };
+    }
+    let inference_server_port = cli.inference_server_port.unwrap_or(8080);
+    let inference_server_host = cli
+        .inference_server_host
+        .clone()
+        .unwrap_or("127.0.0.1".to_string());
+    let inference_server_max_concurrent_inferences =
+        cli.inference_server_max_concurrent_inferences.unwrap_or(5);
+    let inference_server_api_path = cli
+        .inference_server_api_path
+        .clone()
+        .unwrap_or("/rllama/v1/inference".to_string());
+    let inference_server_prompt_cache_size = cli.inference_server_prompt_cache_size.unwrap_or(50);
+
+    pln!(
+        "Maximum concurrent inferences: {}",
+        inference_server_max_concurrent_inferences
+    );
+    pln!("Prompt cache size: {}", inference_server_prompt_cache_size);
+    pln!("Maximum sequence length: {}", max_seq_len);
+    pln!(
+        "--- Starting HTTP server on {}:{}, answering to requests at {} ---",
+        inference_server_host,
+        inference_server_port,
+        inference_server_api_path
+    );
+    // If there are too many connections, they will hang until they get their turn.
+    // Maybe we can later return 503 Slow Down or something similar.
+    let concurrent_requests_semaphore = Semaphore::new(inference_server_max_concurrent_inferences);
+    let rocket_conf = rocket::Config::build(rocket::config::Environment::Production)
+        .address(inference_server_host)
+        .port(inference_server_port)
+        .finalize()
+        .unwrap();
+    let app = rocket::custom(rocket_conf)
+        .mount(&inference_server_api_path, routes![handle_request])
+        .manage(InferenceServerState {
+            transformer: tr,
+            tokenizer: tok,
+            max_seq_len,
+            concurrent_requests_semaphore,
+            attention_cache_repository: Arc::new(RwLock::new(AttentionCacheRepository::empty(
+                inference_server_prompt_cache_size,
+            ))),
+        });
+    app.launch();
+    panic!("Starting web server failed.");
+}
+
+fn is_false(b: &bool) -> bool {
+    !b
+}
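+
+// Example request body (illustrative; every field except "prompt" is optional, and the
+// defaults are the ones applied in handle_request() below):
+//   {"prompt": "Hello", "temperature": 0.8, "top_k": 40, "max_new_tokens": 32}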
+#[derive(Serialize, Deserialize, Clone, Debug)]
+struct InferenceRequest {
+    temperature: Option<f32>,
+    top_k: Option<usize>,
+    top_p: Option<f32>,
+    repetition_penalty: Option<f32>,
+    max_seq_len: Option<usize>,
+    max_new_tokens: Option<usize>,
+    no_token_sampling: Option<bool>,
+    stop_at_end_token: Option<bool>,
+    prompt: String,
+}
+
+#[cfg(feature = "server")]
+#[derive(Serialize, Deserialize, Clone, Debug)]
+struct PredResult {
+    p: f32,
+    #[serde(skip_serializing_if = "is_false")]
+    is_end_token: bool,
+}
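+
+// GeneratingSession drives one streaming inference: each call to read() generates at most
+// one new token and appends it to `result` as a JSON line mapping the token text to a
+// PredResult, which Rocket then streams back to the client chunk by chunk.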
+#[cfg(feature = "server")]
+struct GeneratingSession {
+    transformer: Arc<Transformer>,
+    token_sampler: TokenSampler,
+    tokenizer: Arc<Tokenizer>,
+    attention_cache_repository: Arc<RwLock<AttentionCacheRepository>>,
+    tokens: Vec<TokenId>,
+    req_max_seq_len: usize,
+    req_max_new_tokens: usize,
+    new_tokens_generated: usize,
+    prev_pos: usize,
+    no_token_sampling: bool,
+    stop_at_end_token: bool,
+    sent_stuff_last_time: bool,
+    result: Vec<u8>, // stores JSONL lines to be returned from read()
+}
+
+#[cfg(feature = "server")]
+impl GeneratingSession {
+    fn read_from_result(&mut self, buf: &mut [u8]) -> usize {
+        if !self.result.is_empty() {
+            if self.result.len() <= buf.len() {
+                for idx in 0..self.result.len() {
+                    buf[idx] = self.result[idx];
+                }
+                let len = self.result.len();
+                self.sent_stuff_last_time = true;
+                self.result.truncate(0);
+                return len;
+            } else {
+                for idx in 0..buf.len() {
+                    buf[idx] = self.result[idx];
+                }
+                self.result = self.result[buf.len()..].to_vec();
+                self.sent_stuff_last_time = true;
+                return buf.len();
+            }
+        }
+        return 0;
+    }
+}
+
+#[cfg(feature = "server")]
+impl Read for GeneratingSession {
+    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+        if self.sent_stuff_last_time && self.result.is_empty() {
+            // If we return WouldBlock every time we send something, it'll cause Rocket to
+            // flush available data.
+            self.sent_stuff_last_time = false;
+            return Err(std::io::Error::new(
+                std::io::ErrorKind::WouldBlock,
+                "WouldBlock",
+            ));
+        }
+        // Push more data to the upstream if we have something stored.
+        let bytes_read = self.read_from_result(buf);
+        if bytes_read > 0 {
+            return Ok(bytes_read);
+        }
+        if self.tokens.len() >= self.req_max_seq_len {
+            return Ok(0);
+        }
+        if self.new_tokens_generated >= self.req_max_new_tokens {
+            return Ok(0);
+        }
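+        // Look up cached attention state for the current token sequence: an entry that is at
+        // least as far along as prev_pos is cloned and reused, otherwise fresh caches are
+        // created (and, if there was no entry at all, registered in the repository).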
+        let (mut caches, update_pos) = {
+            let mut ac = self.attention_cache_repository.write().unwrap();
+            match ac.get(&self.tokens) {
+                Some((c, pos)) if pos >= self.prev_pos => (c.true_clone(), pos),
+                Some(_) => {
+                    std::mem::drop(ac);
+                    (self.transformer.make_caches(), 0)
+                }
+                None => {
+                    let caches = self.transformer.make_caches();
+                    ac.put(self.tokens.clone(), caches.true_clone(), self.prev_pos);
+                    (caches, self.prev_pos)
+                }
+            }
+        };
+        if update_pos > self.prev_pos {
+            self.prev_pos = update_pos;
+        }
+        assert!(self.result.is_empty());
+        let predictions =
+            self.transformer
+                .forward(&self.tokens[self.prev_pos..], self.prev_pos, &mut caches);
+        self.prev_pos = self.tokens.len();
+        let (highest_pred_idx, token_prob) =
+            self.token_sampler
+                .sample(&predictions, self.tokenizer.as_ref(), &self.tokens);
+        self.tokens.push(highest_pred_idx as TokenId);
+        {
+            let mut ac = self.attention_cache_repository.write().unwrap();
+            ac.put(self.tokens.clone(), caches, self.prev_pos);
+        }
+        self.new_tokens_generated += 1;
+        let token: &str = self.tokenizer.id_to_str(highest_pred_idx as TokenId);
+        let mut is_end_token: bool = false;
+        if token == "</s>" && self.stop_at_end_token {
+            self.new_tokens_generated = self.req_max_new_tokens;
+            is_end_token = true;
+        }
+        let mut result: BTreeMap<String, PredResult> = BTreeMap::new();
+        if self.no_token_sampling {
+            // All predictions go to the line.
+            let probs = self
+                .token_sampler
+                .logits_to_btreemap(&predictions, self.tokenizer.as_ref());
+            for (k, v) in probs.into_iter() {
+                let mut is_end_token: bool = false;
+                if k == "</s>" {
+                    is_end_token = true;
+                }
+                result.insert(
+                    k,
+                    PredResult {
+                        p: v,
+                        is_end_token: is_end_token,
+                    },
+                );
+            }
+            // Convert to JSON
+            let json = serde_json::to_string(&result).unwrap();
+            self.result.extend(json.as_bytes());
+            self.result.push(b'\n');
+            return Ok(self.read_from_result(buf));
+        } else {
+            result.insert(
+                token.to_string(),
+                PredResult {
+                    p: token_prob,
+                    is_end_token,
+                },
+            );
+            let json = serde_json::to_string(&result).unwrap();
+            self.result.extend(json.as_bytes());
+            self.result.push(b'\n');
+            return Ok(self.read_from_result(buf));
+        }
+    }
+}
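+
+// Cache of transformer attention state keyed by the exact token sequence seen so far, so
+// that repeated or resumed prompts can skip re-running the transformer over tokens that
+// were already processed. Entries carry a timestamp and the oldest ones are evicted once
+// the configured maximum size is exceeded.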
+#[cfg(feature = "server")]
+struct AttentionCacheRepository {
+    caches: BTreeMap<Vec<TokenId>, (TransformerCaches, usize, std::time::Instant)>,
+    max_sz: usize,
+}
+
+#[cfg(feature = "server")]
+impl AttentionCacheRepository {
+    fn empty(max_size: usize) -> AttentionCacheRepository {
+        AttentionCacheRepository {
+            caches: BTreeMap::new(),
+            max_sz: max_size,
+        }
+    }
+    /// Makes sure the cache repository is not larger than `sz`; evicts the oldest items first.
+    fn limit_size(&mut self, sz: usize) {
+        if sz == 0 {
+            self.caches = BTreeMap::new();
+            return;
+        }
+        // Slow algorithm but I guess our cache will never be unimaginably large so it's probably
+        // fine
+        while self.caches.len() > sz {
+            let mut oldest_time = None;
+            let mut oldest_key: Option<&Vec<TokenId>> = None;
+            for (k, (_, _, time)) in self.caches.iter() {
+                if oldest_time.is_none() || time < oldest_time.unwrap() {
+                    oldest_time = Some(time);
+                    oldest_key = Some(k);
+                }
+            }
+            let oldest_key = oldest_key.unwrap().clone();
+            self.caches.remove(&oldest_key);
+        }
+    }
+
+    fn get(&self, tokens: &[TokenId]) -> Option<(&TransformerCaches, usize)> {
+        if let Some((caches, pos, _)) = self.caches.get(tokens) {
+            Some((caches, *pos))
+        } else {
+            None
+        }
+    }
+
+    fn put(&mut self, tokens: Vec<TokenId>, caches: TransformerCaches, prev_pos: usize) {
+        self.caches
+            .insert(tokens, (caches, prev_pos, std::time::Instant::now()));
+        self.limit_size(self.max_sz);
+    }
+}
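+
+// State shared by all requests; registered with Rocket via .manage() in server_inference().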
+#[cfg(feature = "server")]
+#[derive(Clone)]
+struct InferenceServerState {
+    transformer: Arc<Transformer>,
+    tokenizer: Arc<Tokenizer>,
+    max_seq_len: usize,
+    concurrent_requests_semaphore: Semaphore,
+    attention_cache_repository: Arc<RwLock<AttentionCacheRepository>>,
+}
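+
+// POST endpoint: reads an InferenceRequest as JSON from the body and answers with a chunked
+// stream that yields one JSON line per generated token (see GeneratingSession above).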
+#[cfg(feature = "server")]
+#[post("/", data = "<input>")]
+fn handle_request(
+    state: State<InferenceServerState>,
+    input: Data,
+) -> Result<Stream<GeneratingSession>, status::BadRequest<String>> {
+    let _lock = state.concurrent_requests_semaphore.acquire();
+    let tr = state.transformer.clone();
+    let tok = state.tokenizer.clone();
+    let mut data = input.open();
+    let mut databuf: Vec<u8> = Vec::new();
+    data.read_to_end(&mut databuf).unwrap();
+    // Parse the JSON out of the request
+    let request: InferenceRequest = match serde_json::from_slice(&databuf) {
+        Err(_e) => {
+            return Err(status::BadRequest(Some("Invalid JSON.".to_string())));
+        }
+        Ok(ir) => ir,
+    };
+    let stop_at_end_token = request.stop_at_end_token.unwrap_or(true);
+    let temperature = request.temperature.unwrap_or(1.0);
+    let top_k = request.top_k.unwrap_or(20);
+    let top_p = request.top_p.unwrap_or(1.0);
+    let repetition_penalty = request.repetition_penalty.unwrap_or(1.0);
+    let mut req_max_seq_len = request.max_seq_len.unwrap_or(state.max_seq_len);
+    if req_max_seq_len > state.max_seq_len {
+        req_max_seq_len = state.max_seq_len;
+    }
+    let req_max_new_tokens = request.max_new_tokens.unwrap_or(20);
+    let no_token_sampling = request.no_token_sampling.unwrap_or(false);
+    let prompt = request.prompt;
+    if temperature.is_nan() {
+        return Err(status::BadRequest(Some(
+            "Temperature must be a number.".to_string(),
+        )));
+    }
+    if top_k == 0 {
+        return Err(status::BadRequest(Some(
+            "Top-k must be greater than 0.".to_string(),
+        )));
+    }
+    if top_p.is_nan() {
+        return Err(status::BadRequest(Some(
+            "Top-p must be a number.".to_string(),
+        )));
+    }
+    if repetition_penalty.is_nan() {
+        return Err(status::BadRequest(Some(
+            "Repetition penalty must be a number.".to_string(),
+        )));
+    }
+    let token_sampler = TokenSampler::new()
+        .temperature(temperature)
+        .top_p(top_p)
+        .top_k(top_k)
+        .repetition_penalty(repetition_penalty);
+    let toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
+    let gsession = GeneratingSession {
+        transformer: tr,
+        tokenizer: tok,
+        attention_cache_repository: state.attention_cache_repository.clone(),
+        token_sampler: token_sampler,
+        tokens: toks_id,
+        req_max_seq_len: req_max_seq_len,
+        req_max_new_tokens: req_max_new_tokens,
+        new_tokens_generated: 0,
+        prev_pos: 0,
+        no_token_sampling: no_token_sampling,
+        stop_at_end_token: stop_at_end_token,
+        sent_stuff_last_time: false,
+        result: Vec::new(),
+    };
+    return Ok(rocket::response::Stream::chunked(gsession, 1024));
+}
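+
+// Illustrative use with the defaults above (binary name and the usual model/tokenizer/param
+// flags assumed from the rest of the CLI): start `rllama --inference-server ...`, then:
+//   curl --data '{"prompt": "Hello", "max_new_tokens": 32}' \
+//        http://127.0.0.1:8080/rllama/v1/inference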
+
+fn command_line_inference(
+    cli: Cli,
+    tr: Arc<Transformer>,
+    tok: Arc<Tokenizer>,
+    prompt: String,
+    be_quiet: bool,
+    max_seq_len: usize,
+    params: ModelParams,
+    max_threads: usize,
+) -> Result<(), Box<dyn std::error::Error>> {
+    // Custom println-like macro that respects be_quiet
+    macro_rules! pln {
+        ($($arg:tt)*) => {
+            if !be_quiet {
+                std::println!($($arg)*);
+            }
+        };
+    }
     let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
     let mut prev_pos = 0;
     let mut token_sampler = TokenSampler::new()
-        .temperature(0.8)
-        .top_p(0.9)
-        .top_k(50)
-        .repetition_penalty(0.8);
+        .temperature(1.0)
+        .top_p(1.0)
+        .top_k(20)
+        .repetition_penalty(1.0);
     if let Some(temperature) = cli.temperature {
         token_sampler = token_sampler.temperature(temperature);