From 61bc42b728bf11add3731449a1718e1ae3ff429d Mon Sep 17 00:00:00 2001 From: Mikko Juola Date: Thu, 16 Mar 2023 08:53:31 -0700 Subject: [PATCH] Improve the handwritten AVX2 for matrix_mul_inplace_transposed. This is something like ~60% faster than old version. --- src/tensor.rs | 73 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 63 insertions(+), 10 deletions(-) diff --git a/src/tensor.rs b/src/tensor.rs index d5530d6..63f84eb 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1068,6 +1068,7 @@ impl Tensor { let src_data: *const f32 = src.data as *const f32; let other_data: *const f32 = other.data as *const f32; + let src_rows: usize = src.rows as usize; let src_cols: usize = src.cols as usize; let self_rows: usize = self.rows as usize; let self_cols: usize = self.cols as usize; @@ -1080,25 +1081,77 @@ impl Tensor { } else { src_cols / ITEMS_PER_CACHE_LINE + 1 }; + let row_its = if self_rows % 4 == 0 { + self_rows / 4 + } else { + self_rows / 4 + 1 + }; unsafe { - for row in 0..self_rows { - let row = row; + for row in 0..row_its { + let row0 = row * 4; + let row1 = row * 4 + 1; + let row2 = row * 4 + 2; + let row3 = row * 4 + 3; for col in 0..self_cols { - let mut target8: __m256 = _mm256_setzero_ps(); + let mut targets8: [__m256; 4] = [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + ]; for p in 0..src_cols_its { - let src8: __m256 = _mm256_loadu_ps( - src_data - .add(row * src_cols_capacity + p * ITEMS_PER_CACHE_LINE), - ); let other8: __m256 = _mm256_loadu_ps( other_data .add(col * other_cols_capacity + p * ITEMS_PER_CACHE_LINE), ); - target8 = _mm256_fmadd_ps(src8, other8, target8); + let src8_0: __m256 = _mm256_loadu_ps( + src_data + .add(row0 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE), + ); + let src8_1: __m256 = + if row1 < src_rows { + _mm256_loadu_ps(src_data.add( + row1 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE, + )) + } else { + _mm256_setzero_ps() + }; + let src8_2: __m256 = + if row2 < src_rows { + _mm256_loadu_ps(src_data.add( + row2 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE, + )) + } else { + _mm256_setzero_ps() + }; + let src8_3: __m256 = + if row3 < src_rows { + _mm256_loadu_ps(src_data.add( + row3 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE, + )) + } else { + _mm256_setzero_ps() + }; + targets8[0] = _mm256_fmadd_ps(src8_0, other8, targets8[0]); + targets8[1] = _mm256_fmadd_ps(src8_1, other8, targets8[1]); + targets8[2] = _mm256_fmadd_ps(src8_2, other8, targets8[2]); + targets8[3] = _mm256_fmadd_ps(src8_3, other8, targets8[3]); + } + let target0: f32 = horizontal_sum(targets8[0]); + let target1: f32 = horizontal_sum(targets8[1]); + let target2: f32 = horizontal_sum(targets8[2]); + let target3: f32 = horizontal_sum(targets8[3]); + *tgt_data.add(row0 * self_cols_capacity + col) = target0; + if row1 < self_rows { + *tgt_data.add(row1 * self_cols_capacity + col) = target1; + } + if row2 < self_rows { + *tgt_data.add(row2 * self_cols_capacity + col) = target2; + } + if row3 < self_rows { + *tgt_data.add(row3 * self_cols_capacity + col) = target3; } - let target: f32 = horizontal_sum(target8); - *tgt_data.add(row * self_cols_capacity + col) = target; } } }