@@ -10,8 +10,8 @@ use thiserror::Error;
 #[derive(Debug)]
 #[allow(dead_code)]
 struct Programs {
-    matrix_mul_transposed_by_row_f16_program: Program,
-    matrix_mul_transposed_by_row_f16: Kernel,
+    matrix_mul_transposed_f16_program: Program,
+    matrix_mul_transposed_f16: Kernel,
     silu_f16_program: Program,
     silu_f16: Kernel,
     hadamard_product_f16_program: Program,
@@ -306,32 +306,39 @@ impl OpenCLTensor {
         unsafe { self.buf.cmd().fill(0u16, None).block(false).enq()? };
         let prg = self.cl.programs.write().unwrap();
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(0, self.buf.clone())?;
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(1, src.buf.clone())?;
-        prg.matrix_mul_transposed_by_row_f16
+        prg.matrix_mul_transposed_f16.set_arg(0, self.buf.clone())?;
+        prg.matrix_mul_transposed_f16.set_arg(1, src.buf.clone())?;
+        prg.matrix_mul_transposed_f16
             .set_arg(2, other.buf.clone())?;
-        prg.matrix_mul_transposed_by_row_f16
+        prg.matrix_mul_transposed_f16
             .set_arg(3, src.cols_capacity as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
+        prg.matrix_mul_transposed_f16
             .set_arg(4, other.cols_capacity as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
+        prg.matrix_mul_transposed_f16
             .set_arg(5, self.cols_capacity as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(6, self.rows as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(7, self.cols as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(8, src.cols as i32)?;
+        prg.matrix_mul_transposed_f16.set_arg(6, self.rows as i32)?;
+        prg.matrix_mul_transposed_f16.set_arg(7, self.cols as i32)?;
+        prg.matrix_mul_transposed_f16.set_arg(8, src.cols as i32)?;
         let mut event = Event::empty();
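+        // Round the global work size up to the next multiple of the 16x16
+        // work-group size; the kernel itself masks off out-of-range work items.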
+        let rows16 = if self.rows % 16 == 0 {
+            self.rows
+        } else {
+            self.rows + 16 - (self.rows % 16)
+        };
+        let cols16 = if self.cols % 16 == 0 {
+            self.cols
+        } else {
+            self.cols + 16 - (self.cols % 16)
+        };
         unsafe {
             let b = prg
-                .matrix_mul_transposed_by_row_f16
+                .matrix_mul_transposed_f16
                 .cmd()
                 .queue(&self.queue)
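+                // Work dimension 0 spans target columns and dimension 1 target
+                // rows, matching get_global_id(0)/get_global_id(1) in the kernel.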
-                .global_work_size([self.rows as usize, self.cols as usize])
+                .global_work_size([cols16 as usize, rows16 as usize])
+                .local_work_size([16, 16])
                 .enew(&mut event);
             b.enq()?;
         }
@@ -353,11 +360,11 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError>
         Ok(program)
     }
 
-    let matrix_mul_transposed_by_row_f16_program =
-        make_program_with_src(ctx, MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC)?;
-    let matrix_mul_transposed_by_row_f16 = Kernel::builder()
-        .program(&matrix_mul_transposed_by_row_f16_program)
-        .name("matrix_mul_transposed_by_row_f16")
+    let matrix_mul_transposed_f16_program =
+        make_program_with_src(ctx, MATRIX_MUL_TRANSPOSED_F16_SRC)?;
+    let matrix_mul_transposed_f16 = Kernel::builder()
+        .program(&matrix_mul_transposed_f16_program)
+        .name("matrix_mul_transposed_f16")
         .arg(None::<&Buffer<u16>>)
         .arg(None::<&Buffer<u16>>)
         .arg(None::<&Buffer<u16>>)
@@ -398,8 +405,8 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError>
         .queue(queue.clone())
         .build()?;
     Ok(Programs {
-        matrix_mul_transposed_by_row_f16_program,
-        matrix_mul_transposed_by_row_f16,
+        matrix_mul_transposed_f16_program,
+        matrix_mul_transposed_f16,
         silu_f16_program,
         silu_f16,
         hadamard_product_f16_program,
@@ -409,29 +416,14 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError>
     })
 }
 
-const MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC: &str = r#"
+const MATRIX_MUL_TRANSPOSED_F16_SRC: &str = r#"
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 
 /*
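+ * Tiled matrix multiplication with a transposed second matrix, using 16-bit floats.
+ *
+ * Each 16x16 work-group stages one tile of the left and one tile of the right
+ * matrix in local memory, so each value is fetched from global memory once per
+ * work-group rather than once per work-item. Accumulation is done in float32.
+ *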
- * Matrix multiplication with a transposed second matrix, using 16-bit floats.
- *
- * One work unit per row.
- *
- * Assumes that each row in the matrices are zero-padded so that there's space for 32 bytes (or 16
- * halfs) of data and we don't need to care if our loops go over the bounds.
- *
- * Operations are done in float32.
- *
- * This thing is not very fast right now. I compared with PyTorch and this is like 20x slower. It
- * is still much faster than CPU. Not sure PyTorch uses cuBlas but if we could get at least
- * somewhere like 50% of that speed I would be happy.
- *
- * The OpenCL on CPU for Ryzen 3950X seems to easily beat my own AVX2 operations.
- *
  * TODO: need to read resources like https://cnugteren.github.io/tutorial/pages/page1.html to
  * figure out how to do matrix multiplication faster.
  */
-__kernel void matrix_mul_transposed_by_row_f16(
+__kernel void matrix_mul_transposed_f16(
     __global half *tgt,
     __global const half *left,
     __global const half *right,
@@ -442,44 +434,36 @@ __kernel void matrix_mul_transposed_by_row_f16(
     const int ncols, // size of target
     const int shared_sz
 ) {
-    int col_iterations = shared_sz / 16;
-    if (shared_sz % 16 != 0) {
-        col_iterations = col_iterations + 1;
-    }
+    __local float lefttile[16][16];
+    __local float righttile[16][16];
+    int global_x = get_global_id(0);
+    int global_y = get_global_id(1);
+    int local_x = get_local_id(0);
+    int local_y = get_local_id(1);
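+    // Ceiling division: how many 16-wide tiles cover the shared dimension.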
+    int num_tiles = (shared_sz + 15) / 16;
+    float sum = 0.0f;
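+    // Walk the shared dimension one 16x16 tile at a time, zero-padding lanes
+    // that fall outside the real matrix bounds.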
+    for (int t = 0; t < num_tiles; ++t) {
+        if (global_y < nrows) {
+            lefttile[local_y][local_x] = vload_half(global_y * left_cols_capacity + t * 16 + local_x, left);
+        } else {
+            lefttile[local_y][local_x] = 0.0f;
+        }
+        if (global_x < ncols) {
+            righttile[local_y][local_x] = vload_half(global_x * right_cols_capacity + t * 16 + local_y, right);
+        } else {
+            righttile[local_y][local_x] = 0.0f;
+        }
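+        // Make sure both tiles are fully populated before anyone reads them.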
+        barrier(CLK_LOCAL_MEM_FENCE);
+        for (int k = 0; k < 16; ++k) {
+            sum += lefttile[local_y][k] * righttile[k][local_x];
+        }
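+        // Keep the tiles intact until every work-item has finished reading them.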
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
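+    // Only work-items that map to a real output element write a result; the
+    // padded lanes added by the rounded-up global work size are skipped.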
+    if (global_x < ncols && global_y < nrows) {
+        vstore_half(sum, global_y * ncols_capacity + global_x, (__global half *)tgt);
+    }
-    const int tgt_row = get_global_id(0);
-    const int tgt_col = get_global_id(1);
-    float16 sum = 0;
-    for (int col16 = 0; col16 < col_iterations; col16++) {
-        const float16 left8 = vload_half16((tgt_row * left_cols_capacity) / 16 + col16, (__global const half *)left);
-        const float16 right8 = vload_half16((tgt_col * right_cols_capacity) / 16 + col16, (__global const half *)right);
-        // hadamard product FMA add it to sum
-        // const float16 result8 = left8 * right8;
-        // sum += result8;
-        sum = fma(left8, right8, sum);
-    }
-    // Reduce as accurately as possible
-    float sum1 = sum.s0 + sum.s1;
-    float sum2 = sum.s2 + sum.s3;
-    float sum3 = sum.s4 + sum.s5;
-    float sum4 = sum.s6 + sum.s7;
-    float sum5 = sum.s8 + sum.s9;
-    float sum6 = sum.sa + sum.sb;
-    float sum7 = sum.sc + sum.sd;
-    float sum8 = sum.se + sum.sf;
-    float sum11 = sum1 + sum2;
-    float sum12 = sum3 + sum4;
-    float sum13 = sum5 + sum6;
-    float sum14 = sum7 + sum8;
-    float sum21 = sum11 + sum12;
-    float sum22 = sum13 + sum14;
-    float total = sum21 + sum22;
-    vstore_half(total, 0, (__global half *)&tgt[tgt_row * ncols_capacity + tgt_col]);
 }
 "#;