diff --git a/src/benches/benchmark.rs b/src/benches/benchmark.rs
index 9917f25..f6a73c2 100644
--- a/src/benches/benchmark.rs
+++ b/src/benches/benchmark.rs
@@ -102,6 +102,18 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
     let orig_f32 = Tensor::zeros(1024, 1024, TensorDType::Float32);
     let orig_f16 = Tensor::zeros(1024, 1024, TensorDType::Float16);
 
+    let m1 = Tensor::random(1024, 128, TensorDType::Float32);
+    let m2 = Tensor::random(1, 128, TensorDType::Float32);
+
+    c.bench_function(
+        "1024x128 * 1x128 matrix vector transposed multiplication",
+        |b| {
+            b.iter(|| {
+                let _ = m1.matrix_vector_mul_transposed(black_box(&m2));
+            })
+        },
+    );
+
     c.bench_function("1024x1024 matrix from f32->f16", |b| {
         b.iter(|| {
             let _ = black_box(&orig_f32).to_f16();
diff --git a/src/tensor.rs b/src/tensor.rs
index 740e60a..d5530d6 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -1156,19 +1156,74 @@ impl Tensor {
         } else {
             (self.cols / 8 + 1) as usize
         };
+        let row_its: usize = if self.rows % 4 == 0 {
+            (self.rows / 4) as usize
+        } else {
+            (self.rows / 4 + 1) as usize
+        };
+        let mut sum8s: [__m256; 4] = [
+            _mm256_setzero_ps(),
+            _mm256_setzero_ps(),
+            _mm256_setzero_ps(),
+            _mm256_setzero_ps(),
+        ];
         let self_data: *const f32 = self.data as *const f32;
         let other_data: *const f32 = other.data as *const f32;
-        for row in 0..self.rows {
-            let mut sum8: __m256 = _mm256_setzero_ps();
+        let tgt_data: *mut f32 = result.data as *mut f32;
+        let ncols_capacity: usize = result.capacity_cols as usize;
+        for row in 0..row_its {
+            let row: i64 = row as i64;
+            sum8s[0] = _mm256_setzero_ps();
+            sum8s[1] = _mm256_setzero_ps();
+            sum8s[2] = _mm256_setzero_ps();
+            sum8s[3] = _mm256_setzero_ps();
+            let row4_0 = row * 4;
+            let row4_1 = row * 4 + 1;
+            let row4_2 = row * 4 + 2;
+            let row4_3 = row * 4 + 3;
+
             for col in 0..col_its {
                 let col = col * 8;
-                let left_side8 =
-                    _mm256_loadu_ps(self_data.add((row * self.capacity_cols) as usize + col));
                 let right_side8 = _mm256_loadu_ps(other_data.add(col));
-                sum8 = _mm256_fmadd_ps(left_side8, right_side8, sum8);
+                let left_side8_0 = _mm256_loadu_ps(
+                    self_data.add((row4_0 * self.capacity_cols) as usize + col),
+                );
+                let left_side8_1 = if row4_1 < self.rows {
+                    _mm256_loadu_ps(self_data.add((row4_1 * self.capacity_cols) as usize + col))
+                } else {
+                    _mm256_setzero_ps()
+                };
+                let left_side8_2 = if row4_2 < self.rows {
+                    _mm256_loadu_ps(self_data.add((row4_2 * self.capacity_cols) as usize + col))
+                } else {
+                    _mm256_setzero_ps()
+                };
+                let left_side8_3 = if row4_3 < self.rows {
+                    _mm256_loadu_ps(self_data.add((row4_3 * self.capacity_cols) as usize + col))
+                } else {
+                    _mm256_setzero_ps()
+                };
+                sum8s[0] = _mm256_fmadd_ps(left_side8_0, right_side8, sum8s[0]);
+                sum8s[1] = _mm256_fmadd_ps(left_side8_1, right_side8, sum8s[1]);
+                sum8s[2] = _mm256_fmadd_ps(left_side8_2, right_side8, sum8s[2]);
+                sum8s[3] = _mm256_fmadd_ps(left_side8_3, right_side8, sum8s[3]);
+            }
+            let sum_0: f32 = horizontal_sum(sum8s[0]);
+            let sum_1: f32 = horizontal_sum(sum8s[1]);
+            let sum_2: f32 = horizontal_sum(sum8s[2]);
+            let sum_3: f32 = horizontal_sum(sum8s[3]);
+            if row4_0 < result.rows {
+                result.set_f32(row4_0, 0, sum_0);
+            }
+            if row4_1 < result.rows {
+                result.set_f32(row4_1, 0, sum_1);
+            }
+            if row4_2 < result.rows {
+                result.set_f32(row4_2, 0, sum_2);
+            }
+            if row4_3 < result.rows {
+                result.set_f32(row4_3, 0, sum_3);
             }
-            let sum: f32 = horizontal_sum(sum8);
-            result.set_f32(row, 0, sum);
         }
         result
     }
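For reference, the unrolled AVX2 kernel above computes, for each row of the left-hand matrix, a dot product against the single row of the 1xN right-hand tensor, now processing four rows per outer iteration. A minimal scalar sketch of that same computation follows; it assumes a plain row-major &[f32] layout and a function name chosen here for illustration, not the crate's Tensor API.

// Scalar reference for the work the AVX2 kernel performs:
// out[r] = dot(row r of m, v), for an (rows x cols) matrix m and a cols-long vector v.
// Layout and names are assumptions for illustration only.
fn matrix_vector_mul_transposed_ref(m: &[f32], v: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(m.len(), rows * cols);
    assert_eq!(v.len(), cols);
    let mut out = vec![0.0f32; rows];
    for r in 0..rows {
        // Dot product of one matrix row with the vector.
        let mut sum = 0.0f32;
        for c in 0..cols {
            sum += m[r * cols + c] * v[c];
        }
        out[r] = sum;
    }
    out
}

The SIMD version differs only in that it accumulates eight columns at a time with _mm256_fmadd_ps, keeps four per-row accumulators live at once, and zero-pads the trailing rows when the row count is not a multiple of four.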