From acfd6bd5bdde9106f484e1a8c14cc6b32ba7311c Mon Sep 17 00:00:00 2001
From: Mikko Juola
Date: Fri, 17 Mar 2023 13:26:58 -0700
Subject: [PATCH] Add f16, non-OpenCL version of matrix_vector_mul_transposed
 as well.

In the benchmark this seems to be about 100% slower (roughly 2x) than the
pure f32 version. Not sure why as of this commit, but I'll investigate
further.
---
 src/benches/benchmark.rs |  29 +++++---
 src/tensor.rs            | 138 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 157 insertions(+), 10 deletions(-)

diff --git a/src/benches/benchmark.rs b/src/benches/benchmark.rs
index 0d9aadf..4b126dc 100644
--- a/src/benches/benchmark.rs
+++ b/src/benches/benchmark.rs
@@ -108,6 +108,26 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
 
     let m1 = Tensor::random(1024, 128, TensorDType::Float32);
     let m2 = Tensor::random(1, 128, TensorDType::Float32);
+    let m1_f16 = m1.to_f16();
+    let m2_f16 = m2.to_f16();
+
+    c.bench_function(
+        "1024x128 * 1x128 matrix vector transposed multiplication, f32",
+        |b| {
+            b.iter(|| {
+                let _ = m1.matrix_vector_mul_transposed(black_box(&m2));
+            })
+        },
+    );
+
+    c.bench_function(
+        "1024x128 * 1x128 matrix vector transposed multiplication, f16",
+        |b| {
+            b.iter(|| {
+                let _ = m1_f16.matrix_vector_mul_transposed(black_box(&m2_f16));
+            })
+        },
+    );
 
     c.bench_function(
         "matrix multiplication 8x4096 @ 4096x4096 f32 in-place, transposed",
@@ -133,15 +153,6 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
         },
     );
 
-    c.bench_function(
-        "1024x128 * 1x128 matrix vector transposed multiplication",
-        |b| {
-            b.iter(|| {
-                let _ = m1.matrix_vector_mul_transposed(black_box(&m2));
-            })
-        },
-    );
-
     c.bench_function("1024x1024 matrix from f32->f16", |b| {
         b.iter(|| {
             let _ = black_box(&orig_f32).to_f16();
diff --git a/src/tensor.rs b/src/tensor.rs
index edf9678..0f1edc9 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -1495,8 +1495,115 @@ impl Tensor {
         }
         assert_eq!(other.rows, 1);
         assert_eq!(other.dtype, self.dtype);
-        assert_eq!(self.dtype, TensorDType::Float32);
+        match self.dtype {
+            TensorDType::Float32 => self.matrix_vector_mul_transposed_f32(other),
+            TensorDType::Float16 => self.matrix_vector_mul_transposed_f16(other),
+            _ => panic!("Unsupported dtype"),
+        }
+    }
+
+    fn matrix_vector_mul_transposed_f16(&self, other: &Tensor) -> Tensor {
+        self.assume_on_cpu();
+        other.assume_on_cpu();
+        unsafe {
+            let mut result = Tensor::uninitialized(self.rows, 1, self.dtype);
+            let col_its: usize = if self.cols % 8 == 0 {
+                (self.cols / 8) as usize
+            } else {
+                (self.cols / 8 + 1) as usize
+            };
+            let row_its: usize = if self.rows % 4 == 0 {
+                (self.rows / 4) as usize
+            } else {
+                (self.rows / 4 + 1) as usize
+            };
+            let mut sum8s: [__m256; 4] = [
+                _mm256_setzero_ps(),
+                _mm256_setzero_ps(),
+                _mm256_setzero_ps(),
+                _mm256_setzero_ps(),
+            ];
+            let self_data: *const f16 = self.data as *const f16;
+            let other_data: *const f16 = other.data as *const f16;
+            let _ncols_capacity: usize = result.capacity_cols as usize;
+            for row in 0..row_its {
+                let row: i64 = row as i64;
+                sum8s[0] = _mm256_setzero_ps();
+                sum8s[1] = _mm256_setzero_ps();
+                sum8s[2] = _mm256_setzero_ps();
+                sum8s[3] = _mm256_setzero_ps();
+                let row4_0 = row * 4;
+                let row4_1 = row * 4 + 1;
+                let row4_2 = row * 4 + 2;
+                let row4_3 = row * 4 + 3;
+
+                // Loads 8 f16s from (0, col..col+8), converted to f32.
+                #[inline]
+                fn load2(ptr: *const f16, col: usize) -> __m256 {
+                    unsafe { _mm256_cvtph_ps(_mm_loadu_si128(ptr.add(col) as *const __m128i)) }
+                }
+                // Loads 8 f16s from (row, col..col+8) as f32; zeros if row is out of bounds.
+                #[inline]
+                fn load2row(
+                    ptr: *const f16,
+                    row: i64,
+                    col: usize,
+                    cols_capacity: i64,
+                    nrows: i64,
+                ) -> __m256 {
+                    unsafe {
+                        if row < nrows {
+                            _mm256_cvtph_ps(_mm_loadu_si128(
+                                ptr.add(row as usize * cols_capacity as usize + col)
+                                    as *const __m128i,
+                            ))
+                        } else {
+                            _mm256_setzero_ps()
+                        }
+                    }
+                }
+
+                for col in 0..col_its {
+                    let col = col * 8;
+                    let right_side8 = load2(other_data, col);
+                    let left_side8_0 =
+                        load2row(self_data, row4_0, col, self.capacity_cols, self.rows);
+                    let left_side8_1 =
+                        load2row(self_data, row4_1, col, self.capacity_cols, self.rows);
+                    let left_side8_2 =
+                        load2row(self_data, row4_2, col, self.capacity_cols, self.rows);
+                    let left_side8_3 =
+                        load2row(self_data, row4_3, col, self.capacity_cols, self.rows);
+                    sum8s[0] = _mm256_fmadd_ps(left_side8_0, right_side8, sum8s[0]);
+                    sum8s[1] = _mm256_fmadd_ps(left_side8_1, right_side8, sum8s[1]);
+                    sum8s[2] = _mm256_fmadd_ps(left_side8_2, right_side8, sum8s[2]);
+                    sum8s[3] = _mm256_fmadd_ps(left_side8_3, right_side8, sum8s[3]);
+                }
+                let sum_0: f32 = horizontal_sum(sum8s[0]);
+                let sum_1: f32 = horizontal_sum(sum8s[1]);
+                let sum_2: f32 = horizontal_sum(sum8s[2]);
+                let sum_3: f32 = horizontal_sum(sum8s[3]);
+                if row4_0 < result.rows {
+                    result.set_f32(row4_0, 0, sum_0);
+                }
+                if row4_1 < result.rows {
+                    result.set_f32(row4_1, 0, sum_1);
+                }
+                if row4_2 < result.rows {
+                    result.set_f32(row4_2, 0, sum_2);
+                }
+                if row4_3 < result.rows {
+                    result.set_f32(row4_3, 0, sum_3);
+                }
+            }
+            result
+        }
+    }
+
+    fn matrix_vector_mul_transposed_f32(&self, other: &Tensor) -> Tensor {
+        self.assume_on_cpu();
+        other.assume_on_cpu();
         unsafe {
             let mut result = Tensor::uninitialized(self.rows, 1, self.dtype);
             let col_its: usize = if self.cols % 8 == 0 {
@@ -2321,6 +2428,35 @@ mod tests {
         }
     }
 
+    #[test]
+    fn mat_vector_mul_transposed_f32_agrees_mat_vector_mul_transposed_f16() {
+        let mut rng = rand::thread_rng();
+        for _ in 0..1000 {
+            let a = rng.gen_range(1..=128);
+            let r = rng.gen_range(1..=128);
+
+            // Make matrices AxR and Rx1
+            let a = Tensor::random(a, r, TensorDType::Float32);
+            let b = Tensor::random(r, 1, TensorDType::Float32);
+            let a2 = a.clone().to_f16();
+            let b2 = b.clone().to_f16();
+            let b_transposed = b.transpose();
+            let b2_transposed = b2.transpose();
+
+            let c = a.matrix_vector_mul_transposed(&b_transposed);
+            let c2 = a2.matrix_vector_mul_transposed(&b2_transposed);
+
+            assert_eq!(c.rows, c2.rows);
+            assert_eq!(c.cols, c2.cols);
+
+            for row in 0..c.rows {
+                for col in 0..c.cols {
+                    assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-1);
+                }
+            }
+        }
+    }
+
     #[test]
     fn view_preserves_values() {
         fn test_with_type(dtype: TensorDType) {
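
A standalone micro-benchmark may help isolate the slowdown noted in the commit
message: on each inner iteration the f16 kernel pays for 8-lane f16->f32
conversions (_mm256_cvtph_ps) before the fused multiply-add, and for a scalar
f32->f16 conversion per output element in set_f32, while the f32 kernel loads
operands directly. The sketch below is a minimal, hypothetical reproduction
outside the Tensor code, assuming an x86_64 CPU with FMA and F16C; the names
dot_f32/dot_f16 and the vector length are illustrative and not part of this
patch.

    use std::arch::x86_64::*;
    use std::hint::black_box;
    use std::time::Instant;

    // f32 path: load 8 floats and fuse multiply-add, like the f32 kernel.
    #[target_feature(enable = "fma")]
    unsafe fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
        let mut acc = _mm256_setzero_ps();
        for i in (0..a.len()).step_by(8) {
            let x = _mm256_loadu_ps(a.as_ptr().add(i));
            let y = _mm256_loadu_ps(b.as_ptr().add(i));
            acc = _mm256_fmadd_ps(x, y, acc);
        }
        let mut lanes = [0.0f32; 8];
        _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
        lanes.iter().sum()
    }

    // f16 path: each iteration pays for two 8-lane f16 -> f32 conversions
    // before the fused multiply-add, like the f16 kernel.
    #[target_feature(enable = "fma,f16c")]
    unsafe fn dot_f16(a: &[u16], b: &[u16]) -> f32 {
        let mut acc = _mm256_setzero_ps();
        for i in (0..a.len()).step_by(8) {
            let x = _mm256_cvtph_ps(_mm_loadu_si128(a.as_ptr().add(i) as *const __m128i));
            let y = _mm256_cvtph_ps(_mm_loadu_si128(b.as_ptr().add(i) as *const __m128i));
            acc = _mm256_fmadd_ps(x, y, acc);
        }
        let mut lanes = [0.0f32; 8];
        _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
        lanes.iter().sum()
    }

    fn main() {
        assert!(is_x86_feature_detected!("fma") && is_x86_feature_detected!("f16c"));
        let n = 1 << 20; // multiple of 8, so step_by(8) never reads past the end
        let a32 = vec![1.0f32; n];
        let b32 = vec![0.5f32; n];
        let a16 = vec![0x3C00u16; n]; // 0x3C00 is 1.0 in IEEE-754 binary16
        let b16 = vec![0x3800u16; n]; // 0x3800 is 0.5 in IEEE-754 binary16
        unsafe {
            let t = Instant::now();
            let s32 = dot_f32(black_box(&a32), black_box(&b32));
            let d32 = t.elapsed();
            let t = Instant::now();
            let s16 = dot_f16(black_box(&a16), black_box(&b16));
            let d16 = t.elapsed();
            println!("f32: {s32} in {d32:?}; f16: {s16} in {d16:?}");
        }
    }

If dot_f16 shows a comparable gap to the Criterion numbers, the conversion
traffic alone likely accounts for most of the difference; if not, the
remaining suspects would be the scalar set_f32 stores and the bounds check in
load2row.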