diff --git a/src/lib.rs b/src/lib.rs
index 5e1cf29..2deb2f4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,6 +3,7 @@
 pub mod embedding;
 pub mod protomodels;
 pub mod rllama_main;
+pub mod simd_support;
 pub mod tensor;
 #[cfg(feature = "opencl")]
 pub mod tensor_opencl_support;
diff --git a/src/simd_support.rs b/src/simd_support.rs
new file mode 100644
index 0000000..660677b
--- /dev/null
+++ b/src/simd_support.rs
@@ -0,0 +1,122 @@
+// This file contains platform-specific SIMD so that the rest of rllama does not need to care
+// which platform it is on. Currently only x86_64 is implemented; the wrappers below assume
+// that the CPU supports AVX2, FMA and F16C.
+
+use core::arch::x86_64::*;
+use half::f16;
+
+pub type I32x8 = __m256i;
+pub type F32x8 = __m256;
+pub type I16x8 = __m128i;
+
+/* ------------------ */
+/* Loading and storing things */
+/* ------------------ */
+
+#[inline]
+pub fn load_i16x8(ptr: *const I16x8) -> I16x8 {
+    unsafe { _mm_loadu_si128(ptr) }
+}
+
+#[inline]
+pub fn store_i16x8(ptr: *mut I16x8, a: I16x8) {
+    unsafe { _mm_storeu_si128(ptr, a) }
+}
+
+#[inline]
+pub fn load_f32x8(ptr: *const F32x8) -> F32x8 {
+    unsafe { _mm256_loadu_ps(ptr as *const f32) }
+}
+
+#[inline]
+pub fn store_f32x8(ptr: *mut F32x8, a: F32x8) {
+    unsafe { _mm256_storeu_ps(ptr as *mut f32, a) }
+}
+
+// The gather scale is fixed at 1, so `indices` are byte offsets from `ptr`.
+#[inline]
+pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 {
+    unsafe { _mm256_i32gather_ps(ptr, indices, 1) }
+}
+
+/* ------------------ */
+/* Conversions */
+/* ------------------ */
+
+#[inline]
+pub fn i16x8_as_f16_to_f32x8(a: I16x8) -> F32x8 {
+    unsafe { _mm256_cvtph_ps(a) }
+}
+
+#[inline]
+pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 {
+    // Rounding mode 0 = round to nearest even.
+    unsafe { _mm256_cvtps_ph(a, 0) }
+}
+
+/*
+ * Constants, creating from constants
+ */
+
+#[inline]
+pub fn f32x8_zero() -> F32x8 {
+    unsafe { _mm256_setzero_ps() }
+}
+
+#[inline]
+pub fn i16x8_zero() -> I16x8 {
+    unsafe { _mm_setzero_si128() }
+}
+
+#[inline]
+pub fn f32x8_singleton(value: f32) -> F32x8 {
+    unsafe { _mm256_set1_ps(value) }
+}
+
+// _mm256_set_epi32 takes its arguments in high-to-low order, so val0 ends up in
+// the highest 32-bit lane.
+#[inline]
+pub fn i32x8_from_values(
+    val0: i32,
+    val1: i32,
+    val2: i32,
+    val3: i32,
+    val4: i32,
+    val5: i32,
+    val6: i32,
+    val7: i32,
+) -> I32x8 {
+    unsafe { _mm256_set_epi32(val0, val1, val2, val3, val4, val5, val6, val7) }
+}
+
+/*
+ * Operations
+ */
+
+// FMA
+
+// a * b + c
+#[inline]
+pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 {
+    unsafe { _mm256_fmadd_ps(a, b, c) }
+}
+
+// Horizontal sums
+
+// Sums the eight f32 lanes: swap the 128-bit halves and add them together, then
+// two hadd rounds fold the remaining four partial sums; the total lands in lane 0.
+#[inline]
+pub fn horizontal_sum_f32x8(mut ymm: F32x8) -> f32 {
+    unsafe {
+        let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
+        ymm = _mm256_add_ps(ymm, ymm2);
+        ymm = _mm256_hadd_ps(ymm, ymm);
+        ymm = _mm256_hadd_ps(ymm, ymm);
+        _mm256_cvtss_f32(ymm)
+    }
+}
+
+#[inline]
+pub fn horizontal_sum_and_f32_to_f16(ymm: F32x8) -> f16 {
+    f16::from_f32(horizontal_sum_f32x8(ymm))
+}
diff --git a/src/tensor.rs b/src/tensor.rs
index 2bb35c4..8a99834 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -17,6 +17,7 @@
  * If it's "XXX_inplace", then it has a &mut self and it modifies the tensor in place.
*/ +use crate::simd_support::*; #[cfg(feature = "opencl")] use crate::tensor_opencl_support::{OpenCL, OpenCLError, OpenCLEvent, OpenCLTensor}; use crate::unpickler; @@ -25,7 +26,6 @@ use half::f16; use rand::Rng; use rayon::prelude::*; use std::alloc::Layout; -use std::arch::x86_64::*; use std::io::{Read, Seek}; use std::path::{Path, PathBuf}; #[cfg(feature = "opencl")] @@ -175,28 +175,6 @@ fn compute_capacity_cols_f16(cols: i64) -> i64 { } } -#[inline] -fn horizontal_sum(mut ymm: __m256) -> f32 { - unsafe { - let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1); - ymm = _mm256_add_ps(ymm, ymm2); - ymm = _mm256_hadd_ps(ymm, ymm); - ymm = _mm256_hadd_ps(ymm, ymm); - _mm256_cvtss_f32(ymm) - } -} - -#[inline] -fn horizontal_sum_f32_to_f16(mut ymm: __m256) -> f16 { - unsafe { - let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1); - ymm = _mm256_add_ps(ymm, ymm2); - ymm = _mm256_hadd_ps(ymm, ymm); - ymm = _mm256_hadd_ps(ymm, ymm); - f16::from_f32(_mm256_cvtss_f32(ymm)) - } -} - impl Tensor { #[inline] pub fn assume_on_gpu(&self) { @@ -938,19 +916,24 @@ impl Tensor { let i2_self_cols = i2 * self_cols_capacity; let i2_src_cols = i2 * src_cols_capacity; for k2 in k..std::cmp::min(k + ITEMS_PER_CACHE_LINE, src_cols) { - let other_value8: __m256 = _mm256_loadu_ps( - other_data.add(k2 * other_cols_capacity + col), + let other_value8: F32x8 = load_f32x8( + other_data.add(k2 * other_cols_capacity + col) + as *const F32x8, + ); + let src_value8_broadcast: F32x8 = + f32x8_singleton(*src_data.add(i2_src_cols + k2)); + let tgt_value8: F32x8 = load_f32x8( + tgt_data.add(i2_self_cols + col) as *const F32x8, ); - let src_value8_broadcast: __m256 = - _mm256_broadcast_ss(&*src_data.add(i2_src_cols + k2)); - let tgt_value8: __m256 = - _mm256_loadu_ps(tgt_data.add(i2_self_cols + col)); - let result8: __m256 = _mm256_fmadd_ps( + let result8: F32x8 = fma_f32x8( src_value8_broadcast, other_value8, tgt_value8, ); - _mm256_storeu_ps(tgt_data.add(i2_self_cols + col), result8); + store_f32x8( + tgt_data.add(i2_self_cols + col) as *mut F32x8, + result8, + ); } } k += ITEMS_PER_CACHE_LINE; @@ -993,23 +976,20 @@ impl Tensor { let i2_self_cols = i2 * self_cols; let i2_src_cols = i2 * src_cols; for k2 in k..k + ITEMS_PER_CACHE_LINE { - let other_value8: __m256 = _mm256_cvtph_ps(_mm_loadu_si128( + let other_value8: F32x8 = i16x8_as_f16_to_f32x8(load_i16x8( other_data.add(k2 * other_cols + col) as *const _, )); let src_value8: f16 = *src_data.add(i2_src_cols + k2); - let src_value8_broadcast: __m256 = - _mm256_broadcast_ss(&src_value8.to_f32()); - let tgt_value8: __m256 = _mm256_cvtph_ps(_mm_loadu_si128( + let src_value8_broadcast: F32x8 = + f32x8_singleton(src_value8.to_f32()); + let tgt_value8: F32x8 = i16x8_as_f16_to_f32x8(load_i16x8( tgt_data.add(i2_self_cols + col) as *const _, )); - let result8: __m256 = _mm256_fmadd_ps( - src_value8_broadcast, - other_value8, - tgt_value8, - ); - let result8_packed: __m128i = _mm256_cvtps_ph(result8, 0); - _mm_storeu_si128( - tgt_data.add(i2_self_cols + col) as *mut _, + let result8: F32x8 = + fma_f32x8(src_value8_broadcast, other_value8, tgt_value8); + let result8_packed: I16x8 = f32x8_to_i16x8_as_f16(result8); + store_i16x8( + tgt_data.add(i2_self_cols + col) as *mut I16x8, result8_packed, ); } @@ -1186,137 +1166,109 @@ impl Tensor { let col1 = col * 4 + 1; let col2 = col * 4 + 2; let col3 = col * 4 + 3; - let mut targets8: [[__m256; 4]; 4] = [ - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - 
_mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], + let mut targets8: [[F32x8; 4]; 4] = [ + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], ]; for p in 0..src_cols_its { - let other8_0: __m256 = _mm256_loadu_ps( + let other8_0: F32x8 = load_f32x8( other_data - .add(col0 * other_cols_capacity + p * ITEMS_PER_LINE), + .add(col0 * other_cols_capacity + p * ITEMS_PER_LINE) + as *const F32x8, ); - let other8_1: __m256 = + let other8_1: F32x8 = if col1 < other_rows { - _mm256_loadu_ps(other_data.add( + load_f32x8(other_data.add( col1 * other_cols_capacity + p * ITEMS_PER_LINE, - )) + ) + as *const F32x8) } else { - _mm256_setzero_ps() + f32x8_zero() }; - let other8_2: __m256 = + let other8_2: F32x8 = if col2 < other_rows { - _mm256_loadu_ps(other_data.add( + load_f32x8(other_data.add( col2 * other_cols_capacity + p * ITEMS_PER_LINE, - )) + ) + as *const F32x8) } else { - _mm256_setzero_ps() + f32x8_zero() }; - let other8_3: __m256 = + let other8_3: F32x8 = if col3 < other_rows { - _mm256_loadu_ps(other_data.add( + load_f32x8(other_data.add( col3 * other_cols_capacity + p * ITEMS_PER_LINE, - )) + ) + as *const F32x8) } else { - _mm256_setzero_ps() + f32x8_zero() }; - let src8_0: __m256 = _mm256_loadu_ps( - src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE), + let src8_0: F32x8 = load_f32x8( + src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE) + as *const F32x8, ); - let src8_1: __m256 = if row1 < src_rows { - _mm256_loadu_ps( + let src8_1: F32x8 = if row1 < src_rows { + load_f32x8( src_data - .add(row1 * src_cols_capacity + p * ITEMS_PER_LINE), + .add(row1 * src_cols_capacity + p * ITEMS_PER_LINE) + as *const F32x8, ) } else { - _mm256_setzero_ps() + f32x8_zero() }; - let src8_2: __m256 = if row2 < src_rows { - _mm256_loadu_ps( + let src8_2: F32x8 = if row2 < src_rows { + load_f32x8( src_data - .add(row2 * src_cols_capacity + p * ITEMS_PER_LINE), + .add(row2 * src_cols_capacity + p * ITEMS_PER_LINE) + as *const F32x8, ) } else { - _mm256_setzero_ps() + f32x8_zero() }; - let src8_3: __m256 = if row3 < src_rows { - _mm256_loadu_ps( + let src8_3: F32x8 = if row3 < src_rows { + load_f32x8( src_data - .add(row3 * src_cols_capacity + p * ITEMS_PER_LINE), + .add(row3 * src_cols_capacity + p * ITEMS_PER_LINE) + as *const F32x8, ) } else { - _mm256_setzero_ps() + f32x8_zero() }; - targets8[0][0] = - _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]); - targets8[0][1] = - _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]); - targets8[0][2] = - _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]); - targets8[0][3] = - _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]); - targets8[1][0] = - _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]); - targets8[1][1] = - _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]); - targets8[1][2] = - _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]); - targets8[1][3] = - _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]); - targets8[2][0] = - _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]); - targets8[2][1] = - _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]); - targets8[2][2] = - _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]); - targets8[2][3] = - _mm256_fmadd_ps(src8_3, 
other8_2, targets8[2][3]); - targets8[3][0] = - _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]); - targets8[3][1] = - _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]); - targets8[3][2] = - _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]); - targets8[3][3] = - _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]); + targets8[0][0] = fma_f32x8(src8_0, other8_0, targets8[0][0]); + targets8[0][1] = fma_f32x8(src8_1, other8_0, targets8[0][1]); + targets8[0][2] = fma_f32x8(src8_2, other8_0, targets8[0][2]); + targets8[0][3] = fma_f32x8(src8_3, other8_0, targets8[0][3]); + targets8[1][0] = fma_f32x8(src8_0, other8_1, targets8[1][0]); + targets8[1][1] = fma_f32x8(src8_1, other8_1, targets8[1][1]); + targets8[1][2] = fma_f32x8(src8_2, other8_1, targets8[1][2]); + targets8[1][3] = fma_f32x8(src8_3, other8_1, targets8[1][3]); + targets8[2][0] = fma_f32x8(src8_0, other8_2, targets8[2][0]); + targets8[2][1] = fma_f32x8(src8_1, other8_2, targets8[2][1]); + targets8[2][2] = fma_f32x8(src8_2, other8_2, targets8[2][2]); + targets8[2][3] = fma_f32x8(src8_3, other8_2, targets8[2][3]); + targets8[3][0] = fma_f32x8(src8_0, other8_3, targets8[3][0]); + targets8[3][1] = fma_f32x8(src8_1, other8_3, targets8[3][1]); + targets8[3][2] = fma_f32x8(src8_2, other8_3, targets8[3][2]); + targets8[3][3] = fma_f32x8(src8_3, other8_3, targets8[3][3]); } - let target00: f32 = horizontal_sum(targets8[0][0]); - let target01: f32 = horizontal_sum(targets8[0][1]); - let target02: f32 = horizontal_sum(targets8[0][2]); - let target03: f32 = horizontal_sum(targets8[0][3]); - let target10: f32 = horizontal_sum(targets8[1][0]); - let target11: f32 = horizontal_sum(targets8[1][1]); - let target12: f32 = horizontal_sum(targets8[1][2]); - let target13: f32 = horizontal_sum(targets8[1][3]); - let target20: f32 = horizontal_sum(targets8[2][0]); - let target21: f32 = horizontal_sum(targets8[2][1]); - let target22: f32 = horizontal_sum(targets8[2][2]); - let target23: f32 = horizontal_sum(targets8[2][3]); - let target30: f32 = horizontal_sum(targets8[3][0]); - let target31: f32 = horizontal_sum(targets8[3][1]); - let target32: f32 = horizontal_sum(targets8[3][2]); - let target33: f32 = horizontal_sum(targets8[3][3]); + let target00: f32 = horizontal_sum_f32x8(targets8[0][0]); + let target01: f32 = horizontal_sum_f32x8(targets8[0][1]); + let target02: f32 = horizontal_sum_f32x8(targets8[0][2]); + let target03: f32 = horizontal_sum_f32x8(targets8[0][3]); + let target10: f32 = horizontal_sum_f32x8(targets8[1][0]); + let target11: f32 = horizontal_sum_f32x8(targets8[1][1]); + let target12: f32 = horizontal_sum_f32x8(targets8[1][2]); + let target13: f32 = horizontal_sum_f32x8(targets8[1][3]); + let target20: f32 = horizontal_sum_f32x8(targets8[2][0]); + let target21: f32 = horizontal_sum_f32x8(targets8[2][1]); + let target22: f32 = horizontal_sum_f32x8(targets8[2][2]); + let target23: f32 = horizontal_sum_f32x8(targets8[2][3]); + let target30: f32 = horizontal_sum_f32x8(targets8[3][0]); + let target31: f32 = horizontal_sum_f32x8(targets8[3][1]); + let target32: f32 = horizontal_sum_f32x8(targets8[3][2]); + let target33: f32 = horizontal_sum_f32x8(targets8[3][3]); *tgt_data.add(row0 * self_cols_capacity + col0) += target00; *tgt_data.add(row0 * self_cols_capacity + col1) += target10; @@ -1407,31 +1359,11 @@ impl Tensor { let col1 = col * 4 + 1; let col2 = col * 4 + 2; let col3 = col * 4 + 3; - let mut targets8: [[__m256; 4]; 4] = [ - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - 
_mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], + let mut targets8: [[F32x8; 4]; 4] = [ + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], ]; // Loads from (row, column..column+8) and (row+1, column..column+8) #[inline] @@ -1441,30 +1373,26 @@ impl Tensor { column: usize, cols_capacity: usize, nrows: usize, - ) -> (__m256, __m256) { + ) -> (F32x8, F32x8) { unsafe { let (left, right) = if row + 1 < nrows { ( - _mm_loadu_si128( - ptr.add(row * cols_capacity + column) - as *const __m128i, - ), - _mm_loadu_si128( + load_i16x8(ptr.add(row * cols_capacity + column) + as *const I16x8), + load_i16x8( ptr.add((row + 1) * cols_capacity + column) - as *const __m128i, + as *const I16x8, ), ) } else { ( - _mm_loadu_si128( - ptr.add(row * cols_capacity + column) - as *const __m128i, - ), - _mm_setzero_si128(), + load_i16x8(ptr.add(row * cols_capacity + column) + as *const I16x8), + i16x8_zero(), ) }; - let left: __m256 = _mm256_cvtph_ps(left); - let right: __m256 = _mm256_cvtph_ps(right); + let left: F32x8 = i16x8_as_f16_to_f32x8(left); + let right: F32x8 = i16x8_as_f16_to_f32x8(right); (left, right) } } @@ -1497,55 +1425,39 @@ impl Tensor { src_cols_capacity, src_rows, ); - targets8[0][0] = - _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]); - targets8[0][1] = - _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]); - targets8[0][2] = - _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]); - targets8[0][3] = - _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]); - targets8[1][0] = - _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]); - targets8[1][1] = - _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]); - targets8[1][2] = - _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]); - targets8[1][3] = - _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]); - targets8[2][0] = - _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]); - targets8[2][1] = - _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]); - targets8[2][2] = - _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]); - targets8[2][3] = - _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]); - targets8[3][0] = - _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]); - targets8[3][1] = - _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]); - targets8[3][2] = - _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]); - targets8[3][3] = - _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]); + targets8[0][0] = fma_f32x8(src8_0, other8_0, targets8[0][0]); + targets8[0][1] = fma_f32x8(src8_1, other8_0, targets8[0][1]); + targets8[0][2] = fma_f32x8(src8_2, other8_0, targets8[0][2]); + targets8[0][3] = fma_f32x8(src8_3, other8_0, targets8[0][3]); + targets8[1][0] = fma_f32x8(src8_0, other8_1, targets8[1][0]); + targets8[1][1] = fma_f32x8(src8_1, other8_1, targets8[1][1]); + targets8[1][2] = fma_f32x8(src8_2, other8_1, targets8[1][2]); + targets8[1][3] = fma_f32x8(src8_3, other8_1, targets8[1][3]); + targets8[2][0] = fma_f32x8(src8_0, other8_2, targets8[2][0]); + targets8[2][1] = fma_f32x8(src8_1, other8_2, targets8[2][1]); + targets8[2][2] = fma_f32x8(src8_2, other8_2, targets8[2][2]); + targets8[2][3] = fma_f32x8(src8_3, other8_2, targets8[2][3]); + targets8[3][0] = 
fma_f32x8(src8_0, other8_3, targets8[3][0]); + targets8[3][1] = fma_f32x8(src8_1, other8_3, targets8[3][1]); + targets8[3][2] = fma_f32x8(src8_2, other8_3, targets8[3][2]); + targets8[3][3] = fma_f32x8(src8_3, other8_3, targets8[3][3]); } - let target00: f16 = horizontal_sum_f32_to_f16(targets8[0][0]); - let target01: f16 = horizontal_sum_f32_to_f16(targets8[0][1]); - let target02: f16 = horizontal_sum_f32_to_f16(targets8[0][2]); - let target03: f16 = horizontal_sum_f32_to_f16(targets8[0][3]); - let target10: f16 = horizontal_sum_f32_to_f16(targets8[1][0]); - let target11: f16 = horizontal_sum_f32_to_f16(targets8[1][1]); - let target12: f16 = horizontal_sum_f32_to_f16(targets8[1][2]); - let target13: f16 = horizontal_sum_f32_to_f16(targets8[1][3]); - let target20: f16 = horizontal_sum_f32_to_f16(targets8[2][0]); - let target21: f16 = horizontal_sum_f32_to_f16(targets8[2][1]); - let target22: f16 = horizontal_sum_f32_to_f16(targets8[2][2]); - let target23: f16 = horizontal_sum_f32_to_f16(targets8[2][3]); - let target30: f16 = horizontal_sum_f32_to_f16(targets8[3][0]); - let target31: f16 = horizontal_sum_f32_to_f16(targets8[3][1]); - let target32: f16 = horizontal_sum_f32_to_f16(targets8[3][2]); - let target33: f16 = horizontal_sum_f32_to_f16(targets8[3][3]); + let target00: f16 = horizontal_sum_and_f32_to_f16(targets8[0][0]); + let target01: f16 = horizontal_sum_and_f32_to_f16(targets8[0][1]); + let target02: f16 = horizontal_sum_and_f32_to_f16(targets8[0][2]); + let target03: f16 = horizontal_sum_and_f32_to_f16(targets8[0][3]); + let target10: f16 = horizontal_sum_and_f32_to_f16(targets8[1][0]); + let target11: f16 = horizontal_sum_and_f32_to_f16(targets8[1][1]); + let target12: f16 = horizontal_sum_and_f32_to_f16(targets8[1][2]); + let target13: f16 = horizontal_sum_and_f32_to_f16(targets8[1][3]); + let target20: f16 = horizontal_sum_and_f32_to_f16(targets8[2][0]); + let target21: f16 = horizontal_sum_and_f32_to_f16(targets8[2][1]); + let target22: f16 = horizontal_sum_and_f32_to_f16(targets8[2][2]); + let target23: f16 = horizontal_sum_and_f32_to_f16(targets8[2][3]); + let target30: f16 = horizontal_sum_and_f32_to_f16(targets8[3][0]); + let target31: f16 = horizontal_sum_and_f32_to_f16(targets8[3][1]); + let target32: f16 = horizontal_sum_and_f32_to_f16(targets8[3][2]); + let target33: f16 = horizontal_sum_and_f32_to_f16(targets8[3][3]); *tgt_data.add(row0 * self_cols_capacity + col0) += target00; *tgt_data.add(row0 * self_cols_capacity + col1) += target10; @@ -1641,33 +1553,23 @@ impl Tensor { } else { (self.rows / 4 + 1) as usize }; - let mut sum8s: [[__m256; 4]; 2] = [ - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], + let mut sum8s: [[F32x8; 4]; 2] = [ + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], ]; let self_data: *const f16 = self.data as *const f16; let other_data: *const f16 = other.data as *const f16; let _ncols_capacity: usize = result.capacity_cols as usize; for row in 0..row_its { let row: i64 = row as i64; - sum8s[0][0] = _mm256_setzero_ps(); - sum8s[0][1] = _mm256_setzero_ps(); - sum8s[0][2] = _mm256_setzero_ps(); - sum8s[0][3] = _mm256_setzero_ps(); - sum8s[1][0] = _mm256_setzero_ps(); - sum8s[1][1] = _mm256_setzero_ps(); - sum8s[1][2] = _mm256_setzero_ps(); - sum8s[1][3] = _mm256_setzero_ps(); + sum8s[0][0] = f32x8_zero(); + sum8s[0][1] = 
f32x8_zero(); + sum8s[0][2] = f32x8_zero(); + sum8s[0][3] = f32x8_zero(); + sum8s[1][0] = f32x8_zero(); + sum8s[1][1] = f32x8_zero(); + sum8s[1][2] = f32x8_zero(); + sum8s[1][3] = f32x8_zero(); let row4_0 = row * 4; let row4_1 = row * 4 + 1; let row4_2 = row * 4 + 2; @@ -1675,8 +1577,8 @@ impl Tensor { // Loads from (0, column..column+8) #[inline] - fn load2(ptr: *const f16, col: usize) -> __m256 { - unsafe { _mm256_cvtph_ps(_mm_loadu_si128(ptr.add(col) as *const __m128i)) } + fn load2(ptr: *const f16, col: usize) -> F32x8 { + unsafe { i16x8_as_f16_to_f32x8(load_i16x8(ptr.add(col) as *const I16x8)) } } // Loads from (row, column..column+8) #[inline] @@ -1686,15 +1588,15 @@ impl Tensor { col: usize, cols_capacity: i64, nrows: i64, - ) -> __m256 { + ) -> F32x8 { unsafe { if row < nrows { - _mm256_cvtph_ps(_mm_loadu_si128( + i16x8_as_f16_to_f32x8(load_i16x8( ptr.add(row as usize * cols_capacity as usize + col) - as *const __m128i, + as *const I16x8, )) } else { - _mm256_setzero_ps() + f32x8_zero() } } } @@ -1711,10 +1613,10 @@ impl Tensor { load2row(self_data, row4_2, col, self.capacity_cols, self.rows); let left_side8_30 = load2row(self_data, row4_3, col, self.capacity_cols, self.rows); - sum8s[0][0] = _mm256_fmadd_ps(left_side8_00, right_side8_0, sum8s[0][0]); - sum8s[0][1] = _mm256_fmadd_ps(left_side8_10, right_side8_0, sum8s[0][1]); - sum8s[0][2] = _mm256_fmadd_ps(left_side8_20, right_side8_0, sum8s[0][2]); - sum8s[0][3] = _mm256_fmadd_ps(left_side8_30, right_side8_0, sum8s[0][3]); + sum8s[0][0] = fma_f32x8(left_side8_00, right_side8_0, sum8s[0][0]); + sum8s[0][1] = fma_f32x8(left_side8_10, right_side8_0, sum8s[0][1]); + sum8s[0][2] = fma_f32x8(left_side8_20, right_side8_0, sum8s[0][2]); + sum8s[0][3] = fma_f32x8(left_side8_30, right_side8_0, sum8s[0][3]); let right_side8_1 = load2(other_data, col2); let left_side8_01 = load2row(self_data, row4_0, col2, self.capacity_cols, self.rows); @@ -1724,15 +1626,19 @@ impl Tensor { load2row(self_data, row4_2, col2, self.capacity_cols, self.rows); let left_side8_31 = load2row(self_data, row4_3, col2, self.capacity_cols, self.rows); - sum8s[1][0] = _mm256_fmadd_ps(left_side8_01, right_side8_1, sum8s[1][0]); - sum8s[1][1] = _mm256_fmadd_ps(left_side8_11, right_side8_1, sum8s[1][1]); - sum8s[1][2] = _mm256_fmadd_ps(left_side8_21, right_side8_1, sum8s[1][2]); - sum8s[1][3] = _mm256_fmadd_ps(left_side8_31, right_side8_1, sum8s[1][3]); + sum8s[1][0] = fma_f32x8(left_side8_01, right_side8_1, sum8s[1][0]); + sum8s[1][1] = fma_f32x8(left_side8_11, right_side8_1, sum8s[1][1]); + sum8s[1][2] = fma_f32x8(left_side8_21, right_side8_1, sum8s[1][2]); + sum8s[1][3] = fma_f32x8(left_side8_31, right_side8_1, sum8s[1][3]); } - let sum_0: f32 = horizontal_sum(sum8s[0][0]) + horizontal_sum(sum8s[1][0]); - let sum_1: f32 = horizontal_sum(sum8s[0][1]) + horizontal_sum(sum8s[1][1]); - let sum_2: f32 = horizontal_sum(sum8s[0][2]) + horizontal_sum(sum8s[1][2]); - let sum_3: f32 = horizontal_sum(sum8s[0][3]) + horizontal_sum(sum8s[1][3]); + let sum_0: f32 = + horizontal_sum_f32x8(sum8s[0][0]) + horizontal_sum_f32x8(sum8s[1][0]); + let sum_1: f32 = + horizontal_sum_f32x8(sum8s[0][1]) + horizontal_sum_f32x8(sum8s[1][1]); + let sum_2: f32 = + horizontal_sum_f32x8(sum8s[0][2]) + horizontal_sum_f32x8(sum8s[1][2]); + let sum_3: f32 = + horizontal_sum_f32x8(sum8s[0][3]) + horizontal_sum_f32x8(sum8s[1][3]); if row4_0 < result.rows { result.set_f32(row4_0, 0, sum_0); } @@ -1770,19 +1676,14 @@ impl Tensor { let tgt_data: *mut f32 = result.data as *mut f32; let ncols_capacity: usize = 
result.capacity_cols as usize; - let mut sum8s: [__m256; 4] = [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ]; + let mut sum8s: [F32x8; 4] = [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()]; for row in 0..row_its { let row: i64 = row as i64; - sum8s[0] = _mm256_setzero_ps(); - sum8s[1] = _mm256_setzero_ps(); - sum8s[2] = _mm256_setzero_ps(); - sum8s[3] = _mm256_setzero_ps(); + sum8s[0] = f32x8_zero(); + sum8s[1] = f32x8_zero(); + sum8s[2] = f32x8_zero(); + sum8s[3] = f32x8_zero(); let row4_0 = row * 4; let row4_1 = row * 4 + 1; let row4_2 = row * 4 + 2; @@ -1790,34 +1691,37 @@ impl Tensor { for col in 0..col_its { let col = col * 8; - let right_side8 = _mm256_loadu_ps(other_data.add(col)); - let left_side8_0 = _mm256_loadu_ps( - self_data.add((row4_0 * self.capacity_cols) as usize + col), - ); + let right_side8 = load_f32x8(other_data.add(col) as *const F32x8); + let left_side8_0 = + load_f32x8(self_data.add((row4_0 * self.capacity_cols) as usize + col) + as *const F32x8); let left_side8_1 = if row4_1 < self.rows { - _mm256_loadu_ps(self_data.add((row4_1 * self.capacity_cols) as usize + col)) + load_f32x8(self_data.add((row4_1 * self.capacity_cols) as usize + col) + as *const F32x8) } else { - _mm256_setzero_ps() + f32x8_zero() }; let left_side8_2 = if row4_2 < self.rows { - _mm256_loadu_ps(self_data.add((row4_2 * self.capacity_cols) as usize + col)) + load_f32x8(self_data.add((row4_2 * self.capacity_cols) as usize + col) + as *const F32x8) } else { - _mm256_setzero_ps() + f32x8_zero() }; let left_side8_3 = if row4_3 < self.rows { - _mm256_loadu_ps(self_data.add((row4_3 * self.capacity_cols) as usize + col)) + load_f32x8(self_data.add((row4_3 * self.capacity_cols) as usize + col) + as *const F32x8) } else { - _mm256_setzero_ps() + f32x8_zero() }; - sum8s[0] = _mm256_fmadd_ps(left_side8_0, right_side8, sum8s[0]); - sum8s[1] = _mm256_fmadd_ps(left_side8_1, right_side8, sum8s[1]); - sum8s[2] = _mm256_fmadd_ps(left_side8_2, right_side8, sum8s[2]); - sum8s[3] = _mm256_fmadd_ps(left_side8_3, right_side8, sum8s[3]); + sum8s[0] = fma_f32x8(left_side8_0, right_side8, sum8s[0]); + sum8s[1] = fma_f32x8(left_side8_1, right_side8, sum8s[1]); + sum8s[2] = fma_f32x8(left_side8_2, right_side8, sum8s[2]); + sum8s[3] = fma_f32x8(left_side8_3, right_side8, sum8s[3]); } - let sum_0: f32 = horizontal_sum(sum8s[0]); - let sum_1: f32 = horizontal_sum(sum8s[1]); - let sum_2: f32 = horizontal_sum(sum8s[2]); - let sum_3: f32 = horizontal_sum(sum8s[3]); + let sum_0: f32 = horizontal_sum_f32x8(sum8s[0]); + let sum_1: f32 = horizontal_sum_f32x8(sum8s[1]); + let sum_2: f32 = horizontal_sum_f32x8(sum8s[2]); + let sum_3: f32 = horizontal_sum_f32x8(sum8s[3]); if row4_0 < result.rows { *(tgt_data.add(row4_0 as usize * ncols_capacity)) = sum_0; } @@ -1871,10 +1775,10 @@ impl Tensor { for col in 0..other.cols { let col = col as usize; - let mut sum8: __m256 = _mm256_setzero_ps(); + let mut sum8: F32x8 = f32x8_zero(); for row8 in 0..col_its { let row = row8 * 8; - let left = _mm256_loadu_ps(left_data.add(row)); + let left = load_f32x8(left_data.add(row) as *const F32x8); let mut r = [0.0f32; 8]; // i hate you clippy because you ask me // to make code more unreadable @@ -1885,17 +1789,16 @@ impl Tensor { } } let right = if row + 8 <= other.rows as usize { - _mm256_i32gather_ps( + gather_f32x8( right_data.add(row * other_capacity_cols + col), - _mm256_set_epi32(o7, o6, o5, o4, o3, o2, o1, o0), - 1, + i32x8_from_values(o7, o6, o5, o4, o3, o2, o1, o0), ) } else { - 
_mm256_loadu_ps(r.as_ptr()) + load_f32x8(r.as_ptr() as *const F32x8) }; - sum8 = _mm256_fmadd_ps(left, right, sum8); + sum8 = fma_f32x8(left, right, sum8); } - *tgt_data.add(col) = horizontal_sum(sum8); + *tgt_data.add(col) = horizontal_sum_f32x8(sum8); } result } @@ -2159,11 +2062,14 @@ impl Tensor { for row in 0..self.rows { for col in 0..cols_it { let col = col * 8; - let val8: __m128i = - _mm_loadu_si128(self_data.add((row * self_capacity_cols + col) as usize) - as *const __m128i); - let val8: __m256 = _mm256_cvtph_ps(val8); - _mm256_storeu_ps(tgt_data.add((row * tgt_capacity_cols + col) as usize), val8); + let val8: I16x8 = + load_i16x8(self_data.add((row * self_capacity_cols + col) as usize) + as *const I16x8); + let val8: F32x8 = i16x8_as_f16_to_f32x8(val8); + store_f32x8( + tgt_data.add((row * tgt_capacity_cols + col) as usize) as *mut F32x8, + val8, + ); } } result @@ -2209,11 +2115,12 @@ impl Tensor { for row in 0..self.rows { for col in 0..cols_it { let col = col * 8; - let val8: __m256 = - _mm256_loadu_ps(self_data.add((row * self_capacity_cols + col) as usize)); - let val8: __m128i = _mm256_cvtps_ph(val8, 0); - _mm_storeu_si128( - tgt_data.add((row * tgt_capacity_cols + col) as usize) as *mut __m128i, + let val8: F32x8 = + load_f32x8(self_data.add((row * self_capacity_cols + col) as usize) + as *const F32x8); + let val8: I16x8 = f32x8_to_i16x8_as_f16(val8); + store_i16x8( + tgt_data.add((row * tgt_capacity_cols + col) as usize) as *mut I16x8, val8, ); }
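
The notes below are usage sketches against the new simd_support API. They are not part of the patch; they assume the crate is named rllama and is built with the required target features enabled (e.g. RUSTFLAGS="-C target-feature=+avx2,+fma,+f16c") — the wrappers are safe fns, but the underlying intrinsics still need those CPU features.

First, a minimal f32 dot product composed from load_f32x8, fma_f32x8 and horizontal_sum_f32x8, the same three wrappers the matmul hunks lean on. It only handles lengths that are a multiple of 8:

    use rllama::simd_support::*;

    fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len());
        assert_eq!(a.len() % 8, 0, "this sketch handles multiples of 8 only");
        let mut acc = f32x8_zero();
        for i in (0..a.len()).step_by(8) {
            // load_f32x8 wraps an unaligned load, so plain slice pointers are fine.
            let va = load_f32x8(a[i..].as_ptr() as *const F32x8);
            let vb = load_f32x8(b[i..].as_ptr() as *const F32x8);
            // Lane-wise acc += va * vb; one FMA per 8 elements.
            acc = fma_f32x8(va, vb, acc);
        }
        // Reduce the eight lanes to a single f32 only once, at the very end.
        horizontal_sum_f32x8(acc)
    }

    fn main() {
        let a: Vec<f32> = (0..16).map(|x| x as f32).collect();
        let b = vec![2.0f32; 16];
        assert_eq!(dot_f32(&a, &b), 240.0); // 2 * (0 + 1 + ... + 15)
    }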
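
The f16 kernels in tensor.rs all follow one pattern: widen eight f16 values to f32 with i16x8_as_f16_to_f32x8, do the arithmetic in f32, and narrow back with f32x8_to_i16x8_as_f16. A sketch of that round trip (scale_f16 is an illustrative name, not part of the patch):

    use half::f16;
    use rllama::simd_support::*;

    // Multiplies eight f16 values by `factor`, with the arithmetic done in f32.
    fn scale_f16(values: &mut [f16; 8], factor: f32) {
        // Widen: 8 x f16 (one I16x8) -> 8 x f32 (one F32x8).
        let v = i16x8_as_f16_to_f32x8(load_i16x8(values.as_ptr() as *const I16x8));
        // v * factor + 0.0, using the same FMA wrapper as the matmul kernels.
        let scaled = fma_f32x8(v, f32x8_singleton(factor), f32x8_zero());
        // Narrow back to f16 and store.
        store_i16x8(values.as_mut_ptr() as *mut I16x8, f32x8_to_i16x8_as_f16(scaled));
    }

    fn main() {
        let mut xs = [f16::from_f32(1.5); 8];
        scale_f16(&mut xs, 4.0);
        assert!(xs.iter().all(|x| x.to_f32() == 6.0));
    }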
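
The matrix-multiplication hunks keep a 4x4 tile of F32x8 accumulators so that sixteen independent FMA chains are in flight on every k-step, and the comparatively expensive horizontal sums run only once per output element, after the k loop. A simplified sketch of that register-tiling idea (not part of the patch; both operands are row-major with the right-hand side pre-transposed, and K must be a multiple of 8):

    use rllama::simd_support::*;

    // Computes a 4x4 output tile: out[i][j] = dot(lhs[i], rhs[j]) over k elements.
    fn tile_4x4(lhs: &[Vec<f32>; 4], rhs: &[Vec<f32>; 4], k: usize) -> [[f32; 4]; 4] {
        let mut acc = [[f32x8_zero(); 4]; 4];
        for p in (0..k).step_by(8) {
            // One 8-wide chunk from each of the four rows on both sides.
            let l: Vec<F32x8> = (0..4)
                .map(|i| load_f32x8(lhs[i][p..].as_ptr() as *const F32x8))
                .collect();
            let r: Vec<F32x8> = (0..4)
                .map(|j| load_f32x8(rhs[j][p..].as_ptr() as *const F32x8))
                .collect();
            // Sixteen independent accumulators: no FMA waits on the previous one.
            for i in 0..4 {
                for j in 0..4 {
                    acc[i][j] = fma_f32x8(l[i], r[j], acc[i][j]);
                }
            }
        }
        // Pay for the horizontal reduction only once per output element.
        let mut out = [[0.0f32; 4]; 4];
        for i in 0..4 {
            for j in 0..4 {
                out[i][j] = horizontal_sum_f32x8(acc[i][j]);
            }
        }
        out
    }

    fn main() {
        let lhs: [Vec<f32>; 4] = core::array::from_fn(|i| vec![(i + 1) as f32; 8]);
        let rhs: [Vec<f32>; 4] = core::array::from_fn(|j| vec![(j + 1) as f32; 8]);
        let out = tile_4x4(&lhs, &rhs, 8);
        // Dot of constant rows: out[i][j] = 8 * (i+1) * (j+1).
        assert_eq!(out[2][3], 8.0 * 3.0 * 4.0);
    }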
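
Finally, the matrix-vector path uses gather_f32x8 to walk a column of the row-major right-hand side. Because the wrapper fixes the gather scale at 1, the index vector must contain byte offsets, and i32x8_from_values places its first argument in the highest lane — which is why the tensor.rs call site passes o7 first and o0 last. A small sketch (not part of the patch) gathering one column of a flat row-major 8x4 matrix:

    use rllama::simd_support::*;

    fn main() {
        // Row-major 8x4 matrix stored flat: m[r * 4 + c] = r * 4 + c.
        let m: Vec<f32> = (0..32).map(|x| x as f32).collect();
        let row_stride_bytes = (4 * std::mem::size_of::<f32>()) as i32; // 16
        // Offsets are listed for rows 7..0 so that lane k ends up reading row k.
        let idx = i32x8_from_values(
            7 * row_stride_bytes, 6 * row_stride_bytes, 5 * row_stride_bytes,
            4 * row_stride_bytes, 3 * row_stride_bytes, 2 * row_stride_bytes,
            row_stride_bytes, 0,
        );
        // Base pointer at (row 0, col 2); lane k receives m[k * 4 + 2].
        let col2 = gather_f32x8(m[2..].as_ptr(), idx);
        // Column 2 holds 2, 6, 10, ..., 30; their sum is 128.
        assert_eq!(horizontal_sum_f32x8(col2), 128.0);
    }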