@@ -169,6 +169,17 @@ fn horizontal_sum(mut ymm: __m256) -> f32 {
     }
 }
 
+#[inline]
+fn horizontal_sum_f32_to_f16(mut ymm: __m256) -> f16 {
+    unsafe {
+        let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
+        ymm = _mm256_add_ps(ymm, ymm2);
+        ymm = _mm256_hadd_ps(ymm, ymm);
+        ymm = _mm256_hadd_ps(ymm, ymm);
+        f16::from_f32(_mm256_cvtss_f32(ymm))
+    }
+}
+
 impl Tensor {
     #[inline]
     pub fn assume_on_gpu(&self) {
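
Annotation (not part of the patch): the new helper swaps the two 128-bit halves of the vector, adds them, and then two _mm256_hadd_ps passes leave the total of all eight lanes in the lowest lane, which is rounded to f16 exactly once at the end. A scalar sketch of the same reduction, assuming the half crate's f16 already used in this file; results may differ from the SIMD version in the last bit because the summation order differs:

    #[cfg(test)]
    fn horizontal_sum_f32_to_f16_ref(lanes: [f32; 8]) -> f16 {
        // Sum all eight lanes in f32, then round to f16 once.
        f16::from_f32(lanes.iter().sum())
    }
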
@@ -824,8 +835,10 @@ impl Tensor {
                 self.rows, self.cols, other.cols, other.rows
             );
         }
+        // We don't have an implementation for f16, so don't use the vector
+        // function if we have f16.
         #[cfg(not(feature = "opencl"))]
-        if other.rows == 1 {
+        if other.rows == 1 && other.dtype != TensorDType::Float16 {
             return self.matrix_vector_mul_transposed(other);
         }
         #[cfg(feature = "opencl")]
@@ -1054,8 +1067,7 @@ impl Tensor {
         match src.dtype {
             TensorDType::Float32 => {
-                const CACHE_LINE_SIZE: usize = 32;
-                const ITEMS_PER_CACHE_LINE: usize = CACHE_LINE_SIZE / std::mem::size_of::<f32>();
+                const ITEMS_PER_LINE: usize = 8;
                 let tgt_data: *mut f32 = self.data as *mut f32;
                 unsafe {
@@ -1078,10 +1090,10 @@ impl Tensor {
                 let src_cols_capacity: usize = src.capacity_cols as usize;
                 let self_cols_capacity: usize = self.capacity_cols as usize;
-                let src_cols_its = if src_cols % ITEMS_PER_CACHE_LINE == 0 {
-                    src_cols / ITEMS_PER_CACHE_LINE
+                let src_cols_its = if src_cols % ITEMS_PER_LINE == 0 {
+                    src_cols / ITEMS_PER_LINE
                 } else {
-                    src_cols / ITEMS_PER_CACHE_LINE + 1
+                    src_cols / ITEMS_PER_LINE + 1
                 };
                 let row_its = if self_rows % 4 == 0 {
                     self_rows / 4
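
Annotation (not part of the patch): the rename does not change the value — the old constant was 32 / size_of::<f32>() = 8, the same as the new ITEMS_PER_LINE. The if/else blocks above are round-up divisions; on Rust 1.73 or newer the same values could be written with usize::div_ceil, e.g.:

    let src_cols_its = src_cols.div_ceil(ITEMS_PER_LINE);
    let row_its = self_rows.div_ceil(4);
    let self_cols_its = self_cols.div_ceil(4);
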
@@ -1133,61 +1145,56 @@ impl Tensor {
                             ];
                             for p in 0..src_cols_its {
                                 let other8_0: __m256 = _mm256_loadu_ps(
-                                    other_data
-                                        .add(col0 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE),
+                                    other_data.add(col0 * other_cols_capacity + p * ITEMS_PER_LINE),
                                 );
-                                let other8_1: __m256 =
-                                    if col1 < other_rows {
-                                        _mm256_loadu_ps(other_data.add(
-                                            col1 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
-                                        ))
-                                    } else {
-                                        _mm256_setzero_ps()
-                                    };
-                                let other8_2: __m256 =
-                                    if col2 < other_rows {
-                                        _mm256_loadu_ps(other_data.add(
-                                            col2 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
-                                        ))
-                                    } else {
-                                        _mm256_setzero_ps()
-                                    };
-                                let other8_3: __m256 =
-                                    if col3 < other_rows {
-                                        _mm256_loadu_ps(other_data.add(
-                                            col3 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
-                                        ))
-                                    } else {
-                                        _mm256_setzero_ps()
-                                    };
+                                let other8_1: __m256 = if col1 < other_rows {
+                                    _mm256_loadu_ps(
+                                        other_data
+                                            .add(col1 * other_cols_capacity + p * ITEMS_PER_LINE),
+                                    )
+                                } else {
+                                    _mm256_setzero_ps()
+                                };
+                                let other8_2: __m256 = if col2 < other_rows {
+                                    _mm256_loadu_ps(
+                                        other_data
+                                            .add(col2 * other_cols_capacity + p * ITEMS_PER_LINE),
+                                    )
+                                } else {
+                                    _mm256_setzero_ps()
+                                };
+                                let other8_3: __m256 = if col3 < other_rows {
+                                    _mm256_loadu_ps(
+                                        other_data
+                                            .add(col3 * other_cols_capacity + p * ITEMS_PER_LINE),
+                                    )
+                                } else {
+                                    _mm256_setzero_ps()
+                                };
                                 let src8_0: __m256 = _mm256_loadu_ps(
-                                    src_data
-                                        .add(row0 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE),
+                                    src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE),
                                 );
-                                let src8_1: __m256 =
-                                    if row1 < src_rows {
-                                        _mm256_loadu_ps(src_data.add(
-                                            row1 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
-                                        ))
-                                    } else {
-                                        _mm256_setzero_ps()
-                                    };
-                                let src8_2: __m256 =
-                                    if row2 < src_rows {
-                                        _mm256_loadu_ps(src_data.add(
-                                            row2 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
-                                        ))
-                                    } else {
-                                        _mm256_setzero_ps()
-                                    };
-                                let src8_3: __m256 =
-                                    if row3 < src_rows {
-                                        _mm256_loadu_ps(src_data.add(
-                                            row3 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
-                                        ))
-                                    } else {
-                                        _mm256_setzero_ps()
-                                    };
+                                let src8_1: __m256 = if row1 < src_rows {
+                                    _mm256_loadu_ps(
+                                        src_data.add(row1 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                    )
+                                } else {
+                                    _mm256_setzero_ps()
+                                };
+                                let src8_2: __m256 = if row2 < src_rows {
+                                    _mm256_loadu_ps(
+                                        src_data.add(row2 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                    )
+                                } else {
+                                    _mm256_setzero_ps()
+                                };
+                                let src8_3: __m256 = if row3 < src_rows {
+                                    _mm256_loadu_ps(
+                                        src_data.add(row3 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                    )
+                                } else {
+                                    _mm256_setzero_ps()
+                                };
                                 targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
                                 targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
                                 targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
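
Annotation (not part of the patch): this hunk only renames the constant and lets rustfmt reflow the load expressions; the algorithm is unchanged. For orientation, the blocked loop computes a matrix product against a pre-transposed right-hand side, keeping a 4x4 tile of dot products live in registers and advancing each by eight elements per fmadd. A plain scalar reference of the same operation (names are illustrative, not from this codebase):

    /// C = A * B, where bt stores B transposed: bt[k] is column k of B.
    fn matmul_transposed_ref(a: &[Vec<f32>], bt: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let mut c = vec![vec![0.0f32; bt.len()]; a.len()];
        for r in 0..a.len() {
            for k in 0..bt.len() {
                // Both scans are contiguous and row-major; that locality is
                // why the kernel wants the right-hand side transposed.
                c[r][k] = a[r].iter().zip(&bt[k]).map(|(x, y)| x * y).sum();
            }
        }
        c
    }
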
@@ -1248,7 +1255,203 @@ impl Tensor {
                     }
                 }
             }
-            TensorDType::Float16 => unimplemented!(),
+            TensorDType::Float16 => {
+                const ITEMS_PER_LINE: usize = 8;
+                let tgt_data: *mut f16 = self.data as *mut f16;
+                // Zero the output buffer; the per-tile sums below are
+                // accumulated into it with +=.
+                unsafe {
+                    std::ptr::write_bytes(
+                        tgt_data,
+                        0,
+                        self.rows as usize * self.capacity_cols as usize,
+                    );
+                }
+                let src_data: *const f16 = src.data as *const f16;
+                let other_data: *const f16 = other.data as *const f16;
+                let src_rows: usize = src.rows as usize;
+                let src_cols: usize = src.cols as usize;
+                let self_rows: usize = self.rows as usize;
+                let self_cols: usize = self.cols as usize;
+                let _other_cols: usize = other.cols as usize;
+                let other_rows: usize = other.rows as usize;
+                let other_cols_capacity: usize = other.capacity_cols as usize;
+                let src_cols_capacity: usize = src.capacity_cols as usize;
+                let self_cols_capacity: usize = self.capacity_cols as usize;
+                let src_cols_its = if src_cols % ITEMS_PER_LINE == 0 {
+                    src_cols / ITEMS_PER_LINE
+                } else {
+                    src_cols / ITEMS_PER_LINE + 1
+                };
+                let row_its = if self_rows % 4 == 0 {
+                    self_rows / 4
+                } else {
+                    self_rows / 4 + 1
+                };
+                let self_cols_its = if self_cols % 4 == 0 {
+                    self_cols / 4
+                } else {
+                    self_cols / 4 + 1
+                };
+                unsafe {
+                    for row in 0..row_its {
+                        let row0 = row * 4;
+                        let row1 = row * 4 + 1;
+                        let row2 = row * 4 + 2;
+                        let row3 = row * 4 + 3;
+                        for col in 0..self_cols_its {
+                            let col0 = col * 4;
+                            let col1 = col * 4 + 1;
+                            let col2 = col * 4 + 2;
+                            let col3 = col * 4 + 3;
+                            // 4x4 tile of 8-wide f32 accumulators; each one is
+                            // horizontally summed into a single f16 below.
+                            let mut targets8: [[__m256; 4]; 4] = [
+                                [
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                ],
+                                [
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                ],
+                                [
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                ],
+                                [
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                    _mm256_setzero_ps(),
+                                ],
+                            ];
+                            // Loads 8 f16 values from (row, column..column+8)
+                            // and (row+1, column..column+8), widened to f32;
+                            // a row past the end reads as zero.
+                            #[inline]
+                            fn load2_rows(
+                                ptr: *const f16,
+                                row: usize,
+                                column: usize,
+                                cols_capacity: usize,
+                                nrows: usize,
+                            ) -> (__m256, __m256) {
+                                unsafe {
+                                    let (left, right) = if row + 1 < nrows {
+                                        (
+                                            _mm_loadu_si128(ptr.add(row * cols_capacity + column)
+                                                as *const __m128i),
+                                            _mm_loadu_si128(
+                                                ptr.add((row + 1) * cols_capacity + column)
+                                                    as *const __m128i,
+                                            ),
+                                        )
+                                    } else {
+                                        (
+                                            _mm_loadu_si128(ptr.add(row * cols_capacity + column)
+                                                as *const __m128i),
+                                            _mm_setzero_si128(),
+                                        )
+                                    };
+                                    let left: __m256 = _mm256_cvtph_ps(left);
+                                    let right: __m256 = _mm256_cvtph_ps(right);
+                                    (left, right)
+                                }
+                            }
+                            for p in 0..src_cols_its {
+                                let (other8_0, other8_1) = load2_rows(
+                                    other_data,
+                                    col0,
+                                    p * ITEMS_PER_LINE,
+                                    other_cols_capacity,
+                                    other_rows,
+                                );
+                                let (other8_2, other8_3) = load2_rows(
+                                    other_data,
+                                    col2,
+                                    p * ITEMS_PER_LINE,
+                                    other_cols_capacity,
+                                    other_rows,
+                                );
+                                let (src8_0, src8_1) = load2_rows(
+                                    src_data,
+                                    row0,
+                                    p * ITEMS_PER_LINE,
+                                    src_cols_capacity,
+                                    src_rows,
+                                );
+                                let (src8_2, src8_3) = load2_rows(
+                                    src_data,
+                                    row2,
+                                    p * ITEMS_PER_LINE,
+                                    src_cols_capacity,
+                                    src_rows,
+                                );
+                                targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
+                                targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
+                                targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
+                                targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
+                                targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
+                                targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
+                                targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
+                                targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
+                                targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
+                                targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
+                                targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
+                                targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
+                                targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
+                                targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
+                                targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
+                                targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
+                            }
+                            let target00: f16 = horizontal_sum_f32_to_f16(targets8[0][0]);
+                            let target01: f16 = horizontal_sum_f32_to_f16(targets8[0][1]);
+                            let target02: f16 = horizontal_sum_f32_to_f16(targets8[0][2]);
+                            let target03: f16 = horizontal_sum_f32_to_f16(targets8[0][3]);
+                            let target10: f16 = horizontal_sum_f32_to_f16(targets8[1][0]);
+                            let target11: f16 = horizontal_sum_f32_to_f16(targets8[1][1]);
+                            let target12: f16 = horizontal_sum_f32_to_f16(targets8[1][2]);
+                            let target13: f16 = horizontal_sum_f32_to_f16(targets8[1][3]);
+                            let target20: f16 = horizontal_sum_f32_to_f16(targets8[2][0]);
+                            let target21: f16 = horizontal_sum_f32_to_f16(targets8[2][1]);
+                            let target22: f16 = horizontal_sum_f32_to_f16(targets8[2][2]);
+                            let target23: f16 = horizontal_sum_f32_to_f16(targets8[2][3]);
+                            let target30: f16 = horizontal_sum_f32_to_f16(targets8[3][0]);
+                            let target31: f16 = horizontal_sum_f32_to_f16(targets8[3][1]);
+                            let target32: f16 = horizontal_sum_f32_to_f16(targets8[3][2]);
+                            let target33: f16 = horizontal_sum_f32_to_f16(targets8[3][3]);
+                            // targets8[i][j] holds row (row0+j) dotted with
+                            // column (col0+i); rows past self_rows are skipped.
+                            *tgt_data.add(row0 * self_cols_capacity + col0) += target00;
+                            *tgt_data.add(row0 * self_cols_capacity + col1) += target10;
+                            *tgt_data.add(row0 * self_cols_capacity + col2) += target20;
+                            *tgt_data.add(row0 * self_cols_capacity + col3) += target30;
+                            if row1 < self_rows {
+                                *tgt_data.add(row1 * self_cols_capacity + col0) += target01;
+                                *tgt_data.add(row1 * self_cols_capacity + col1) += target11;
+                                *tgt_data.add(row1 * self_cols_capacity + col2) += target21;
+                                *tgt_data.add(row1 * self_cols_capacity + col3) += target31;
+                            }
+                            if row2 < self_rows {
+                                *tgt_data.add(row2 * self_cols_capacity + col0) += target02;
+                                *tgt_data.add(row2 * self_cols_capacity + col1) += target12;
+                                *tgt_data.add(row2 * self_cols_capacity + col2) += target22;
+                                *tgt_data.add(row2 * self_cols_capacity + col3) += target32;
+                            }
+                            if row3 < self_rows {
+                                *tgt_data.add(row3 * self_cols_capacity + col0) += target03;
+                                *tgt_data.add(row3 * self_cols_capacity + col1) += target13;
+                                *tgt_data.add(row3 * self_cols_capacity + col2) += target23;
+                                *tgt_data.add(row3 * self_cols_capacity + col3) += target33;
+                            }
+                        }
+                    }
+                }
+            }
         }
     }
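
Annotation (not part of the patch): the new Float16 arm mirrors the Float32 kernel, but loads eight f16 values at a time with _mm_loadu_si128 and widens them with _mm256_cvtph_ps, so it requires F16C on top of AVX2/FMA. The patch assumes those target features are enabled at compile time; if these kernels could ever run on a CPU where that is not guaranteed, a runtime guard along these lines would be needed (sketch only):

    if is_x86_feature_detected!("f16c") && is_x86_feature_detected!("fma") {
        // safe to dispatch to the AVX2/F16C kernel above
    } else {
        // fall back to a scalar implementation
    }
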
@@ -2088,6 +2291,36 @@ mod tests {
         }
     }
 
+    #[test]
+    fn mat_mul_transposed_f32_agrees_mat_mul_transposed_f16() {
+        let mut rng = rand::thread_rng();
+        for _ in 0..1000 {
+            let a = rng.gen_range(1..=128);
+            let b = rng.gen_range(1..=128);
+            let r = rng.gen_range(1..=128);
+            // Make matrices AxR and RxB
+            let a = Tensor::random(a, r, TensorDType::Float32);
+            let b = Tensor::random(r, b, TensorDType::Float32);
+            let a2 = a.clone().to_f16();
+            let b2 = b.clone().to_f16();
+            let b_transposed = b.transpose();
+            let b2_transposed = b2.transpose();
+            let c = a.matrix_mul_transposed(&b_transposed);
+            let c2 = a2.matrix_mul_transposed(&b2_transposed);
+            assert_eq!(c.rows, c2.rows);
+            assert_eq!(c.cols, c2.cols);
+            for row in 0..c.rows {
+                for col in 0..c.cols {
+                    assert_relative_eq!(
+                        c.get_f32(row, col),
+                        c2.get_f32(row, col),
+                        epsilon = 1e-1
+                    );
+                }
+            }
+        }
+    }
+
     #[test]
     fn view_preserves_values() {
         fn test_with_type(dtype: TensorDType) {
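
Annotation (not part of the patch): the loose epsilon = 1e-1 in the new test is deliberate. f16 stores a 10-bit mantissa (roughly three decimal digits per element), and the products are accumulated over an inner dimension of up to 128, compounding the rounding error. For instance, with the half crate:

    use half::f16;
    // f16 spacing near 1.0 is 2^-10, so 1.0002 rounds back down to 1.0.
    assert_eq!(f16::from_f32(1.0002).to_f32(), 1.0);
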