Rewrite the matrix multiplication.

The new kernel is something like ~10 times faster than the old one, but
surprisingly this didn't have much impact on text generation time. Maybe
most of the remaining slowness no longer comes from matrix multiplication.

This also slowed down the CPU implementation. I think I'll try adding a
separate kernel for CPU OpenCL later.
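
For reference, the operation both paths compute is a plain matrix product where the second operand is stored transposed (each row of `right` is a column of the logical right-hand matrix). A rough CPU sketch of the same operation, using plain f32 slices instead of the crate's Tensor type (the function name is just for illustration, not from the codebase):

    // Reference only: tgt = left * right^T, everything row-major.
    // tgt[r][c] = sum over k of left[r][k] * right[c][k].
    fn matrix_mul_transposed_ref(
        left: &[f32],    // nrows * shared
        right: &[f32],   // ncols * shared (transposed operand)
        tgt: &mut [f32], // nrows * ncols
        nrows: usize,
        ncols: usize,
        shared: usize,
    ) {
        for r in 0..nrows {
            for c in 0..ncols {
                let mut sum = 0.0f32;
                for k in 0..shared {
                    sum += left[r * shared + k] * right[c * shared + k];
                }
                tgt[r * ncols + c] = sum;
            }
        }
    }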
master
Mikko Juola 3 years ago
parent 862d4a15d6
commit 8c64313fec

@@ -2152,6 +2152,33 @@ mod tests {
         }
     }
+    #[cfg(feature = "opencl")]
+    #[test]
+    fn gpu_matrix_mul_transposed_is_close_to_cpu_matrix_mul_transposed_512x1024() {
+        let cl = OpenCL::new(false, 0).unwrap();
+        let a = Tensor::random(512, 1024, TensorDType::Float32);
+        let b = Tensor::random(768, 1024, TensorDType::Float32);
+        let mut a2 = a.to_f16();
+        let mut b2 = b.to_f16();
+        let mut c = Tensor::random(512, 768, TensorDType::Float32);
+        let mut c2 = Tensor::zeros(512, 768, TensorDType::Float32).to_f16();
+        a2.to_gpu(&cl).unwrap();
+        b2.to_gpu(&cl).unwrap();
+        c2.to_gpu(&cl).unwrap();
+        c.matrix_mul_inplace_transposed(&a, &b);
+        c2.matrix_mul_inplace_transposed(&a2, &b2);
+        c2.to_cpu().unwrap();
+        assert_eq!(c.rows(), c2.rows());
+        assert_eq!(c.cols(), c2.cols());
+        for row in 0..c.rows {
+            for col in 0..c.cols {
+                assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-1);
+            }
+        }
+    }
     #[cfg(feature = "opencl")]
     #[test]
     fn gpu_matrix_mul_transposed_is_close_to_cpu_matrix_mul_transposed_1024x1024() {

@@ -10,8 +10,8 @@ use thiserror::Error;
 #[derive(Debug)]
 #[allow(dead_code)]
 struct Programs {
-    matrix_mul_transposed_by_row_f16_program: Program,
-    matrix_mul_transposed_by_row_f16: Kernel,
+    matrix_mul_transposed_f16_program: Program,
+    matrix_mul_transposed_f16: Kernel,
     silu_f16_program: Program,
     silu_f16: Kernel,
     hadamard_product_f16_program: Program,
@@ -306,32 +306,39 @@ impl OpenCLTensor {
         unsafe { self.buf.cmd().fill(0u16, None).block(false).enq()? };
         let prg = self.cl.programs.write().unwrap();
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(0, self.buf.clone())?;
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(1, src.buf.clone())?;
-        prg.matrix_mul_transposed_by_row_f16
+        prg.matrix_mul_transposed_f16.set_arg(0, self.buf.clone())?;
+        prg.matrix_mul_transposed_f16.set_arg(1, src.buf.clone())?;
+        prg.matrix_mul_transposed_f16
             .set_arg(2, other.buf.clone())?;
-        prg.matrix_mul_transposed_by_row_f16
+        prg.matrix_mul_transposed_f16
             .set_arg(3, src.cols_capacity as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
+        prg.matrix_mul_transposed_f16
            .set_arg(4, other.cols_capacity as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
+        prg.matrix_mul_transposed_f16
            .set_arg(5, self.cols_capacity as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(6, self.rows as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(7, self.cols as i32)?;
-        prg.matrix_mul_transposed_by_row_f16
-            .set_arg(8, src.cols as i32)?;
+        prg.matrix_mul_transposed_f16.set_arg(6, self.rows as i32)?;
+        prg.matrix_mul_transposed_f16.set_arg(7, self.cols as i32)?;
+        prg.matrix_mul_transposed_f16.set_arg(8, src.cols as i32)?;
         let mut event = Event::empty();
+        let rows16 = if self.rows % 16 == 0 {
+            self.rows
+        } else {
+            self.rows + 16 - (self.rows % 16)
+        };
+        let cols16 = if self.cols % 16 == 0 {
+            self.cols
+        } else {
+            self.cols + 16 - (self.cols % 16)
+        };
         unsafe {
             let b = prg
-                .matrix_mul_transposed_by_row_f16
+                .matrix_mul_transposed_f16
                 .cmd()
                 .queue(&self.queue)
-                .global_work_size([self.rows as usize, self.cols as usize])
+                .global_work_size([cols16 as usize, rows16 as usize])
+                .local_work_size([16, 16])
                 .enew(&mut event);
             b.enq()?;
         }
@@ -353,11 +360,11 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError>
         Ok(program)
     }
-    let matrix_mul_transposed_by_row_f16_program =
-        make_program_with_src(ctx, MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC)?;
-    let matrix_mul_transposed_by_row_f16 = Kernel::builder()
-        .program(&matrix_mul_transposed_by_row_f16_program)
-        .name("matrix_mul_transposed_by_row_f16")
+    let matrix_mul_transposed_f16_program =
+        make_program_with_src(ctx, MATRIX_MUL_TRANSPOSED_F16_SRC)?;
+    let matrix_mul_transposed_f16 = Kernel::builder()
+        .program(&matrix_mul_transposed_f16_program)
+        .name("matrix_mul_transposed_f16")
         .arg(None::<&Buffer<u16>>)
         .arg(None::<&Buffer<u16>>)
         .arg(None::<&Buffer<u16>>)
@@ -398,8 +405,8 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError>
         .queue(queue.clone())
         .build()?;
     Ok(Programs {
-        matrix_mul_transposed_by_row_f16_program,
-        matrix_mul_transposed_by_row_f16,
+        matrix_mul_transposed_f16_program,
+        matrix_mul_transposed_f16,
         silu_f16_program,
         silu_f16,
         hadamard_product_f16_program,
@@ -409,29 +416,14 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError>
     })
 }
-const MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC: &str = r#"
+const MATRIX_MUL_TRANSPOSED_F16_SRC: &str = r#"
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 /*
  * Matrix multiplication with a transposed second matrix, using 16-bit floats.
  *
- * One work unit per row.
- *
- * Assumes that each row in the matrices are zero-padded so that there's space for 32 bytes (or 16
- * halfs) of data and we don't need to care if our loops go over the bounds.
- *
  * Operations are done in float32.
- *
- * This thing is not very fast right now. I compared with PyTorch and this is like 20x slower. It
- * is still much faster than CPU. Not sure PyTorch uses cuBlas but if we could get at least
- * somewhere like 50% of that speed I would be happy.
- *
- * The OpenCL on CPU for Ryzen 3950X seems to easily beat my own AVX2 operations.
- *
- * TODO: need to read resources like https://cnugteren.github.io/tutorial/pages/page1.html to
- * figure out how matrix multiply faster.
  */
-__kernel void matrix_mul_transposed_by_row_f16(
+__kernel void matrix_mul_transposed_f16(
     __global half *tgt,
     __global const half *left,
     __global const half *right,
@@ -442,44 +434,36 @@ __kernel void matrix_mul_transposed_by_row_f16(
     const int ncols, // size of target
     const int shared_sz
 ) {
-    int col_iterations = shared_sz / 16;
-    if (shared_sz % 16 != 0) {
-        col_iterations = col_iterations + 1;
+    __local float lefttile[16][16];
+    __local float righttile[16][16];
+    int global_x = get_global_id(0);
+    int global_y = get_global_id(1);
+    int local_x = get_local_id(0);
+    int local_y = get_local_id(1);
+    int num_tiles = (shared_sz + 15) / 16;
+    float sum = 0.0f;
+    for (int t = 0; t < num_tiles; ++t) {
+        if (global_y < nrows) {
+            lefttile[local_y][local_x] = vload_half(global_y * left_cols_capacity + t * 16 + local_x, left);
+        } else {
+            lefttile[local_y][local_x] = 0.0f;
+        }
+        if (global_x < ncols) {
+            righttile[local_y][local_x] = vload_half(global_x * right_cols_capacity + t * 16 + local_y, right);
+        } else {
+            righttile[local_y][local_x] = 0.0f;
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+        for (int k = 0; k < 16; ++k) {
+            sum += lefttile[local_y][k] * righttile[k][local_x];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
     }
-    const int tgt_row = get_global_id(0);
-    const int tgt_col = get_global_id(1);
-    float16 sum = 0;
-    for (int col16 = 0; col16 < col_iterations; col16++) {
-        const float16 left8 = vload_half16((tgt_row * left_cols_capacity)/16 + col16, (__global const half*) left);
-        const float16 right8 = vload_half16((tgt_col * right_cols_capacity)/16 + col16, (__global const half*) right);
-        // hadamard product FMA add it to sum
-        // const float16 result8 = left8 * right8;
-        // sum += result8;
-        sum = fma(left8, right8, sum);
+    if (global_x < ncols && global_y < nrows) {
+        vstore_half(sum, global_y * ncols_capacity + global_x, (__global half*) tgt);
     }
-    // Reduce as accurately as possible
-    float sum1 = sum.s0 + sum.s1;
-    float sum2 = sum.s2 + sum.s3;
-    float sum3 = sum.s4 + sum.s5;
-    float sum4 = sum.s6 + sum.s7;
-    float sum5 = sum.s8 + sum.s9;
-    float sum6 = sum.sa + sum.sb;
-    float sum7 = sum.sc + sum.sd;
-    float sum8 = sum.se + sum.sf;
-    float sum11 = sum1 + sum2;
-    float sum12 = sum3 + sum4;
-    float sum13 = sum5 + sum6;
-    float sum14 = sum7 + sum8;
-    float sum21 = sum11 + sum12;
-    float sum22 = sum13 + sum14;
-    float total = sum21 + sum22;
-    vstore_half(total, 0, (__global half*) &tgt[tgt_row * ncols_capacity + tgt_col]);
 }
 "#;
