diff --git a/src/rllama_main.rs b/src/rllama_main.rs index 3e4df19..a64cd1b 100644 --- a/src/rllama_main.rs +++ b/src/rllama_main.rs @@ -4,6 +4,7 @@ use crate::tokenizer::{TokenId, Tokenizer}; use crate::transformer::Transformer; use crate::unpickler; use clap::Parser; +use colored::Colorize; use std::io::{Read, Write}; #[derive(Parser)] @@ -31,13 +32,27 @@ pub fn main() -> Result<(), Box> { let model_path = cli.model_path; let tokenizer_path = cli.tokenizer_path; + let mut be_quiet: bool = false; + if !colored::control::SHOULD_COLORIZE.should_colorize() { + be_quiet = true; + } + + // Custom println-like macro that respects be_quiet + macro_rules! pln { + ($($arg:tt)*) => { + if !be_quiet { + std::println!($($arg)*); + } + }; + } + let prompt: String = match (cli.prompt, cli.prompt_file) { (Some(prompt), None) => { - println!("Using prompt: {}", prompt); + pln!("Using prompt: {}", prompt); prompt } (None, Some(prompt_file)) => { - println!("Using prompt file: {}", prompt_file); + pln!("Using prompt file: {}", prompt_file); let mut fs = std::fs::File::open(prompt_file)?; let mut bs = Vec::new(); fs.read_to_end(&mut bs)?; @@ -45,14 +60,14 @@ pub fn main() -> Result<(), Box> { String::from_utf8(bs)? } _ => { - println!("Please provide either a prompt or a prompt file."); - return Ok(()); + eprintln!("Please provide either a prompt or a prompt file."); + return Err("Please provide either a prompt or a prompt file.".into()); } }; - println!("Starting up. Loading tokenizer from {}...", tokenizer_path); + pln!("Starting up. Loading tokenizer from {}...", tokenizer_path); let tok = Tokenizer::load(tokenizer_path.as_str())?; - println!("Tokenizer loaded. Loading model from {}...", model_path); + pln!("Tokenizer loaded. Loading model from {}...", model_path); let mut fs = std::fs::File::open(model_path.as_str())?; let mut bs = Vec::new(); fs.read_to_end(&mut bs)?; @@ -66,25 +81,27 @@ pub fn main() -> Result<(), Box> { .join("/") + "/data/"; let result = unpickler::unpickle(&bs)?; - println!("Loading embeddings from {}...", model_data_dir); + pln!("Loading embeddings from {}...", model_data_dir); let emb = Embedding::from_unpickled(&result, model_data_dir.clone())?; - println!("Loading transformer weights from {}...", model_data_dir); + let max_seq_len = 512; + + pln!("Loading transformer weights from {}...", model_data_dir); let tr = Transformer::from_unpickled( &result, emb, 4096, 32, 32, - 512, + max_seq_len, 1e-6, 32, 128, model_data_dir, )?; - println!("All is loaded. Starting inference."); + pln!("All is loaded. Starting inference."); - let mut toks_id: Vec = tok.tokenize_to_ids(prompt); + let mut toks_id: Vec = tok.tokenize_to_ids(prompt.clone()); let mut prev_pos = 0; let mut token_sampler = TokenSampler::new().temperature(0.8).top_p(0.9).top_k(50); @@ -98,30 +115,65 @@ pub fn main() -> Result<(), Box> { token_sampler = token_sampler.top_k(top_k as usize); } - println!("Temperature: {}", token_sampler.get_temperature()); - println!("Top P: {}", token_sampler.get_top_p()); - println!("Top K: {}", token_sampler.get_top_k()); + pln!("---"); + pln!("Temperature: {}", token_sampler.get_temperature()); + pln!("Top P: {}", token_sampler.get_top_p()); + pln!("Top K: {}", token_sampler.get_top_k()); + pln!("---"); + pln!( + "{}", + " This is the color of the initial prompt".truecolor(128, 128, 255) + ); + pln!( + "{}", + " This is the color of the generated text while full context is available" + .truecolor(128, 255, 128) + ); + pln!( + "{}", + " Remaining text is in this color".truecolor(255, 128, 128) + ); + pln!("---"); + print!("{}", prompt.as_str().truecolor(128, 128, 255)); + let _ = std::io::stdout().flush(); let mut caches = tr.make_caches(); + let mut first: bool = true; + let mut shifts: usize = 0; loop { - let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches); + if toks_id.len() >= max_seq_len { + toks_id = toks_id[1..].to_vec(); + prev_pos -= 1; + caches.shift_left(1); + shifts += 1; + // TODO: it seems that text beyond context is just broken. + // Maybe I cannot just go and shift it. + } + let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches, shifts); let highest_pred_idx = token_sampler.sample(&preds); toks_id.push(highest_pred_idx as TokenId); - let mut tok_str: String = "".to_string(); - for tok_id in toks_id[prev_pos + 1..].iter() { + for (tok_idx, tok_id) in toks_id[prev_pos + 1..].iter().enumerate() { if *tok_id == 1 { continue; } + let mut tok_str: String = "".to_string(); let tok = tok.id_to_str(*tok_id); if tok == "<0x0A>" { tok_str += "\n"; } else { tok_str += tok.replace('▁', " ").as_str(); } + if first && tok_idx < toks_id.len() - 1 { + // intentionally left empty + } else if shifts == 0 { + print!("{}", tok_str.truecolor(128, 255, 128)); + } else { + print!("{}", tok_str.truecolor(255, 128, 128)); + } } - print!("{}", tok_str); let _ = std::io::stdout().flush(); prev_pos = toks_id.len() - 1; + first = false; } } diff --git a/src/tensor.rs b/src/tensor.rs index 48779bd..c3a4302 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -865,8 +865,8 @@ impl Tensor { unsafe { let result = Tensor::uninitialized(self.rows, 1, self.dtype); - let capacity_cols: i64 = self.capacity_cols as i64; - let result_capacity_cols = result.capacity_cols as i64; + let capacity_cols: i64 = self.capacity_cols; + let result_capacity_cols: i64 = result.capacity_cols; let col_its: usize = if self.cols % 8 == 0 { (self.cols / 8) as usize } else { @@ -902,6 +902,8 @@ impl Tensor { } // Computes matrix multiplication assuming left side has number of rows as 1 + #[allow(clippy::erasing_op)] + #[allow(clippy::identity_op)] pub fn vector_matrix_mul(&self, other: &Tensor) -> Tensor { if self.cols != other.rows { panic!( @@ -938,6 +940,9 @@ impl Tensor { let row = row8 * 8; let left = _mm256_loadu_ps(left_data.add(row)); let mut r = [0.0f32; 8]; + // i hate you clippy because you ask me + // to make code more unreadable + #[allow(clippy::needless_range_loop)] for i in 0..8 { if row + i < other.rows as usize { r[i] = *right_data.add((row + i) * other_capacity_cols + col); diff --git a/src/transformer.rs b/src/transformer.rs index 4977316..bfbba5e 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -62,6 +62,33 @@ impl AttentionCache { } AttentionCache { cache_k, cache_v } } + + fn shift_left(&mut self, shifts: usize) { + for _ in 0..shifts { + for idx in 0..self.cache_k.len() { + let mut k = self.cache_k[idx].write().unwrap(); + let mut v = self.cache_v[idx].write().unwrap(); + let k_rows = k.rows(); + let k_cols = k.cols(); + for head_idx in 0..k_rows { + for seq_idx in 0..k_cols - 1 { + let kval = k.get_f32(head_idx, seq_idx + 1); + let vval = v.get_f32(head_idx, seq_idx + 1); + k.set_f32(head_idx, seq_idx, kval); + v.set_f32(head_idx, seq_idx, vval); + } + } + } + } + } +} + +impl TransformerCaches { + pub fn shift_left(&mut self, shifts: usize) { + for layer in self.layer_caches.iter_mut() { + layer.shift_left(shifts); + } + } } pub struct RMSNorm { @@ -122,7 +149,9 @@ impl Transformer { let output = Tensor::from_unpickled(unpickled, "output.weight", data_dir)?.to_f32(); Ok(Transformer { - freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len * 2, 10000.0), + // TODO: maybe rotary embedding can be computed on the fly if the sequence gets really long. I just + // slapped * 20 on max seq len here. + freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len * 20, 10000.0), emb, dim, n_layers, @@ -157,6 +186,7 @@ impl Transformer { tokens: &[TokenId], start_pos: usize, caches: &mut TransformerCaches, + shifts: usize, ) -> Tensor { assert!(caches.layer_caches.len() == self.n_layers); let mask: Option = if tokens.len() > 1 { @@ -185,6 +215,7 @@ impl Transformer { &self.freqs_cis, &mask, &mut caches.layer_caches[idx], + shifts, ); } let out = self.norm.forward(&emb_tensor); @@ -234,11 +265,17 @@ impl TransformerBlock { freqs_cis: &FreqsCis, mask: &Option, attention_cache: &mut AttentionCache, + shifts: usize, ) -> Tensor { let attnorm_out = self.attention_norm.forward(x); - let att_out = self - .attn - .forward(&attnorm_out, start_pos, freqs_cis, mask, attention_cache); + let att_out = self.attn.forward( + &attnorm_out, + start_pos, + freqs_cis, + mask, + attention_cache, + shifts, + ); let h = x.add(&att_out); let att_out = self.ffn_norm.forward(&h); let att_out = self.feed_forward.forward(&att_out).transpose(); @@ -300,8 +337,8 @@ impl FeedForward { pub fn forward(&self, x: &Tensor) -> Tensor { let (w1_out, w3_out) = rayon::join( - || self.w1.matrix_mul_transposed(&x), - || self.w3.matrix_mul_transposed(&x), + || self.w1.matrix_mul_transposed(x), + || self.w3.matrix_mul_transposed(x), ); let w1_out = w1_out.silu(); let w1w3_out = w1_out.hadamard_product(&w3_out).transpose(); @@ -367,6 +404,7 @@ impl Attention { freqs_cis: &FreqsCis, mask: &Option, attention_cache: &mut AttentionCache, + shifts: usize, ) -> Tensor { let seq_len = x.rows(); let (xq_out, (xk_out, xv_out)) = rayon::join( @@ -394,8 +432,13 @@ impl Attention { .row(idx) .view(self.n_local_heads as i64, self.head_dim as i64); - let (xq_row, xk_row) = - apply_rotary_emb(&xq_row, &xk_row, freqs_cis, idx as usize, start_pos); + let (xq_row, xk_row) = apply_rotary_emb( + &xq_row, + &xk_row, + freqs_cis, + idx as usize, + start_pos + shifts, + ); xq_views.push(xq_row); xk_views.push(xk_row); @@ -464,8 +507,7 @@ impl Attention { } let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect(); let xq_row = Tensor::concat(&concat_vec2).view(1, 4096); - let result = xq_row.matrix_mul_transposed(&self.wo); - result + xq_row.matrix_mul_transposed(&self.wo) }) .collect();