@@ -17,6 +17,7 @@
  * If it's "XXX_inplace", then it has a &mut self and it modifies the tensor in place.
  */
+use crate::simd_support::*;
 #[cfg(feature = "opencl")]
 use crate::tensor_opencl_support::{OpenCL, OpenCLError, OpenCLEvent, OpenCLTensor};
 use crate::unpickler;
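Everything below leans on the new `use crate::simd_support::*;` import. The module itself is not part of this diff, so the following is only a sketch of the kind of wrappers it presumably exposes, inferred from the call sites that follow (`F32x8`, `I16x8`, `load_f32x8`, `store_f32x8`, `f32x8_zero`, `f32x8_singleton`, `fma_f32x8`); the names mirror the diff, the bodies are assumptions:

```rust
// Hypothetical sketch of the simd_support wrappers this patch switches to.
// Assumes AVX2/FMA are available; the real module presumably feature-gates this.
use std::arch::x86_64::*;

#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct F32x8(__m256); // eight packed f32 lanes

#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct I16x8(__m128i); // eight packed 16-bit lanes (raw f16 bits in this file)

#[inline]
pub fn f32x8_zero() -> F32x8 {
    unsafe { F32x8(_mm256_setzero_ps()) }
}

#[inline]
pub fn f32x8_singleton(x: f32) -> F32x8 {
    // Broadcast one scalar into all eight lanes.
    unsafe { F32x8(_mm256_set1_ps(x)) }
}

#[inline]
pub unsafe fn load_f32x8(ptr: *const F32x8) -> F32x8 {
    F32x8(_mm256_loadu_ps(ptr as *const f32))
}

#[inline]
pub unsafe fn store_f32x8(ptr: *mut F32x8, v: F32x8) {
    _mm256_storeu_ps(ptr as *mut f32, v.0)
}

#[inline]
pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 {
    // Lane-wise a * b + c.
    unsafe { F32x8(_mm256_fmadd_ps(a.0, b.0, c.0)) }
}
```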
@@ -25,7 +26,6 @@ use half::f16;
 use rand::Rng;
 use rayon::prelude::*;
 use std::alloc::Layout;
-use std::arch::x86_64::*;
 use std::io::{Read, Seek};
 use std::path::{Path, PathBuf};
 #[cfg(feature = "opencl")]
@ -175,28 +175,6 @@ fn compute_capacity_cols_f16(cols: i64) -> i64 {
}
}
}
}
#[ inline ]
fn horizontal_sum ( mut ymm : __m256 ) -> f32 {
unsafe {
let ymm2 = _mm256_permute2f128_ps ( ymm , ymm , 1 ) ;
ymm = _mm256_add_ps ( ymm , ymm2 ) ;
ymm = _mm256_hadd_ps ( ymm , ymm ) ;
ymm = _mm256_hadd_ps ( ymm , ymm ) ;
_mm256_cvtss_f32 ( ymm )
}
}
#[ inline ]
fn horizontal_sum_f32_to_f16 ( mut ymm : __m256 ) -> f16 {
unsafe {
let ymm2 = _mm256_permute2f128_ps ( ymm , ymm , 1 ) ;
ymm = _mm256_add_ps ( ymm , ymm2 ) ;
ymm = _mm256_hadd_ps ( ymm , ymm ) ;
ymm = _mm256_hadd_ps ( ymm , ymm ) ;
f16 ::from_f32 ( _mm256_cvtss_f32 ( ymm ) )
}
}
impl Tensor {
impl Tensor {
#[ inline ]
#[ inline ]
pub fn assume_on_gpu ( & self ) {
pub fn assume_on_gpu ( & self ) {
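The two helpers removed here reappear later in the diff under new names (`horizontal_sum_f32x8` and `horizontal_sum_and_f32_to_f16`), presumably relocated into `simd_support`. A scalar reference for what those replacements have to compute, shown over a plain `[f32; 8]` so the semantics are unambiguous (hypothetical `_ref` helpers, not the actual module code):

```rust
use half::f16;

// Reference semantics for the relocated horizontal reductions: add up the
// eight lanes of an accumulator, optionally narrowing the total to f16.
fn horizontal_sum_f32x8_ref(lanes: [f32; 8]) -> f32 {
    lanes.iter().sum()
}

fn horizontal_sum_and_f32_to_f16_ref(lanes: [f32; 8]) -> f16 {
    f16::from_f32(horizontal_sum_f32x8_ref(lanes))
}
```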
@@ -938,19 +916,24 @@ impl Tensor {
                        let i2_self_cols = i2 * self_cols_capacity;
                        let i2_src_cols = i2 * src_cols_capacity;
                        for k2 in k..std::cmp::min(k + ITEMS_PER_CACHE_LINE, src_cols) {
-                            let other_value8: __m256 = _mm256_loadu_ps(
-                                other_data.add(k2 * other_cols_capacity + col),
-                            );
-                            let src_value8_broadcast: __m256 =
-                                _mm256_broadcast_ss(&*src_data.add(i2_src_cols + k2));
-                            let tgt_value8: __m256 =
-                                _mm256_loadu_ps(tgt_data.add(i2_self_cols + col));
-                            let result8: __m256 = _mm256_fmadd_ps(
-                                src_value8_broadcast,
-                                other_value8,
-                                tgt_value8,
-                            );
-                            _mm256_storeu_ps(tgt_data.add(i2_self_cols + col), result8);
+                            let other_value8: F32x8 = load_f32x8(
+                                other_data.add(k2 * other_cols_capacity + col)
+                                    as *const F32x8,
+                            );
+                            let src_value8_broadcast: F32x8 =
+                                f32x8_singleton(*src_data.add(i2_src_cols + k2));
+                            let tgt_value8: F32x8 = load_f32x8(
+                                tgt_data.add(i2_self_cols + col) as *const F32x8,
+                            );
+                            let result8: F32x8 = fma_f32x8(
+                                src_value8_broadcast,
+                                other_value8,
+                                tgt_value8,
+                            );
+                            store_f32x8(
+                                tgt_data.add(i2_self_cols + col) as *mut F32x8,
+                                result8,
+                            );
                        }
                    }
                    k += ITEMS_PER_CACHE_LINE;
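The rewritten inner loop above is a straight transliteration: each `fma_f32x8` step still performs the same rank-one update on eight output columns at a time. A scalar restatement of one step, purely for reference (hypothetical helper, not part of the patch):

```rust
// One eight-lane FMA step of the loop above, written out in scalar form:
// tgt[i2][col..col+8] += src[i2][k2] * other[k2][col..col+8].
fn fma_step_scalar(tgt_cols: &mut [f32; 8], other_cols: &[f32; 8], src_scalar: f32) {
    for lane in 0..8 {
        tgt_cols[lane] += src_scalar * other_cols[lane];
    }
}
```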
@@ -993,23 +976,20 @@ impl Tensor {
                        let i2_self_cols = i2 * self_cols;
                        let i2_src_cols = i2 * src_cols;
                        for k2 in k..k + ITEMS_PER_CACHE_LINE {
-                            let other_value8: __m256 = _mm256_cvtph_ps(_mm_loadu_si128(
-                                other_data.add(k2 * other_cols + col) as *const _,
-                            ));
-                            let src_value8: f16 = *src_data.add(i2_src_cols + k2);
-                            let src_value8_broadcast: __m256 =
-                                _mm256_broadcast_ss(&src_value8.to_f32());
-                            let tgt_value8: __m256 = _mm256_cvtph_ps(_mm_loadu_si128(
-                                tgt_data.add(i2_self_cols + col) as *const _,
-                            ));
-                            let result8: __m256 = _mm256_fmadd_ps(
-                                src_value8_broadcast,
-                                other_value8,
-                                tgt_value8,
-                            );
-                            let result8_packed: __m128i = _mm256_cvtps_ph(result8, 0);
-                            _mm_storeu_si128(
-                                tgt_data.add(i2_self_cols + col) as *mut _,
-                                result8_packed,
-                            );
+                            let other_value8: F32x8 = i16x8_as_f16_to_f32x8(load_i16x8(
+                                other_data.add(k2 * other_cols + col) as *const _,
+                            ));
+                            let src_value8: f16 = *src_data.add(i2_src_cols + k2);
+                            let src_value8_broadcast: F32x8 =
+                                f32x8_singleton(src_value8.to_f32());
+                            let tgt_value8: F32x8 = i16x8_as_f16_to_f32x8(load_i16x8(
+                                tgt_data.add(i2_self_cols + col) as *const _,
+                            ));
+                            let result8: F32x8 =
+                                fma_f32x8(src_value8_broadcast, other_value8, tgt_value8);
+                            let result8_packed: I16x8 = f32x8_to_i16x8_as_f16(result8);
+                            store_i16x8(
+                                tgt_data.add(i2_self_cols + col) as *mut I16x8,
+                                result8_packed,
+                            );
                        }
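The f16 variant is the same update with a widen-before, narrow-after step wrapped around it. The conversion helpers (`i16x8_as_f16_to_f32x8`, `f32x8_to_i16x8_as_f16`) presumably wrap the F16C intrinsics the old code called directly; a scalar reference for their semantics, using the `half` crate this file already imports (hypothetical helpers, not the simd_support code):

```rust
use half::f16;

// Widen eight half-precision values to f32.
fn widen_f16x8_ref(src: &[f16; 8]) -> [f32; 8] {
    let mut out = [0.0f32; 8];
    for (o, s) in out.iter_mut().zip(src.iter()) {
        *o = s.to_f32();
    }
    out
}

// Narrow eight f32 values back to half precision (round to nearest).
fn narrow_f32x8_ref(src: &[f32; 8]) -> [f16; 8] {
    let mut out = [f16::from_f32(0.0); 8];
    for (o, s) in out.iter_mut().zip(src.iter()) {
        *o = f16::from_f32(*s);
    }
    out
}
```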
@@ -1186,137 +1166,109 @@ impl Tensor {
                    let col1 = col * 4 + 1;
                    let col2 = col * 4 + 2;
                    let col3 = col * 4 + 3;
-                    let mut targets8: [[__m256; 4]; 4] = [
-                        [
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                        ],
-                        [
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                        ],
-                        [
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                        ],
-                        [
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                        ],
-                    ];
+                    let mut targets8: [[F32x8; 4]; 4] = [
+                        [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                        [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                        [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                        [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                    ];
                    for p in 0..src_cols_its {
-                        let other8_0: __m256 = _mm256_loadu_ps(
+                        let other8_0: F32x8 = load_f32x8(
                            other_data
-                                .add(col0 * other_cols_capacity + p * ITEMS_PER_LINE),
+                                .add(col0 * other_cols_capacity + p * ITEMS_PER_LINE)
+                                as *const F32x8,
                        );
-                        let other8_1: __m256 =
+                        let other8_1: F32x8 =
                            if col1 < other_rows {
-                                _mm256_loadu_ps(other_data.add(
+                                load_f32x8(other_data.add(
                                    col1 * other_cols_capacity + p * ITEMS_PER_LINE,
-                                ))
+                                )
+                                    as *const F32x8)
                            } else {
-                                _mm256_setzero_ps()
+                                f32x8_zero()
                            };
-                        let other8_2: __m256 =
+                        let other8_2: F32x8 =
                            if col2 < other_rows {
-                                _mm256_loadu_ps(other_data.add(
+                                load_f32x8(other_data.add(
                                    col2 * other_cols_capacity + p * ITEMS_PER_LINE,
-                                ))
+                                )
+                                    as *const F32x8)
                            } else {
-                                _mm256_setzero_ps()
+                                f32x8_zero()
                            };
-                        let other8_3: __m256 =
+                        let other8_3: F32x8 =
                            if col3 < other_rows {
-                                _mm256_loadu_ps(other_data.add(
+                                load_f32x8(other_data.add(
                                    col3 * other_cols_capacity + p * ITEMS_PER_LINE,
-                                ))
+                                )
+                                    as *const F32x8)
                            } else {
-                                _mm256_setzero_ps()
+                                f32x8_zero()
                            };
-                        let src8_0: __m256 = _mm256_loadu_ps(
-                            src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE),
+                        let src8_0: F32x8 = load_f32x8(
+                            src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE)
+                                as *const F32x8,
                        );
-                        let src8_1: __m256 = if row1 < src_rows {
-                            _mm256_loadu_ps(
+                        let src8_1: F32x8 = if row1 < src_rows {
+                            load_f32x8(
                                src_data
-                                    .add(row1 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                    .add(row1 * src_cols_capacity + p * ITEMS_PER_LINE)
+                                    as *const F32x8,
                            )
                        } else {
-                            _mm256_setzero_ps()
+                            f32x8_zero()
                        };
-                        let src8_2: __m256 = if row2 < src_rows {
-                            _mm256_loadu_ps(
+                        let src8_2: F32x8 = if row2 < src_rows {
+                            load_f32x8(
                                src_data
-                                    .add(row2 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                    .add(row2 * src_cols_capacity + p * ITEMS_PER_LINE)
+                                    as *const F32x8,
                            )
                        } else {
-                            _mm256_setzero_ps()
+                            f32x8_zero()
                        };
-                        let src8_3: __m256 = if row3 < src_rows {
-                            _mm256_loadu_ps(
+                        let src8_3: F32x8 = if row3 < src_rows {
+                            load_f32x8(
                                src_data
-                                    .add(row3 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                    .add(row3 * src_cols_capacity + p * ITEMS_PER_LINE)
+                                    as *const F32x8,
                            )
                        } else {
-                            _mm256_setzero_ps()
+                            f32x8_zero()
                        };
-                        targets8[0][0] =
-                            _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
-                        targets8[0][1] =
-                            _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
-                        targets8[0][2] =
-                            _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
-                        targets8[0][3] =
-                            _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
-                        targets8[1][0] =
-                            _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
-                        targets8[1][1] =
-                            _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
-                        targets8[1][2] =
-                            _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
-                        targets8[1][3] =
-                            _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
-                        targets8[2][0] =
-                            _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
-                        targets8[2][1] =
-                            _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
-                        targets8[2][2] =
-                            _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
-                        targets8[2][3] =
-                            _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
-                        targets8[3][0] =
-                            _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
-                        targets8[3][1] =
-                            _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
-                        targets8[3][2] =
-                            _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
-                        targets8[3][3] =
-                            _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
+                        targets8[0][0] = fma_f32x8(src8_0, other8_0, targets8[0][0]);
+                        targets8[0][1] = fma_f32x8(src8_1, other8_0, targets8[0][1]);
+                        targets8[0][2] = fma_f32x8(src8_2, other8_0, targets8[0][2]);
+                        targets8[0][3] = fma_f32x8(src8_3, other8_0, targets8[0][3]);
+                        targets8[1][0] = fma_f32x8(src8_0, other8_1, targets8[1][0]);
+                        targets8[1][1] = fma_f32x8(src8_1, other8_1, targets8[1][1]);
+                        targets8[1][2] = fma_f32x8(src8_2, other8_1, targets8[1][2]);
+                        targets8[1][3] = fma_f32x8(src8_3, other8_1, targets8[1][3]);
+                        targets8[2][0] = fma_f32x8(src8_0, other8_2, targets8[2][0]);
+                        targets8[2][1] = fma_f32x8(src8_1, other8_2, targets8[2][1]);
+                        targets8[2][2] = fma_f32x8(src8_2, other8_2, targets8[2][2]);
+                        targets8[2][3] = fma_f32x8(src8_3, other8_2, targets8[2][3]);
+                        targets8[3][0] = fma_f32x8(src8_0, other8_3, targets8[3][0]);
+                        targets8[3][1] = fma_f32x8(src8_1, other8_3, targets8[3][1]);
+                        targets8[3][2] = fma_f32x8(src8_2, other8_3, targets8[3][2]);
+                        targets8[3][3] = fma_f32x8(src8_3, other8_3, targets8[3][3]);
                    }
-                    let target00: f32 = horizontal_sum(targets8[0][0]);
-                    let target01: f32 = horizontal_sum(targets8[0][1]);
-                    let target02: f32 = horizontal_sum(targets8[0][2]);
-                    let target03: f32 = horizontal_sum(targets8[0][3]);
-                    let target10: f32 = horizontal_sum(targets8[1][0]);
-                    let target11: f32 = horizontal_sum(targets8[1][1]);
-                    let target12: f32 = horizontal_sum(targets8[1][2]);
-                    let target13: f32 = horizontal_sum(targets8[1][3]);
-                    let target20: f32 = horizontal_sum(targets8[2][0]);
-                    let target21: f32 = horizontal_sum(targets8[2][1]);
-                    let target22: f32 = horizontal_sum(targets8[2][2]);
-                    let target23: f32 = horizontal_sum(targets8[2][3]);
-                    let target30: f32 = horizontal_sum(targets8[3][0]);
-                    let target31: f32 = horizontal_sum(targets8[3][1]);
-                    let target32: f32 = horizontal_sum(targets8[3][2]);
-                    let target33: f32 = horizontal_sum(targets8[3][3]);
+                    let target00: f32 = horizontal_sum_f32x8(targets8[0][0]);
+                    let target01: f32 = horizontal_sum_f32x8(targets8[0][1]);
+                    let target02: f32 = horizontal_sum_f32x8(targets8[0][2]);
+                    let target03: f32 = horizontal_sum_f32x8(targets8[0][3]);
+                    let target10: f32 = horizontal_sum_f32x8(targets8[1][0]);
+                    let target11: f32 = horizontal_sum_f32x8(targets8[1][1]);
+                    let target12: f32 = horizontal_sum_f32x8(targets8[1][2]);
+                    let target13: f32 = horizontal_sum_f32x8(targets8[1][3]);
+                    let target20: f32 = horizontal_sum_f32x8(targets8[2][0]);
+                    let target21: f32 = horizontal_sum_f32x8(targets8[2][1]);
+                    let target22: f32 = horizontal_sum_f32x8(targets8[2][2]);
+                    let target23: f32 = horizontal_sum_f32x8(targets8[2][3]);
+                    let target30: f32 = horizontal_sum_f32x8(targets8[3][0]);
+                    let target31: f32 = horizontal_sum_f32x8(targets8[3][1]);
+                    let target32: f32 = horizontal_sum_f32x8(targets8[3][2]);
+                    let target33: f32 = horizontal_sum_f32x8(targets8[3][3]);
                    *tgt_data.add(row0 * self_cols_capacity + col0) += target00;
                    *tgt_data.add(row0 * self_cols_capacity + col1) += target10;
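The hunk above keeps the existing 4x4 register-tiling strategy and only changes the spelling of the vector ops. For orientation, this is the tiling in scalar form (illustrative only, not part of the patch): sixteen accumulators, one per (src row, other row) pair, with out-of-range rows contributing zero exactly like the `f32x8_zero()` fallbacks:

```rust
// Scalar sketch of the 4x4 tile: acc[c][r] ends up as the dot product of
// src row (row0 + r) with other row (col0 + c), i.e. one output cell.
fn dot_tile_4x4(
    src: &[Vec<f32>],
    other: &[Vec<f32>],
    row0: usize,
    col0: usize,
    k: usize,
) -> [[f32; 4]; 4] {
    let mut acc = [[0.0f32; 4]; 4];
    for p in 0..k {
        for c in 0..4 {
            for r in 0..4 {
                let s = src.get(row0 + r).map_or(0.0, |row| row[p]);
                let o = other.get(col0 + c).map_or(0.0, |row| row[p]);
                acc[c][r] += s * o;
            }
        }
    }
    acc
}
```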
@@ -1407,31 +1359,11 @@ impl Tensor {
                    let col1 = col * 4 + 1;
                    let col2 = col * 4 + 2;
                    let col3 = col * 4 + 3;
-                    let mut targets8: [[__m256; 4]; 4] = [
-                        [
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                        ],
-                        [
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                        ],
-                        [
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                        ],
-                        [
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                            _mm256_setzero_ps(),
-                        ],
-                    ];
+                    let mut targets8: [[F32x8; 4]; 4] = [
+                        [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                        [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                        [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                        [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+                    ];
                    // Loads from (row, column..column+8) and (row+1, column..column+8)
                    #[inline]
@@ -1441,30 +1373,26 @@ impl Tensor {
                        column: usize,
                        cols_capacity: usize,
                        nrows: usize,
-                    ) -> (__m256, __m256) {
+                    ) -> (F32x8, F32x8) {
                        unsafe {
                            let (left, right) = if row + 1 < nrows {
                                (
-                                    _mm_loadu_si128(
-                                        ptr.add(row * cols_capacity + column)
-                                            as *const __m128i,
-                                    ),
-                                    _mm_loadu_si128(
-                                        ptr.add((row + 1) * cols_capacity + column)
-                                            as *const __m128i,
-                                    ),
+                                    load_i16x8(ptr.add(row * cols_capacity + column)
+                                        as *const I16x8),
+                                    load_i16x8(
+                                        ptr.add((row + 1) * cols_capacity + column)
+                                            as *const I16x8,
+                                    ),
                                )
                            } else {
                                (
-                                    _mm_loadu_si128(
-                                        ptr.add(row * cols_capacity + column)
-                                            as *const __m128i,
-                                    ),
-                                    _mm_setzero_si128(),
+                                    load_i16x8(ptr.add(row * cols_capacity + column)
+                                        as *const I16x8),
+                                    i16x8_zero(),
                                )
                            };
-                            let left: __m256 = _mm256_cvtph_ps(left);
-                            let right: __m256 = _mm256_cvtph_ps(right);
+                            let left: F32x8 = i16x8_as_f16_to_f32x8(left);
+                            let right: F32x8 = i16x8_as_f16_to_f32x8(right);
                            (left, right)
                        }
                    }
@@ -1497,55 +1425,39 @@ impl Tensor {
                            src_cols_capacity,
                            src_rows,
                        );
-                        targets8[0][0] =
-                            _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
-                        targets8[0][1] =
-                            _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
-                        targets8[0][2] =
-                            _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
-                        targets8[0][3] =
-                            _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
-                        targets8[1][0] =
-                            _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
-                        targets8[1][1] =
-                            _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
-                        targets8[1][2] =
-                            _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
-                        targets8[1][3] =
-                            _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
-                        targets8[2][0] =
-                            _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
-                        targets8[2][1] =
-                            _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
-                        targets8[2][2] =
-                            _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
-                        targets8[2][3] =
-                            _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
-                        targets8[3][0] =
-                            _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
-                        targets8[3][1] =
-                            _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
-                        targets8[3][2] =
-                            _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
-                        targets8[3][3] =
-                            _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
+                        targets8[0][0] = fma_f32x8(src8_0, other8_0, targets8[0][0]);
+                        targets8[0][1] = fma_f32x8(src8_1, other8_0, targets8[0][1]);
+                        targets8[0][2] = fma_f32x8(src8_2, other8_0, targets8[0][2]);
+                        targets8[0][3] = fma_f32x8(src8_3, other8_0, targets8[0][3]);
+                        targets8[1][0] = fma_f32x8(src8_0, other8_1, targets8[1][0]);
+                        targets8[1][1] = fma_f32x8(src8_1, other8_1, targets8[1][1]);
+                        targets8[1][2] = fma_f32x8(src8_2, other8_1, targets8[1][2]);
+                        targets8[1][3] = fma_f32x8(src8_3, other8_1, targets8[1][3]);
+                        targets8[2][0] = fma_f32x8(src8_0, other8_2, targets8[2][0]);
+                        targets8[2][1] = fma_f32x8(src8_1, other8_2, targets8[2][1]);
+                        targets8[2][2] = fma_f32x8(src8_2, other8_2, targets8[2][2]);
+                        targets8[2][3] = fma_f32x8(src8_3, other8_2, targets8[2][3]);
+                        targets8[3][0] = fma_f32x8(src8_0, other8_3, targets8[3][0]);
+                        targets8[3][1] = fma_f32x8(src8_1, other8_3, targets8[3][1]);
+                        targets8[3][2] = fma_f32x8(src8_2, other8_3, targets8[3][2]);
+                        targets8[3][3] = fma_f32x8(src8_3, other8_3, targets8[3][3]);
                    }
-                    let target00: f16 = horizontal_sum_f32_to_f16(targets8[0][0]);
-                    let target01: f16 = horizontal_sum_f32_to_f16(targets8[0][1]);
-                    let target02: f16 = horizontal_sum_f32_to_f16(targets8[0][2]);
-                    let target03: f16 = horizontal_sum_f32_to_f16(targets8[0][3]);
-                    let target10: f16 = horizontal_sum_f32_to_f16(targets8[1][0]);
-                    let target11: f16 = horizontal_sum_f32_to_f16(targets8[1][1]);
-                    let target12: f16 = horizontal_sum_f32_to_f16(targets8[1][2]);
-                    let target13: f16 = horizontal_sum_f32_to_f16(targets8[1][3]);
-                    let target20: f16 = horizontal_sum_f32_to_f16(targets8[2][0]);
-                    let target21: f16 = horizontal_sum_f32_to_f16(targets8[2][1]);
-                    let target22: f16 = horizontal_sum_f32_to_f16(targets8[2][2]);
-                    let target23: f16 = horizontal_sum_f32_to_f16(targets8[2][3]);
-                    let target30: f16 = horizontal_sum_f32_to_f16(targets8[3][0]);
-                    let target31: f16 = horizontal_sum_f32_to_f16(targets8[3][1]);
-                    let target32: f16 = horizontal_sum_f32_to_f16(targets8[3][2]);
-                    let target33: f16 = horizontal_sum_f32_to_f16(targets8[3][3]);
+                    let target00: f16 = horizontal_sum_and_f32_to_f16(targets8[0][0]);
+                    let target01: f16 = horizontal_sum_and_f32_to_f16(targets8[0][1]);
+                    let target02: f16 = horizontal_sum_and_f32_to_f16(targets8[0][2]);
+                    let target03: f16 = horizontal_sum_and_f32_to_f16(targets8[0][3]);
+                    let target10: f16 = horizontal_sum_and_f32_to_f16(targets8[1][0]);
+                    let target11: f16 = horizontal_sum_and_f32_to_f16(targets8[1][1]);
+                    let target12: f16 = horizontal_sum_and_f32_to_f16(targets8[1][2]);
+                    let target13: f16 = horizontal_sum_and_f32_to_f16(targets8[1][3]);
+                    let target20: f16 = horizontal_sum_and_f32_to_f16(targets8[2][0]);
+                    let target21: f16 = horizontal_sum_and_f32_to_f16(targets8[2][1]);
+                    let target22: f16 = horizontal_sum_and_f32_to_f16(targets8[2][2]);
+                    let target23: f16 = horizontal_sum_and_f32_to_f16(targets8[2][3]);
+                    let target30: f16 = horizontal_sum_and_f32_to_f16(targets8[3][0]);
+                    let target31: f16 = horizontal_sum_and_f32_to_f16(targets8[3][1]);
+                    let target32: f16 = horizontal_sum_and_f32_to_f16(targets8[3][2]);
+                    let target33: f16 = horizontal_sum_and_f32_to_f16(targets8[3][3]);
                    *tgt_data.add(row0 * self_cols_capacity + col0) += target00;
                    *tgt_data.add(row0 * self_cols_capacity + col1) += target10;
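Note that the f16 kernel above accumulates in f32 (`F32x8` targets) and only narrows to f16 once per output cell, via `horizontal_sum_and_f32_to_f16`. A scalar illustration of that accumulate-wide, store-narrow pattern (hypothetical, for reference only):

```rust
use half::f16;

// Dot product of two half-precision rows, accumulated in f32 and narrowed once
// at the end; summing in f16 directly would lose precision over long rows.
fn dot_f16_accumulate_f32(a: &[f16], b: &[f16]) -> f16 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x.to_f32() * y.to_f32();
    }
    f16::from_f32(acc)
}
```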
@@ -1641,33 +1553,23 @@ impl Tensor {
        } else {
            (self.rows / 4 + 1) as usize
        };
-        let mut sum8s: [[__m256; 4]; 2] = [
-            [
-                _mm256_setzero_ps(),
-                _mm256_setzero_ps(),
-                _mm256_setzero_ps(),
-                _mm256_setzero_ps(),
-            ],
-            [
-                _mm256_setzero_ps(),
-                _mm256_setzero_ps(),
-                _mm256_setzero_ps(),
-                _mm256_setzero_ps(),
-            ],
-        ];
+        let mut sum8s: [[F32x8; 4]; 2] = [
+            [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+            [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
+        ];
        let self_data: *const f16 = self.data as *const f16;
        let other_data: *const f16 = other.data as *const f16;
        let _ncols_capacity: usize = result.capacity_cols as usize;
        for row in 0..row_its {
            let row: i64 = row as i64;
-            sum8s[0][0] = _mm256_setzero_ps();
-            sum8s[0][1] = _mm256_setzero_ps();
-            sum8s[0][2] = _mm256_setzero_ps();
-            sum8s[0][3] = _mm256_setzero_ps();
-            sum8s[1][0] = _mm256_setzero_ps();
-            sum8s[1][1] = _mm256_setzero_ps();
-            sum8s[1][2] = _mm256_setzero_ps();
-            sum8s[1][3] = _mm256_setzero_ps();
+            sum8s[0][0] = f32x8_zero();
+            sum8s[0][1] = f32x8_zero();
+            sum8s[0][2] = f32x8_zero();
+            sum8s[0][3] = f32x8_zero();
+            sum8s[1][0] = f32x8_zero();
+            sum8s[1][1] = f32x8_zero();
+            sum8s[1][2] = f32x8_zero();
+            sum8s[1][3] = f32x8_zero();
            let row4_0 = row * 4;
            let row4_1 = row * 4 + 1;
            let row4_2 = row * 4 + 2;
@@ -1675,8 +1577,8 @@ impl Tensor {
        // Loads from (0, column..column+8)
        #[inline]
-        fn load2(ptr: *const f16, col: usize) -> __m256 {
-            unsafe { _mm256_cvtph_ps(_mm_loadu_si128(ptr.add(col) as *const __m128i)) }
+        fn load2(ptr: *const f16, col: usize) -> F32x8 {
+            unsafe { i16x8_as_f16_to_f32x8(load_i16x8(ptr.add(col) as *const I16x8)) }
        }
        // Loads from (row, column..column+8)
        #[inline]
@@ -1686,15 +1588,15 @@ impl Tensor {
            col: usize,
            cols_capacity: i64,
            nrows: i64,
-        ) -> __m256 {
+        ) -> F32x8 {
            unsafe {
                if row < nrows {
-                    _mm256_cvtph_ps(_mm_loadu_si128(
+                    i16x8_as_f16_to_f32x8(load_i16x8(
                        ptr.add(row as usize * cols_capacity as usize + col)
-                            as *const __m128i,
+                            as *const I16x8,
                    ))
                } else {
-                    _mm256_setzero_ps()
+                    f32x8_zero()
                }
            }
        }
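`load2row` keeps the old bounds guard: rows past the end of the matrix read as all zeros, so the FMAs that consume them are no-ops. A scalar analogue of that guarded load (hypothetical helper, indices in elements):

```rust
use half::f16;

// Read eight f16 values starting at (row, col) and widen them to f32, or
// return zeros when the row does not exist.
fn load_row_or_zero(
    data: &[f16],
    row: usize,
    col: usize,
    cols_capacity: usize,
    nrows: usize,
) -> [f32; 8] {
    let mut out = [0.0f32; 8];
    if row < nrows {
        let base = row * cols_capacity + col;
        for (o, v) in out.iter_mut().zip(data[base..base + 8].iter()) {
            *o = v.to_f32();
        }
    }
    out
}
```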
@@ -1711,10 +1613,10 @@ impl Tensor {
                    load2row(self_data, row4_2, col, self.capacity_cols, self.rows);
                let left_side8_30 =
                    load2row(self_data, row4_3, col, self.capacity_cols, self.rows);
-                sum8s[0][0] = _mm256_fmadd_ps(left_side8_00, right_side8_0, sum8s[0][0]);
-                sum8s[0][1] = _mm256_fmadd_ps(left_side8_10, right_side8_0, sum8s[0][1]);
-                sum8s[0][2] = _mm256_fmadd_ps(left_side8_20, right_side8_0, sum8s[0][2]);
-                sum8s[0][3] = _mm256_fmadd_ps(left_side8_30, right_side8_0, sum8s[0][3]);
+                sum8s[0][0] = fma_f32x8(left_side8_00, right_side8_0, sum8s[0][0]);
+                sum8s[0][1] = fma_f32x8(left_side8_10, right_side8_0, sum8s[0][1]);
+                sum8s[0][2] = fma_f32x8(left_side8_20, right_side8_0, sum8s[0][2]);
+                sum8s[0][3] = fma_f32x8(left_side8_30, right_side8_0, sum8s[0][3]);
                let right_side8_1 = load2(other_data, col2);
                let left_side8_01 =
                    load2row(self_data, row4_0, col2, self.capacity_cols, self.rows);
@@ -1724,15 +1626,19 @@ impl Tensor {
                    load2row(self_data, row4_2, col2, self.capacity_cols, self.rows);
                let left_side8_31 =
                    load2row(self_data, row4_3, col2, self.capacity_cols, self.rows);
-                sum8s[1][0] = _mm256_fmadd_ps(left_side8_01, right_side8_1, sum8s[1][0]);
-                sum8s[1][1] = _mm256_fmadd_ps(left_side8_11, right_side8_1, sum8s[1][1]);
-                sum8s[1][2] = _mm256_fmadd_ps(left_side8_21, right_side8_1, sum8s[1][2]);
-                sum8s[1][3] = _mm256_fmadd_ps(left_side8_31, right_side8_1, sum8s[1][3]);
+                sum8s[1][0] = fma_f32x8(left_side8_01, right_side8_1, sum8s[1][0]);
+                sum8s[1][1] = fma_f32x8(left_side8_11, right_side8_1, sum8s[1][1]);
+                sum8s[1][2] = fma_f32x8(left_side8_21, right_side8_1, sum8s[1][2]);
+                sum8s[1][3] = fma_f32x8(left_side8_31, right_side8_1, sum8s[1][3]);
            }
-            let sum_0: f32 = horizontal_sum(sum8s[0][0]) + horizontal_sum(sum8s[1][0]);
-            let sum_1: f32 = horizontal_sum(sum8s[0][1]) + horizontal_sum(sum8s[1][1]);
-            let sum_2: f32 = horizontal_sum(sum8s[0][2]) + horizontal_sum(sum8s[1][2]);
-            let sum_3: f32 = horizontal_sum(sum8s[0][3]) + horizontal_sum(sum8s[1][3]);
+            let sum_0: f32 =
+                horizontal_sum_f32x8(sum8s[0][0]) + horizontal_sum_f32x8(sum8s[1][0]);
+            let sum_1: f32 =
+                horizontal_sum_f32x8(sum8s[0][1]) + horizontal_sum_f32x8(sum8s[1][1]);
+            let sum_2: f32 =
+                horizontal_sum_f32x8(sum8s[0][2]) + horizontal_sum_f32x8(sum8s[1][2]);
+            let sum_3: f32 =
+                horizontal_sum_f32x8(sum8s[0][3]) + horizontal_sum_f32x8(sum8s[1][3]);
            if row4_0 < result.rows {
                result.set_f32(row4_0, 0, sum_0);
            }
@@ -1770,19 +1676,14 @@ impl Tensor {
        let tgt_data: *mut f32 = result.data as *mut f32;
        let ncols_capacity: usize = result.capacity_cols as usize;
-        let mut sum8s: [__m256; 4] = [
-            _mm256_setzero_ps(),
-            _mm256_setzero_ps(),
-            _mm256_setzero_ps(),
-            _mm256_setzero_ps(),
-        ];
+        let mut sum8s: [F32x8; 4] = [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
        for row in 0..row_its {
            let row: i64 = row as i64;
-            sum8s[0] = _mm256_setzero_ps();
-            sum8s[1] = _mm256_setzero_ps();
-            sum8s[2] = _mm256_setzero_ps();
-            sum8s[3] = _mm256_setzero_ps();
+            sum8s[0] = f32x8_zero();
+            sum8s[1] = f32x8_zero();
+            sum8s[2] = f32x8_zero();
+            sum8s[3] = f32x8_zero();
            let row4_0 = row * 4;
            let row4_1 = row * 4 + 1;
            let row4_2 = row * 4 + 2;
@@ -1790,34 +1691,37 @@ impl Tensor {
            for col in 0..col_its {
                let col = col * 8;
-                let right_side8 = _mm256_loadu_ps(other_data.add(col));
-                let left_side8_0 = _mm256_loadu_ps(
-                    self_data.add((row4_0 * self.capacity_cols) as usize + col),
-                );
+                let right_side8 = load_f32x8(other_data.add(col) as *const F32x8);
+                let left_side8_0 =
+                    load_f32x8(self_data.add((row4_0 * self.capacity_cols) as usize + col)
+                        as *const F32x8);
                let left_side8_1 = if row4_1 < self.rows {
-                    _mm256_loadu_ps(self_data.add((row4_1 * self.capacity_cols) as usize + col))
+                    load_f32x8(self_data.add((row4_1 * self.capacity_cols) as usize + col)
+                        as *const F32x8)
                } else {
-                    _mm256_setzero_ps()
+                    f32x8_zero()
                };
                let left_side8_2 = if row4_2 < self.rows {
-                    _mm256_loadu_ps(self_data.add((row4_2 * self.capacity_cols) as usize + col))
+                    load_f32x8(self_data.add((row4_2 * self.capacity_cols) as usize + col)
+                        as *const F32x8)
                } else {
-                    _mm256_setzero_ps()
+                    f32x8_zero()
                };
                let left_side8_3 = if row4_3 < self.rows {
-                    _mm256_loadu_ps(self_data.add((row4_3 * self.capacity_cols) as usize + col))
+                    load_f32x8(self_data.add((row4_3 * self.capacity_cols) as usize + col)
+                        as *const F32x8)
                } else {
-                    _mm256_setzero_ps()
+                    f32x8_zero()
                };
-                sum8s[0] = _mm256_fmadd_ps(left_side8_0, right_side8, sum8s[0]);
-                sum8s[1] = _mm256_fmadd_ps(left_side8_1, right_side8, sum8s[1]);
-                sum8s[2] = _mm256_fmadd_ps(left_side8_2, right_side8, sum8s[2]);
-                sum8s[3] = _mm256_fmadd_ps(left_side8_3, right_side8, sum8s[3]);
+                sum8s[0] = fma_f32x8(left_side8_0, right_side8, sum8s[0]);
+                sum8s[1] = fma_f32x8(left_side8_1, right_side8, sum8s[1]);
+                sum8s[2] = fma_f32x8(left_side8_2, right_side8, sum8s[2]);
+                sum8s[3] = fma_f32x8(left_side8_3, right_side8, sum8s[3]);
            }
-            let sum_0: f32 = horizontal_sum(sum8s[0]);
-            let sum_1: f32 = horizontal_sum(sum8s[1]);
-            let sum_2: f32 = horizontal_sum(sum8s[2]);
-            let sum_3: f32 = horizontal_sum(sum8s[3]);
+            let sum_0: f32 = horizontal_sum_f32x8(sum8s[0]);
+            let sum_1: f32 = horizontal_sum_f32x8(sum8s[1]);
+            let sum_2: f32 = horizontal_sum_f32x8(sum8s[2]);
+            let sum_3: f32 = horizontal_sum_f32x8(sum8s[3]);
            if row4_0 < result.rows {
                *(tgt_data.add(row4_0 as usize * ncols_capacity)) = sum_0;
            }
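This is the f32 matrix-times-vector path: four rows are processed per iteration, each with its own eight-wide accumulator, and the horizontal sums land in column 0 of the result. A scalar sketch of the same blocking (illustrative only, not part of the patch):

```rust
// Four-rows-at-a-time matrix * vector product; rows past the end contribute
// nothing, matching the f32x8_zero() fallbacks above.
fn matvec_block4(m: &[Vec<f32>], v: &[f32], row0: usize) -> [f32; 4] {
    let mut sums = [0.0f32; 4];
    for r in 0..4 {
        if let Some(row) = m.get(row0 + r) {
            sums[r] = row.iter().zip(v.iter()).map(|(a, b)| a * b).sum();
        }
    }
    sums
}
```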
@@ -1871,10 +1775,10 @@ impl Tensor {
        for col in 0..other.cols {
            let col = col as usize;
-            let mut sum8: __m256 = _mm256_setzero_ps();
+            let mut sum8: F32x8 = f32x8_zero();
            for row8 in 0..col_its {
                let row = row8 * 8;
-                let left = _mm256_loadu_ps(left_data.add(row));
+                let left = load_f32x8(left_data.add(row) as *const F32x8);
                let mut r = [0.0f32; 8];
                // i hate you clippy because you ask me
                // to make code more unreadable
@@ -1885,17 +1789,16 @@ impl Tensor {
                    }
                }
                let right = if row + 8 <= other.rows as usize {
-                    _mm256_i32gather_ps(
+                    gather_f32x8(
                        right_data.add(row * other_capacity_cols + col),
-                        _mm256_set_epi32(o7, o6, o5, o4, o3, o2, o1, o0),
-                        1,
+                        i32x8_from_values(o7, o6, o5, o4, o3, o2, o1, o0),
                    )
                } else {
-                    _mm256_loadu_ps(r.as_ptr())
+                    load_f32x8(r.as_ptr() as *const F32x8)
                };
-                sum8 = _mm256_fmadd_ps(left, right, sum8);
+                sum8 = fma_f32x8(left, right, sum8);
            }
-            *tgt_data.add(col) = horizontal_sum(sum8);
+            *tgt_data.add(col) = horizontal_sum_f32x8(sum8);
        }
        result
    }
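`gather_f32x8` and `i32x8_from_values` presumably wrap `_mm256_i32gather_ps` and `_mm256_set_epi32`; note the old call's explicit scale argument is gone, so the wrapper must fix its own index convention internally. A scalar reference for what the gather step produces (hypothetical; offsets given in elements here purely for clarity):

```rust
// Pull eight f32 values from eight offsets relative to a base pointer.
unsafe fn gather8_ref(base: *const f32, offsets: [usize; 8]) -> [f32; 8] {
    let mut out = [0.0f32; 8];
    for (o, &off) in out.iter_mut().zip(offsets.iter()) {
        *o = *base.add(off);
    }
    out
}
```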
@@ -2159,11 +2062,14 @@ impl Tensor {
        for row in 0..self.rows {
            for col in 0..cols_it {
                let col = col * 8;
-                let val8: __m128i =
-                    _mm_loadu_si128(self_data.add((row * self_capacity_cols + col) as usize)
-                        as *const __m128i);
-                let val8: __m256 = _mm256_cvtph_ps(val8);
-                _mm256_storeu_ps(tgt_data.add((row * tgt_capacity_cols + col) as usize), val8);
+                let val8: I16x8 =
+                    load_i16x8(self_data.add((row * self_capacity_cols + col) as usize)
+                        as *const I16x8);
+                let val8: F32x8 = i16x8_as_f16_to_f32x8(val8);
+                store_f32x8(
+                    tgt_data.add((row * tgt_capacity_cols + col) as usize) as *mut F32x8,
+                    val8,
+                );
            }
        }
        result
@@ -2209,11 +2115,12 @@ impl Tensor {
        for row in 0..self.rows {
            for col in 0..cols_it {
                let col = col * 8;
-                let val8: __m256 =
-                    _mm256_loadu_ps(self_data.add((row * self_capacity_cols + col) as usize));
-                let val8: __m128i = _mm256_cvtps_ph(val8, 0);
-                _mm_storeu_si128(
-                    tgt_data.add((row * tgt_capacity_cols + col) as usize) as *mut __m128i,
-                    val8,
-                );
+                let val8: F32x8 =
+                    load_f32x8(self_data.add((row * self_capacity_cols + col) as usize)
+                        as *const F32x8);
+                let val8: I16x8 = f32x8_to_i16x8_as_f16(val8);
+                store_i16x8(
+                    tgt_data.add((row * tgt_capacity_cols + col) as usize) as *mut I16x8,
+                    val8,
+                );
            }