Add a flag that makes the HTTP server exit after just one query.

This is for some experiments I want to run: I pull logits out of the server
from a Python script and want to be able to shut it down gracefully afterwards.
Mikko Juola 3 years ago
parent 957a8f9f98
commit 26f343ad15

@@ -115,6 +115,9 @@ The command line flags for this are:
   calculations should be cached. Default is 50. This speeds up token
   generation for prompts that were already requested before, however it also
   increases memory use as the cache gets more full.
+* `--inference-server-exit-after-one-query` will make the server exit with
+  exit code 0 after it has served one HTTP query. This is used for
+  troubleshooting and experiments.
 
 Prompts and flags related to token sampling are all ignored in inference server
 mode. Instead, they are obtained from each HTTP JSON API request.

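To make the intended workflow concrete, here is a rough client-side sketch in Rust: start the server with `--inference-server-exit-after-one-query`, send a single request, and the process exits with code 0 once the response has been streamed. The address `127.0.0.1:8080`, the `/inference` path and the JSON field names are placeholders for illustration only, not the server's documented API.

```rust
// Hedged sketch (not part of this commit): send one request to a server
// started with --inference-server-exit-after-one-query and read the streamed
// response until the process goes away. Address, path and JSON fields are
// placeholders; check the server's actual HTTP JSON API for the real names.
use std::io::{Read, Write};
use std::net::TcpStream;

fn main() -> std::io::Result<()> {
    // Placeholder request body; the real API expects its own field names.
    let body = r#"{"prompt": "Hello", "max_new_tokens": 8}"#;
    let mut stream = TcpStream::connect("127.0.0.1:8080")?;
    write!(
        stream,
        "POST /inference HTTP/1.1\r\nHost: 127.0.0.1\r\nContent-Type: application/json\r\n\
         Content-Length: {}\r\nConnection: close\r\n\r\n{}",
        body.len(),
        body
    )?;
    // The server may drop the connection when it exits after this query, so a
    // read error here is treated as end of stream rather than a failure.
    let mut response = String::new();
    let _ = stream.read_to_string(&mut response);
    println!("{response}");
    Ok(())
}
```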
@@ -74,6 +74,9 @@ struct Cli {
     #[arg(long)]
     inference_server_prompt_cache_size: Option<usize>,
+    #[arg(long, action)]
+    inference_server_exit_after_one_query: bool,
 }
 
 #[derive(Clone, Serialize, Deserialize)]
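The `#[arg(long, action)]` attribute turns the new field into a plain on/off switch: it defaults to `false` and becomes `true` when `--inference-server-exit-after-one-query` is passed. Below is a minimal standalone sketch of the same behaviour with clap's derive API, written with an explicit `ArgAction::SetTrue` (which is what the bare `action` should amount to for a `bool` field); it needs `clap` with the `derive` feature.

```rust
// Standalone sketch of the boolean flag added in this commit; not the
// project's code, just the same clap pattern in isolation.
use clap::{ArgAction, Parser};

#[derive(Parser, Debug)]
struct Cli {
    /// Exit with code 0 after serving one HTTP query.
    #[arg(long, action = ArgAction::SetTrue)]
    inference_server_exit_after_one_query: bool,
}

fn main() {
    let cli = Cli::parse();
    println!(
        "exit after one query: {}",
        cli.inference_server_exit_after_one_query
    );
}
```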
@@ -337,6 +340,7 @@ fn server_inference(
         attention_cache_repository: Arc::new(RwLock::new(AttentionCacheRepository::empty(
             inference_server_prompt_cache_size,
         ))),
+        exit_after_one_query: cli.inference_server_exit_after_one_query,
     });
 
     app.launch();
@@ -382,6 +386,7 @@ struct GeneratingSession {
     no_token_sampling: bool,
     stop_at_end_token: bool,
     sent_stuff_last_time: bool,
+    exit_after_one_query: bool,
     result: Vec<u8>, // stores JSONL lines to be returned from read()
 }
@@ -429,9 +434,15 @@ impl Read for GeneratingSession {
             return Ok(bytes_read);
         }
         if self.tokens.len() >= self.req_max_seq_len {
+            if self.exit_after_one_query {
+                std::process::exit(0);
+            }
             return Ok(0);
         }
         if self.new_tokens_generated >= self.req_max_new_tokens {
+            if self.exit_after_one_query {
+                std::process::exit(0);
+            }
             return Ok(0);
         }
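The exit hook sits inside `GeneratingSession`'s `Read` implementation, on the two paths that would otherwise report end of stream (maximum sequence length or maximum new tokens reached), i.e. only after any buffered output has already been returned by earlier `read()` calls. A simplified, self-contained sketch of the pattern, with invented types and fields for illustration:

```rust
// Simplified sketch of the pattern used in GeneratingSession::read above: a
// streaming body that returns Ok(0) at end of stream, or terminates the whole
// process there when the exit-after-one-query flag is set.
use std::io::Read;

struct OneShotBody {
    data: Vec<u8>,
    pos: usize,
    exit_after_one_query: bool,
}

impl Read for OneShotBody {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if self.pos >= self.data.len() {
            // Everything has been handed to the caller in earlier read()
            // calls; with the flag set, end the process instead of just
            // signalling EOF.
            if self.exit_after_one_query {
                std::process::exit(0);
            }
            return Ok(0);
        }
        let n = buf.len().min(self.data.len() - self.pos);
        buf[..n].copy_from_slice(&self.data[self.pos..self.pos + n]);
        self.pos += n;
        Ok(n)
    }
}
```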
@@ -575,6 +586,7 @@ struct InferenceServerState {
     max_seq_len: usize,
     concurrent_requests_semaphore: Semaphore,
     attention_cache_repository: Arc<RwLock<AttentionCacheRepository>>,
+    exit_after_one_query: bool,
 }
 
 #[cfg(feature = "server")]
@@ -652,6 +664,7 @@ fn handle_request(
             no_token_sampling: no_token_sampling,
             stop_at_end_token: stop_at_end_token,
             sent_stuff_last_time: false,
+            exit_after_one_query: state.exit_after_one_query,
             result: Vec::new(),
         };
