Add some OpenCL bits.

I wrote an OpenCL matrix_mul_inplace_transposed. On the GPU it is much faster
than my CPU implementation, and even when OpenCL runs on the CPU (it can target
both CPU and GPU) it is quite a lot faster than my own code.

Basically it can destroy all of my crappy code, so I think I will be
replacing some of my other operations with this stuff in the near future.
broken-opencl-code
Mikko Juola 3 years ago
parent a92017bf56
commit 1a88482988
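For context, a rough sketch of how the new GPU path is meant to be used; it mirrors the benchmark and test code in this diff (OpenCL::new, to_gpu, matrix_mul_inplace_transposed, to_cpu, finish), not a finished API:

let cl = OpenCL::new(false, 0).unwrap();            // non-verbose, device 0

let mut a = Tensor::random(1024, 1024, TensorDType::Float16);
let mut b = Tensor::random(1024, 1024, TensorDType::Float16);
let mut c = Tensor::zeros(1024, 1024, TensorDType::Float16);

a.to_gpu(&cl).unwrap();                             // uploads are asynchronous
b.to_gpu(&cl).unwrap();
c.to_gpu(&cl).unwrap();

// Runs the OpenCL kernel because self, src and other are all on the GPU;
// computes c = a * b^T in place.
c.matrix_mul_inplace_transposed(&a, &b);

c.to_cpu().unwrap();                                // download is asynchronous too
c.finish();                                         // wait for all pending operations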

@ -25,7 +25,9 @@ PyTorch. Well almost, it doesn't unzip them automatically (see below).
# How to run
You will need Rust. Make sure you can run `cargo` from a command line.
You will need Rust. Make sure you can run `cargo` from a command line. In
particular, this project uses unstable features, so you need nightly Rust. Make sure
that `cargo --version` reports a nightly toolchain.
You will need to download LLaMA-7B weights. Refer to https://github.com/facebookresearch/llama/
@ -58,4 +60,10 @@ settings.
This is a hobby thing for me so don't expect updates or help.
* Some other CPU implementations use quantization to reduce the size of weights
* Put some of the operations on the OpenCL GPU
* Put some of the operations on the OpenCL GPU/CPU. I've made some initial
OpenCL code but it is not used in the transformer loop yet. The OpenCL-on-CPU
path is roughly 2x faster than my own AVX2 code, and the GPU path is much faster
still, although it is also about 20x slower than the equivalent operation in
PyTorch on the same GPU.
* I've heard there is something called Tensor Cores on NVIDIA GPUs. They are not
accessible from OpenCL, but might be accessible from Vulkan with an extension.

@ -13,11 +13,39 @@ pub fn opencl_benchmarks(c: &mut Criterion) {
let mut orig32 = Tensor::random(4096, 4096, TensorDType::Float16);
let cl = OpenCL::new(false, 0).unwrap();
let mut mul_left = Tensor::random(1024, 1024, TensorDType::Float16);
mul_left.to_gpu(&cl).unwrap();
let mut mul_right = Tensor::random(1024, 1024, TensorDType::Float16);
mul_right.to_gpu(&cl).unwrap();
let mut mul_target = Tensor::zeros(1024, 1024, TensorDType::Float16);
mul_target.to_gpu(&cl).unwrap();
let mut mul_left_cpu = Tensor::random(1024, 1024, TensorDType::Float32);
let mut mul_right_cpu = Tensor::random(1024, 1024, TensorDType::Float32);
let mut mul_target_cpu = Tensor::random(1024, 1024, TensorDType::Float32);
c.bench_function(
"1024x1024 matrix multiplication transposed on OpenCL",
|b| {
b.iter(|| {
mul_target
.matrix_mul_inplace_transposed(black_box(&mul_left), black_box(&mul_right));
mul_target.finish();
})
},
);
c.bench_function("1024x1024 matrix multiplication transposed on CPU", |b| {
b.iter(|| {
let _ = mul_target_cpu.matrix_mul_inplace_transposed(&mul_left_cpu, &mul_right_cpu);
})
});
c.bench_function("1x1 matrix from CPU to OpenCL device and back", |b| {
b.iter(|| {
let _ = orig1.to_gpu(&cl).unwrap();
let _ = orig1.to_cpu();
orig1.process_waiting_for_data();
orig1.finish();
})
});
@ -25,7 +53,7 @@ pub fn opencl_benchmarks(c: &mut Criterion) {
b.iter(|| {
let _ = orig16.to_gpu(&cl).unwrap();
let _ = orig16.to_cpu();
orig16.process_waiting_for_data();
orig16.finish();
})
});
@ -33,7 +61,7 @@ pub fn opencl_benchmarks(c: &mut Criterion) {
b.iter(|| {
let _ = orig32.to_gpu(&cl).unwrap();
let _ = orig32.to_cpu();
orig32.process_waiting_for_data();
orig32.finish();
})
});
}

@ -63,11 +63,12 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
}
#[cfg(feature = "opencl")]
let opencl: Option<OpenCL> = {
let _opencl: Option<OpenCL> = {
let opencl_device = cli.opencl_device.unwrap_or(0);
match OpenCL::new(!be_quiet, opencl_device) {
Err(openclerr) => {
eprintln!("OpenCL error: {}", openclerr);
eprintln!("OpenCL is disabled because it failed to initialize.");
None
}
Ok(opencl) => {

@ -827,9 +827,45 @@ impl Tensor {
}
}
#[cfg(feature = "opencl")]
pub fn is_on_gpu(&self) -> bool {
if self.waiting_for_data.is_some() {
return false;
}
let od = self.opencl_data.read().unwrap();
if od.is_some() {
return true;
}
false
}
#[cfg(feature = "opencl")]
fn matrix_mul_inplace_transposed_gpu(&mut self, src: &Tensor, other: &Tensor) {
let mut self_od = self.opencl_data.write().unwrap();
let src_od = src.opencl_data.read().unwrap();
let other_od = other.opencl_data.read().unwrap();
let self_od: &mut OpenCLTensor = self_od.as_mut().unwrap();
let src_od: &OpenCLTensor = src_od.as_ref().unwrap();
let other_od: &OpenCLTensor = other_od.as_ref().unwrap();
// TODO: if this fails, we panic. Think about whether this is alright; I think for
// now it is.
self_od
.matrix_mul_inplace_transposed(src_od, other_od)
.unwrap();
std::mem::drop(self_od);
std::mem::drop(src_od);
std::mem::drop(other_od);
}
/// Matrix multiplication done in-place, but the second matrix is transposed.
/// With this, you can avoid using .transpose() on the second matrix.
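/// Shape-wise (as exercised by the tests in this commit): for `src` of size MxK and
/// `other` of size NxK, `self` must be MxN and receives
/// self[m][n] = sum over k of src[m][k] * other[n][k], i.e. self = src * other^T.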
pub fn matrix_mul_inplace_transposed(&mut self, src: &Tensor, other: &Tensor) {
#[cfg(feature = "opencl")]
if self.is_on_gpu() && src.is_on_gpu() && other.is_on_gpu() {
self.matrix_mul_inplace_transposed_gpu(src, other);
return;
}
self.assume_on_cpu();
src.assume_on_cpu();
other.assume_on_cpu();
@ -1256,6 +1292,9 @@ impl Tensor {
self.data as *const u16,
self.layout,
(self.rows * self.capacity_cols) as usize,
self.rows,
self.cols,
self.capacity_cols,
)?;
self.data = std::ptr::null_mut();
*od = Some(cl_tensor);
@ -1263,7 +1302,7 @@ impl Tensor {
}
#[cfg(feature = "opencl")]
pub fn process_waiting_for_data_mut(&mut self) {
fn process_waiting_for_data_mut(&mut self) {
if let Some(ref wfd) = self.waiting_for_data {
wfd.wait();
let mut od = self.opencl_data.write().unwrap();
@ -1273,7 +1312,7 @@ impl Tensor {
}
#[cfg(feature = "opencl")]
pub fn process_waiting_for_data(&self) {
fn process_waiting_for_data(&self) {
if let Some(ref wfd) = self.waiting_for_data {
wfd.wait();
let mut od = self.opencl_data.write().unwrap();
@ -1281,12 +1320,22 @@ impl Tensor {
}
}
/// Waits until all asynchronous operations on this tensor are done.
#[cfg(feature = "opencl")]
pub fn finish(&mut self) {
self.process_waiting_for_data_mut();
let mut od = self.opencl_data.write().unwrap();
if od.is_some() {
od.as_mut().unwrap().wait_until_ready();
}
}
/// Sends a tensor from the GPU to the CPU. This is a no-op if the tensor is already on the
/// CPU.
#[cfg(feature = "opencl")]
pub fn to_cpu(&mut self) -> Result<(), TensorError> {
self.process_waiting_for_data_mut();
let od = self.opencl_data.read().unwrap();
let mut od = self.opencl_data.write().unwrap();
if od.is_none() {
return Ok(());
}
@ -1294,7 +1343,7 @@ impl Tensor {
if data.is_null() {
panic!("to_cpu: Failed to allocate tensor");
}
let ev = od.as_ref().unwrap().data_u16_from_gpu(data as *mut u16)?;
let ev = od.as_mut().unwrap().data_u16_from_gpu(data as *mut u16)?;
self.data = data as *mut u16 as *mut u8;
self.waiting_for_data = Some(ev);
Ok(())
@ -1678,10 +1727,10 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let mut a: i64 = 0;
let mut b: i64 = 0;
let mut c: i64 = 0;
let mut d: i64 = 0;
let mut a: i64;
let mut b: i64;
let mut c: i64;
let d: i64;
loop {
a = rng.gen_range(8..64);
b = rng.gen_range(8..64);
@ -1954,4 +2003,75 @@ mod tests {
}
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_matrix_mul_transposed_is_close_to_cpu_matrix_mul_transposed_1024x1024() {
let cl = OpenCL::new(false, 0).unwrap();
let a = Tensor::random(1024, 1024, TensorDType::Float32);
let b = Tensor::random(1024, 1024, TensorDType::Float32);
let mut a2 = a.to_f16();
let mut b2 = b.to_f16();
let mut c = Tensor::random(1024, 1024, TensorDType::Float32);
let mut c2 = Tensor::zeros(1024, 1024, TensorDType::Float32).to_f16();
a2.to_gpu(&cl).unwrap();
b2.to_gpu(&cl).unwrap();
c2.to_gpu(&cl).unwrap();
c.matrix_mul_inplace_transposed(&a, &b);
c2.matrix_mul_inplace_transposed(&a2, &b2);
c2.to_cpu().unwrap();
assert_eq!(c.rows(), c2.rows());
assert_eq!(c.cols(), c2.cols());
for row in 0..c.rows {
for col in 0..c.cols {
assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-1);
}
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_matrix_mul_transposed_is_close_to_cpu_matrix_mul_transposed() {
let cl = OpenCL::new(true, 1).unwrap();
let mut rng = rand::thread_rng();
for _trial in 0..300 {
let a = rng.gen_range(1..=300);
let b = rng.gen_range(1..=300);
let c = rng.gen_range(1..=300);
let mat1 = Tensor::random(a, b, TensorDType::Float16);
let mat2 = Tensor::random(c, b, TensorDType::Float16);
let mat3 = Tensor::random(a, c, TensorDType::Float16);
let mut mat1_gpu = mat1.clone();
let mut mat2_gpu = mat2.clone();
let mut mat3_gpu = mat3.clone();
mat1_gpu.to_gpu(&cl).unwrap();
mat2_gpu.to_gpu(&cl).unwrap();
mat3_gpu.to_gpu(&cl).unwrap();
let mat1 = mat1.to_f32();
let mat2 = mat2.to_f32();
let mut mat3 = mat3.to_f32();
mat3.matrix_mul_inplace_transposed(&mat1, &mat2);
mat3_gpu.matrix_mul_inplace_transposed(&mat1_gpu, &mat2_gpu);
mat3_gpu.to_cpu().unwrap();
assert_eq!(mat3.rows(), mat3_gpu.rows());
assert_eq!(mat3.cols(), mat3_gpu.cols());
for row in 0..mat3.rows {
for col in 0..mat3.cols {
assert_relative_eq!(
mat3.get_f32(row, col),
mat3_gpu.get_f32(row, col),
epsilon = 1e-2,
);
}
}
}
}
}

@ -2,23 +2,39 @@
* OpenCL stuff to run (some of) the tensor operations.
*/
use ocl::{Buffer, Context, Device, Event, Platform, Queue};
use ocl::{Buffer, Context, Device, Event, Kernel, Platform, Program, Queue};
use std::alloc::Layout;
use std::sync::{Arc, RwLock};
use thiserror::Error;
#[derive(Debug)]
#[allow(dead_code)]
struct Programs {
matrix_mul_transposed_by_row_f16_program: Program,
matrix_mul_transposed_by_row_f16: Kernel,
}
#[derive(Debug)]
#[allow(dead_code)]
pub struct OpenCL {
ctx: Context,
queue: Queue,
programs: Arc<RwLock<Programs>>,
}
#[derive(Debug)]
pub struct OpenCLTensor {
buf: Buffer<u16>, // really is f16
write_event: Option<ocl::Event>, // if Some, the buffer is being written to
data: *const u16, // if non-null, is host pointer that should be freed
buf: Buffer<u16>, // really is f16
initial_write_event: Option<ocl::Event>,
last_event: Option<ocl::Event>,
data: *const u16,
data_layout: Layout,
nitems: usize,
rows: i64,
cols: i64,
cols_capacity: i64,
queue: Queue,
programs: Arc<RwLock<Programs>>,
}
#[derive(Debug)]
@ -28,10 +44,15 @@ pub struct OpenCLEvent {
impl Drop for OpenCLTensor {
fn drop(&mut self) {
if self.initial_write_event.is_some() {
self.initial_write_event
.as_ref()
.unwrap()
.wait_for()
.unwrap();
}
self.initial_write_event = None;
if !self.data.is_null() {
if self.write_event.is_some() {
self.write_event.as_ref().unwrap().wait_for().unwrap();
}
unsafe {
std::alloc::dealloc(self.data as *mut u8, self.data_layout);
}
@ -59,9 +80,9 @@ impl OpenCL {
if verbose {
println!("Enumerating OpenCL devices:");
}
for (idx, (_, device)) in devices.iter().enumerate() {
for (idx, (_plat, device)) in devices.iter().enumerate() {
if verbose {
println!("OpenCL {} device: {}", idx, device.name()?);
println!("OpenCL {} device: {}", idx, device.name()?,);
}
}
if nth_device > devices.len() {
@ -78,8 +99,12 @@ impl OpenCL {
.build()?;
let queue = Queue::new(&ctx, devices[nth_device].1, None)?;
Ok(OpenCL { ctx, queue })
let programs = make_programs(&ctx, &queue)?;
Ok(OpenCL {
ctx: ctx,
queue: queue,
programs: Arc::new(RwLock::new(programs)),
})
}
pub fn flush(&self) {
@ -91,6 +116,9 @@ impl OpenCL {
data: *const u16,
data_layout: Layout,
nitems: usize,
rows: i64,
cols: i64,
cols_capacity: i64,
) -> Result<OpenCLTensor, OpenCLError> {
unsafe {
let buf = Buffer::builder()
@ -106,10 +134,16 @@ impl OpenCL {
.enq()?;
Ok(OpenCLTensor {
buf,
write_event: Some(event),
initial_write_event: Some(event),
last_event: None,
data,
data_layout,
nitems,
rows,
cols,
cols_capacity,
queue: self.queue.clone(),
programs: self.programs.clone(),
})
}
}
@ -117,34 +151,96 @@ impl OpenCL {
impl OpenCLTensor {
pub fn wait_until_ready(&mut self) {
if self.write_event.is_none() {
return;
if self.last_event.is_some() {
self.last_event.as_ref().unwrap().wait_for().unwrap();
self.last_event = None;
}
if self.initial_write_event.is_some() {
self.initial_write_event
.as_ref()
.unwrap()
.wait_for()
.unwrap();
self.initial_write_event = None;
}
self.write_event.as_ref().unwrap().wait_for().unwrap();
self.write_event = None;
if !self.data.is_null() {
if self.write_event.is_some() {
self.write_event.as_ref().unwrap().wait_for().unwrap();
}
unsafe {
std::alloc::dealloc(self.data as *mut u8, self.data_layout);
}
self.data = std::ptr::null();
}
}
pub fn data_u16_from_gpu(&self, data: *mut u16) -> Result<OpenCLEvent, OpenCLError> {
pub fn data_u16_from_gpu(&mut self, data: *mut u16) -> Result<OpenCLEvent, OpenCLError> {
unsafe {
let mut event = Event::empty();
let data_slice: &mut [u16] = std::slice::from_raw_parts_mut(data, self.nitems);
self.buf
let b = self
.buf
.cmd()
.read(data_slice)
.block(false)
.enew(&mut event)
.enq()?;
.enew(&mut event);
b.enq()?;
self.last_event = Some(event.clone());
return Ok(OpenCLEvent { event });
}
}
pub fn matrix_mul_inplace_transposed(
&mut self,
src: &OpenCLTensor,
other: &OpenCLTensor,
) -> Result<OpenCLEvent, OpenCLError> {
if src.cols != other.cols {
panic!(
"OpenCL matrix_mul_inplace_transposed: src.cols must equal other.cols: {}x{} vs {}x{}",
src.rows, src.cols, other.rows, other.cols
);
}
if self.rows != src.rows || self.cols != other.rows {
panic!(
"OpenCL matrix_mul_inplace_transposed: self.rows must equal src.rows and self.cols must equal other.cols: {}x{} vs {}x{} vs {}x{}",
self.rows, self.cols, src.rows, src.cols, other.rows, other.cols
);
}
// Clear out the target memory
unsafe { self.buf.cmd().fill(0u16, None).block(false).enq()? };
let prg = self.programs.write().unwrap();
prg.matrix_mul_transposed_by_row_f16
.set_arg(0, self.buf.clone())?;
prg.matrix_mul_transposed_by_row_f16
.set_arg(1, src.buf.clone())?;
prg.matrix_mul_transposed_by_row_f16
.set_arg(2, other.buf.clone())?;
prg.matrix_mul_transposed_by_row_f16
.set_arg(3, src.cols_capacity as i32)?;
prg.matrix_mul_transposed_by_row_f16
.set_arg(4, other.cols_capacity as i32)?;
prg.matrix_mul_transposed_by_row_f16
.set_arg(5, self.cols_capacity as i32)?;
prg.matrix_mul_transposed_by_row_f16
.set_arg(6, self.rows as i32)?;
prg.matrix_mul_transposed_by_row_f16
.set_arg(7, self.cols as i32)?;
prg.matrix_mul_transposed_by_row_f16
.set_arg(8, src.cols as i32)?;
let mut event = Event::empty();
unsafe {
let b = prg
.matrix_mul_transposed_by_row_f16
.cmd()
.queue(&self.queue)
.global_work_size([self.rows as usize, self.cols_capacity as usize])
.enew(&mut event);
b.enq()?;
}
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
}
}
impl OpenCLEvent {
@ -153,3 +249,121 @@ impl OpenCLEvent {
self.event.wait_for().unwrap();
}
}
fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError> {
let mut last_err: Option<OpenCLError> = None;
// There used to be more sources here but now it's just one. This loops through the
// candidate sources and accepts the first one that compiles.
for src in &[MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC] {
fn make_programs_with_src(
ctx: &Context,
queue: &Queue,
src: &str,
) -> Result<Programs, OpenCLError> {
let program = Program::builder().src(src).build(&ctx)?;
let kernel = Kernel::builder()
.program(&program)
.name("matrix_mul_transposed_by_row_f16")
.arg(None::<&Buffer<u16>>)
.arg(None::<&Buffer<u16>>)
.arg(None::<&Buffer<u16>>)
.arg(&0)
.arg(&0)
.arg(&0)
.arg(&0)
.arg(&0)
.arg(&0)
.queue(queue.clone())
.build()?;
Ok(Programs {
matrix_mul_transposed_by_row_f16_program: program,
matrix_mul_transposed_by_row_f16: kernel,
})
}
match make_programs_with_src(ctx, queue, src) {
Err(e) => {
last_err = Some(e);
continue;
}
Ok(p) => return Ok(p),
}
}
if last_err.is_none() {
unreachable!();
}
Err(last_err.unwrap())
}
const MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC: &str = r#"
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
/*
* Matrix multiplication with a transposed second matrix, using 16-bit floats.
*
* One work-item per output element (the global work size is rows x cols_capacity).
*
* Assumes that each row in the matrices is zero-padded so that there's space for 32 bytes (or 16
* halfs) of data and we don't need to care if our loops go over the bounds.
*
* Operations are done in float32.
*
* This thing is not very fast right now. I compared with PyTorch and this is like 20x slower. It
* is still much faster than the CPU. I'm not sure if PyTorch uses cuBLAS, but if we could get to
* at least somewhere around 50% of its speed I would be happy.
*
* The OpenCL on CPU for Ryzen 3950X seems to easily beat my own AVX2 operations.
*
* TODO: need to read resources like https://cnugteren.github.io/tutorial/pages/page1.html to
* figure out how to make the matrix multiply faster.
*/
__kernel void matrix_mul_transposed_by_row_f16(
__global half *tgt,
__global const half *left,
__global const half *right,
const int left_cols_capacity,
const int right_cols_capacity,
const int ncols_capacity,
const int nrows,
const int ncols, // size of target
const int shared_sz
) {
int col_iterations = shared_sz / 16;
if (shared_sz % 16 != 0) {
col_iterations = col_iterations + 1;
}
const int tgt_row = get_global_id(0);
const int tgt_col = get_global_id(1);
float16 sum = 0;
for (int col16 = 0; col16 < col_iterations; col16++) {
const float16 left8 = vload_half16((tgt_row * left_cols_capacity)/16 + col16, (__global const half*) left);
const float16 right8 = vload_half16((tgt_col * right_cols_capacity)/16 + col16, (__global const half*) right);
// Hadamard product, FMA-accumulated into sum
// const float16 result8 = left8 * right8;
// sum += result8;
sum = fma(left8, right8, sum);
}
// Reduce as accurately as possible
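// (Pairwise tree reduction: summing neighbours first keeps the partial sums at similar
// magnitudes, which typically loses less precision than one long left-to-right sum.)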
float sum1 = sum.s0 + sum.s1;
float sum2 = sum.s2 + sum.s3;
float sum3 = sum.s4 + sum.s5;
float sum4 = sum.s6 + sum.s7;
float sum5 = sum.s8 + sum.s9;
float sum6 = sum.sa + sum.sb;
float sum7 = sum.sc + sum.sd;
float sum8 = sum.se + sum.sf;
float sum11 = sum1 + sum2;
float sum12 = sum3 + sum4;
float sum13 = sum5 + sum6;
float sum14 = sum7 + sum8;
float sum21 = sum11 + sum12;
float sum22 = sum13 + sum14;
float total = sum21 + sum22;
vstore_half(total, 0, (__global half*) &tgt[tgt_row * ncols_capacity + tgt_col]);
}
"#;
