From a1970b8a9cbd3cb56527d27079c2d9011735e1ac Mon Sep 17 00:00:00 2001 From: Mikko Juola Date: Fri, 17 Mar 2023 11:04:35 -0700 Subject: [PATCH] Improve matrix multiplication transposed further, this gives around ~10%-20% further increase by improving memory load to instruction ratio. --- src/rllama_main.rs | 13 ++-- src/tensor.rs | 152 ++++++++++++++++++++++++++++++++++--------- src/token_sampler.rs | 4 +- src/tokenizer.rs | 2 +- src/transformer.rs | 8 +-- src/unpickler.rs | 2 +- 6 files changed, 134 insertions(+), 47 deletions(-) diff --git a/src/rllama_main.rs b/src/rllama_main.rs index a3e3203..cf30ddf 100644 --- a/src/rllama_main.rs +++ b/src/rllama_main.rs @@ -152,10 +152,7 @@ pub fn main() -> Result<(), Box> { pln!("Loading embeddings from {}...", model_path); let emb = Embedding::from_unpickled(&unpickle_results, model_path.clone())?; - let max_seq_len = match cli.max_seq_len { - Some(max_seq_len) => max_seq_len, - None => 1024, - }; + let max_seq_len = cli.max_seq_len.unwrap_or(1024); let data_settings = { #[cfg(feature = "opencl")] @@ -291,11 +288,9 @@ pub fn main() -> Result<(), Box> { break; } } - println!(""); - if stop_seen { - if !be_quiet { - println!("Stop token seen. Stopping."); - } + println!(); + if stop_seen && !be_quiet { + println!("Stop token seen. Stopping."); } if !be_quiet { println!("---"); diff --git a/src/tensor.rs b/src/tensor.rs index 63f84eb..a26c8e1 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1072,6 +1072,8 @@ impl Tensor { let src_cols: usize = src.cols as usize; let self_rows: usize = self.rows as usize; let self_cols: usize = self.cols as usize; + let _other_cols: usize = other.cols as usize; + let other_rows: usize = other.rows as usize; let other_cols_capacity: usize = other.capacity_cols as usize; let src_cols_capacity: usize = src.capacity_cols as usize; let self_cols_capacity: usize = self.capacity_cols as usize; @@ -1086,6 +1088,11 @@ impl Tensor { } else { self_rows / 4 + 1 }; + let self_cols_its = if self_cols % 4 == 0 { + self_cols / 4 + } else { + self_cols / 4 + 1 + }; unsafe { for row in 0..row_its { @@ -1093,18 +1100,66 @@ impl Tensor { let row1 = row * 4 + 1; let row2 = row * 4 + 2; let row3 = row * 4 + 3; - for col in 0..self_cols { - let mut targets8: [__m256; 4] = [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), + for col in 0..self_cols_its { + let col0 = col * 4; + let col1 = col * 4 + 1; + let col2 = col * 4 + 2; + let col3 = col * 4 + 3; + let mut targets8: [[__m256; 4]; 4] = [ + [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + ], + [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + ], + [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + ], + [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + ], ]; for p in 0..src_cols_its { - let other8: __m256 = _mm256_loadu_ps( + let other8_0: __m256 = _mm256_loadu_ps( other_data - .add(col * other_cols_capacity + p * ITEMS_PER_CACHE_LINE), + .add(col0 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE), ); + let other8_1: __m256 = + if col1 < other_rows { + _mm256_loadu_ps(other_data.add( + col1 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE, + )) + } else { + _mm256_setzero_ps() + }; + let other8_2: __m256 = + if col2 < other_rows { + _mm256_loadu_ps(other_data.add( + col2 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE, + )) + } else { + _mm256_setzero_ps() + }; + let other8_3: __m256 = + if col3 < other_rows { + _mm256_loadu_ps(other_data.add( + col3 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE, + )) + } else { + _mm256_setzero_ps() + }; let src8_0: __m256 = _mm256_loadu_ps( src_data .add(row0 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE), @@ -1133,24 +1188,61 @@ impl Tensor { } else { _mm256_setzero_ps() }; - targets8[0] = _mm256_fmadd_ps(src8_0, other8, targets8[0]); - targets8[1] = _mm256_fmadd_ps(src8_1, other8, targets8[1]); - targets8[2] = _mm256_fmadd_ps(src8_2, other8, targets8[2]); - targets8[3] = _mm256_fmadd_ps(src8_3, other8, targets8[3]); + targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]); + targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]); + targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]); + targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]); + targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]); + targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]); + targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]); + targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]); + targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]); + targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]); + targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]); + targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]); + targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]); + targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]); + targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]); + targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]); } - let target0: f32 = horizontal_sum(targets8[0]); - let target1: f32 = horizontal_sum(targets8[1]); - let target2: f32 = horizontal_sum(targets8[2]); - let target3: f32 = horizontal_sum(targets8[3]); - *tgt_data.add(row0 * self_cols_capacity + col) = target0; + let target00: f32 = horizontal_sum(targets8[0][0]); + let target01: f32 = horizontal_sum(targets8[0][1]); + let target02: f32 = horizontal_sum(targets8[0][2]); + let target03: f32 = horizontal_sum(targets8[0][3]); + let target10: f32 = horizontal_sum(targets8[1][0]); + let target11: f32 = horizontal_sum(targets8[1][1]); + let target12: f32 = horizontal_sum(targets8[1][2]); + let target13: f32 = horizontal_sum(targets8[1][3]); + let target20: f32 = horizontal_sum(targets8[2][0]); + let target21: f32 = horizontal_sum(targets8[2][1]); + let target22: f32 = horizontal_sum(targets8[2][2]); + let target23: f32 = horizontal_sum(targets8[2][3]); + let target30: f32 = horizontal_sum(targets8[3][0]); + let target31: f32 = horizontal_sum(targets8[3][1]); + let target32: f32 = horizontal_sum(targets8[3][2]); + let target33: f32 = horizontal_sum(targets8[3][3]); + + *tgt_data.add(row0 * self_cols_capacity + col0) += target00; + *tgt_data.add(row0 * self_cols_capacity + col1) += target10; + *tgt_data.add(row0 * self_cols_capacity + col2) += target20; + *tgt_data.add(row0 * self_cols_capacity + col3) += target30; if row1 < self_rows { - *tgt_data.add(row1 * self_cols_capacity + col) = target1; + *tgt_data.add(row1 * self_cols_capacity + col0) += target01; + *tgt_data.add(row1 * self_cols_capacity + col1) += target11; + *tgt_data.add(row1 * self_cols_capacity + col2) += target21; + *tgt_data.add(row1 * self_cols_capacity + col3) += target31; } if row2 < self_rows { - *tgt_data.add(row2 * self_cols_capacity + col) = target2; + *tgt_data.add(row2 * self_cols_capacity + col0) += target02; + *tgt_data.add(row2 * self_cols_capacity + col1) += target12; + *tgt_data.add(row2 * self_cols_capacity + col2) += target22; + *tgt_data.add(row2 * self_cols_capacity + col3) += target32; } if row3 < self_rows { - *tgt_data.add(row3 * self_cols_capacity + col) = target3; + *tgt_data.add(row3 * self_cols_capacity + col0) += target03; + *tgt_data.add(row3 * self_cols_capacity + col1) += target13; + *tgt_data.add(row3 * self_cols_capacity + col2) += target23; + *tgt_data.add(row3 * self_cols_capacity + col3) += target33; } } } @@ -1222,8 +1314,8 @@ impl Tensor { ]; let self_data: *const f32 = self.data as *const f32; let other_data: *const f32 = other.data as *const f32; - let tgt_data: *mut f32 = result.data as *mut f32; - let ncols_capacity: usize = result.capacity_cols as usize; + let _tgt_data: *mut f32 = result.data as *mut f32; + let _ncols_capacity: usize = result.capacity_cols as usize; for row in 0..row_its { let row: i64 = row as i64; sum8s[0] = _mm256_setzero_ps(); @@ -1670,8 +1762,8 @@ impl Tensor { let self_data: *const f16 = self.data as *const f16; let tgt_data: *mut f32 = result.data as *mut f32; - let tgt_capacity_cols = result.capacity_cols as i64; - let self_capacity_cols = self.capacity_cols as i64; + let tgt_capacity_cols = result.capacity_cols; + let self_capacity_cols = self.capacity_cols; for row in 0..self.rows { for col in 0..cols_it { let col = col * 8; @@ -1719,8 +1811,8 @@ impl Tensor { let result = Tensor::uninitialized(self.rows, self.cols, TensorDType::Float16); let self_data: *const f32 = self.data as *const f32; let tgt_data: *mut f16 = result.data as *mut f16; - let tgt_capacity_cols = result.capacity_cols as i64; - let self_capacity_cols = self.capacity_cols as i64; + let tgt_capacity_cols = result.capacity_cols; + let self_capacity_cols = self.capacity_cols; for row in 0..self.rows { for col in 0..cols_it { @@ -1973,9 +2065,9 @@ mod tests { fn mat_mul_transposed_agrees_with_regular_mat_mul() { let mut rng = rand::thread_rng(); for _ in 0..1000 { - let a = rng.gen_range(8..64); - let b = rng.gen_range(8..64); - let r = rng.gen_range(8..64); + let a = rng.gen_range(1..=128); + let b = rng.gen_range(1..=128); + let r = rng.gen_range(1..=128); // Make matrixes AxR and RxB let a = Tensor::random(a, r, TensorDType::Float32); @@ -1990,7 +2082,7 @@ mod tests { for row in 0..c.rows { for col in 0..c.cols { - assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-5); + assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-3); } } } diff --git a/src/token_sampler.rs b/src/token_sampler.rs index 5426478..1b760ad 100644 --- a/src/token_sampler.rs +++ b/src/token_sampler.rs @@ -68,7 +68,7 @@ impl TokenSampler { pub fn sample( &self, logits: &Tensor, - tokenizer: &Tokenizer, + _tokenizer: &Tokenizer, existing_tokens: &[TokenId], ) -> (TokenId, f32) { let mut times_used: BTreeMap = BTreeMap::new(); @@ -119,7 +119,7 @@ impl TokenSampler { None => { // Sort NaNs to bottom if b.1.is_nan() { - return std::cmp::Ordering::Less; + std::cmp::Ordering::Less } else if a.1.is_nan() { return std::cmp::Ordering::Greater; } else { diff --git a/src/tokenizer.rs b/src/tokenizer.rs index c836c1e..d7c7174 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -133,7 +133,7 @@ impl Tokenizer { let mut skip_s: &str = ""; // Specially recognize newline. Otherwise it matches something we don't actually // want. - if s.starts_with("\n") { + if s.starts_with('\n') { if self.str_to_id("<0x0A>").is_some() { best_candidate = "<0x0A>"; best_candidate_len = best_candidate.len(); diff --git a/src/transformer.rs b/src/transformer.rs index a20da45..3af2a0f 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -354,7 +354,7 @@ impl RMSNorm { let data_dir: &Path = data_dir.as_ref(); let weights = Tensor::from_unpickled_pieces( &unpickled[0..=0], - name.clone(), + name, data_dir, FromPiecesDirection::Rows, )? @@ -448,11 +448,11 @@ impl FeedForward { #[cfg(not(feature = "opencl"))] if w1w3_out.rows() == 1 { - return self + self .w2 - .matrix_vector_mul_transposed_multithreaded(&w1w3_out); + .matrix_vector_mul_transposed_multithreaded(&w1w3_out) } else { - return self.w2.matrix_mul_transposed(&w1w3_out); + self.w2.matrix_mul_transposed(&w1w3_out) } #[cfg(feature = "opencl")] { diff --git a/src/unpickler.rs b/src/unpickler.rs index 84a167f..6ee148d 100644 --- a/src/unpickler.rs +++ b/src/unpickler.rs @@ -587,7 +587,7 @@ pub fn unpickle(bytes: &[u8]) -> Result { )); } let idx = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); - match memo.get(&(idx as u32)) { + match memo.get(&{ idx }) { None => { return Err(UnpicklingError::UnpicklingError( "LONG_BINGET index out of range".to_string(),