diff --git a/README.md b/README.md index b1d3653..d4595e2 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,9 @@ cast to 32-bit floats. You can use `--temperature`, `--top-p` and `--top-k` to adjust token sampler settings. +You can also use `--prompt-file` to read the prompt from a file instead from +the command line. + # How to turn on OpenCL Use `opencl` Cargo feature. diff --git a/src/rllama_main.rs b/src/rllama_main.rs index f097346..1c1e3fb 100644 --- a/src/rllama_main.rs +++ b/src/rllama_main.rs @@ -226,6 +226,7 @@ pub fn main() -> Result<(), Box> { let mut times_per_token: Vec = vec![]; let mut caches = tr.make_caches(); let mut first: bool = true; + let mut stop_seen: bool = false; while toks_id.len() < max_seq_len { let now = std::time::Instant::now(); @@ -239,6 +240,10 @@ pub fn main() -> Result<(), Box> { } let mut tok_str: String = "".to_string(); let tok = tok.id_to_str(*tok_id); + if tok == "" { + tok_str += ""; + stop_seen = true; + } if tok == "<0x0A>" { tok_str += "\n"; } else { @@ -258,8 +263,16 @@ pub fn main() -> Result<(), Box> { let _ = std::io::stdout().flush(); prev_pos = toks_id.len() - 1; first = false; + if stop_seen { + break; + } } println!(""); + if stop_seen { + if !be_quiet { + println!("Stop token seen. Stopping."); + } + } if !be_quiet { println!("---"); println!(