@@ -74,6 +74,9 @@ struct Cli {
+    #[arg(long)]
+    inference_server_prompt_cache_size: Option<usize>,
+    #[arg(long, action)]
+    inference_server_exit_after_one_query: bool,
 }
 
 #[derive(Clone, Serialize, Deserialize)]
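
For reference, clap's derive macro kebab-cases field names when generating long flags, so the two new fields parse as `--inference-server-prompt-cache-size <N>` and `--inference-server-exit-after-one-query`. A minimal stand-alone sketch (the `DemoCli` type is hypothetical, not part of this patch), assuming the same clap derive API the patch uses:

```rust
// Hypothetical demo, not part of the patch: shows how clap derives the
// long flags from the two new field names.
use clap::Parser;

#[derive(Parser)]
struct DemoCli {
    // Parses as: --inference-server-prompt-cache-size 50
    #[arg(long)]
    inference_server_prompt_cache_size: Option<usize>,
    // Parses as a bare switch: --inference-server-exit-after-one-query
    #[arg(long, action)]
    inference_server_exit_after_one_query: bool,
}

fn main() {
    let cli = DemoCli::parse();
    println!(
        "prompt cache size: {:?}, exit after one query: {}",
        cli.inference_server_prompt_cache_size,
        cli.inference_server_exit_after_one_query
    );
}
```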
@@ -337,6 +340,7 @@ fn server_inference(
+        attention_cache_repository: Arc::new(RwLock::new(AttentionCacheRepository::empty(
+            inference_server_prompt_cache_size,
+        ))),
+        exit_after_one_query: cli.inference_server_exit_after_one_query,
     });
 
     app.launch();
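
The cache repository is wrapped in `Arc<RwLock<...>>` so every request handler can share one capacity-bounded cache. A minimal sketch of that pattern with std's lock, using a hypothetical stand-in type (`AttentionCacheRepository` itself is not shown in this patch):

```rust
// Illustrative only: stand-in for the shared, capacity-bounded cache.
use std::sync::{Arc, RwLock};
use std::thread;

struct CacheRepo {
    capacity: usize,
    entries: Vec<String>,
}

fn main() {
    let repo = Arc::new(RwLock::new(CacheRepo { capacity: 4, entries: Vec::new() }));

    // A "handler" takes the write lock to update the cache...
    let writer = {
        let repo = Arc::clone(&repo);
        thread::spawn(move || {
            let mut r = repo.write().unwrap();
            if r.entries.len() < r.capacity {
                r.entries.push("prompt prefix".to_string());
            }
        })
    };
    writer.join().unwrap();

    // ...while readers only need the read lock.
    println!("cached entries: {}", repo.read().unwrap().entries.len());
}
```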
@@ -382,6 +386,7 @@ struct GeneratingSession {
     no_token_sampling: bool,
     stop_at_end_token: bool,
     sent_stuff_last_time: bool,
+    exit_after_one_query: bool,
     result: Vec<u8>, // stores JSONL lines to be returned from read()
 }
@@ -429,9 +434,15 @@ impl Read for GeneratingSession {
             return Ok(bytes_read);
         }
         if self.tokens.len() >= self.req_max_seq_len {
+            if self.exit_after_one_query {
+                std::process::exit(0);
+            }
             return Ok(0);
         }
         if self.new_tokens_generated >= self.req_max_new_tokens {
+            if self.exit_after_one_query {
+                std::process::exit(0);
+            }
             return Ok(0);
         }
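
The session streams its output through `std::io::Read`, where returning `Ok(0)` signals end-of-stream; with the new flag set, the process exits outright at that point instead (presumably so a supervisor can restart it with a fresh memory footprint). A self-contained sketch of the same control flow, with a hypothetical `OneShot` type:

```rust
// Sketch of the control flow above: a Read impl signals end-of-stream by
// returning Ok(0); with exit_after_one_query set, the whole process exits
// at that point instead of reporting EOF.
use std::io::Read;

struct OneShot {
    remaining: Vec<u8>,
    exit_after_one_query: bool,
}

impl Read for OneShot {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if self.remaining.is_empty() {
            if self.exit_after_one_query {
                std::process::exit(0); // terminate after serving one query
            }
            return Ok(0); // normal EOF
        }
        let n = buf.len().min(self.remaining.len());
        buf[..n].copy_from_slice(&self.remaining[..n]);
        self.remaining.drain(..n);
        Ok(n)
    }
}

fn main() {
    let mut s = OneShot {
        remaining: b"{\"token\":\"hi\"}\n".to_vec(),
        exit_after_one_query: false,
    };
    let mut out = String::new();
    s.read_to_string(&mut out).unwrap();
    print!("{out}");
}
```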
@@ -575,6 +586,7 @@ struct InferenceServerState {
     max_seq_len: usize,
     concurrent_requests_semaphore: Semaphore,
     attention_cache_repository: Arc<RwLock<AttentionCacheRepository>>,
+    exit_after_one_query: bool,
 }
 
 #[cfg(feature = "server")]
@@ -652,6 +664,7 @@ fn handle_request(
         no_token_sampling: no_token_sampling,
         stop_at_end_token: stop_at_end_token,
         sent_stuff_last_time: false,
+        exit_after_one_query: state.exit_after_one_query,
         result: Vec::new(),
     };
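
Taken together, the flag is threaded through three layers unchanged: `Cli` in the argument parser, `InferenceServerState` at server construction, and `GeneratingSession` per request. A compressed sketch with stand-in types mirroring only the fields this patch touches:

```rust
// Illustrative stand-ins for the three structs the patch touches.
struct Cli { inference_server_exit_after_one_query: bool }
struct InferenceServerState { exit_after_one_query: bool }
struct GeneratingSession { exit_after_one_query: bool }

fn main() {
    let cli = Cli { inference_server_exit_after_one_query: true };
    // server_inference(): Cli -> InferenceServerState
    let state = InferenceServerState {
        exit_after_one_query: cli.inference_server_exit_after_one_query,
    };
    // handle_request(): InferenceServerState -> GeneratingSession
    let session = GeneratingSession { exit_after_one_query: state.exit_after_one_query };
    assert!(session.exit_after_one_query);
    println!("flag survives all three hops");
}
```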