diff --git a/README.md b/README.md
index f707752..de8417b 100644
--- a/README.md
+++ b/README.md
@@ -115,6 +115,9 @@ The command line flags for this are:
     calculations should be cached. Default is 50. This speeds up token
     generation for prompts that were already requested before, however it also
     increases memory use as the cache gets more full.
+  * `--inference-server-exit-after-one-query` will make the server exit with
+    exit code 0 after it has served one HTTP query. This is used for
+    troubleshooting and experiments.

 Prompts and flags related to token sampling are all ignored in inference
 server mode. Instead, they are obtained from each HTTP JSON API request.
diff --git a/src/rllama_main.rs b/src/rllama_main.rs
index 45a2228..84dfa1d 100644
--- a/src/rllama_main.rs
+++ b/src/rllama_main.rs
@@ -74,6 +74,9 @@ struct Cli {
     #[arg(long)]
     inference_server_prompt_cache_size: Option<usize>,
+
+    #[arg(long, action)]
+    inference_server_exit_after_one_query: bool,
 }

 #[derive(Clone, Serialize, Deserialize)]
@@ -337,6 +340,7 @@ fn server_inference(
         attention_cache_repository: Arc::new(RwLock::new(AttentionCacheRepository::empty(
             inference_server_prompt_cache_size,
         ))),
+        exit_after_one_query: cli.inference_server_exit_after_one_query,
     });
     app.launch();
@@ -382,6 +386,7 @@ struct GeneratingSession {
     no_token_sampling: bool,
     stop_at_end_token: bool,
     sent_stuff_last_time: bool,
+    exit_after_one_query: bool,
     result: Vec<u8>, // stores JSONL lines to be returned from read()
 }
@@ -429,9 +434,15 @@ impl Read for GeneratingSession {
             return Ok(bytes_read);
         }
         if self.tokens.len() >= self.req_max_seq_len {
+            if self.exit_after_one_query {
+                std::process::exit(0);
+            }
             return Ok(0);
         }
         if self.new_tokens_generated >= self.req_max_new_tokens {
+            if self.exit_after_one_query {
+                std::process::exit(0);
+            }
             return Ok(0);
         }
@@ -575,6 +586,7 @@ struct InferenceServerState {
     max_seq_len: usize,
     concurrent_requests_semaphore: Semaphore,
     attention_cache_repository: Arc<RwLock<AttentionCacheRepository>>,
+    exit_after_one_query: bool,
 }

 #[cfg(feature = "server")]
@@ -652,6 +664,7 @@ fn handle_request(
         no_token_sampling: no_token_sampling,
         stop_at_end_token: stop_at_end_token,
         sent_stuff_last_time: false,
+        exit_after_one_query: state.exit_after_one_query,
         result: Vec::new(),
     };
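
For reference, a minimal sketch (not part of the patch) of how the new flag could be exercised end to end, assuming the binary is named `rllama` and is started in server mode with an `--inference-server` style flag; the actual model arguments, endpoint path, and request body are not shown in this patch and are left elided here:

// Hypothetical harness: spawn the server with
// --inference-server-exit-after-one-query, send one HTTP query, and verify
// that the process then terminates with exit code 0.
use std::process::Command;

fn main() -> std::io::Result<()> {
    // Assumed invocation; real runs also need model/tokenizer arguments.
    let mut server = Command::new("rllama")
        .args([
            "--inference-server",
            "--inference-server-exit-after-one-query",
        ])
        .spawn()?;

    // ... issue a single request against the HTTP JSON API here ...

    // After serving that one query the server is expected to exit on its own.
    let status = server.wait()?;
    assert!(status.success(), "expected exit code 0 after one query");
    Ok(())
}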