diff --git a/README.md b/README.md
index 06134fa..ec4febb 100644
--- a/README.md
+++ b/README.md
@@ -4,9 +4,9 @@ This is my attempt at making the LLaMA language model working on a pure Rust
 CPU implementation. I was inspired by an amazing CPU implementation here:
 https://github.com/ggerganov/ggml that could run GPT-J 8B models.
 
-As of writing of this, this can run LLaMA-7B at around ~1 token per second,
-using something like 1.5 threads because I haven't yet properly figured out how
-to multithread this.
+As of writing this, it can run LLaMA-7B at around 1 token per second on a
+Ryzen 3950X, using something like 1.5 threads because I haven't yet properly
+figured out how to multithread this.
 
 It uses AVX2 intrinsics to speed up itself. Therefore, you need an x86-family
 CPU to run this.
diff --git a/src/rllama_main.rs b/src/rllama_main.rs
index 2ccc85b..2fffd84 100644
--- a/src/rllama_main.rs
+++ b/src/rllama_main.rs
@@ -14,7 +14,9 @@ struct Cli {
     #[arg(long)]
     tokenizer_path: String,
     #[arg(long)]
-    prompt: String,
+    prompt: Option<String>,
+    #[arg(long)]
+    prompt_file: Option<String>,
     #[arg(long)]
     temperature: Option<f32>,
 
@@ -28,11 +30,29 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     let cli = Cli::parse();
     let model_path = cli.model_path;
    let tokenizer_path = cli.tokenizer_path;
-    let prompt = cli.prompt;
+
+    let prompt: String = match (cli.prompt, cli.prompt_file) {
+        (Some(prompt), None) => {
+            println!("Using prompt: {}", prompt);
+            prompt
+        }
+        (None, Some(prompt_file)) => {
+            println!("Using prompt file: {}", prompt_file);
+            let mut fs = std::fs::File::open(prompt_file)?;
+            let mut bs = Vec::new();
+            fs.read_to_end(&mut bs)?;
+            std::mem::drop(fs);
+            String::from_utf8(bs)?
+        }
+        _ => {
+            println!("Please provide either a prompt or a prompt file.");
+            return Ok(());
+        }
+    };
 
     println!("Starting up. Loading tokenizer from {}...", tokenizer_path);
     let tok = Tokenizer::load(tokenizer_path.as_str())?;
-    println!("Tokenizer loeaded. Loading model from {}...", model_path);
+    println!("Tokenizer loaded. Loading model from {}...", model_path);
     let mut fs = std::fs::File::open(model_path.as_str())?;
     let mut bs = Vec::new();
     fs.read_to_end(&mut bs)?;
diff --git a/src/tensor.rs b/src/tensor.rs
index eabafe5..48779bd 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -2,6 +2,7 @@ use crate::unpickler;
 use crate::unpickler::UnpicklingError;
 use half::f16;
 use rand::Rng;
+use rayon::prelude::*;
 use std::alloc::Layout;
 use std::arch::x86_64::*;
 use std::io::Read;
@@ -833,6 +834,73 @@ impl Tensor {
         }
     }
 
+    /// Same as matrix_vector_mul but uses threading.
+    pub fn matrix_vector_mul_transposed_multithreaded(&self, other: &Tensor) -> Tensor {
+        if self.cols != other.cols {
+            panic!(
+                "Invalid matrix-vector transposed multiplication {}x{} vs {}x{}",
+                self.rows, self.cols, other.rows, other.cols
+            );
+        }
+        assert_eq!(other.rows, 1);
+        assert_eq!(other.dtype, self.dtype);
+        assert_eq!(self.dtype, TensorDType::Float32);
+
+        // Use this to smuggle pointers to threads without Rust getting so goddamn mad
+        //
+        // Assumption usize = pointer size.
+        #[derive(Copy, Clone)]
+        struct WrappedPtr {
+            ptr: usize,
+        }
+        impl WrappedPtr {
+            fn wrap(ptr: *const u8) -> WrappedPtr {
+                WrappedPtr { ptr: ptr as usize }
+            }
+
+            fn unwrap(self) -> *const u8 {
+                self.ptr as *const u8
+            }
+        }
+
+        unsafe {
+            let result = Tensor::uninitialized(self.rows, 1, self.dtype);
+            let capacity_cols: i64 = self.capacity_cols as i64;
+            let result_capacity_cols = result.capacity_cols as i64;
+            let col_its: usize = if self.cols % 8 == 0 {
+                (self.cols / 8) as usize
+            } else {
+                (self.cols / 8 + 1) as usize
+            };
+            let self_data = WrappedPtr::wrap(self.data);
+            let other_data = WrappedPtr::wrap(other.data);
+            let result_data = WrappedPtr::wrap(result.data);
+            (0..self.rows as usize)
+                .into_par_iter()
+                .with_min_len(64)
+                .for_each(|row| {
+                    let row = row as i64;
+                    let self_data: *const f32 = self_data.unwrap() as *const f32;
+                    let other_data: *const f32 = other_data.unwrap() as *const f32;
+                    let result_data: *mut f32 = result_data.unwrap() as *mut f32;
+
+                    let mut sum8: __m256 = _mm256_setzero_ps();
+                    for col in 0..col_its {
+                        let col = col * 8;
+                        let left_side8 =
+                            _mm256_loadu_ps(self_data.add((row * capacity_cols) as usize + col));
+                        let right_side8 = _mm256_loadu_ps(other_data.add(col));
+                        sum8 = _mm256_fmadd_ps(left_side8, right_side8, sum8);
+                    }
+                    let sum: f32 = horizontal_sum(sum8);
+                    result_data
+                        .add((row * result_capacity_cols) as usize)
+                        .write(sum);
+                });
+            result
+        }
+    }
+
     // Computes matrix multiplication assuming left side has number of rows as 1
     pub fn vector_matrix_mul(&self, other: &Tensor) -> Tensor {
         if self.cols != other.rows {
@@ -842,15 +910,54 @@ impl Tensor {
             );
         }
         assert_eq!(self.rows, 1);
-        let mut result = unsafe { Tensor::uninitialized(1, other.cols, self.dtype) };
-        for col in 0..other.cols {
-            let mut sum = 0.0;
-            for row in 0..self.cols {
-                sum += self.get_f32(0, row) * other.get_f32(row, col);
+        unsafe {
+            let result = Tensor::uninitialized(1, other.cols, self.dtype);
+            let col_its: usize = if other.rows % 8 == 0 {
+                (other.rows / 8) as usize
+            } else {
+                (other.rows / 8 + 1) as usize
+            };
+            let left_data: *const f32 = self.data as *const f32;
+            let right_data: *const f32 = other.data as *const f32;
+            let tgt_data: *mut f32 = result.data as *mut f32;
+            let other_capacity_cols = other.capacity_cols as usize;
+
+            let o0: i32 = other_capacity_cols as i32 * 0 * 4;
+            let o1: i32 = other_capacity_cols as i32 * 1 * 4;
+            let o2: i32 = other_capacity_cols as i32 * 2 * 4;
+            let o3: i32 = other_capacity_cols as i32 * 3 * 4;
+            let o4: i32 = other_capacity_cols as i32 * 4 * 4;
+            let o5: i32 = other_capacity_cols as i32 * 5 * 4;
+            let o6: i32 = other_capacity_cols as i32 * 6 * 4;
+            let o7: i32 = other_capacity_cols as i32 * 7 * 4;
+
+            for col in 0..other.cols {
+                let col = col as usize;
+                let mut sum8: __m256 = _mm256_setzero_ps();
+                for row8 in 0..col_its {
+                    let row = row8 * 8;
+                    let left = _mm256_loadu_ps(left_data.add(row));
+                    let mut r = [0.0f32; 8];
+                    for i in 0..8 {
+                        if row + i < other.rows as usize {
+                            r[i] = *right_data.add((row + i) * other_capacity_cols + col);
+                        }
+                    }
+                    let right = if row + 8 <= other.rows as usize {
+                        _mm256_i32gather_ps(
+                            right_data.add(row * other_capacity_cols + col),
+                            _mm256_set_epi32(o7, o6, o5, o4, o3, o2, o1, o0),
+                            1,
+                        )
+                    } else {
+                        _mm256_loadu_ps(r.as_ptr())
+                    };
+                    sum8 = _mm256_fmadd_ps(left, right, sum8);
+                }
+                *tgt_data.add(col) = horizontal_sum(sum8);
             }
-            result.set_f32(0, col, sum);
+            result
         }
-        result
     }
 
     pub fn random(rows: i64, cols: i64, dtype: TensorDType) -> Self {
@@ -1246,6 +1353,30 @@ mod tests {
         }
     }
 
+    #[test]
+    fn vector_mat_mul_and_naive_mat_mul_agree() {
+        let mut rng = rand::thread_rng();
+        for _ in 0..50 {
+            let a = rng.gen_range(1..100);
+            let b = rng.gen_range(1..100);
+
+            let m1 = Tensor::random(1, a, TensorDType::Float32);
+            let m2 = Tensor::random(a, b, TensorDType::Float32);
+
+            let c = m1.matrix_mul_naive(&m2);
+            let c2 = m1.vector_matrix_mul(&m2);
+
+            assert_eq!(c.rows(), c2.rows());
+            assert_eq!(c.cols(), c2.cols());
+
+            for row in 0..c.rows {
+                for col in 0..c.cols {
+                    assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-5);
+                }
+            }
+        }
+    }
+
     #[test]
     fn naive_mat_mul_and_fast_are_same_f16() {
         for _ in 0..50 {
diff --git a/src/transformer.rs b/src/transformer.rs
index 5b56afa..4977316 100644
--- a/src/transformer.rs
+++ b/src/transformer.rs
@@ -241,8 +241,7 @@ impl TransformerBlock {
             .forward(&attnorm_out, start_pos, freqs_cis, mask, attention_cache);
         let h = x.add(&att_out);
         let att_out = self.ffn_norm.forward(&h);
-        let att_out = self.feed_forward.forward(&att_out.transpose()).transpose();
-
+        let att_out = self.feed_forward.forward(&att_out).transpose();
         h.add(&att_out)
     }
 }
@@ -300,7 +299,6 @@ impl FeedForward {
     }
 
     pub fn forward(&self, x: &Tensor) -> Tensor {
-        let x = x.transpose();
         let (w1_out, w3_out) = rayon::join(
             || self.w1.matrix_mul_transposed(&x),
             || self.w3.matrix_mul_transposed(&x),
@@ -308,6 +306,11 @@
         let w1_out = w1_out.silu();
         let w1w3_out = w1_out.hadamard_product(&w3_out).transpose();
+        if w1w3_out.rows() == 1 {
+            return self
+                .w2
+                .matrix_vector_mul_transposed_multithreaded(&w1w3_out);
+        }
         self.w2.matrix_mul_transposed(&w1w3_out)
     }
 }
 
@@ -366,9 +369,15 @@ impl Attention {
         attention_cache: &mut AttentionCache,
     ) -> Tensor {
         let seq_len = x.rows();
-        let xq_out = x.matrix_mul_transposed(&self.wq);
-        let xk_out = x.matrix_mul_transposed(&self.wk);
-        let xv_out = x.matrix_mul_transposed(&self.wv);
+        let (xq_out, (xk_out, xv_out)) = rayon::join(
+            || x.matrix_mul_transposed(&self.wq),
+            || {
+                rayon::join(
+                    || x.matrix_mul_transposed(&self.wk),
+                    || x.matrix_mul_transposed(&self.wv),
+                )
+            },
+        );
 
         let mut xq_views: Vec<Tensor> = Vec::with_capacity(seq_len as usize);
         let mut xk_views: Vec<Tensor> = Vec::with_capacity(seq_len as usize);
@@ -420,20 +429,6 @@
                 let mut cache_k = attention_cache.cache_k[idx].write().unwrap();
                 let mut cache_v = attention_cache.cache_v[idx].write().unwrap();
 
-                /*
-                let m = xq_row
-                    .matrix_mul(&xk_row)
-                    .scalar_multiply_f32(1.0 / (self.head_dim as f32).sqrt());
-                //println!("mask size: {} {}", mask.rows(), mask.cols());
-                //println!("m size: {} {}", m.rows(), m.cols());
-                let m2 = m.add(mask).to_f32().softmax().matrix_mul(&xv_row);
-                m2
-                println!("xk_row size: {} {}", xk_row.rows(), xk_row.cols());
-                println!("xv_row size: {} {}", xv_row.rows(), xv_row.cols());
-                println!("cache_k size: {} {}", cache_k.rows(), cache_k.cols());
-                panic!("stop");
-                */
-
                 for pos in start_pos..start_pos + seq_len as usize {
                     for dim in 0..self.head_dim {
                         let k = xk_row.get_f32(dim as i64, (pos - start_pos) as i64);
@@ -460,8 +455,6 @@
             })
             .collect();
 
-        // convert from 32 matrices of size 8x128 to 8 matrices of size 32x128
-        // or rather 4096x1
         let output2: Vec<Tensor> = (0..seq_len)
             .into_par_iter()
             .map(|idx| {
@@ -471,10 +464,11 @@
                 }
                 let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect();
                 let xq_row = Tensor::concat(&concat_vec2).view(1, 4096);
-
-                xq_row.matrix_mul_transposed(&self.wo)
+                let result = xq_row.matrix_mul_transposed(&self.wo);
+                result
             })
             .collect();
+
         let output3: Vec<&Tensor> = output2.iter().collect();
         let output2: Tensor = Tensor::concat(&output3);
         output2
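
Note on testing: the new vector_mat_mul_and_naive_mat_mul_agree test covers the rewritten vector_matrix_mul, but nothing exercises matrix_vector_mul_transposed_multithreaded, which FeedForward::forward now calls whenever the intermediate activation has a single row. A sketch of a companion test is below. It is only a sketch: the test name is made up, and it assumes matrix_mul_transposed is a valid single-threaded reference for the 1-row case (that is the path FeedForward::forward took before this change); every other helper it calls already appears in the diff.

    #[test]
    fn matrix_vector_mul_transposed_multithreaded_agrees_with_single_threaded() {
        let mut rng = rand::thread_rng();
        for _ in 0..50 {
            let a = rng.gen_range(1..100);
            let b = rng.gen_range(1..100);

            // An (a x b) matrix against a transposed (1 x b) vector should give an (a x 1) result.
            let m = Tensor::random(a, b, TensorDType::Float32);
            let v = Tensor::random(1, b, TensorDType::Float32);

            let c = m.matrix_mul_transposed(&v);
            let c2 = m.matrix_vector_mul_transposed_multithreaded(&v);

            assert_eq!(c.rows(), c2.rows());
            assert_eq!(c.cols(), c2.cols());

            for row in 0..c.rows() {
                for col in 0..c.cols() {
                    assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-5);
                }
            }
        }
    }

Keeping the operand sizes random rather than multiples of 8 mirrors the existing test and would also catch regressions in how the 8-wide AVX2 loop handles a column count that does not divide evenly.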