From 846759b2776f69b576522604c4f5067e682804fd Mon Sep 17 00:00:00 2001
From: Mikko Juola
Date: Sat, 11 Mar 2023 23:21:00 -0800
Subject: [PATCH] Optimize conversions between f16 and f32.

x86 cannot do f16 arithmetic natively, but it does have instructions to
convert to and from f32. I optimized the conversions to use those SIMD
instructions.
---
 src/benches/benchmark.rs |  15 +++++
 src/tensor.rs            | 135 +++++++++++++++++++++++++++++++++++++--
 2 files changed, 145 insertions(+), 5 deletions(-)

diff --git a/src/benches/benchmark.rs b/src/benches/benchmark.rs
index a92c472..9a4c5fa 100644
--- a/src/benches/benchmark.rs
+++ b/src/benches/benchmark.rs
@@ -19,6 +19,21 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
     let orig_84096_2 = Tensor::zeros(4096, 4096, TensorDType::Float32);
     let mut result_84096 = Tensor::zeros(8, 4096, TensorDType::Float32);
 
+    let orig_f32 = Tensor::zeros(1024, 1024, TensorDType::Float32);
+    let orig_f16 = Tensor::zeros(1024, 1024, TensorDType::Float16);
+
+    c.bench_function("1024x1024 matrix from f32->f16", |b| {
+        b.iter(|| {
+            let _ = black_box(&orig_f32).to_f16();
+        })
+    });
+
+    c.bench_function("1024x1024 matrix from f16->f32", |b| {
+        b.iter(|| {
+            let _ = black_box(&orig_f16).to_f32();
+        })
+    });
+
     c.bench_function(
         "matrix multiplication 8x4096 @ 4096x4096 f32 in-place",
         |b| {
diff --git a/src/tensor.rs b/src/tensor.rs
index d1f94dd..f1f93e9 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -90,7 +90,14 @@ impl Drop for Tensor {
     }
 }
 
-fn compute_capacity_cols(cols: i64) -> i64 {
+fn compute_capacity_cols(dtype: TensorDType, cols: i64) -> i64 {
+    match dtype {
+        TensorDType::Float16 => compute_capacity_cols_f16(cols),
+        TensorDType::Float32 => compute_capacity_cols_f32(cols),
+    }
+}
+
+fn compute_capacity_cols_f32(cols: i64) -> i64 {
     if cols % 8 == 0 {
         cols
     } else {
@@ -98,6 +105,14 @@ fn compute_capacity_cols(cols: i64) -> i64 {
     }
 }
 
+fn compute_capacity_cols_f16(cols: i64) -> i64 {
+    if cols % 16 == 0 {
+        cols
+    } else {
+        cols + 16 - cols % 16
+    }
+}
+
 #[inline]
 fn horizontal_sum(mut ymm: __m256) -> f32 {
     unsafe {
@@ -231,7 +246,7 @@ impl Tensor {
             return tensor;
         }
-        // Rouns up cols to 8
-        let capacity_cols = compute_capacity_cols(cols);
+        // Rounds up cols to a multiple of 8 (16 for f16)
+        let capacity_cols = compute_capacity_cols(dtype, cols);
         let nitems = rows * capacity_cols;
         let layout =
             Layout::from_size_align((nitems as usize) * dtype.bytes_per_item(), 32).unwrap();
@@ -1030,7 +1045,7 @@ impl Tensor {
             tensor.cols = cols;
             return tensor;
         }
-        let capacity_cols = compute_capacity_cols(cols);
+        let capacity_cols = compute_capacity_cols(dtype, cols);
         let nitems = rows * capacity_cols;
         let layout =
             Layout::from_size_align((nitems as usize) * dtype.bytes_per_item(), 32).unwrap();
@@ -1129,7 +1144,8 @@ impl Tensor {
         }
     }
 
-    pub fn to_f32(&self) -> Tensor {
+    /// Naive implementation of to_f32, used for testing that the faster methods are correct.
+    pub fn to_f32_naive(&self) -> Tensor {
         if self.dtype == TensorDType::Float32 {
             return self.clone();
         }
@@ -1145,7 +1161,41 @@ impl Tensor {
         result
     }
 
-    pub fn to_f16(&self) -> Tensor {
+    pub fn to_f32(&self) -> Tensor {
+        if self.dtype == TensorDType::Float32 {
+            return self.clone();
+        }
+
+        assert_eq!(self.dtype, TensorDType::Float16);
+
+        unsafe {
+            let cols_it = if self.cols % 8 == 0 {
+                self.cols / 8
+            } else {
+                self.cols / 8 + 1
+            };
+            let result = Tensor::uninitialized(self.rows, self.cols, TensorDType::Float32);
+
+            let self_data: *const f16 = self.data as *const f16;
+            let tgt_data: *mut f32 = result.data as *mut f32;
+            let tgt_capacity_cols = result.capacity_cols as i64;
+            let self_capacity_cols = self.capacity_cols as i64;
+            for row in 0..self.rows {
+                for col in 0..cols_it {
+                    let col = col * 8;
+                    let val8: __m128i =
+                        _mm_loadu_si128(self_data.add((row * self_capacity_cols + col) as usize)
+                            as *const __m128i);
+                    let val8: __m256 = _mm256_cvtph_ps(val8);
+                    _mm256_storeu_ps(tgt_data.add((row * tgt_capacity_cols + col) as usize), val8);
+                }
+            }
+            result
+        }
+    }
+
+    /// Naive implementation of to_f16, used for testing that the faster methods are correct.
+    pub fn to_f16_naive(&self) -> Tensor {
         if self.dtype == TensorDType::Float16 {
             return self.clone();
         }
@@ -1161,6 +1211,39 @@ impl Tensor {
         result
     }
 
+    pub fn to_f16(&self) -> Tensor {
+        if self.dtype == TensorDType::Float16 {
+            return self.clone();
+        }
+
+        unsafe {
+            let cols_it = if self.cols % 8 == 0 {
+                self.cols / 8
+            } else {
+                self.cols / 8 + 1
+            };
+            let result = Tensor::uninitialized(self.rows, self.cols, TensorDType::Float16);
+            let self_data: *const f32 = self.data as *const f32;
+            let tgt_data: *mut f16 = result.data as *mut f16;
+            let tgt_capacity_cols = result.capacity_cols as i64;
+            let self_capacity_cols = self.capacity_cols as i64;
+
+            for row in 0..self.rows {
+                for col in 0..cols_it {
+                    let col = col * 8;
+                    let val8: __m256 =
+                        _mm256_loadu_ps(self_data.add((row * self_capacity_cols + col) as usize));
+                    let val8: __m128i = _mm256_cvtps_ph(val8, 0);
+                    _mm_storeu_si128(
+                        tgt_data.add((row * tgt_capacity_cols + col) as usize) as *mut __m128i,
+                        val8,
+                    );
+                }
+            }
+            result
+        }
+    }
+
     pub fn row(&self, row: i64) -> Tensor {
         if row < 0 || row > self.rows {
             panic!("Invalid row index");
         }
@@ -1658,4 +1741,46 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn conversion_from_f16_tensor_to_f32_tensor_agrees_with_naive() {
+        let mut rng = rand::thread_rng();
+        for _ in 0..200 {
+            let rows = rng.gen_range(1..100);
+            let cols = rng.gen_range(1..100);
+
+            let src = Tensor::random(rows, cols, TensorDType::Float16);
+            let tgt1 = src.to_f32_naive();
+            let tgt2 = src.to_f32();
+
+            assert_eq!(tgt1.rows(), tgt2.rows());
+            assert_eq!(tgt1.cols(), tgt2.cols());
+            for row in 0..tgt1.rows {
+                for col in 0..tgt1.cols {
+                    assert_eq!(tgt1.get_f32(row, col), tgt2.get_f32(row, col));
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn conversion_from_f32_tensor_to_f16_tensor_agrees_with_naive() {
+        let mut rng = rand::thread_rng();
+        for _ in 0..200 {
+            let rows = rng.gen_range(1..100);
+            let cols = rng.gen_range(1..100);
+
+            let src = Tensor::random(rows, cols, TensorDType::Float32);
+            let tgt1 = src.to_f16_naive();
+            let tgt2 = src.to_f16();
+
+            assert_eq!(tgt1.rows(), tgt2.rows());
+            assert_eq!(tgt1.cols(), tgt2.cols());
+            for row in 0..tgt1.rows {
+                for col in 0..tgt1.cols {
+                    assert_eq!(tgt1.get_f32(row, col), tgt2.get_f32(row, col));
+                }
+            }
+        }
+    }
 }
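
For reference, the inner loops in to_f32()/to_f16() convert eight elements per
iteration using the x86 F16C conversion instructions. Below is a minimal,
self-contained sketch of that 8-lane step for the f16 -> f32 direction,
assuming an x86_64 target with AVX and F16C available; the helper name is
illustrative only and does not appear in the patch.

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// Widen 8 consecutive f16 values (stored as raw 16-bit words) to 8 f32
/// values. This mirrors one iteration of the inner loop in Tensor::to_f32().
/// Caller must guarantee AVX+F16C support and that both pointers are valid
/// for 8 elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx,f16c")]
unsafe fn widen_8_f16_to_f32(src: *const u16, dst: *mut f32) {
    // Load 8 x 16-bit half floats (128 bits total).
    let halves: __m128i = _mm_loadu_si128(src as *const __m128i);
    // Convert them to 8 x f32 (256 bits) in one instruction.
    let floats: __m256 = _mm256_cvtph_ps(halves);
    // Store the widened values.
    _mm256_storeu_ps(dst, floats);
}

The f32 -> f16 direction in to_f16() is the mirror image: _mm256_loadu_ps,
then _mm256_cvtps_ph with rounding control 0 (round to nearest even), then
_mm_storeu_si128. Note that both loops may read and write up to 7 elements
past `cols` in each row; this stays inside the allocation because
compute_capacity_cols() pads every row to a multiple of 8 (f32) or 16 (f16).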