diff --git a/src/benches/benchmark.rs b/src/benches/benchmark.rs
index f6a73c2..0d9aadf 100644
--- a/src/benches/benchmark.rs
+++ b/src/benches/benchmark.rs
@@ -99,12 +99,40 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
     let orig_84096_2 = Tensor::zeros(4096, 4096, TensorDType::Float32);
     let mut result_84096 = Tensor::zeros(8, 4096, TensorDType::Float32);
 
+    let orig_84096_1_f16 = Tensor::zeros(8, 4096, TensorDType::Float16);
+    let orig_84096_2_f16 = Tensor::zeros(4096, 4096, TensorDType::Float16);
+    let mut result_84096_f16 = Tensor::zeros(8, 4096, TensorDType::Float16);
+
     let orig_f32 = Tensor::zeros(1024, 1024, TensorDType::Float32);
     let orig_f16 = Tensor::zeros(1024, 1024, TensorDType::Float16);
 
     let m1 = Tensor::random(1024, 128, TensorDType::Float32);
     let m2 = Tensor::random(1, 128, TensorDType::Float32);
 
+    c.bench_function(
+        "matrix multiplication 8x4096 @ 4096x4096 f32 in-place, transposed",
+        |b| {
+            b.iter(|| {
+                let _ = result_84096.matrix_mul_inplace_transposed(
+                    black_box(&orig_84096_1),
+                    black_box(&orig_84096_2),
+                );
+            })
+        },
+    );
+
+    c.bench_function(
+        "matrix multiplication 8x4096 @ 4096x4096 f16 in-place, transposed",
+        |b| {
+            b.iter(|| {
+                let _ = result_84096_f16.matrix_mul_inplace_transposed(
+                    black_box(&orig_84096_1_f16),
+                    black_box(&orig_84096_2_f16),
+                );
+            })
+        },
+    );
+
     c.bench_function(
         "1024x128 * 1x128 matrix vector transposed multiplication",
         |b| {
@@ -136,18 +164,6 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
         },
     );
 
-    c.bench_function(
-        "matrix multiplication 8x4096 @ 4096x4096 f32 in-place, transposed",
-        |b| {
-            b.iter(|| {
-                let _ = result_84096.matrix_mul_inplace_transposed(
-                    black_box(&orig_84096_1),
-                    black_box(&orig_84096_2),
-                );
-            })
-        },
-    );
-
     c.bench_function("matrix multiplication f32 not in-place", |b| {
         b.iter(|| {
             let _ = black_box(&orig32_1).matrix_mul(black_box(&orig32_2));
diff --git a/src/tensor.rs b/src/tensor.rs
index a26c8e1..edf9678 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -169,6 +169,17 @@ fn horizontal_sum(mut ymm: __m256) -> f32 {
     }
 }
 
+#[inline]
+fn horizontal_sum_f32_to_f16(mut ymm: __m256) -> f16 {
+    unsafe {
+        let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
+        ymm = _mm256_add_ps(ymm, ymm2);
+        ymm = _mm256_hadd_ps(ymm, ymm);
+        ymm = _mm256_hadd_ps(ymm, ymm);
+        f16::from_f32(_mm256_cvtss_f32(ymm))
+    }
+}
+
 impl Tensor {
     #[inline]
     pub fn assume_on_gpu(&self) {
@@ -824,8 +835,10 @@ impl Tensor {
                 self.rows, self.cols, other.cols, other.rows
            );
         }
+        // We don't have an implementation for f16, so don't use the vector function if we
+        // have f16.
         #[cfg(not(feature = "opencl"))]
-        if other.rows == 1 {
+        if other.rows == 1 && other.dtype != TensorDType::Float16 {
             return self.matrix_vector_mul_transposed(other);
         }
         #[cfg(feature = "opencl")]
@@ -1054,8 +1067,7 @@ impl Tensor {
 
         match src.dtype {
             TensorDType::Float32 => {
-                const CACHE_LINE_SIZE: usize = 32;
-                const ITEMS_PER_CACHE_LINE: usize = CACHE_LINE_SIZE / std::mem::size_of::<f32>();
+                const ITEMS_PER_LINE: usize = 8;
 
                 let tgt_data: *mut f32 = self.data as *mut f32;
                 unsafe {
@@ -1078,10 +1090,10 @@ impl Tensor {
                 let src_cols_capacity: usize = src.capacity_cols as usize;
                 let self_cols_capacity: usize = self.capacity_cols as usize;
 
-                let src_cols_its = if src_cols % ITEMS_PER_CACHE_LINE == 0 {
-                    src_cols / ITEMS_PER_CACHE_LINE
+                let src_cols_its = if src_cols % ITEMS_PER_LINE == 0 {
+                    src_cols / ITEMS_PER_LINE
                 } else {
-                    src_cols / ITEMS_PER_CACHE_LINE + 1
+                    src_cols / ITEMS_PER_LINE + 1
                 };
                 let row_its = if self_rows % 4 == 0 {
                     self_rows / 4
@@ -1133,61 +1145,56 @@ impl Tensor {
                             ];
                             for p in 0..src_cols_its {
                                 let other8_0: __m256 = _mm256_loadu_ps(
-                                    other_data
-                                        .add(col0 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE),
+                                    other_data.add(col0 * other_cols_capacity + p * ITEMS_PER_LINE),
                                 );
-                                let other8_1: __m256 =
-                                    if col1 < other_rows {
-                                        _mm256_loadu_ps(other_data.add(
-                                            col1 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
-                                        ))
-                                    } else {
-                                        _mm256_setzero_ps()
-                                    };
-                                let other8_2: __m256 =
-                                    if col2 < other_rows {
-                                        _mm256_loadu_ps(other_data.add(
-                                            col2 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
-                                        ))
-                                    } else {
-                                        _mm256_setzero_ps()
-                                    };
-                                let other8_3: __m256 =
-                                    if col3 < other_rows {
-                                        _mm256_loadu_ps(other_data.add(
-                                            col3 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
-                                        ))
-                                    } else {
-                                        _mm256_setzero_ps()
-                                    };
+                                let other8_1: __m256 = if col1 < other_rows {
+                                    _mm256_loadu_ps(
+                                        other_data
+                                            .add(col1 * other_cols_capacity + p * ITEMS_PER_LINE),
+                                    )
+                                } else {
+                                    _mm256_setzero_ps()
+                                };
+                                let other8_2: __m256 = if col2 < other_rows {
+                                    _mm256_loadu_ps(
+                                        other_data
+                                            .add(col2 * other_cols_capacity + p * ITEMS_PER_LINE),
+                                    )
+                                } else {
+                                    _mm256_setzero_ps()
+                                };
+                                let other8_3: __m256 = if col3 < other_rows {
+                                    _mm256_loadu_ps(
+                                        other_data
+                                            .add(col3 * other_cols_capacity + p * ITEMS_PER_LINE),
+                                    )
+                                } else {
+                                    _mm256_setzero_ps()
+                                };
                                 let src8_0: __m256 = _mm256_loadu_ps(
-                                    src_data
-                                        .add(row0 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE),
+                                    src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE),
                                 );
-                                let src8_1: __m256 =
-                                    if row1 < src_rows {
-                                        _mm256_loadu_ps(src_data.add(
-                                            row1 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
-                                        ))
-                                    } else {
-                                        _mm256_setzero_ps()
-                                    };
-                                let src8_2: __m256 =
-                                    if row2 < src_rows {
-                                        _mm256_loadu_ps(src_data.add(
-                                            row2 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
-                                        ))
-                                    } else {
-                                        _mm256_setzero_ps()
-                                    };
-                                let src8_3: __m256 =
-                                    if row3 < src_rows {
-                                        _mm256_loadu_ps(src_data.add(
-                                            row3 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
-                                        ))
-                                    } else {
-                                        _mm256_setzero_ps()
-                                    };
+                                let src8_1: __m256 = if row1 < src_rows {
+                                    _mm256_loadu_ps(
+                                        src_data.add(row1 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                    )
+                                } else {
+                                    _mm256_setzero_ps()
+                                };
+                                let src8_2: __m256 = if row2 < src_rows {
+                                    _mm256_loadu_ps(
+                                        src_data.add(row2 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                    )
+                                } else {
+                                    _mm256_setzero_ps()
+                                };
+                                let src8_3: __m256 = if row3 < src_rows {
+                                    _mm256_loadu_ps(
+                                        src_data.add(row3 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                    )
+                                } else {
+                                    _mm256_setzero_ps()
+                                };
                                 targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
                                 targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
                                 targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
@@ -1248,7 +1255,203 @@ impl Tensor {
                     }
                 }
             }
-            TensorDType::Float16 => unimplemented!(),
+            TensorDType::Float16 => {
+                const ITEMS_PER_LINE: usize = 8;
+
+                let tgt_data: *mut f16 = self.data as *mut f16;
+                unsafe {
+                    std::ptr::write_bytes(
+                        tgt_data,
+                        0,
+                        self.rows as usize * self.capacity_cols as usize,
+                    );
+                }
+                let src_data: *const f16 = src.data as *const f16;
+                let other_data: *const f16 = other.data as *const f16;
+
+                let src_rows: usize = src.rows as usize;
+                let src_cols: usize = src.cols as usize;
+                let self_rows: usize = self.rows as usize;
+                let self_cols: usize = self.cols as usize;
+                let _other_cols: usize = other.cols as usize;
+                let other_rows: usize = other.rows as usize;
+                let other_cols_capacity: usize = other.capacity_cols as usize;
+                let src_cols_capacity: usize = src.capacity_cols as usize;
+                let self_cols_capacity: usize = self.capacity_cols as usize;
+
+                let src_cols_its = if src_cols % ITEMS_PER_LINE == 0 {
+                    src_cols / ITEMS_PER_LINE
+                } else {
+                    src_cols / ITEMS_PER_LINE + 1
+                };
+                let row_its = if self_rows % 4 == 0 {
+                    self_rows / 4
+                } else {
+                    self_rows / 4 + 1
+                };
+                let self_cols_its = if self_cols % 4 == 0 {
+                    self_cols / 4
+                } else {
+                    self_cols / 4 + 1
+                };
+
+                unsafe {
+                    for row in 0..row_its {
+                        let row0 = row * 4;
+                        let row1 = row * 4 + 1;
+                        let row2 = row * 4 + 2;
+                        let row3 = row * 4 + 3;
+                        for col in 0..self_cols_its {
+                            let col0 = col * 4;
+                            let col1 = col * 4 + 1;
+                            let col2 = col * 4 + 2;
+                            let col3 = col * 4 + 3;
+                            let mut targets8: [[__m256; 4]; 4] = [
+                                [
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                ],
+                                [
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                ],
+                                [
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                ],
+                                [
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                ],
+                            ];
+                            // Loads from (row, column..column+8) and (row+1, column..column+8)
+                            #[inline]
+                            fn load2_rows(
+                                ptr: *const f16,
+                                row: usize,
+                                column: usize,
+                                cols_capacity: usize,
+                                nrows: usize,
+                            ) -> (__m256, __m256) {
+                                unsafe {
+                                    let (left, right) = if row + 1 < nrows {
+                                        (
+                                            _mm_loadu_si128(ptr.add(row * cols_capacity + column)
+                                                as *const __m128i),
+                                            _mm_loadu_si128(
+                                                ptr.add((row + 1) * cols_capacity + column)
+                                                    as *const __m128i,
+                                            ),
+                                        )
+                                    } else {
+                                        (
+                                            _mm_loadu_si128(ptr.add(row * cols_capacity + column)
+                                                as *const __m128i),
+                                            _mm_setzero_si128(),
+                                        )
+                                    };
+                                    let left: __m256 = _mm256_cvtph_ps(left);
+                                    let right: __m256 = _mm256_cvtph_ps(right);
+                                    (left, right)
+                                }
+                            }
+                            for p in 0..src_cols_its {
+                                let (other8_0, other8_1) = load2_rows(
+                                    other_data,
+                                    col0,
+                                    p * ITEMS_PER_LINE,
+                                    other_cols_capacity,
+                                    other_rows,
+                                );
+                                let (other8_2, other8_3) = load2_rows(
+                                    other_data,
+                                    col2,
+                                    p * ITEMS_PER_LINE,
+                                    other_cols_capacity,
+                                    other_rows,
+                                );
+                                let (src8_0, src8_1) = load2_rows(
+                                    src_data,
+                                    row0,
+                                    p * ITEMS_PER_LINE,
+                                    src_cols_capacity,
+                                    src_rows,
+                                );
+                                let (src8_2, src8_3) = load2_rows(
+                                    src_data,
+                                    row2,
+                                    p * ITEMS_PER_LINE,
+                                    src_cols_capacity,
+                                    src_rows,
+                                );
+                                targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
+                                targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
+                                targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
+                                targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
+                                targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
+                                targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
+                                targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
+                                targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
+                                targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
+                                targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
+                                targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
+                                targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
+                                targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
+                                targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
+                                targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
+                                targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
+                            }
+                            let target00: f16 = horizontal_sum_f32_to_f16(targets8[0][0]);
+                            let target01: f16 = horizontal_sum_f32_to_f16(targets8[0][1]);
+                            let target02: f16 = horizontal_sum_f32_to_f16(targets8[0][2]);
+                            let target03: f16 = horizontal_sum_f32_to_f16(targets8[0][3]);
+                            let target10: f16 = horizontal_sum_f32_to_f16(targets8[1][0]);
+                            let target11: f16 = horizontal_sum_f32_to_f16(targets8[1][1]);
+                            let target12: f16 = horizontal_sum_f32_to_f16(targets8[1][2]);
+                            let target13: f16 = horizontal_sum_f32_to_f16(targets8[1][3]);
+                            let target20: f16 = horizontal_sum_f32_to_f16(targets8[2][0]);
+                            let target21: f16 = horizontal_sum_f32_to_f16(targets8[2][1]);
+                            let target22: f16 = horizontal_sum_f32_to_f16(targets8[2][2]);
+                            let target23: f16 = horizontal_sum_f32_to_f16(targets8[2][3]);
+                            let target30: f16 = horizontal_sum_f32_to_f16(targets8[3][0]);
+                            let target31: f16 = horizontal_sum_f32_to_f16(targets8[3][1]);
+                            let target32: f16 = horizontal_sum_f32_to_f16(targets8[3][2]);
+                            let target33: f16 = horizontal_sum_f32_to_f16(targets8[3][3]);
+
+                            *tgt_data.add(row0 * self_cols_capacity + col0) += target00;
+                            *tgt_data.add(row0 * self_cols_capacity + col1) += target10;
+                            *tgt_data.add(row0 * self_cols_capacity + col2) += target20;
+                            *tgt_data.add(row0 * self_cols_capacity + col3) += target30;
+                            if row1 < self_rows {
+                                *tgt_data.add(row1 * self_cols_capacity + col0) += target01;
+                                *tgt_data.add(row1 * self_cols_capacity + col1) += target11;
+                                *tgt_data.add(row1 * self_cols_capacity + col2) += target21;
+                                *tgt_data.add(row1 * self_cols_capacity + col3) += target31;
+                            }
+                            if row2 < self_rows {
+                                *tgt_data.add(row2 * self_cols_capacity + col0) += target02;
+                                *tgt_data.add(row2 * self_cols_capacity + col1) += target12;
+                                *tgt_data.add(row2 * self_cols_capacity + col2) += target22;
+                                *tgt_data.add(row2 * self_cols_capacity + col3) += target32;
+                            }
+                            if row3 < self_rows {
+                                *tgt_data.add(row3 * self_cols_capacity + col0) += target03;
+                                *tgt_data.add(row3 * self_cols_capacity + col1) += target13;
+                                *tgt_data.add(row3 * self_cols_capacity + col2) += target23;
+                                *tgt_data.add(row3 * self_cols_capacity + col3) += target33;
+                            }
+                        }
+                    }
+                }
+            }
         }
     }
 
@@ -2088,6 +2291,36 @@ mod tests {
         }
     }
 
+    #[test]
+    fn mat_mul_transposed_f32_agrees_mat_mul_transposed_f16() {
+        let mut rng = rand::thread_rng();
+        for _ in 0..1000 {
+            let a = rng.gen_range(1..=128);
+            let b = rng.gen_range(1..=128);
+            let r = rng.gen_range(1..=128);
+
+            // Make matrices AxR and RxB
+            let a = Tensor::random(a, r, TensorDType::Float32);
+            let b = Tensor::random(r, b, TensorDType::Float32);
+            let a2 = a.clone().to_f16();
+            let b2 = b.clone().to_f16();
+            let b_transposed = b.transpose();
+            let b2_transposed = b2.transpose();
+
+            let c = a.matrix_mul_transposed(&b_transposed);
+            let c2 = a2.matrix_mul_transposed(&b2_transposed);
+
+            assert_eq!(c.rows, c2.rows);
+            assert_eq!(c.cols, c2.cols);
+
+            for row in 0..c.rows {
+                for col in 0..c.cols {
+                    assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-1);
+                }
+            }
+        }
+    }
+
     #[test]
     fn view_preserves_values() {
         fn test_with_type(dtype: TensorDType) {