Refactor all SIMD to one file, simd_support.rs

This should make it a bit easier to port to other SIMD instruction sets, now that the
SIMD intrinsics are no longer littered randomly around the tensor.rs file.
Branch: master
Author: Mikko Juola, 3 years ago
Parent: 25e3e12d9d
Commit: 9c86c17318
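With all of the wrappers in one file, a port to another architecture only needs to
re-implement this small API behind a cfg gate. A minimal sketch of what that could look
like for AArch64 NEON follows; the module below is hypothetical and not part of this
commit, it only illustrates the porting story for three of the functions.

// Hypothetical sketch, not in this commit: the same API re-implemented for
// AArch64 NEON, where the natural vector is 128 bits, so F32x8 becomes a
// pair of float32x4_t halves.
#[cfg(target_arch = "aarch64")]
mod simd_support {
    use core::arch::aarch64::*;

    #[derive(Copy, Clone)]
    pub struct F32x8(float32x4_t, float32x4_t);

    #[inline]
    pub fn f32x8_zero() -> F32x8 {
        unsafe { F32x8(vdupq_n_f32(0.0), vdupq_n_f32(0.0)) }
    }

    // a * b + c; vfmaq_f32(acc, x, y) computes acc + x * y.
    #[inline]
    pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 {
        unsafe { F32x8(vfmaq_f32(c.0, a.0, b.0), vfmaq_f32(c.1, a.1, b.1)) }
    }

    // vaddvq_f32 sums the four lanes of one 128-bit register.
    #[inline]
    pub fn horizontal_sum_f32x8(a: F32x8) -> f32 {
        unsafe { vaddvq_f32(a.0) + vaddvq_f32(a.1) }
    }
}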

@@ -3,6 +3,7 @@
 pub mod embedding;
 pub mod protomodels;
 pub mod rllama_main;
+pub mod simd_support;
 pub mod tensor;
 #[cfg(feature = "opencl")]
 pub mod tensor_opencl_support;

@@ -0,0 +1,116 @@
+// This file contains platform-specific SIMD so that the rest of rllama does not need
+// to care which platform it is on.
+
+use core::arch::x86_64::*;
+use half::f16;
+
+pub type I32x8 = __m256i;
+pub type F32x8 = __m256;
+pub type I16x8 = __m128i;
+
+/* ------------------ */
+/* Loading and storing things */
+/* ------------------ */
+
+#[inline]
+pub fn load_i16x8(ptr: *const I16x8) -> I16x8 {
+    unsafe { _mm_loadu_si128(ptr) }
+}
+
+#[inline]
+pub fn store_i16x8(ptr: *mut I16x8, a: I16x8) {
+    unsafe { _mm_storeu_si128(ptr, a) }
+}
+
+#[inline]
+pub fn load_f32x8(ptr: *const F32x8) -> F32x8 {
+    unsafe { _mm256_loadu_ps(ptr as *const f32) }
+}
+
+#[inline]
+pub fn store_f32x8(ptr: *mut F32x8, a: F32x8) {
+    unsafe { _mm256_storeu_ps(ptr as *mut f32, a) }
+}
+
+#[inline]
+pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 {
+    unsafe { _mm256_i32gather_ps(ptr, indices, 1) }
+}
+
+/* ------------------ */
+/* Conversions */
+/* ------------------ */
+
+#[inline]
+pub fn i16x8_as_f16_to_f32x8(a: I16x8) -> F32x8 {
+    unsafe { _mm256_cvtph_ps(a) }
+}
+
+#[inline]
+pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 {
+    unsafe { _mm256_cvtps_ph(a, 0) }
+}
+
+/*
+ * Constants, creating from constants
+ */
+
+pub fn f32x8_zero() -> F32x8 {
+    unsafe { _mm256_setzero_ps() }
+}
+
+pub fn i16x8_zero() -> I16x8 {
+    unsafe { _mm_setzero_si128() }
+}
+
+pub fn f32x8_singleton(value: f32) -> F32x8 {
+    unsafe { _mm256_set1_ps(value) }
+}
+
+pub fn i32x8_from_values(
+    val0: i32,
+    val1: i32,
+    val2: i32,
+    val3: i32,
+    val4: i32,
+    val5: i32,
+    val6: i32,
+    val7: i32,
+) -> I32x8 {
+    unsafe { _mm256_set_epi32(val0, val1, val2, val3, val4, val5, val6, val7) }
+}
+
+/*
+ * Operations
+ */
+
+// FMA
+// a * b + c
+pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 {
+    unsafe { _mm256_fmadd_ps(a, b, c) }
+}
+
+// Horizontal sums
+#[inline]
+pub fn horizontal_sum_f32x8(mut ymm: __m256) -> f32 {
+    unsafe {
+        let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
+        ymm = _mm256_add_ps(ymm, ymm2);
+        ymm = _mm256_hadd_ps(ymm, ymm);
+        ymm = _mm256_hadd_ps(ymm, ymm);
+        _mm256_cvtss_f32(ymm)
+    }
+}
+
+#[inline]
+pub fn horizontal_sum_and_f32_to_f16(mut ymm: __m256) -> f16 {
+    unsafe {
+        let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
+        ymm = _mm256_add_ps(ymm, ymm2);
+        ymm = _mm256_hadd_ps(ymm, ymm);
+        ymm = _mm256_hadd_ps(ymm, ymm);
+        f16::from_f32(_mm256_cvtss_f32(ymm))
+    }
+}
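For reference, a small usage sketch written against the new module, in the same style as
the tensor.rs call sites below. The dot_f32 helper is made up for illustration and is
not part of the commit; it assumes slice lengths that are a multiple of 8.

use crate::simd_support::*;

// Hypothetical example, not in the commit: dot product of two f32 slices
// whose lengths are a multiple of 8.
fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len() % 8, 0);
    let mut acc: F32x8 = f32x8_zero();
    for i in (0..a.len()).step_by(8) {
        // The wrappers take raw pointers, just like the tensor.rs call sites.
        let x = load_f32x8(a[i..].as_ptr() as *const F32x8);
        let y = load_f32x8(b[i..].as_ptr() as *const F32x8);
        acc = fma_f32x8(x, y, acc); // acc += x * y, lane-wise
    }
    horizontal_sum_f32x8(acc) // collapse the 8 lanes into one f32
}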

@@ -17,6 +17,7 @@
  * If it's "XXX_inplace", then it has a &mut self and it modifies the tensor in place.
  */
+use crate::simd_support::*;
 #[cfg(feature = "opencl")]
 use crate::tensor_opencl_support::{OpenCL, OpenCLError, OpenCLEvent, OpenCLTensor};
 use crate::unpickler;
@@ -25,7 +26,6 @@ use half::f16;
 use rand::Rng;
 use rayon::prelude::*;
 use std::alloc::Layout;
-use std::arch::x86_64::*;
 use std::io::{Read, Seek};
 use std::path::{Path, PathBuf};
 #[cfg(feature = "opencl")]
@@ -175,28 +175,6 @@ fn compute_capacity_cols_f16(cols: i64) -> i64 {
     }
 }
-
-#[inline]
-fn horizontal_sum(mut ymm: __m256) -> f32 {
-    unsafe {
-        let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
-        ymm = _mm256_add_ps(ymm, ymm2);
-        ymm = _mm256_hadd_ps(ymm, ymm);
-        ymm = _mm256_hadd_ps(ymm, ymm);
-        _mm256_cvtss_f32(ymm)
-    }
-}
-
-#[inline]
-fn horizontal_sum_f32_to_f16(mut ymm: __m256) -> f16 {
-    unsafe {
-        let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
-        ymm = _mm256_add_ps(ymm, ymm2);
-        ymm = _mm256_hadd_ps(ymm, ymm);
-        ymm = _mm256_hadd_ps(ymm, ymm);
-        f16::from_f32(_mm256_cvtss_f32(ymm))
-    }
-}
-
 impl Tensor {
     #[inline]
     pub fn assume_on_gpu(&self) {
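The two helpers deleted here are the same code that now lives in simd_support.rs as
horizontal_sum_f32x8 and horizontal_sum_and_f32_to_f16. For anyone porting them, a
scalar model of the reduction they perform (a hypothetical helper, not in the commit):

// What the permute/add/hadd/hadd sequence computes: permute2f128 + add folds
// the upper 128-bit half onto the lower one, and the two hadds collapse the
// remaining four partial sums into lane 0.
fn horizontal_sum_model(lanes: [f32; 8]) -> f32 {
    let t = [
        lanes[0] + lanes[4],
        lanes[1] + lanes[5],
        lanes[2] + lanes[6],
        lanes[3] + lanes[7],
    ];
    (t[0] + t[1]) + (t[2] + t[3])
}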
@@ -938,19 +916,24 @@ impl Tensor {
                     let i2_self_cols = i2 * self_cols_capacity;
                     let i2_src_cols = i2 * src_cols_capacity;
                     for k2 in k..std::cmp::min(k + ITEMS_PER_CACHE_LINE, src_cols) {
-                        let other_value8: __m256 = _mm256_loadu_ps(
-                            other_data.add(k2 * other_cols_capacity + col),
+                        let other_value8: F32x8 = load_f32x8(
+                            other_data.add(k2 * other_cols_capacity + col)
+                                as *const F32x8,
+                        );
+                        let src_value8_broadcast: F32x8 =
+                            f32x8_singleton(*src_data.add(i2_src_cols + k2));
+                        let tgt_value8: F32x8 = load_f32x8(
+                            tgt_data.add(i2_self_cols + col) as *const F32x8,
                         );
-                        let src_value8_broadcast: __m256 =
-                            _mm256_broadcast_ss(&*src_data.add(i2_src_cols + k2));
-                        let tgt_value8: __m256 =
-                            _mm256_loadu_ps(tgt_data.add(i2_self_cols + col));
-                        let result8: __m256 = _mm256_fmadd_ps(
+                        let result8: F32x8 = fma_f32x8(
                             src_value8_broadcast,
                             other_value8,
                             tgt_value8,
                         );
-                        _mm256_storeu_ps(tgt_data.add(i2_self_cols + col), result8);
+                        store_f32x8(
+                            tgt_data.add(i2_self_cols + col) as *mut F32x8,
+                            result8,
+                        );
                     }
                 }
                 k += ITEMS_PER_CACHE_LINE;
@@ -993,23 +976,20 @@ impl Tensor {
                     let i2_self_cols = i2 * self_cols;
                     let i2_src_cols = i2 * src_cols;
                     for k2 in k..k + ITEMS_PER_CACHE_LINE {
-                        let other_value8: __m256 = _mm256_cvtph_ps(_mm_loadu_si128(
+                        let other_value8: F32x8 = i16x8_as_f16_to_f32x8(load_i16x8(
                             other_data.add(k2 * other_cols + col) as *const _,
                         ));
                         let src_value8: f16 = *src_data.add(i2_src_cols + k2);
-                        let src_value8_broadcast: __m256 =
-                            _mm256_broadcast_ss(&src_value8.to_f32());
-                        let tgt_value8: __m256 = _mm256_cvtph_ps(_mm_loadu_si128(
+                        let src_value8_broadcast: F32x8 =
+                            f32x8_singleton(src_value8.to_f32());
+                        let tgt_value8: F32x8 = i16x8_as_f16_to_f32x8(load_i16x8(
                             tgt_data.add(i2_self_cols + col) as *const _,
                         ));
-                        let result8: __m256 = _mm256_fmadd_ps(
-                            src_value8_broadcast,
-                            other_value8,
-                            tgt_value8,
-                        );
-                        let result8_packed: __m128i = _mm256_cvtps_ph(result8, 0);
-                        _mm_storeu_si128(
-                            tgt_data.add(i2_self_cols + col) as *mut _,
+                        let result8: F32x8 =
+                            fma_f32x8(src_value8_broadcast, other_value8, tgt_value8);
+                        let result8_packed: I16x8 = f32x8_to_i16x8_as_f16(result8);
+                        store_i16x8(
+                            tgt_data.add(i2_self_cols + col) as *mut I16x8,
                             result8_packed,
                         );
                     }
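Worth noting about this f16 hunk: there is no native f16 arithmetic here, so every
fused multiply-add round-trips through f32 (load f16, widen, FMA in f32, narrow,
store). A scalar model of one lane, using the half crate's conversions (illustrative
only, not in the commit):

use half::f16;

// One lane of the loop above: widen to f32, fuse multiply-add, narrow back.
fn fma_f16_lane(a: f16, b: f16, c: f16) -> f16 {
    f16::from_f32(a.to_f32() * b.to_f32() + c.to_f32())
}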
@@ -1186,137 +1166,109 @@ impl Tensor {
                 let col1 = col * 4 + 1;
                 let col2 = col * 4 + 2;
                 let col3 = col * 4 + 3;
-                let mut targets8: [[__m256; 4]; 4] = [
-                    [
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                    ],
-                    [
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                    ],
-                    [
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                    ],
-                    [
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                    ],
+                let mut targets8: [[F32x8; 4]; 4] = [
+                    [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                    [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                    [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                    [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
                 ];
                 for p in 0..src_cols_its {
-                    let other8_0: __m256 = _mm256_loadu_ps(
+                    let other8_0: F32x8 = load_f32x8(
                         other_data
-                            .add(col0 * other_cols_capacity + p * ITEMS_PER_LINE),
+                            .add(col0 * other_cols_capacity + p * ITEMS_PER_LINE)
+                            as *const F32x8,
                     );
-                    let other8_1: __m256 =
+                    let other8_1: F32x8 =
                         if col1 < other_rows {
-                            _mm256_loadu_ps(other_data.add(
+                            load_f32x8(other_data.add(
                                 col1 * other_cols_capacity + p * ITEMS_PER_LINE,
-                            ))
+                            )
+                                as *const F32x8)
                         } else {
-                            _mm256_setzero_ps()
+                            f32x8_zero()
                         };
-                    let other8_2: __m256 =
+                    let other8_2: F32x8 =
                         if col2 < other_rows {
-                            _mm256_loadu_ps(other_data.add(
+                            load_f32x8(other_data.add(
                                 col2 * other_cols_capacity + p * ITEMS_PER_LINE,
-                            ))
+                            )
+                                as *const F32x8)
                         } else {
-                            _mm256_setzero_ps()
+                            f32x8_zero()
                         };
-                    let other8_3: __m256 =
+                    let other8_3: F32x8 =
                         if col3 < other_rows {
-                            _mm256_loadu_ps(other_data.add(
+                            load_f32x8(other_data.add(
                                 col3 * other_cols_capacity + p * ITEMS_PER_LINE,
-                            ))
+                            )
+                                as *const F32x8)
                         } else {
-                            _mm256_setzero_ps()
+                            f32x8_zero()
                         };
-                    let src8_0: __m256 = _mm256_loadu_ps(
-                        src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE),
+                    let src8_0: F32x8 = load_f32x8(
+                        src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE)
+                            as *const F32x8,
                     );
-                    let src8_1: __m256 = if row1 < src_rows {
-                        _mm256_loadu_ps(
+                    let src8_1: F32x8 = if row1 < src_rows {
+                        load_f32x8(
                             src_data
-                                .add(row1 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                .add(row1 * src_cols_capacity + p * ITEMS_PER_LINE)
+                                as *const F32x8,
                         )
                     } else {
-                        _mm256_setzero_ps()
+                        f32x8_zero()
                     };
-                    let src8_2: __m256 = if row2 < src_rows {
-                        _mm256_loadu_ps(
+                    let src8_2: F32x8 = if row2 < src_rows {
+                        load_f32x8(
                             src_data
-                                .add(row2 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                .add(row2 * src_cols_capacity + p * ITEMS_PER_LINE)
+                                as *const F32x8,
                         )
                     } else {
-                        _mm256_setzero_ps()
+                        f32x8_zero()
                     };
-                    let src8_3: __m256 = if row3 < src_rows {
-                        _mm256_loadu_ps(
+                    let src8_3: F32x8 = if row3 < src_rows {
+                        load_f32x8(
                             src_data
-                                .add(row3 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                .add(row3 * src_cols_capacity + p * ITEMS_PER_LINE)
+                                as *const F32x8,
                         )
                     } else {
-                        _mm256_setzero_ps()
+                        f32x8_zero()
                     };
-                    targets8[0][0] =
-                        _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
-                    targets8[0][1] =
-                        _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
-                    targets8[0][2] =
-                        _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
-                    targets8[0][3] =
-                        _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
-                    targets8[1][0] =
-                        _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
-                    targets8[1][1] =
-                        _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
-                    targets8[1][2] =
-                        _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
-                    targets8[1][3] =
-                        _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
-                    targets8[2][0] =
-                        _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
-                    targets8[2][1] =
-                        _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
-                    targets8[2][2] =
-                        _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
-                    targets8[2][3] =
-                        _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
-                    targets8[3][0] =
-                        _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
-                    targets8[3][1] =
-                        _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
-                    targets8[3][2] =
-                        _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
-                    targets8[3][3] =
-                        _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
+                    targets8[0][0] = fma_f32x8(src8_0, other8_0, targets8[0][0]);
+                    targets8[0][1] = fma_f32x8(src8_1, other8_0, targets8[0][1]);
+                    targets8[0][2] = fma_f32x8(src8_2, other8_0, targets8[0][2]);
+                    targets8[0][3] = fma_f32x8(src8_3, other8_0, targets8[0][3]);
+                    targets8[1][0] = fma_f32x8(src8_0, other8_1, targets8[1][0]);
+                    targets8[1][1] = fma_f32x8(src8_1, other8_1, targets8[1][1]);
+                    targets8[1][2] = fma_f32x8(src8_2, other8_1, targets8[1][2]);
+                    targets8[1][3] = fma_f32x8(src8_3, other8_1, targets8[1][3]);
+                    targets8[2][0] = fma_f32x8(src8_0, other8_2, targets8[2][0]);
+                    targets8[2][1] = fma_f32x8(src8_1, other8_2, targets8[2][1]);
+                    targets8[2][2] = fma_f32x8(src8_2, other8_2, targets8[2][2]);
+                    targets8[2][3] = fma_f32x8(src8_3, other8_2, targets8[2][3]);
+                    targets8[3][0] = fma_f32x8(src8_0, other8_3, targets8[3][0]);
+                    targets8[3][1] = fma_f32x8(src8_1, other8_3, targets8[3][1]);
+                    targets8[3][2] = fma_f32x8(src8_2, other8_3, targets8[3][2]);
+                    targets8[3][3] = fma_f32x8(src8_3, other8_3, targets8[3][3]);
                 }
-                let target00: f32 = horizontal_sum(targets8[0][0]);
-                let target01: f32 = horizontal_sum(targets8[0][1]);
-                let target02: f32 = horizontal_sum(targets8[0][2]);
-                let target03: f32 = horizontal_sum(targets8[0][3]);
-                let target10: f32 = horizontal_sum(targets8[1][0]);
-                let target11: f32 = horizontal_sum(targets8[1][1]);
-                let target12: f32 = horizontal_sum(targets8[1][2]);
-                let target13: f32 = horizontal_sum(targets8[1][3]);
-                let target20: f32 = horizontal_sum(targets8[2][0]);
-                let target21: f32 = horizontal_sum(targets8[2][1]);
-                let target22: f32 = horizontal_sum(targets8[2][2]);
-                let target23: f32 = horizontal_sum(targets8[2][3]);
-                let target30: f32 = horizontal_sum(targets8[3][0]);
-                let target31: f32 = horizontal_sum(targets8[3][1]);
-                let target32: f32 = horizontal_sum(targets8[3][2]);
-                let target33: f32 = horizontal_sum(targets8[3][3]);
+                let target00: f32 = horizontal_sum_f32x8(targets8[0][0]);
+                let target01: f32 = horizontal_sum_f32x8(targets8[0][1]);
+                let target02: f32 = horizontal_sum_f32x8(targets8[0][2]);
+                let target03: f32 = horizontal_sum_f32x8(targets8[0][3]);
+                let target10: f32 = horizontal_sum_f32x8(targets8[1][0]);
+                let target11: f32 = horizontal_sum_f32x8(targets8[1][1]);
+                let target12: f32 = horizontal_sum_f32x8(targets8[1][2]);
+                let target13: f32 = horizontal_sum_f32x8(targets8[1][3]);
+                let target20: f32 = horizontal_sum_f32x8(targets8[2][0]);
+                let target21: f32 = horizontal_sum_f32x8(targets8[2][1]);
+                let target22: f32 = horizontal_sum_f32x8(targets8[2][2]);
+                let target23: f32 = horizontal_sum_f32x8(targets8[2][3]);
+                let target30: f32 = horizontal_sum_f32x8(targets8[3][0]);
+                let target31: f32 = horizontal_sum_f32x8(targets8[3][1]);
+                let target32: f32 = horizontal_sum_f32x8(targets8[3][2]);
+                let target33: f32 = horizontal_sum_f32x8(targets8[3][3]);

                 *tgt_data.add(row0 * self_cols_capacity + col0) += target00;
                 *tgt_data.add(row0 * self_cols_capacity + col1) += target10;
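The hunk above rewrites a 4x4 register-blocked kernel: sixteen F32x8 accumulators cover
a 4x4 tile of the output, so each src8_i and other8_j load is reused four times before
the sixteen horizontal sums collapse the tile. A scalar model of the blocking
(hypothetical, for illustration only):

// Scalar model of the 4x4 blocking: each a[i] and b[j] is loaded once and
// reused across four accumulators, which is what cuts memory traffic.
fn tile_4x4_step(a: [f32; 4], b: [f32; 4], acc: &mut [[f32; 4]; 4]) {
    for j in 0..4 {
        for i in 0..4 {
            // mirrors: targets8[j][i] = fma_f32x8(src8_i, other8_j, targets8[j][i])
            acc[j][i] += a[i] * b[j];
        }
    }
}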
@@ -1407,31 +1359,11 @@ impl Tensor {
                 let col1 = col * 4 + 1;
                 let col2 = col * 4 + 2;
                 let col3 = col * 4 + 3;
-                let mut targets8: [[__m256; 4]; 4] = [
-                    [
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                    ],
-                    [
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                    ],
-                    [
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                    ],
-                    [
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                        _mm256_setzero_ps(),
-                    ],
+                let mut targets8: [[F32x8; 4]; 4] = [
+                    [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                    [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                    [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                    [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
                 ];
                 // Loads from (row, column..column+8) and (row+1, column..column+8)
                 #[inline]
@@ -1441,30 +1373,26 @@ impl Tensor {
                     column: usize,
                     cols_capacity: usize,
                     nrows: usize,
-                ) -> (__m256, __m256) {
+                ) -> (F32x8, F32x8) {
                     unsafe {
                         let (left, right) = if row + 1 < nrows {
                             (
-                                _mm_loadu_si128(
-                                    ptr.add(row * cols_capacity + column)
-                                        as *const __m128i,
-                                ),
-                                _mm_loadu_si128(
+                                load_i16x8(ptr.add(row * cols_capacity + column)
+                                    as *const I16x8),
+                                load_i16x8(
                                     ptr.add((row + 1) * cols_capacity + column)
-                                        as *const __m128i,
+                                        as *const I16x8,
                                 ),
                             )
                         } else {
                             (
-                                _mm_loadu_si128(
-                                    ptr.add(row * cols_capacity + column)
-                                        as *const __m128i,
-                                ),
-                                _mm_setzero_si128(),
+                                load_i16x8(ptr.add(row * cols_capacity + column)
+                                    as *const I16x8),
+                                i16x8_zero(),
                             )
                         };
-                        let left: __m256 = _mm256_cvtph_ps(left);
-                        let right: __m256 = _mm256_cvtph_ps(right);
+                        let left: F32x8 = i16x8_as_f16_to_f32x8(left);
+                        let right: F32x8 = i16x8_as_f16_to_f32x8(right);
                         (left, right)
                     }
                 }
@@ -1497,55 +1425,39 @@ impl Tensor {
                         src_cols_capacity,
                         src_rows,
                     );
-                    targets8[0][0] =
-                        _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
-                    targets8[0][1] =
-                        _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
-                    targets8[0][2] =
-                        _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
-                    targets8[0][3] =
-                        _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
-                    targets8[1][0] =
-                        _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
-                    targets8[1][1] =
-                        _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
-                    targets8[1][2] =
-                        _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
-                    targets8[1][3] =
-                        _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
-                    targets8[2][0] =
-                        _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
-                    targets8[2][1] =
-                        _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
-                    targets8[2][2] =
-                        _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
-                    targets8[2][3] =
-                        _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
-                    targets8[3][0] =
-                        _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
-                    targets8[3][1] =
-                        _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
-                    targets8[3][2] =
-                        _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
-                    targets8[3][3] =
-                        _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
+                    targets8[0][0] = fma_f32x8(src8_0, other8_0, targets8[0][0]);
+                    targets8[0][1] = fma_f32x8(src8_1, other8_0, targets8[0][1]);
+                    targets8[0][2] = fma_f32x8(src8_2, other8_0, targets8[0][2]);
+                    targets8[0][3] = fma_f32x8(src8_3, other8_0, targets8[0][3]);
+                    targets8[1][0] = fma_f32x8(src8_0, other8_1, targets8[1][0]);
+                    targets8[1][1] = fma_f32x8(src8_1, other8_1, targets8[1][1]);
+                    targets8[1][2] = fma_f32x8(src8_2, other8_1, targets8[1][2]);
+                    targets8[1][3] = fma_f32x8(src8_3, other8_1, targets8[1][3]);
+                    targets8[2][0] = fma_f32x8(src8_0, other8_2, targets8[2][0]);
+                    targets8[2][1] = fma_f32x8(src8_1, other8_2, targets8[2][1]);
+                    targets8[2][2] = fma_f32x8(src8_2, other8_2, targets8[2][2]);
+                    targets8[2][3] = fma_f32x8(src8_3, other8_2, targets8[2][3]);
+                    targets8[3][0] = fma_f32x8(src8_0, other8_3, targets8[3][0]);
+                    targets8[3][1] = fma_f32x8(src8_1, other8_3, targets8[3][1]);
+                    targets8[3][2] = fma_f32x8(src8_2, other8_3, targets8[3][2]);
+                    targets8[3][3] = fma_f32x8(src8_3, other8_3, targets8[3][3]);
                 }
-                let target00: f16 = horizontal_sum_f32_to_f16(targets8[0][0]);
-                let target01: f16 = horizontal_sum_f32_to_f16(targets8[0][1]);
-                let target02: f16 = horizontal_sum_f32_to_f16(targets8[0][2]);
-                let target03: f16 = horizontal_sum_f32_to_f16(targets8[0][3]);
-                let target10: f16 = horizontal_sum_f32_to_f16(targets8[1][0]);
-                let target11: f16 = horizontal_sum_f32_to_f16(targets8[1][1]);
-                let target12: f16 = horizontal_sum_f32_to_f16(targets8[1][2]);
-                let target13: f16 = horizontal_sum_f32_to_f16(targets8[1][3]);
-                let target20: f16 = horizontal_sum_f32_to_f16(targets8[2][0]);
-                let target21: f16 = horizontal_sum_f32_to_f16(targets8[2][1]);
-                let target22: f16 = horizontal_sum_f32_to_f16(targets8[2][2]);
-                let target23: f16 = horizontal_sum_f32_to_f16(targets8[2][3]);
-                let target30: f16 = horizontal_sum_f32_to_f16(targets8[3][0]);
-                let target31: f16 = horizontal_sum_f32_to_f16(targets8[3][1]);
-                let target32: f16 = horizontal_sum_f32_to_f16(targets8[3][2]);
-                let target33: f16 = horizontal_sum_f32_to_f16(targets8[3][3]);
+                let target00: f16 = horizontal_sum_and_f32_to_f16(targets8[0][0]);
+                let target01: f16 = horizontal_sum_and_f32_to_f16(targets8[0][1]);
+                let target02: f16 = horizontal_sum_and_f32_to_f16(targets8[0][2]);
+                let target03: f16 = horizontal_sum_and_f32_to_f16(targets8[0][3]);
+                let target10: f16 = horizontal_sum_and_f32_to_f16(targets8[1][0]);
+                let target11: f16 = horizontal_sum_and_f32_to_f16(targets8[1][1]);
+                let target12: f16 = horizontal_sum_and_f32_to_f16(targets8[1][2]);
+                let target13: f16 = horizontal_sum_and_f32_to_f16(targets8[1][3]);
+                let target20: f16 = horizontal_sum_and_f32_to_f16(targets8[2][0]);
+                let target21: f16 = horizontal_sum_and_f32_to_f16(targets8[2][1]);
+                let target22: f16 = horizontal_sum_and_f32_to_f16(targets8[2][2]);
+                let target23: f16 = horizontal_sum_and_f32_to_f16(targets8[2][3]);
+                let target30: f16 = horizontal_sum_and_f32_to_f16(targets8[3][0]);
+                let target31: f16 = horizontal_sum_and_f32_to_f16(targets8[3][1]);
+                let target32: f16 = horizontal_sum_and_f32_to_f16(targets8[3][2]);
+                let target33: f16 = horizontal_sum_and_f32_to_f16(targets8[3][3]);

                 *tgt_data.add(row0 * self_cols_capacity + col0) += target00;
                 *tgt_data.add(row0 * self_cols_capacity + col1) += target10;
@@ -1641,33 +1553,23 @@ impl Tensor {
             } else {
                 (self.rows / 4 + 1) as usize
             };
-            let mut sum8s: [[__m256; 4]; 2] = [
-                [
-                    _mm256_setzero_ps(),
-                    _mm256_setzero_ps(),
-                    _mm256_setzero_ps(),
-                    _mm256_setzero_ps(),
-                ],
-                [
-                    _mm256_setzero_ps(),
-                    _mm256_setzero_ps(),
-                    _mm256_setzero_ps(),
-                    _mm256_setzero_ps(),
-                ],
+            let mut sum8s: [[F32x8; 4]; 2] = [
+                [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
             ];
             let self_data: *const f16 = self.data as *const f16;
             let other_data: *const f16 = other.data as *const f16;
             let _ncols_capacity: usize = result.capacity_cols as usize;
             for row in 0..row_its {
                 let row: i64 = row as i64;
-                sum8s[0][0] = _mm256_setzero_ps();
-                sum8s[0][1] = _mm256_setzero_ps();
-                sum8s[0][2] = _mm256_setzero_ps();
-                sum8s[0][3] = _mm256_setzero_ps();
-                sum8s[1][0] = _mm256_setzero_ps();
-                sum8s[1][1] = _mm256_setzero_ps();
-                sum8s[1][2] = _mm256_setzero_ps();
-                sum8s[1][3] = _mm256_setzero_ps();
+                sum8s[0][0] = f32x8_zero();
+                sum8s[0][1] = f32x8_zero();
+                sum8s[0][2] = f32x8_zero();
+                sum8s[0][3] = f32x8_zero();
+                sum8s[1][0] = f32x8_zero();
+                sum8s[1][1] = f32x8_zero();
+                sum8s[1][2] = f32x8_zero();
+                sum8s[1][3] = f32x8_zero();
                 let row4_0 = row * 4;
                 let row4_1 = row * 4 + 1;
                 let row4_2 = row * 4 + 2;
@@ -1675,8 +1577,8 @@ impl Tensor {
             // Loads from (0, column..column+8)
             #[inline]
-            fn load2(ptr: *const f16, col: usize) -> __m256 {
-                unsafe { _mm256_cvtph_ps(_mm_loadu_si128(ptr.add(col) as *const __m128i)) }
+            fn load2(ptr: *const f16, col: usize) -> F32x8 {
+                unsafe { i16x8_as_f16_to_f32x8(load_i16x8(ptr.add(col) as *const I16x8)) }
             }
             // Loads from (row, column..column+8)
             #[inline]
@@ -1686,15 +1588,15 @@ impl Tensor {
                 col: usize,
                 cols_capacity: i64,
                 nrows: i64,
-            ) -> __m256 {
+            ) -> F32x8 {
                 unsafe {
                     if row < nrows {
-                        _mm256_cvtph_ps(_mm_loadu_si128(
+                        i16x8_as_f16_to_f32x8(load_i16x8(
                             ptr.add(row as usize * cols_capacity as usize + col)
-                                as *const __m128i,
+                                as *const I16x8,
                         ))
                     } else {
-                        _mm256_setzero_ps()
+                        f32x8_zero()
                     }
                 }
             }
@@ -1711,10 +1613,10 @@ impl Tensor {
                     load2row(self_data, row4_2, col, self.capacity_cols, self.rows);
                 let left_side8_30 =
                     load2row(self_data, row4_3, col, self.capacity_cols, self.rows);
-                sum8s[0][0] = _mm256_fmadd_ps(left_side8_00, right_side8_0, sum8s[0][0]);
-                sum8s[0][1] = _mm256_fmadd_ps(left_side8_10, right_side8_0, sum8s[0][1]);
-                sum8s[0][2] = _mm256_fmadd_ps(left_side8_20, right_side8_0, sum8s[0][2]);
-                sum8s[0][3] = _mm256_fmadd_ps(left_side8_30, right_side8_0, sum8s[0][3]);
+                sum8s[0][0] = fma_f32x8(left_side8_00, right_side8_0, sum8s[0][0]);
+                sum8s[0][1] = fma_f32x8(left_side8_10, right_side8_0, sum8s[0][1]);
+                sum8s[0][2] = fma_f32x8(left_side8_20, right_side8_0, sum8s[0][2]);
+                sum8s[0][3] = fma_f32x8(left_side8_30, right_side8_0, sum8s[0][3]);
                 let right_side8_1 = load2(other_data, col2);
                 let left_side8_01 =
                     load2row(self_data, row4_0, col2, self.capacity_cols, self.rows);
@@ -1724,15 +1626,19 @@ impl Tensor {
                     load2row(self_data, row4_2, col2, self.capacity_cols, self.rows);
                 let left_side8_31 =
                     load2row(self_data, row4_3, col2, self.capacity_cols, self.rows);
-                sum8s[1][0] = _mm256_fmadd_ps(left_side8_01, right_side8_1, sum8s[1][0]);
-                sum8s[1][1] = _mm256_fmadd_ps(left_side8_11, right_side8_1, sum8s[1][1]);
-                sum8s[1][2] = _mm256_fmadd_ps(left_side8_21, right_side8_1, sum8s[1][2]);
-                sum8s[1][3] = _mm256_fmadd_ps(left_side8_31, right_side8_1, sum8s[1][3]);
+                sum8s[1][0] = fma_f32x8(left_side8_01, right_side8_1, sum8s[1][0]);
+                sum8s[1][1] = fma_f32x8(left_side8_11, right_side8_1, sum8s[1][1]);
+                sum8s[1][2] = fma_f32x8(left_side8_21, right_side8_1, sum8s[1][2]);
+                sum8s[1][3] = fma_f32x8(left_side8_31, right_side8_1, sum8s[1][3]);
             }
-            let sum_0: f32 = horizontal_sum(sum8s[0][0]) + horizontal_sum(sum8s[1][0]);
-            let sum_1: f32 = horizontal_sum(sum8s[0][1]) + horizontal_sum(sum8s[1][1]);
-            let sum_2: f32 = horizontal_sum(sum8s[0][2]) + horizontal_sum(sum8s[1][2]);
-            let sum_3: f32 = horizontal_sum(sum8s[0][3]) + horizontal_sum(sum8s[1][3]);
+            let sum_0: f32 =
+                horizontal_sum_f32x8(sum8s[0][0]) + horizontal_sum_f32x8(sum8s[1][0]);
+            let sum_1: f32 =
+                horizontal_sum_f32x8(sum8s[0][1]) + horizontal_sum_f32x8(sum8s[1][1]);
+            let sum_2: f32 =
+                horizontal_sum_f32x8(sum8s[0][2]) + horizontal_sum_f32x8(sum8s[1][2]);
+            let sum_3: f32 =
+                horizontal_sum_f32x8(sum8s[0][3]) + horizontal_sum_f32x8(sum8s[1][3]);
             if row4_0 < result.rows {
                 result.set_f32(row4_0, 0, sum_0);
             }
@@ -1770,19 +1676,14 @@ impl Tensor {
             let tgt_data: *mut f32 = result.data as *mut f32;
             let ncols_capacity: usize = result.capacity_cols as usize;
-            let mut sum8s: [__m256; 4] = [
-                _mm256_setzero_ps(),
-                _mm256_setzero_ps(),
-                _mm256_setzero_ps(),
-                _mm256_setzero_ps(),
-            ];
+            let mut sum8s: [F32x8; 4] = [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
             for row in 0..row_its {
                 let row: i64 = row as i64;
-                sum8s[0] = _mm256_setzero_ps();
-                sum8s[1] = _mm256_setzero_ps();
-                sum8s[2] = _mm256_setzero_ps();
-                sum8s[3] = _mm256_setzero_ps();
+                sum8s[0] = f32x8_zero();
+                sum8s[1] = f32x8_zero();
+                sum8s[2] = f32x8_zero();
+                sum8s[3] = f32x8_zero();
                 let row4_0 = row * 4;
                 let row4_1 = row * 4 + 1;
                 let row4_2 = row * 4 + 2;
@@ -1790,34 +1691,37 @@ impl Tensor {
                 for col in 0..col_its {
                     let col = col * 8;
-                    let right_side8 = _mm256_loadu_ps(other_data.add(col));
-                    let left_side8_0 = _mm256_loadu_ps(
-                        self_data.add((row4_0 * self.capacity_cols) as usize + col),
-                    );
+                    let right_side8 = load_f32x8(other_data.add(col) as *const F32x8);
+                    let left_side8_0 =
+                        load_f32x8(self_data.add((row4_0 * self.capacity_cols) as usize + col)
+                            as *const F32x8);
                     let left_side8_1 = if row4_1 < self.rows {
-                        _mm256_loadu_ps(self_data.add((row4_1 * self.capacity_cols) as usize + col))
+                        load_f32x8(self_data.add((row4_1 * self.capacity_cols) as usize + col)
+                            as *const F32x8)
                     } else {
-                        _mm256_setzero_ps()
+                        f32x8_zero()
                     };
                     let left_side8_2 = if row4_2 < self.rows {
-                        _mm256_loadu_ps(self_data.add((row4_2 * self.capacity_cols) as usize + col))
+                        load_f32x8(self_data.add((row4_2 * self.capacity_cols) as usize + col)
+                            as *const F32x8)
                     } else {
-                        _mm256_setzero_ps()
+                        f32x8_zero()
                    };
                     let left_side8_3 = if row4_3 < self.rows {
-                        _mm256_loadu_ps(self_data.add((row4_3 * self.capacity_cols) as usize + col))
+                        load_f32x8(self_data.add((row4_3 * self.capacity_cols) as usize + col)
+                            as *const F32x8)
                     } else {
-                        _mm256_setzero_ps()
+                        f32x8_zero()
                     };
-                    sum8s[0] = _mm256_fmadd_ps(left_side8_0, right_side8, sum8s[0]);
-                    sum8s[1] = _mm256_fmadd_ps(left_side8_1, right_side8, sum8s[1]);
-                    sum8s[2] = _mm256_fmadd_ps(left_side8_2, right_side8, sum8s[2]);
-                    sum8s[3] = _mm256_fmadd_ps(left_side8_3, right_side8, sum8s[3]);
+                    sum8s[0] = fma_f32x8(left_side8_0, right_side8, sum8s[0]);
+                    sum8s[1] = fma_f32x8(left_side8_1, right_side8, sum8s[1]);
+                    sum8s[2] = fma_f32x8(left_side8_2, right_side8, sum8s[2]);
+                    sum8s[3] = fma_f32x8(left_side8_3, right_side8, sum8s[3]);
                 }
-                let sum_0: f32 = horizontal_sum(sum8s[0]);
-                let sum_1: f32 = horizontal_sum(sum8s[1]);
-                let sum_2: f32 = horizontal_sum(sum8s[2]);
-                let sum_3: f32 = horizontal_sum(sum8s[3]);
+                let sum_0: f32 = horizontal_sum_f32x8(sum8s[0]);
+                let sum_1: f32 = horizontal_sum_f32x8(sum8s[1]);
+                let sum_2: f32 = horizontal_sum_f32x8(sum8s[2]);
+                let sum_3: f32 = horizontal_sum_f32x8(sum8s[3]);
                 if row4_0 < result.rows {
                     *(tgt_data.add(row4_0 as usize * ncols_capacity)) = sum_0;
                 }
@@ -1871,10 +1775,10 @@ impl Tensor {
             for col in 0..other.cols {
                 let col = col as usize;
-                let mut sum8: __m256 = _mm256_setzero_ps();
+                let mut sum8: F32x8 = f32x8_zero();
                 for row8 in 0..col_its {
                     let row = row8 * 8;
-                    let left = _mm256_loadu_ps(left_data.add(row));
+                    let left = load_f32x8(left_data.add(row) as *const F32x8);
                     let mut r = [0.0f32; 8];
                     // i hate you clippy because you ask me
                     // to make code more unreadable
@@ -1885,17 +1789,16 @@ impl Tensor {
                         }
                     }
                     let right = if row + 8 <= other.rows as usize {
-                        _mm256_i32gather_ps(
+                        gather_f32x8(
                             right_data.add(row * other_capacity_cols + col),
-                            _mm256_set_epi32(o7, o6, o5, o4, o3, o2, o1, o0),
-                            1,
+                            i32x8_from_values(o7, o6, o5, o4, o3, o2, o1, o0),
                         )
                     } else {
-                        _mm256_loadu_ps(r.as_ptr())
+                        load_f32x8(r.as_ptr() as *const F32x8)
                     };
-                    sum8 = _mm256_fmadd_ps(left, right, sum8);
+                    sum8 = fma_f32x8(left, right, sum8);
                 }
-                *tgt_data.add(col) = horizontal_sum(sum8);
+                *tgt_data.add(col) = horizontal_sum_f32x8(sum8);
             }
             result
         }
@@ -2159,11 +2062,14 @@ impl Tensor {
         for row in 0..self.rows {
             for col in 0..cols_it {
                 let col = col * 8;
-                let val8: __m128i =
-                    _mm_loadu_si128(self_data.add((row * self_capacity_cols + col) as usize)
-                        as *const __m128i);
-                let val8: __m256 = _mm256_cvtph_ps(val8);
-                _mm256_storeu_ps(tgt_data.add((row * tgt_capacity_cols + col) as usize), val8);
+                let val8: I16x8 =
+                    load_i16x8(self_data.add((row * self_capacity_cols + col) as usize)
+                        as *const I16x8);
+                let val8: F32x8 = i16x8_as_f16_to_f32x8(val8);
+                store_f32x8(
+                    tgt_data.add((row * tgt_capacity_cols + col) as usize) as *mut F32x8,
+                    val8,
+                );
             }
         }
         result
@@ -2209,11 +2115,12 @@ impl Tensor {
         for row in 0..self.rows {
             for col in 0..cols_it {
                 let col = col * 8;
-                let val8: __m256 =
-                    _mm256_loadu_ps(self_data.add((row * self_capacity_cols + col) as usize));
-                let val8: __m128i = _mm256_cvtps_ph(val8, 0);
-                _mm_storeu_si128(
-                    tgt_data.add((row * tgt_capacity_cols + col) as usize) as *mut __m128i,
+                let val8: F32x8 =
+                    load_f32x8(self_data.add((row * self_capacity_cols + col) as usize)
+                        as *const F32x8);
+                let val8: I16x8 = f32x8_to_i16x8_as_f16(val8);
+                store_i16x8(
+                    tgt_data.add((row * tgt_capacity_cols + col) as usize) as *mut I16x8,
                     val8,
                 );
             }
