@@ -1,18 +1,23 @@
use crate::embedding::Embedding;
use crate::semaphore::Semaphore;
#[cfg(feature = "opencl")]
use crate::tensor_opencl_support::OpenCL;
use crate::token_sampler::TokenSampler;
use crate::tokenizer::{TokenId, Tokenizer};
use crate::transformer::{DataSettings, Transformer};
use crate::transformer::{DataSettings, Transformer, TransformerCaches};
use crate::unpickler;
use crate::unpickler::Value;
use clap::Parser;
use colored::Colorize;
#[cfg(feature = "server")]
use rocket::{response::status, response::Stream, Data, State};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::io::{Read, Write};
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
#[derive(Parser)]
#[derive(Parser, Clone)]
#[command(author, version, about, long_about = None)]
struct Cli {
    #[arg(long)]
@@ -22,6 +27,9 @@ struct Cli {
    #[arg(long)]
    param_path: String,
    #[arg(short, long, action)]
    quiet: bool,
    #[arg(long)]
    prompt: Option<String>,
    #[arg(long)]
@@ -36,10 +44,39 @@ struct Cli {
    top_p: Option<f32>,
    #[arg(long)]
    top_k: Option<i32>,
    #[arg(long)]
    repetition_penalty: Option<f32>,
    #[arg(long)]
    max_threads: Option<usize>,
    #[arg(long, action)]
    f16: bool,
    #[cfg(feature = "opencl")]
    #[arg(long)]
    opencl_device: Option<usize>,
    #[arg(long, action)]
    inference_server: bool,
    #[arg(long)]
    inference_server_port: Option<u16>,
    #[arg(long)]
    inference_server_host: Option<String>,
    #[arg(long)]
    inference_server_max_concurrent_inferences: Option<usize>,
    #[arg(long)]
    inference_server_api_path: Option<String>,
    #[arg(long)]
    inference_server_prompt_cache_size: Option<usize>,
    #[arg(long, action)]
    inference_server_exit_after_one_query: bool,
}
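// Illustrative invocations only (the binary name and file paths below are placeholders;
// clap derives the kebab-case flag names from the fields above):
//
//   rllama --model-path /path/to/model.pth \
//          --tokenizer-path /path/to/tokenizer.model \
//          --param-path /path/to/params.json \
//          --prompt "Hello world" \
//          --temperature 0.8 --top-k 40 --repetition-penalty 1.1
//
//   rllama --model-path ... --tokenizer-path ... --param-path ... \
//          --inference-server --inference-server-port 8080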
#[derive(Clone, Serialize, Deserialize)]
@@ -54,14 +91,46 @@ struct ModelParams {
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    let cli = Cli::parse();
    let model_path = cli.model_path;
    let tokenizer_path = cli.tokenizer_path;
    let param_path = cli.param_path;
    let model_path = cli.model_path.clone();
    let tokenizer_path = cli.tokenizer_path.clone();
    let param_path = cli.param_path.clone();
    #[cfg(not(feature = "server"))]
    if cli.inference_server {
        eprintln!("Inference server is not enabled in this build.");
        return Err("Inference server is not enabled in this build.".into());
    }
    let max_threads: usize = match cli.max_threads {
        None => rayon::current_num_threads(),
        Some(max_threads) => {
            rayon::ThreadPoolBuilder::new()
                .num_threads(max_threads)
                .build_global()
                .unwrap();
            max_threads
        }
    };
    let mut be_quiet: bool = false;
    if !colored::control::SHOULD_COLORIZE.should_colorize() {
        be_quiet = true;
    }
    if cli.quiet {
        be_quiet = true;
    }
    if be_quiet {
        colored::control::SHOULD_COLORIZE.set_override(false);
    }
    // Custom println-like macro that respects be_quiet
    macro_rules! pln {
        ($($arg:tt)*) => {
            if !be_quiet {
                std::println!($($arg)*);
            }
        };
    }
    #[cfg(feature = "opencl")]
    let opencl: Option<OpenCL> = {
@@ -79,15 +148,6 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
        }
    };
    // Custom println-like macro that respects be_quiet
    macro_rules! pln {
        ($($arg:tt)*) => {
            if !be_quiet {
                std::println!($($arg)*);
            }
        };
    }
    // Read ModelParams from param_path, we expect it to be JSON
    let mut fs = std::fs::File::open(&param_path)?;
    let mut bs = Vec::new();
@@ -96,12 +156,12 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    let params: ModelParams = serde_json::from_slice(&bs)?;
    pln!("Loaded model parameters from {}.", param_path);
    let prompt: String = match (cli.prompt, cli.prompt_file) {
        (Some(prompt), None) => {
    let prompt: String = match (&cli.prompt, &cli.prompt_file) {
        (Some(ref prompt), None) => {
            pln!("Using prompt: {}", prompt);
            prompt
            prompt.clone()
        }
        (None, Some(prompt_file)) => {
        (None, Some(ref prompt_file)) => {
            pln!("Using prompt file: {}", prompt_file);
            let mut fs = std::fs::File::open(prompt_file)?;
            let mut bs = Vec::new();
@@ -110,8 +170,12 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
            String::from_utf8(bs)?
        }
        _ => {
            eprintln!("Please provide either a prompt or a prompt file.");
            return Err("Please provide either a prompt or a prompt file.".into());
            if cli.inference_server {
                "".to_string()
            } else {
                eprintln!("Please provide either a prompt or a prompt file.");
                return Err("Please provide either a prompt or a prompt file.".into());
            }
        }
    };
@@ -150,12 +214,9 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    pln!("Loading embeddings from {}...", model_path);
    let emb = Embedding::from_unpickled(&unpickle_results, model_path.clone())?;
    let max_seq_len = match cli.max_seq_len {
        Some(max_seq_len) => max_seq_len,
        None => 1024,
    };
    let max_seq_len = cli.max_seq_len.unwrap_or(1024);
    let data_settings = {
    let mut data_settings = {
        #[cfg(feature = "opencl")]
        {
            if let Some(opencl) = opencl {
@@ -169,6 +230,10 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
            DataSettings::new()
    };
    if cli.f16 {
        data_settings = data_settings.force_f16();
    }
    pln!("Loading transformer weights from {}...", model_path);
    let tr = Transformer::from_unpickled(
        &unpickle_results,
@@ -183,9 +248,455 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    )?;
    pln!("All is loaded. Starting inference.");
    let tr: Arc<Transformer> = Arc::new(tr);
    let tok: Arc<Tokenizer> = Arc::new(tok);
    if cli.inference_server {
        #[cfg(feature = "server")]
        {
            server_inference(cli, tr, tok, be_quiet, max_seq_len, params, max_threads)
        }
        #[cfg(not(feature = "server"))]
        {
            eprintln!("The inference server feature is not enabled.");
            eprintln!("Please enable it with the \"inference-server\" feature.");
            Err("The inference server feature is not enabled.".into())
        }
    } else {
        command_line_inference(
            cli.clone(),
            tr.clone(),
            tok.clone(),
            prompt.clone(),
            be_quiet,
            max_seq_len,
            params.clone(),
            max_threads,
        )
    }
}
#[cfg(feature = "server")]
fn server_inference(
    cli: Cli,
    tr: Arc<Transformer>,
    tok: Arc<Tokenizer>,
    be_quiet: bool,
    max_seq_len: usize,
    _params: ModelParams,
    _max_threads: usize,
) -> Result<(), Box<dyn std::error::Error>> {
    macro_rules! pln {
        ($($arg:tt)*) => {
            if !be_quiet {
                std::println!($($arg)*);
            }
        };
    }
    let inference_server_port = cli.inference_server_port.unwrap_or(8080);
    let inference_server_host = cli
        .inference_server_host
        .clone()
        .unwrap_or("127.0.0.1".to_string());
    let inference_server_max_concurrent_inferences =
        cli.inference_server_max_concurrent_inferences.unwrap_or(5);
    let inference_server_api_path = cli
        .inference_server_api_path
        .clone()
        .unwrap_or("/rllama/v1/inference".to_string());
    let inference_server_prompt_cache_size = cli.inference_server_prompt_cache_size.unwrap_or(50);
    pln!(
        "Maximum concurrent inferences: {}",
        inference_server_max_concurrent_inferences
    );
    pln!("Prompt cache size: {}", inference_server_prompt_cache_size);
    pln!("Maximum sequence length: {}", max_seq_len);
    pln!(
        "--- Starting HTTP server on {}:{}, answering to requests at {} ---",
        inference_server_host,
        inference_server_port,
        inference_server_api_path
    );
    // If there are too many connections, they will hang until they get their turn.
    // Maybe we can later return 503 Slow Down or something similar.
    let concurrent_requests_semaphore = Semaphore::new(inference_server_max_concurrent_inferences);
    let rocket_conf = rocket::Config::build(rocket::config::Environment::Production)
        .address(inference_server_host)
        .port(inference_server_port)
        .finalize()
        .unwrap();
    let app = rocket::custom(rocket_conf)
        .mount(&inference_server_api_path, routes![handle_request])
        .manage(InferenceServerState {
            transformer: tr,
            tokenizer: tok,
            max_seq_len,
            concurrent_requests_semaphore,
            attention_cache_repository: Arc::new(RwLock::new(AttentionCacheRepository::empty(
                inference_server_prompt_cache_size,
            ))),
            exit_after_one_query: cli.inference_server_exit_after_one_query,
        });
    app.launch();
    panic!("Starting web server failed.");
}
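// Illustrative request against the server above, assuming the default host, port and
// API path ("127.0.0.1", 8080, "/rllama/v1/inference"); the JSON fields come from
// InferenceRequest below:
//
//   curl -X POST -d '{"prompt": "Hello", "max_new_tokens": 32}' \
//        http://127.0.0.1:8080/rllama/v1/inference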
fn is_false(b: &bool) -> bool {
    !b
}
#[derive(Serialize, Deserialize, Clone, Debug)]
struct InferenceRequest {
    temperature: Option<f32>,
    top_k: Option<usize>,
    top_p: Option<f32>,
    repetition_penalty: Option<f32>,
    max_seq_len: Option<usize>,
    max_new_tokens: Option<usize>,
    no_token_sampling: Option<bool>,
    stop_at_end_token: Option<bool>,
    prompt: String,
}
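// Example request body (illustrative values; every field except "prompt" is optional
// and falls back to the defaults chosen in handle_request):
//
//   {
//     "prompt": "Once upon a time",
//     "temperature": 0.8,
//     "top_k": 40,
//     "top_p": 0.9,
//     "repetition_penalty": 1.1,
//     "max_new_tokens": 64,
//     "stop_at_end_token": true
//   }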
#[cfg(feature = "server")]
#[derive(Serialize, Deserialize, Clone, Debug)]
struct PredResult {
    p: f32,
    #[serde(skip_serializing_if = "is_false")]
    is_end_token: bool,
}
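// Each generated token is streamed back as one JSON line mapping the token string to a
// PredResult, e.g. (illustrative values):
//
//   {"Hello":{"p":0.83}}
//   {"</s>":{"p":0.95,"is_end_token":true}}
//
// "is_end_token" is omitted when false because of the skip_serializing_if above.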
#[cfg(feature = "server")]
struct GeneratingSession {
    transformer: Arc<Transformer>,
    token_sampler: TokenSampler,
    tokenizer: Arc<Tokenizer>,
    attention_cache_repository: Arc<RwLock<AttentionCacheRepository>>,
    tokens: Vec<TokenId>,
    req_max_seq_len: usize,
    req_max_new_tokens: usize,
    new_tokens_generated: usize,
    prev_pos: usize,
    no_token_sampling: bool,
    stop_at_end_token: bool,
    sent_stuff_last_time: bool,
    exit_after_one_query: bool,
    result: Vec<u8>, // stores JSONL lines to be returned from read()
}
#[cfg(feature = "server")]
impl GeneratingSession {
    /// Drains as many buffered result bytes as fit into `buf`; returns the number of bytes copied.
    fn read_from_result(&mut self, buf: &mut [u8]) -> usize {
        if !self.result.is_empty() {
            if self.result.len() <= buf.len() {
                // Everything that is buffered fits: copy it all and clear the buffer.
                buf[..self.result.len()].copy_from_slice(&self.result);
                let len = self.result.len();
                self.sent_stuff_last_time = true;
                self.result.truncate(0);
                return len;
            } else {
                // Only a prefix fits: copy it and keep the rest for the next read.
                buf.copy_from_slice(&self.result[..buf.len()]);
                self.result = self.result[buf.len()..].to_vec();
                self.sent_stuff_last_time = true;
                return buf.len();
            }
        }
        0
    }
}
#[cfg(feature = "server")]
impl Read for GeneratingSession {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if self.sent_stuff_last_time && self.result.is_empty() {
            // If we return WouldBlock every time we send something, it'll cause Rocket to
            // flush available data.
            self.sent_stuff_last_time = false;
            return Err(std::io::Error::new(
                std::io::ErrorKind::WouldBlock,
                "WouldBlock",
            ));
        }
        // Push more data to the upstream if we have something stored.
        let bytes_read = self.read_from_result(buf);
        if bytes_read > 0 {
            return Ok(bytes_read);
        }
        if self.tokens.len() >= self.req_max_seq_len {
            if self.exit_after_one_query {
                std::process::exit(0);
            }
            return Ok(0);
        }
        if self.new_tokens_generated >= self.req_max_new_tokens {
            if self.exit_after_one_query {
                std::process::exit(0);
            }
            return Ok(0);
        }
        // Try to reuse a cached attention state for this token prefix:
        //  - hit at or beyond prev_pos: clone it and continue from that position,
        //  - stale hit (behind prev_pos): start from fresh caches at position 0,
        //  - miss: make fresh caches and register them for this prefix.
        let (mut caches, update_pos) = {
            let mut ac = self.attention_cache_repository.write().unwrap();
            match ac.get(&self.tokens) {
                Some((c, pos)) if pos >= self.prev_pos => (c.true_clone(), pos),
                Some(_) => {
                    std::mem::drop(ac);
                    (self.transformer.make_caches(), 0)
                }
                None => {
                    let caches = self.transformer.make_caches();
                    ac.put(self.tokens.clone(), caches.true_clone(), self.prev_pos);
                    (caches, self.prev_pos)
                }
            }
        };
        if update_pos > self.prev_pos {
            self.prev_pos = update_pos;
        }
        assert!(self.result.is_empty());
        let predictions =
            self.transformer
                .forward(&self.tokens[self.prev_pos..], self.prev_pos, &mut caches);
        self.prev_pos = self.tokens.len();
        let (highest_pred_idx, token_prob) =
            self.token_sampler
                .sample(&predictions, self.tokenizer.as_ref(), &self.tokens);
        self.tokens.push(highest_pred_idx as TokenId);
        {
            let mut ac = self.attention_cache_repository.write().unwrap();
            ac.put(self.tokens.clone(), caches, self.prev_pos);
        }
        self.new_tokens_generated += 1;
        let token: &str = self.tokenizer.id_to_str(highest_pred_idx as TokenId);
        let mut is_end_token: bool = false;
        if token == "</s>" && self.stop_at_end_token {
            self.new_tokens_generated = self.req_max_new_tokens;
            is_end_token = true;
        }
        let mut result: BTreeMap<String, PredResult> = BTreeMap::new();
        if self.no_token_sampling {
            // All predictions go on the line.
            let probs = self
                .token_sampler
                .logits_to_btreemap(&predictions, self.tokenizer.as_ref());
            for (k, v) in probs.into_iter() {
                let mut is_end_token: bool = false;
                if k == "</s>" {
                    is_end_token = true;
                }
                result.insert(
                    k,
                    PredResult {
                        p: v,
                        is_end_token,
                    },
                );
            }
            // Convert to JSON
            let json = serde_json::to_string(&result).unwrap();
            self.result.extend(json.as_bytes());
            self.result.push(b'\n');
            return Ok(self.read_from_result(buf));
        } else {
            result.insert(
                token.to_string(),
                PredResult {
                    p: token_prob,
                    is_end_token,
                },
            );
            let json = serde_json::to_string(&result).unwrap();
            self.result.extend(json.as_bytes());
            self.result.push(b'\n');
            return Ok(self.read_from_result(buf));
        }
    }
}
#[cfg(feature = "server")]
struct AttentionCacheRepository {
    caches: BTreeMap<Vec<TokenId>, (TransformerCaches, usize, std::time::Instant)>,
    max_sz: usize,
}
#[cfg(feature = "server")]
impl AttentionCacheRepository {
    fn empty(max_size: usize) -> AttentionCacheRepository {
        AttentionCacheRepository {
            caches: BTreeMap::new(),
            max_sz: max_size,
        }
    }
    /// Makes sure the cache repository is not larger than sz; evicts the oldest items first.
    fn limit_size(&mut self, sz: usize) {
        if sz == 0 {
            self.caches = BTreeMap::new();
            return;
        }
        // Slow algorithm, but the cache will never be unimaginably large, so it's probably fine.
        while self.caches.len() > sz {
            let mut oldest_time = None;
            let mut oldest_key: Option<&Vec<TokenId>> = None;
            for (k, (_, _, time)) in self.caches.iter() {
                if oldest_time.is_none() || time < oldest_time.unwrap() {
                    oldest_time = Some(time);
                    oldest_key = Some(k);
                }
            }
            let oldest_key = oldest_key.unwrap().clone();
            self.caches.remove(&oldest_key);
        }
    }
    fn get(&self, tokens: &[TokenId]) -> Option<(&TransformerCaches, usize)> {
        if let Some((caches, pos, _)) = self.caches.get(tokens) {
            Some((caches, *pos))
        } else {
            None
        }
    }
    fn put(&mut self, tokens: Vec<TokenId>, caches: TransformerCaches, prev_pos: usize) {
        self.caches
            .insert(tokens, (caches, prev_pos, std::time::Instant::now()));
        self.limit_size(self.max_sz);
    }
}
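// Sketch of the eviction behaviour (illustrative): with max_sz = 2, put(A, ...),
// put(B, ...), put(C, ...) leaves only the two most recently inserted prefixes,
// because put() timestamps each entry with Instant::now() and limit_size() keeps
// removing the entry with the oldest timestamp until len() <= max_sz.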
#[cfg(feature = "server")]
#[derive(Clone)]
struct InferenceServerState {
    transformer: Arc<Transformer>,
    tokenizer: Arc<Tokenizer>,
    max_seq_len: usize,
    concurrent_requests_semaphore: Semaphore,
    attention_cache_repository: Arc<RwLock<AttentionCacheRepository>>,
    exit_after_one_query: bool,
}
#[cfg(feature = "server")]
#[post("/", data = "<input>")]
fn handle_request(
    state: State<InferenceServerState>,
    input: Data,
) -> Result<Stream<GeneratingSession>, status::BadRequest<String>> {
    let _lock = state.concurrent_requests_semaphore.acquire();
    let tr = state.transformer.clone();
    let tok = state.tokenizer.clone();
    let mut data = input.open();
    let mut databuf: Vec<u8> = Vec::new();
    data.read_to_end(&mut databuf).unwrap();
    // Parse the JSON out of the request
    let request: InferenceRequest = match serde_json::from_slice(&databuf) {
        Err(_e) => {
            return Err(status::BadRequest(Some("Invalid JSON.".to_string())));
        }
        Ok(ir) => ir,
    };
    let stop_at_end_token = request.stop_at_end_token.unwrap_or(true);
    let temperature = request.temperature.unwrap_or(1.0);
    let top_k = request.top_k.unwrap_or(20);
    let top_p = request.top_p.unwrap_or(1.0);
    let repetition_penalty = request.repetition_penalty.unwrap_or(1.0);
    let mut req_max_seq_len = request.max_seq_len.unwrap_or(state.max_seq_len);
    if req_max_seq_len > state.max_seq_len {
        req_max_seq_len = state.max_seq_len;
    }
    let req_max_new_tokens = request.max_new_tokens.unwrap_or(20);
    let no_token_sampling = request.no_token_sampling.unwrap_or(false);
    let prompt = request.prompt;
    if temperature.is_nan() {
        return Err(status::BadRequest(Some(
            "Temperature must be a number.".to_string(),
        )));
    }
    if top_k == 0 {
        return Err(status::BadRequest(Some(
            "Top-k must be greater than 0.".to_string(),
        )));
    }
    if top_p.is_nan() {
        return Err(status::BadRequest(Some(
            "Top-p must be a number.".to_string(),
        )));
    }
    if repetition_penalty.is_nan() {
        return Err(status::BadRequest(Some(
            "Repetition penalty must be a number.".to_string(),
        )));
    }
    let token_sampler = TokenSampler::new()
        .temperature(temperature)
        .top_p(top_p)
        .top_k(top_k)
        .repetition_penalty(repetition_penalty);
    let toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
    let gsession = GeneratingSession {
        transformer: tr,
        tokenizer: tok,
        attention_cache_repository: state.attention_cache_repository.clone(),
        token_sampler,
        tokens: toks_id,
        req_max_seq_len,
        req_max_new_tokens,
        new_tokens_generated: 0,
        prev_pos: 0,
        no_token_sampling,
        stop_at_end_token,
        sent_stuff_last_time: false,
        exit_after_one_query: state.exit_after_one_query,
        result: Vec::new(),
    };
    Ok(rocket::response::Stream::chunked(gsession, 1024))
}
fn command_line_inference(
    cli: Cli,
    tr: Arc<Transformer>,
    tok: Arc<Tokenizer>,
    prompt: String,
    be_quiet: bool,
    max_seq_len: usize,
    params: ModelParams,
    max_threads: usize,
) -> Result<(), Box<dyn std::error::Error>> {
    // Custom println-like macro that respects be_quiet
    macro_rules! pln {
        ($($arg:tt)*) => {
            if !be_quiet {
                std::println!($($arg)*);
            }
        };
    }
    let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
    let mut prev_pos = 0;
    let mut token_sampler = TokenSampler::new().temperature(0.8).top_p(0.9).top_k(50);
    let mut token_sampler = TokenSampler::new()
        .temperature(1.0)
        .top_p(1.0)
        .top_k(20)
        .repetition_penalty(1.0);
    if let Some(temperature) = cli.temperature {
        token_sampler = token_sampler.temperature(temperature);
@@ -196,6 +707,9 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    if let Some(top_k) = cli.top_k {
        token_sampler = token_sampler.top_k(top_k as usize);
    }
    if let Some(repetition_penalty) = cli.repetition_penalty {
        token_sampler = token_sampler.repetition_penalty(repetition_penalty);
    }
    pln!("---");
    pln!(" dim: {}", params.dim);
@@ -205,10 +719,16 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    pln!(" norm_eps: {}", params.norm_eps);
    pln!(" vocab_size: {}", params.vocab_size);
    pln!("---");
    pln!(" maximum number of threads: {}", max_threads);
    pln!("---");
    pln!("Max sequence length: {}", max_seq_len);
    pln!("Temperature: {}", token_sampler.get_temperature());
    pln!("Top P: {}", token_sampler.get_top_p());
    pln!("Top K: {}", token_sampler.get_top_k());
    pln!(
        "Repetition penalty: {}",
        token_sampler.get_repetition_penalty()
    );
    pln!("---");
    pln!(
        "{}",
@@ -226,11 +746,12 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut times_per_token: Vec<std::time::Duration> = vec![];
    let mut caches = tr.make_caches();
    let mut first: bool = true;
    let mut stop_seen: bool = false;
    while toks_id.len() < max_seq_len {
        let now = std::time::Instant::now();
        let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
        let highest_pred_idx = token_sampler.sample(&preds);
        let (highest_pred_idx, token_prob) = token_sampler.sample(&preds, &tok, &toks_id);
        toks_id.push(highest_pred_idx as TokenId);
        for (tok_idx, tok_id) in toks_id[prev_pos + 1..].iter().enumerate() {
@@ -239,6 +760,10 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
            }
            let mut tok_str: String = "".to_string();
            let tok = tok.id_to_str(*tok_id);
            if tok == "</s>" {
                tok_str += "";
                stop_seen = true;
            }
            if tok == "<0x0A>" {
                tok_str += "\n";
            } else {
@@ -247,7 +772,18 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
            if first && tok_idx < toks_id.len() - 2 {
                // intentionally left empty
            } else {
                print!("{}", tok_str.truecolor(128, 255, 128));
                // Shade the token by its sampled probability: the higher token_prob is,
                // the redder (and less green) the printed token becomes.
                let redness: f32 = token_prob * 255.0;
                let redness = if redness > 255.0 {
                    255
                } else if redness < 0.0 {
                    0
                } else {
                    redness as u8
                };
                print!(
                    "{}",
                    tok_str.truecolor(128 + redness / 2, 255 - redness / 2, 128)
                );
            }
        }
@@ -258,8 +794,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
        let _ = std::io::stdout().flush();
        prev_pos = toks_id.len() - 1;
        first = false;
        if stop_seen {
            break;
        }
    }
    println!();
    if stop_seen && !be_quiet {
        println!("Stop token seen. Stopping.");
    }
    println!("");
    if !be_quiet {
        println!("---");
        println!(