diff --git a/src/tensor.rs b/src/tensor.rs
index 176ec93..a65d69b 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -2152,6 +2152,33 @@ mod tests {
         }
     }
 
+    #[cfg(feature = "opencl")]
+    #[test]
+    fn gpu_matrix_mul_transposed_is_close_to_cpu_matrix_mul_transposed_512x1024() {
+        let cl = OpenCL::new(false, 0).unwrap();
+        let a = Tensor::random(512, 1024, TensorDType::Float32);
+        let b = Tensor::random(768, 1024, TensorDType::Float32);
+        let mut a2 = a.to_f16();
+        let mut b2 = b.to_f16();
+        let mut c = Tensor::random(512, 768, TensorDType::Float32);
+        let mut c2 = Tensor::zeros(512, 768, TensorDType::Float32).to_f16();
+        a2.to_gpu(&cl).unwrap();
+        b2.to_gpu(&cl).unwrap();
+        c2.to_gpu(&cl).unwrap();
+        c.matrix_mul_inplace_transposed(&a, &b);
+        c2.matrix_mul_inplace_transposed(&a2, &b2);
+        c2.to_cpu().unwrap();
+
+        assert_eq!(c.rows(), c2.rows());
+        assert_eq!(c.cols(), c2.cols());
+
+        for row in 0..c.rows {
+            for col in 0..c.cols {
+                assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-1);
+            }
+        }
+    }
+
     #[cfg(feature = "opencl")]
     #[test]
     fn gpu_matrix_mul_transposed_is_close_to_cpu_matrix_mul_transposed_1024x1024() {
diff --git a/src/tensor_opencl_support.rs b/src/tensor_opencl_support.rs
index 59e8ace..220c6f2 100644
--- a/src/tensor_opencl_support.rs
+++ b/src/tensor_opencl_support.rs
@@ -10,8 +10,8 @@ use thiserror::Error;
 #[derive(Debug)]
 #[allow(dead_code)]
 struct Programs {
-    matrix_mul_transposed_by_row_f16_program: Program,
-    matrix_mul_transposed_by_row_f16: Kernel,
+    matrix_mul_transposed_f16_program: Program,
+    matrix_mul_transposed_f16: Kernel,
     silu_f16_program: Program,
     silu_f16: Kernel,
     hadamard_product_f16_program: Program,
@@ -306,32 +306,39 @@ impl OpenCLTensor {
         unsafe {
             self.buf.cmd().fill(0u16, None).block(false).enq()?
         };
         let prg = self.cl.programs.write().unwrap();
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(0, self.buf.clone())?;
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(1, src.buf.clone())?;
-        prg.matrix_mul_transposed_by_row_f16
+        prg.matrix_mul_transposed_f16.set_arg(0, self.buf.clone())?;
+        prg.matrix_mul_transposed_f16.set_arg(1, src.buf.clone())?;
+        prg.matrix_mul_transposed_f16
             .set_arg(2, other.buf.clone())?;
-        prg.matrix_mul_transposed_by_row_f16
+        prg.matrix_mul_transposed_f16
             .set_arg(3, src.cols_capacity as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
+        prg.matrix_mul_transposed_f16
             .set_arg(4, other.cols_capacity as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
+        prg.matrix_mul_transposed_f16
             .set_arg(5, self.cols_capacity as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(6, self.rows as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(7, self.cols as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(8, src.cols as i32)?;
+        prg.matrix_mul_transposed_f16.set_arg(6, self.rows as i32)?;
+        prg.matrix_mul_transposed_f16.set_arg(7, self.cols as i32)?;
+        prg.matrix_mul_transposed_f16.set_arg(8, src.cols as i32)?;
         let mut event = Event::empty();
+        let rows16 = if self.rows % 16 == 0 {
+            self.rows
+        } else {
+            self.rows + 16 - (self.rows % 16)
+        };
+        let cols16 = if self.cols % 16 == 0 {
+            self.cols
+        } else {
+            self.cols + 16 - (self.cols % 16)
+        };
+
         unsafe {
             let b = prg
-                .matrix_mul_transposed_by_row_f16
+                .matrix_mul_transposed_f16
                 .cmd()
                 .queue(&self.queue)
-                .global_work_size([self.rows as usize, self.cols as usize])
+                .global_work_size([cols16 as usize, rows16 as usize])
+                .local_work_size([16, 16])
                 .enew(&mut event);
             b.enq()?;
         }
@@ -353,11 +360,11 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError>
         Ok(program)
     }
 
-    let matrix_mul_transposed_by_row_f16_program =
-        make_program_with_src(ctx, MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC)?;
-    let matrix_mul_transposed_by_row_f16 = Kernel::builder()
-        .program(&matrix_mul_transposed_by_row_f16_program)
-        .name("matrix_mul_transposed_by_row_f16")
+    let matrix_mul_transposed_f16_program =
+        make_program_with_src(ctx, MATRIX_MUL_TRANSPOSED_F16_SRC)?;
+    let matrix_mul_transposed_f16 = Kernel::builder()
+        .program(&matrix_mul_transposed_f16_program)
+        .name("matrix_mul_transposed_f16")
         .arg(None::<&Buffer>)
         .arg(None::<&Buffer>)
         .arg(None::<&Buffer>)
@@ -398,8 +405,8 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError>
         .queue(queue.clone())
         .build()?;
     Ok(Programs {
-        matrix_mul_transposed_by_row_f16_program,
-        matrix_mul_transposed_by_row_f16,
+        matrix_mul_transposed_f16_program,
+        matrix_mul_transposed_f16,
         silu_f16_program,
         silu_f16,
         hadamard_product_f16_program,
@@ -409,29 +416,21 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError>
     })
 }
 
-const MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC: &str = r#"
+const MATRIX_MUL_TRANSPOSED_F16_SRC: &str = r#"
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 
 /*
- * Matrix multiplication with a transposed second matrix, using 16-bit floats.
- *
- * One work unit per row.
- *
- * Assumes that each row in the matrices are zero-padded so that there's space for 32 bytes (or 16
- * halfs) of data and we don't need to care if our loops go over the bounds.
- *
- * Operations are done in float32.
- *
- * This thing is not very fast right now. I compared with PyTorch and this is like 20x slower. It
- * is still much faster than CPU. Not sure PyTorch uses cuBlas but if we could get at least
- * somewhere like 50% of that speed I would be happy.
- *
- * The OpenCL on CPU for Ryzen 3950X seems to easily beat my own AVX2 operations.
- *
+ * Tiled matrix multiplication with a transposed second matrix, using 16-bit floats.
+ *
+ * Work is dispatched in 16x16 work-groups. Each work-group stages one 16x16 tile of
+ * each operand in local memory, and each work-item accumulates one element of the
+ * output in float32. Assumes each row is zero-padded to a multiple of 16 halfs, so
+ * the tile loads past shared_sz stay in bounds and contribute zeros.
+ *
  * TODO: need to read resources like https://cnugteren.github.io/tutorial/pages/page1.html to
  * figure out how to multiply matrices faster.
  */
-__kernel void matrix_mul_transposed_by_row_f16(
+__kernel void matrix_mul_transposed_f16(
     __global half *tgt,
     __global const half *left,
     __global const half *right,
@@ -442,44 +441,36 @@
     const int ncols, // size of target
     const int shared_sz
 ) {
-    int col_iterations = shared_sz / 16;
-    if (shared_sz % 16 != 0) {
-        col_iterations = col_iterations + 1;
+    __local float lefttile[16][16];
+    __local float righttile[16][16];
+
+    int global_x = get_global_id(0);
+    int global_y = get_global_id(1);
+    int local_x = get_local_id(0);
+    int local_y = get_local_id(1);
+    int num_tiles = (shared_sz + 15) / 16;
+
+    float sum = 0.0f;
+    for (int t = 0; t < num_tiles; ++t) {
+        if (global_y < nrows) {
+            lefttile[local_y][local_x] = vload_half(global_y * left_cols_capacity + t * 16 + local_x, left);
+        } else {
+            lefttile[local_y][local_x] = 0.0f;
+        }
+        if (global_x < ncols) {
+            righttile[local_y][local_x] = vload_half(global_x * right_cols_capacity + t * 16 + local_y, right);
+        } else {
+            righttile[local_y][local_x] = 0.0f;
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+        for (int k = 0; k < 16; ++k) {
+            sum += lefttile[local_y][k] * righttile[k][local_x];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
     }
-
-    const int tgt_row = get_global_id(0);
-    const int tgt_col = get_global_id(1);
-
-    float16 sum = 0;
-    for (int col16 = 0; col16 < col_iterations; col16++) {
-        const float16 left8 = vload_half16((tgt_row * left_cols_capacity)/16 + col16, (__global const half*) left);
-        const float16 right8 = vload_half16((tgt_col * right_cols_capacity)/16 + col16, (__global const half*) right);
-        // hadamard product FMA add it to sum
-        // const float16 result8 = left8 * right8;
-        // sum += result8;
-        sum = fma(left8, right8, sum);
+    if (global_x < ncols && global_y < nrows) {
+        vstore_half(sum, global_y * ncols_capacity + global_x, (__global half*) tgt);
     }
-    // Reduce as accurately as possible
-    float sum1 = sum.s0 + sum.s1;
-    float sum2 = sum.s2 + sum.s3;
-    float sum3 = sum.s4 + sum.s5;
-    float sum4 = sum.s6 + sum.s7;
-    float sum5 = sum.s8 + sum.s9;
-    float sum6 = sum.sa + sum.sb;
-    float sum7 = sum.sc + sum.sd;
-    float sum8 = sum.se + sum.sf;
-
-    float sum11 = sum1 + sum2;
-    float sum12 = sum3 + sum4;
-    float sum13 = sum5 + sum6;
-    float sum14 = sum7 + sum8;
-
-    float sum21 = sum11 + sum12;
-    float sum22 = sum13 + sum14;
-
-    float total = sum21 + sum22;
-
-    vstore_half(total, 0, (__global half*) &tgt[tgt_row * ncols_capacity + tgt_col]);
 }
 "#;
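
For reference: the kernel computes tgt = left * right^T, with both operands row-major and the
second matrix stored pre-transposed so that both are read along contiguous rows. Below is a
minimal CPU sketch of that same computation, the kind of scalar oracle the epsilon-based GPU
tests compare against; the helper name matmul_transposed_ref and the flat row-major slice
layout are illustrative assumptions, not code from this repository.

// CPU oracle for the transposed matmul above. `matmul_transposed_ref` and the
// flat row-major layout are illustrative assumptions, not code from this repo.
// `left` is nrows x shared_sz, `right` is ncols x shared_sz (already transposed),
// and the result is nrows x ncols, matching the kernel's `tgt`.
fn matmul_transposed_ref(
    left: &[f32],
    right: &[f32],
    nrows: usize,
    ncols: usize,
    shared_sz: usize,
) -> Vec<f32> {
    let mut tgt = vec![0.0f32; nrows * ncols];
    for y in 0..nrows {
        for x in 0..ncols {
            let mut sum = 0.0f32;
            for k in 0..shared_sz {
                // Both reads walk contiguous rows; keeping the second matrix
                // transposed is what makes the right-hand access sequential.
                sum += left[y * shared_sz + k] * right[x * shared_sz + k];
            }
            tgt[y * ncols + x] = sum;
        }
    }
    tgt
}

fn main() {
    // (2x3) times the transpose of (4x3) gives a 2x4 result.
    let left = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let right = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
    let tgt = matmul_transposed_ref(&left, &right, 2, 4, 3);
    assert_eq!(tgt, vec![1.0, 2.0, 3.0, 6.0, 4.0, 5.0, 6.0, 15.0]);
}

The 16x16 tiling in the kernel computes exactly these dot products; it only reorders the k loop
so that each 16x16 block of left and right is loaded into local memory once per work-group
rather than once per work-item.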