diff --git a/README.md b/README.md index f399f56..5c6f36c 100644 --- a/README.md +++ b/README.md @@ -62,10 +62,10 @@ This is a hobby thing for me so don't expect updates or help. * Some other CPU implementations use quantization to reduce the size of weights * Put some of the operations on the OpenCL GPU/CPU. I've made some initial - OpenCL code but it is not used in the transformer loop yet. The CPU OpenCL - improves my own AVX2 code by like 100% and massively so on GPU although I am - also like 20x slower than equivalent operation on PyTorch on the same GPU. -* I've heard there is some thing called Tensor Cores on nVidia GPUs. Not + OpenCL code but there's still bunch of stuff that could be OpenCLified. + The OpenCL code is fast for both GPU OpenCL and CPU OpenCL (better than my + own handwritten AVX2 code which makes me sad). +* I've heard there is some thing called Tensor Cores on NVidia GPUs. Not accessible with OpenCL. But might be accessible on Vulkan with a an extension. * More sophisticated token sampling. I saw on Hackernews some comments how the diff --git a/src/tensor.rs b/src/tensor.rs index 176ec93..c019691 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -386,8 +386,40 @@ impl Tensor { tensor } - // Computes mean for each row, so that columns become 1. + // Computes mean for each row. The resulting matrix will have 1 column (which will contain the + // mean) pub fn mean_cols(&self) -> Tensor { + #[cfg(feature = "opencl")] + { + if self.is_on_gpu() { + self.mean_cols_gpu() + } else { + self.mean_cols_cpu() + } + } + #[cfg(not(feature = "opencl"))] + { + self.mean_cols_cpu() + } + } + + #[cfg(feature = "opencl")] + fn mean_cols_gpu(&self) -> Tensor { + self.assume_on_gpu(); + self.with_opencl_data(|src_tensor| { + let cl: OpenCL = src_tensor.cl(); + // TODO: don't generate a CPU-side copy, create the result directly on OpenCL side + let mut result = unsafe { Tensor::uninitialized(self.rows, 1, self.dtype) }; + result = result.to_f16(); + result.to_gpu(&cl).unwrap(); + result.with_opencl_data_mut(|tgt_tensor| { + tgt_tensor.mean_cols_from(src_tensor).unwrap(); + }); + result + }) + } + + fn mean_cols_cpu(&self) -> Tensor { self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, 1, self.dtype) }; for row in 0..self.rows { @@ -414,6 +446,38 @@ impl Tensor { } pub fn pow(&self, power: f32) -> Tensor { + #[cfg(feature = "opencl")] + { + if self.is_on_gpu() { + return self.pow_gpu(power); + } else { + return self.pow_cpu(power); + } + } + #[cfg(not(feature = "opencl"))] + { + return self.pow_cpu(power); + } + } + + #[cfg(feature = "opencl")] + fn pow_gpu(&self, power: f32) -> Tensor { + self.assume_on_gpu(); + self.with_opencl_data(|src_tensor| { + let cl: OpenCL = src_tensor.cl(); + // TODO: don't generate a CPU-side copy, create the result directly on OpenCL side + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + result = result.to_f16(); + result.to_gpu(&cl).unwrap(); + result.with_opencl_data_mut(|tgt_tensor| { + tgt_tensor.copy_inplace(src_tensor).unwrap(); + tgt_tensor.pow_inplace(power).unwrap(); + }); + result + }) + } + + fn pow_cpu(&self, power: f32) -> Tensor { self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { @@ -437,7 +501,39 @@ impl Tensor { result } + /// Computes 1/sqrt(x) for each element in the tensor. 
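+    /// Dispatches to the OpenCL implementation when the tensor is on the GPU, otherwise falls back to the CPU loop.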
pub fn rsqrt(&self) -> Tensor { + #[cfg(feature = "opencl")] + { + if self.is_on_gpu() { + return self.rsqrt_gpu(); + } else { + return self.rsqrt_cpu(); + } + } + #[cfg(not(feature = "opencl"))] + { + return self.rsqrt_cpu(); + } + } + + fn rsqrt_gpu(&self) -> Tensor { + self.assume_on_gpu(); + self.with_opencl_data(|src_tensor| { + let cl: OpenCL = src_tensor.cl(); + // TODO: don't generate a CPU-side copy, create the result directly on OpenCL side + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + result = result.to_f16(); + result.to_gpu(&cl).unwrap(); + result.with_opencl_data_mut(|tgt_tensor| { + tgt_tensor.copy_inplace(src_tensor).unwrap(); + tgt_tensor.rsqrt_inplace().unwrap(); + }); + result + }) + } + + fn rsqrt_cpu(&self) -> Tensor { self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { @@ -450,8 +546,6 @@ impl Tensor { } pub fn add(&self, other: &Tensor) -> Tensor { - self.assume_on_cpu(); - other.assume_on_cpu(); if self.rows() != other.rows() || self.cols() != other.cols() { panic!( "add: Tensors must have the same shape, left: {}x{} right: {}x{}", @@ -461,6 +555,43 @@ impl Tensor { other.cols() ); } + + #[cfg(feature = "opencl")] + { + if self.is_on_gpu() { + return self.add_gpu(other); + } else { + return self.add_cpu(other); + } + } + #[cfg(not(feature = "opencl"))] + { + return self.add_cpu(other); + } + } + + fn add_gpu(&self, other: &Tensor) -> Tensor { + self.assume_on_gpu(); + other.assume_on_gpu(); + self.with_opencl_data(|src_tensor| { + let cl: OpenCL = src_tensor.cl(); + // TODO: don't generate a CPU-side copy, create the result directly on OpenCL side + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + result = result.to_f16(); + result.to_gpu(&cl).unwrap(); + other.with_opencl_data(|other_tensor| { + result.with_opencl_data_mut(|tgt_tensor| { + tgt_tensor.copy_inplace(src_tensor).unwrap(); + tgt_tensor.add_inplace(other_tensor).unwrap(); + }); + }); + result + }) + } + + fn add_cpu(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); + other.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { for col in 0..self.cols { @@ -472,6 +603,37 @@ impl Tensor { } pub fn add_scalar(&self, scalar: f32) -> Tensor { + #[cfg(feature = "opencl")] + { + if self.is_on_gpu() { + self.add_scalar_gpu(scalar) + } else { + self.add_scalar_cpu(scalar) + } + } + #[cfg(not(feature = "opencl"))] + { + self.add_scalar_cpu(scalar) + } + } + + fn add_scalar_gpu(&self, scalar: f32) -> Tensor { + self.assume_on_gpu(); + self.with_opencl_data(|src_tensor| { + let cl: OpenCL = src_tensor.cl(); + // TODO: don't generate a CPU-side copy, create the result directly on OpenCL side + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + result = result.to_f16(); + result.to_gpu(&cl).unwrap(); + result.with_opencl_data_mut(|tgt_tensor| { + tgt_tensor.copy_inplace(src_tensor).unwrap(); + tgt_tensor.add_scalar_inplace(scalar).unwrap(); + }); + result + }) + } + + fn add_scalar_cpu(&self, scalar: f32) -> Tensor { self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { @@ -496,6 +658,48 @@ impl Tensor { } pub fn scalar_multiply_broadcast(&self, other: &Tensor) -> Tensor { + #[cfg(feature = "opencl")] + { + if self.is_on_gpu() { + 
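+                // The tensor already lives on the GPU, so run the broadcast multiply with the OpenCL kernel.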
self.scalar_multiply_broadcast_gpu(other) + } else { + self.scalar_multiply_broadcast_cpu(other) + } + } + #[cfg(not(feature = "opencl"))] + { + self.scalar_multiply_broadcast_cpu(other) + } + } + + fn scalar_multiply_broadcast_gpu(&self, other: &Tensor) -> Tensor { + self.assume_on_gpu(); + other.assume_on_gpu(); + if other.cols != 1 { + panic!("Invalid scalar broadcast"); + } + if other.rows != self.rows { + panic!("Invalid scalar broadcast"); + } + self.with_opencl_data(|src_tensor| { + let cl: OpenCL = src_tensor.cl(); + // TODO: don't generate a CPU-side copy, create the result directly on OpenCL side + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + result = result.to_f16(); + result.to_gpu(&cl).unwrap(); + other.with_opencl_data(|other_tensor| { + result.with_opencl_data_mut(|tgt_tensor| { + tgt_tensor.copy_inplace(src_tensor).unwrap(); + tgt_tensor + .scalar_multiply_broadcast_inplace(other_tensor) + .unwrap(); + }); + }); + result + }) + } + + fn scalar_multiply_broadcast_cpu(&self, other: &Tensor) -> Tensor { self.assume_on_cpu(); if other.cols != 1 { panic!("Invalid scalar broadcast"); @@ -531,8 +735,6 @@ impl Tensor { } pub fn hadamard_product_broadcast(&self, other: &Tensor) -> Tensor { - self.assume_on_cpu(); - other.assume_on_cpu(); if self.cols != other.cols { panic!( "Invalid hadamard product broadcast: {}x{} vs {}x{}", @@ -545,6 +747,44 @@ impl Tensor { self.rows, self.cols, other.rows, other.cols ); } + + #[cfg(feature = "opencl")] + { + if self.is_on_gpu() { + self.hadamard_product_broadcast_gpu(other) + } else { + self.hadamard_product_broadcast_cpu(other) + } + } + #[cfg(not(feature = "opencl"))] + { + self.hadamard_product_broadcast_cpu(scalar) + } + } + + fn hadamard_product_broadcast_gpu(&self, other: &Tensor) -> Tensor { + self.assume_on_gpu(); + self.with_opencl_data(|src_tensor| { + let cl: OpenCL = src_tensor.cl(); + // TODO: don't generate a CPU-side copy, create the result directly on OpenCL side + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + result = result.to_f16(); + result.to_gpu(&cl).unwrap(); + other.with_opencl_data(|other_tensor| { + result.with_opencl_data_mut(|tgt_tensor| { + tgt_tensor.copy_inplace(src_tensor).unwrap(); + tgt_tensor + .hadamard_product_broadcast_inplace(other_tensor) + .unwrap(); + }); + }); + result + }) + } + + fn hadamard_product_broadcast_cpu(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); + other.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { for col in 0..self.cols { @@ -644,21 +884,6 @@ impl Tensor { result } - pub fn silu(&self) -> Tensor { - #[cfg(feature = "opencl")] - { - if self.is_on_gpu() { - self.silu_gpu() - } else { - self.silu_cpu() - } - } - #[cfg(not(feature = "opencl"))] - { - self.silu_cpu() - } - } - // with_opencl_data & with_opencl_data_mut are utilities to get access to the underlying // OpenCLTensor, if the tensor is on gpu. Panics if they are not on GPU. 
#[cfg(feature = "opencl")] @@ -681,6 +906,21 @@ impl Tensor { f(opencl_data.unwrap()) } + pub fn silu(&self) -> Tensor { + #[cfg(feature = "opencl")] + { + if self.is_on_gpu() { + self.silu_gpu() + } else { + self.silu_cpu() + } + } + #[cfg(not(feature = "opencl"))] + { + self.silu_cpu() + } + } + #[cfg(feature = "opencl")] fn silu_gpu(&self) -> Tensor { self.assume_on_gpu(); @@ -2212,6 +2452,181 @@ mod tests { } } + #[cfg(feature = "opencl")] + #[test] + fn gpu_rsqrt_and_cpu_rsqrt_agree() { + let cl = OpenCL::new(false, 0).unwrap(); + + for _trial in 0..300 { + let mut rng = rand::thread_rng(); + let a = rng.gen_range(1..=300); + let b = rng.gen_range(1..=300); + let mat1 = Tensor::random(a, b, TensorDType::Float16); + let mat2 = mat1.clone(); + let mut mat2 = mat2.to_f16(); + mat2.to_gpu(&cl).unwrap(); + + let mat1_result = mat1.rsqrt(); + let mut mat2_result = mat2.rsqrt(); + mat2_result.to_cpu().unwrap(); + + assert_eq!(mat1_result.rows(), mat2_result.rows()); + assert_eq!(mat1_result.cols(), mat2_result.cols()); + + for row in 0..mat1_result.rows { + for col in 0..mat1_result.cols { + let mat1_v = mat1_result.get_f32(row, col); + let mat2_v = mat2_result.get_f32(row, col); + if mat1_v.is_nan() && mat2_v.is_nan() { + continue; + } + assert_relative_eq!(mat1_v, mat2_v, epsilon = 1e-2); + } + } + } + } + + #[cfg(feature = "opencl")] + #[test] + fn gpu_add_and_cpu_add_agree() { + let cl = OpenCL::new(false, 0).unwrap(); + + for _trial in 0..300 { + let mut rng = rand::thread_rng(); + let a = rng.gen_range(1..=300); + let b = rng.gen_range(1..=300); + let mat1 = Tensor::random(a, b, TensorDType::Float16); + let mat2 = Tensor::random(a, b, TensorDType::Float16); + + let mut mat2 = mat2.to_f16(); + mat2.to_gpu(&cl).unwrap(); + + let mat1_result = mat1.rsqrt(); + let mut mat2_result = mat2.rsqrt(); + mat2_result.to_cpu().unwrap(); + + assert_eq!(mat1_result.rows(), mat2_result.rows()); + assert_eq!(mat1_result.cols(), mat2_result.cols()); + + for row in 0..mat1_result.rows { + for col in 0..mat1_result.cols { + let mat1_v = mat1_result.get_f32(row, col); + let mat2_v = mat2_result.get_f32(row, col); + if mat1_v.is_nan() && mat2_v.is_nan() { + continue; + } + assert_relative_eq!(mat1_v, mat2_v, epsilon = 1e-2); + } + } + } + } + + #[cfg(feature = "opencl")] + #[test] + fn gpu_pow_and_cpu_pow_agree() { + let cl = OpenCL::new(false, 0).unwrap(); + + for _trial in 0..300 { + let mut rng = rand::thread_rng(); + let a = rng.gen_range(1..=100); + let b = rng.gen_range(1..=100); + let c = rng.gen_range(-1.2..1.2); + let mat1 = Tensor::random(a, b, TensorDType::Float16); + let mat2 = mat1.clone(); + let mut mat2 = mat2.to_f16(); + mat2.to_gpu(&cl).unwrap(); + + let mat1_result = mat1.pow(c); + let mut mat2_result = mat2.pow(c); + mat2_result.to_cpu().unwrap(); + + assert_eq!(mat1_result.rows(), mat2_result.rows()); + assert_eq!(mat1_result.cols(), mat2_result.cols()); + + for row in 0..mat1_result.rows { + for col in 0..mat1_result.cols { + let left = mat1_result.get_f32(row, col); + let right = mat2_result.get_f32(row, col); + if left.is_nan() && right.is_nan() { + continue; + } + + assert_relative_eq!(left, right, epsilon = 1e-1); + } + } + } + } + + #[cfg(feature = "opencl")] + #[test] + fn gpu_add_scalar_and_cpu_add_scalar_agree() { + let cl = OpenCL::new(false, 0).unwrap(); + for _trial in 0..300 { + let mut rng = rand::thread_rng(); + let a = rng.gen_range(1..=100); + let b = rng.gen_range(1..=100); + let mat1 = Tensor::random(a, b, TensorDType::Float16); + let mat2 = Tensor::random(a, b, 
TensorDType::Float16); + let mut mat1_gpu = mat1.clone(); + let mut mat2_gpu = mat2.clone(); + mat1_gpu.to_gpu(&cl).unwrap(); + mat2_gpu.to_gpu(&cl).unwrap(); + + let result1 = mat1.add(&mat2); + let mut result2 = mat1_gpu.add(&mat2_gpu); + result2.to_cpu().unwrap(); + + assert_eq!(result1.rows(), result2.rows()); + assert_eq!(result1.cols(), result2.cols()); + + for row in 0..result1.rows { + for col in 0..result1.cols { + let left = result1.get_f32(row, col); + let right = result2.get_f32(row, col); + if left.is_nan() && right.is_nan() { + continue; + } + + assert_relative_eq!(left, right, epsilon = 1e-2); + } + } + } + } + + #[cfg(feature = "opencl")] + #[test] + fn gpu_mean_cols_and_cpu_mean_cols_agree() { + let cl = OpenCL::new(false, 0).unwrap(); + for _trial in 0..300 { + let mut rng = rand::thread_rng(); + let a = rng.gen_range(1..=100); + let b = rng.gen_range(1..=100); + let mat1 = Tensor::random(a, b, TensorDType::Float16); + let mat2 = mat1.clone(); + let mut mat2 = mat2.to_f16(); + mat2.to_gpu(&cl).unwrap(); + + let mat1_result = mat1.mean_cols(); + let mut mat2_result = mat2.mean_cols(); + mat2_result.to_cpu().unwrap(); + + assert_eq!(mat1_result.rows(), mat2_result.rows()); + assert_eq!(mat1_result.cols(), mat2_result.cols()); + + for row in 0..mat1_result.rows { + for col in 0..mat1_result.cols { + let left = mat1_result.get_f32(row, col); + let right = mat2_result.get_f32(row, col); + if left.is_nan() && right.is_nan() { + continue; + } + + assert_relative_eq!(left, right, epsilon = 1e-2,); + } + } + } + } + #[cfg(feature = "opencl")] #[test] fn gpu_hadamard_product_and_cpu_hadamard_product_agree() { @@ -2248,6 +2663,78 @@ mod tests { } } + #[cfg(feature = "opencl")] + #[test] + fn gpu_hadamard_product_broadcast_and_cpu_hadamard_product_broadcast_agree() { + let cl = OpenCL::new(false, 0).unwrap(); + + for _trial in 0..300 { + let mut rng = rand::thread_rng(); + let a = rng.gen_range(1..=300); + let b = rng.gen_range(1..=300); + let mat1 = Tensor::random(a, b, TensorDType::Float16); + let mat2 = Tensor::random(1, b, TensorDType::Float16); + + let mut mat1_gpu = mat1.to_f16(); + let mut mat2_gpu = mat2.to_f16(); + mat1_gpu.to_gpu(&cl).unwrap(); + mat2_gpu.to_gpu(&cl).unwrap(); + + let result1 = mat1.hadamard_product_broadcast(&mat2); + let mut result2 = mat1_gpu.hadamard_product_broadcast(&mat2_gpu); + result2.to_cpu().unwrap(); + + assert_eq!(result1.rows(), result2.rows()); + assert_eq!(result1.cols(), result2.cols()); + + for row in 0..result1.rows() { + for col in 0..result2.cols() { + assert_relative_eq!( + result1.get_f32(row, col), + result2.get_f32(row, col), + epsilon = 1e-2 + ); + } + } + } + } + + #[cfg(feature = "opencl")] + #[test] + fn gpu_scalar_multiply_product_broadcast_and_cpu_scalar_multiply_product_broadcast_agree() { + let cl = OpenCL::new(false, 0).unwrap(); + + for _trial in 0..300 { + let mut rng = rand::thread_rng(); + let a = rng.gen_range(1..=300); + let b = rng.gen_range(1..=300); + let mat1 = Tensor::random(a, b, TensorDType::Float16); + let mat2 = Tensor::random(a, 1, TensorDType::Float16); + + let mut mat1_gpu = mat1.to_f16(); + let mut mat2_gpu = mat2.to_f16(); + mat1_gpu.to_gpu(&cl).unwrap(); + mat2_gpu.to_gpu(&cl).unwrap(); + + let result1 = mat1.scalar_multiply_broadcast(&mat2); + let mut result2 = mat1_gpu.scalar_multiply_broadcast(&mat2_gpu); + result2.to_cpu().unwrap(); + + assert_eq!(result1.rows(), result2.rows()); + assert_eq!(result1.cols(), result2.cols()); + + for row in 0..result1.rows() { + for col in 0..result2.cols() { + 
assert_relative_eq!( + result1.get_f32(row, col), + result2.get_f32(row, col), + epsilon = 1e-2 + ); + } + } + } + } + #[cfg(feature = "opencl")] #[test] fn gpu_transpose_and_cpu_transpose_agree() { @@ -2313,11 +2800,13 @@ mod tests { for row in 0..mat3.rows { for col in 0..mat3.cols { - assert_relative_eq!( - mat3.get_f32(row, col), - mat3_gpu.get_f32(row, col), - epsilon = 1e-2, - ); + let left = mat3.get_f32(row, col); + let right = mat3_gpu.get_f32(row, col); + if left.is_nan() && right.is_nan() { + continue; + } + + assert_relative_eq!(left, right, epsilon = 1e-2,); } } } diff --git a/src/tensor_opencl_support.rs b/src/tensor_opencl_support.rs index 59e8ace..3464b31 100644 --- a/src/tensor_opencl_support.rs +++ b/src/tensor_opencl_support.rs @@ -18,6 +18,20 @@ struct Programs { hadamard_product_f16: Kernel, transpose_f16_program: Program, transpose_f16: Kernel, + pow_f16_program: Program, + pow_f16: Kernel, + mean_cols_f16_program: Program, + mean_cols_f16: Kernel, + add_scalar_f16_program: Program, + add_scalar_f16: Kernel, + scalar_multiply_broadcast_f16_program: Program, + scalar_multiply_broadcast_f16: Kernel, + hadamard_product_broadcast_f16_program: Program, + hadamard_product_broadcast_f16: Kernel, + rsqrt_f16_program: Program, + rsqrt_f16: Kernel, + add_f16_program: Program, + add_f16: Kernel, } #[derive(Debug, Clone)] @@ -217,6 +231,58 @@ impl OpenCLTensor { Ok(OpenCLEvent { event }) } + pub fn add_scalar_inplace(&mut self, scalar: f32) -> Result { + let prg = self.cl.programs.write().unwrap(); + prg.add_scalar_f16.set_arg(0, self.buf.clone()).unwrap(); + prg.add_scalar_f16 + .set_arg(1, self.cols_capacity as i32) + .unwrap(); + prg.add_scalar_f16.set_arg(2, scalar).unwrap(); + let mut event = Event::empty(); + unsafe { + let b = prg + .add_scalar_f16 + .cmd() + .queue(&self.queue) + .global_work_size([self.rows as usize, self.cols as usize]) + .enew(&mut event); + b.enq()?; + } + self.last_event = Some(event.clone()); + Ok(OpenCLEvent { event }) + } + + pub fn scalar_multiply_broadcast_inplace( + &mut self, + other: &OpenCLTensor, + ) -> Result { + let prg = self.cl.programs.write().unwrap(); + prg.scalar_multiply_broadcast_f16 + .set_arg(0, self.buf.clone()) + .unwrap(); + prg.scalar_multiply_broadcast_f16 + .set_arg(1, other.buf.clone()) + .unwrap(); + prg.scalar_multiply_broadcast_f16 + .set_arg(2, self.cols_capacity as i32) + .unwrap(); + prg.scalar_multiply_broadcast_f16 + .set_arg(3, other.cols_capacity as i32) + .unwrap(); + let mut event = Event::empty(); + unsafe { + let b = prg + .scalar_multiply_broadcast_f16 + .cmd() + .queue(&self.queue) + .global_work_size([self.rows as usize, (self.cols_capacity / 16) as usize]) + .enew(&mut event); + b.enq()?; + } + self.last_event = Some(event.clone()); + Ok(OpenCLEvent { event }) + } + pub fn transpose_from(&mut self, other: &OpenCLTensor) -> Result { let prg = self.cl.programs.write().unwrap(); prg.transpose_f16.set_arg(0, self.buf.clone()).unwrap(); @@ -235,7 +301,7 @@ impl OpenCLTensor { .queue(&self.queue) .global_work_size([self.rows as usize, self.cols as usize]) .enew(&mut event); - b.enq().unwrap(); + b.enq()?; } self.last_event = Some(event.clone()); Ok(OpenCLEvent { event }) @@ -266,6 +332,85 @@ impl OpenCLTensor { Ok(OpenCLEvent { event }) } + pub fn hadamard_product_broadcast_inplace( + &mut self, + other: &OpenCLTensor, + ) -> Result { + let prg = self.cl.programs.write().unwrap(); + prg.hadamard_product_broadcast_f16 + .set_arg(0, self.buf.clone())?; + prg.hadamard_product_broadcast_f16 + .set_arg(1, 
other.buf.clone())?; + prg.hadamard_product_broadcast_f16 + .set_arg(2, self.cols_capacity as i32)?; + prg.hadamard_product_broadcast_f16 + .set_arg(3, other.cols_capacity as i32)?; + let mut event = Event::empty(); + unsafe { + let b = prg + .hadamard_product_broadcast_f16 + .cmd() + .queue(&self.queue) + .global_work_size([self.rows as usize, (self.cols_capacity as usize) / 16]) + .enew(&mut event); + b.enq()?; + } + self.last_event = Some(event.clone()); + Ok(OpenCLEvent { event }) + } + + pub fn mean_cols_from(&mut self, other: &OpenCLTensor) -> Result { + if self.cols != 1 { + panic!( + "mean_cols_from: number of columns in target is not 1: {}", + self.cols + ); + } + if self.rows != other.rows { + panic!( + "mean_cols_from: number of rows in target is not equal to number of rows in source: {} != {}", + self.rows, other.rows + ); + } + let prg = self.cl.programs.write().unwrap(); + prg.mean_cols_f16.set_arg(0, self.buf.clone())?; + prg.mean_cols_f16.set_arg(1, other.buf.clone())?; + prg.mean_cols_f16.set_arg(2, self.cols_capacity as i32)?; + prg.mean_cols_f16.set_arg(3, other.cols_capacity as i32)?; + prg.mean_cols_f16.set_arg(4, other.cols as i32)?; + let mut event = Event::empty(); + unsafe { + let b = prg + .mean_cols_f16 + .cmd() + .queue(&self.queue) + .global_work_size([self.rows as usize, 1]) + .enew(&mut event); + b.enq()?; + } + self.last_event = Some(event.clone()); + Ok(OpenCLEvent { event }) + } + + pub fn pow_inplace(&mut self, scalar: f32) -> Result { + let prg = self.cl.programs.write().unwrap(); + prg.pow_f16.set_arg(0, self.buf.clone())?; + prg.pow_f16.set_arg(1, self.cols_capacity as i32)?; + prg.pow_f16.set_arg(2, scalar)?; + let mut event = Event::empty(); + unsafe { + let b = prg + .pow_f16 + .cmd() + .queue(&self.queue) + .global_work_size([self.rows as usize, self.cols as usize]) + .enew(&mut event); + b.enq()?; + } + self.last_event = Some(event.clone()); + Ok(OpenCLEvent { event }) + } + pub fn silu_inplace(&mut self) -> Result { let prg = self.cl.programs.write().unwrap(); prg.silu_f16.set_arg(0, self.buf.clone())?; @@ -284,6 +429,44 @@ impl OpenCLTensor { Ok(OpenCLEvent { event }) } + pub fn add_inplace(&mut self, left: &OpenCLTensor) -> Result { + let prg = self.cl.programs.write().unwrap(); + prg.add_f16.set_arg(0, self.buf.clone())?; + prg.add_f16.set_arg(1, left.buf.clone())?; + prg.add_f16.set_arg(2, self.cols_capacity as i32)?; + prg.add_f16.set_arg(3, left.cols_capacity as i32)?; + let mut event = Event::empty(); + unsafe { + let b = prg + .add_f16 + .cmd() + .queue(&self.queue) + .global_work_size([self.rows as usize, self.cols as usize]) + .enew(&mut event); + b.enq()?; + } + self.last_event = Some(event.clone()); + Ok(OpenCLEvent { event }) + } + + pub fn rsqrt_inplace(&mut self) -> Result { + let prg = self.cl.programs.write().unwrap(); + prg.rsqrt_f16.set_arg(0, self.buf.clone())?; + prg.rsqrt_f16.set_arg(1, self.cols_capacity as i32)?; + let mut event = Event::empty(); + unsafe { + let b = prg + .rsqrt_f16 + .cmd() + .queue(&self.queue) + .global_work_size([self.rows as usize, self.cols as usize]) + .enew(&mut event); + b.enq()?; + } + self.last_event = Some(event.clone()); + Ok(OpenCLEvent { event }) + } + pub fn matrix_mul_inplace_transposed( &mut self, src: &OpenCLTensor, @@ -397,6 +580,75 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result .arg(&0) .queue(queue.clone()) .build()?; + let pow_f16_program = make_program_with_src(ctx, POW_F16_SRC)?; + let pow_f16 = Kernel::builder() + .program(&pow_f16_program) + .name("pow_f16") + 
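+        // Placeholder arguments: target buffer, ncols_capacity, exponent; the real values are set in pow_inplace().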
.arg(None::<&Buffer>) + .arg(&0) + .arg(&0) + .queue(queue.clone()) + .build()?; + let mean_cols_f16_program = make_program_with_src(ctx, MEAN_COLS_F16_SRC)?; + let mean_cols_f16 = Kernel::builder() + .program(&mean_cols_f16_program) + .name("mean_cols_f16") + .arg(None::<&Buffer>) + .arg(None::<&Buffer>) + .arg(&0) + .arg(&0) + .arg(&0) + .queue(queue.clone()) + .build()?; + let add_scalar_f16_program = make_program_with_src(ctx, ADD_SCALAR_F16_SRC)?; + let add_scalar_f16 = Kernel::builder() + .program(&add_scalar_f16_program) + .name("add_scalar_f16") + .arg(None::<&Buffer>) + .arg(&0) + .arg(&0) + .queue(queue.clone()) + .build()?; + let scalar_multiply_broadcast_f16_program = + make_program_with_src(ctx, SCALAR_MULTIPLY_BROADCAST_F16_SRC)?; + let scalar_multiply_broadcast_f16 = Kernel::builder() + .program(&scalar_multiply_broadcast_f16_program) + .name("scalar_multiply_broadcast_f16") + .arg(None::<&Buffer>) + .arg(None::<&Buffer>) + .arg(&0) + .arg(&0) + .queue(queue.clone()) + .build()?; + let hadamard_product_broadcast_f16_program = + make_program_with_src(ctx, HADAMARD_PRODUCT_BROADCAST_F16_SRC)?; + let hadamard_product_broadcast_f16 = Kernel::builder() + .program(&hadamard_product_broadcast_f16_program) + .name("hadamard_product_broadcast_f16") + .arg(None::<&Buffer>) + .arg(None::<&Buffer>) + .arg(&0) + .arg(&0) + .queue(queue.clone()) + .build()?; + let rsqrt_f16_program = make_program_with_src(ctx, RSQRT_F16_SRC)?; + let rsqrt_f16 = Kernel::builder() + .program(&rsqrt_f16_program) + .name("rsqrt_f16") + .arg(None::<&Buffer>) + .arg(&0) + .queue(queue.clone()) + .build()?; + let add_f16_program = make_program_with_src(ctx, ADD_F16_SRC)?; + let add_f16 = Kernel::builder() + .program(&add_f16_program) + .name("add_f16") + .arg(None::<&Buffer>) + .arg(None::<&Buffer>) + .arg(&0) + .arg(&0) + .queue(queue.clone()) + .build()?; Ok(Programs { matrix_mul_transposed_by_row_f16_program, matrix_mul_transposed_by_row_f16, @@ -406,6 +658,20 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result hadamard_product_f16, transpose_f16_program, transpose_f16, + pow_f16_program, + pow_f16, + mean_cols_f16_program, + mean_cols_f16, + add_scalar_f16_program, + add_scalar_f16, + scalar_multiply_broadcast_f16_program, + scalar_multiply_broadcast_f16, + hadamard_product_broadcast_f16_program, + hadamard_product_broadcast_f16, + rsqrt_f16_program, + rsqrt_f16, + add_f16_program, + add_f16, }) } @@ -532,3 +798,131 @@ __kernel void transpose_f16(__global half *tgt, vstore_half(val, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt); } "#; + +/// Computes x^scalar for every f16 value in the tensor +const POW_F16_SRC: &str = r#" +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void pow_f16(__global half *tgt, + const int ncols_capacity, + const float scalar) +{ + const int tgt_row = get_global_id(0); + const int tgt_col = get_global_id(1); + const float val = vload_half(tgt_row * ncols_capacity + tgt_col, (__global const half*) tgt); + const float result = pow(val, scalar); + vstore_half(result, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt); +} +"#; + +/// Computes the mean of each column in a tensor +const MEAN_COLS_F16_SRC: &str = r#" +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void mean_cols_f16(__global half *tgt, + __global const half *left, + const int ncols_capacity, + const int left_cols_capacity, + const int ncolumns) +{ + // global work group size is nrows x 1 + const int row = get_global_id(0); + + float16 src_value = 0.0; + for (int col16 = 0; 
col16 < left_cols_capacity; col16 += 16) { + const int actual_col = col16; + if (actual_col >= ncolumns) { + break; + } + src_value += vload_half16((row * left_cols_capacity)/16 + col16/16, (__global const half*) left); + } + float src_value_sum = src_value.s0 + src_value.s1 + src_value.s2 + src_value.s3 + src_value.s4 + src_value.s5 + src_value.s6 + src_value.s7 + src_value.s8 + src_value.s9 + src_value.sa + src_value.sb + src_value.sc + src_value.sd + src_value.se + src_value.sf; + src_value_sum = src_value_sum / (float) ncolumns; + vstore_half(src_value_sum, row * ncols_capacity, (__global half*) tgt); +} +"#; + +/// Adds a scalar to a tensor +const ADD_SCALAR_F16_SRC: &str = r#" +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void add_scalar_f16(__global half *tgt, const int ncols_capacity, const float scalar) +{ + const int tgt_row = get_global_id(0); + const int tgt_col = get_global_id(1); + const float val = vload_half(tgt_row * ncols_capacity + tgt_col, (__global const half*) tgt); + const float result = val + scalar; + vstore_half(result, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt); +} +"#; + +/// Adds scalars from a row vector to each row of a tensor +const SCALAR_MULTIPLY_BROADCAST_F16_SRC: &str = r#" +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void scalar_multiply_broadcast_f16(__global half *tgt, + __global const half *left, + const int ncols_capacity, + const int left_cols_capacity) +{ + // global work group size is nrows x (ncols/16) + const int row = get_global_id(0); + const int col = get_global_id(1) * 16; + + const float scalar = vload_half(row * left_cols_capacity, (__global const half*) left); + + float16 src_value = vload_half16((row * ncols_capacity)/16 + col/16, (__global const half*) tgt) * scalar; + vstore_half16(src_value, (row * ncols_capacity)/16 + col/16, (__global half*) tgt); +} +"#; + +/// Does a hadamard product from a column vector to each column of a tensor +const HADAMARD_PRODUCT_BROADCAST_F16_SRC: &str = r#" +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void hadamard_product_broadcast_f16(__global half *tgt, + __global const half *left, + const int ncols_capacity, + const int left_cols_capacity) +{ + // global work group size is nrows x (ncols/16) + const int row = get_global_id(0); + const int col16 = get_global_id(1) * 16; + const float16 product_value = vload_half16(col16/16, (__global const half*) left); + const float16 src_value = vload_half16((row * ncols_capacity)/16 + col16/16, (__global const half*) tgt); + const float16 result = src_value * product_value; + vstore_half16(result, (row * ncols_capacity)/16 + col16/16, (__global half*) tgt); +} +"#; + +/// Computes 1/sqrt(x) for each f16 value in the tensor +const RSQRT_F16_SRC: &str = r#" +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void rsqrt_f16(__global half *tgt, const int ncols_capacity) +{ + const int tgt_row = get_global_id(0); + const int tgt_col = get_global_id(1); + const float val = vload_half(tgt_row * ncols_capacity + tgt_col, (__global const half*) tgt); + const float result = rsqrt(val); + vstore_half(result, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt); +} +"#; + +/// Computes sum of two tensors +const ADD_F16_SRC: &str = r#" +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void add_f16(__global half *tgt, + __global const half *left, + const int tgt_ncols_capacity, + const int left_ncols_capacity) +{ + const int tgt_row = get_global_id(0); + const int tgt_col = get_global_id(1); + const 
float tgt_v = vload_half(tgt_row * tgt_ncols_capacity + tgt_col, (__global const half*) tgt); + const float left_v = vload_half(tgt_row * left_ncols_capacity + tgt_col, (__global const half*) left); + const float result = tgt_v + left_v; + vstore_half(result, tgt_row * tgt_ncols_capacity + tgt_col, (__global half*) tgt); +} +"#; diff --git a/src/transformer.rs b/src/transformer.rs index 4d158b8..75afd12 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -8,6 +8,7 @@ use crate::unpickler::UnpicklingError; use indicatif::ProgressBar; use num_complex::Complex; use rayon::prelude::*; +use std::mem::drop; use std::path::Path; use std::sync::{Arc, RwLock}; @@ -38,6 +39,8 @@ pub struct DataSettings { #[cfg(feature = "opencl")] use_opencl_for_attention: bool, #[cfg(feature = "opencl")] + use_opencl_for_rmsnorm: bool, + #[cfg(feature = "opencl")] cl: Option, } @@ -51,6 +54,7 @@ impl DataSettings { DataSettings { use_opencl_for_feedforward: false, use_opencl_for_attention: false, + use_opencl_for_rmsnorm: false, cl: cl.clone(), } } @@ -67,6 +71,7 @@ impl DataSettings { } self.use_opencl_for_feedforward = true; self.use_opencl_for_attention = true; + self.use_opencl_for_rmsnorm = true; self } } @@ -137,6 +142,7 @@ impl TransformerCaches { pub struct RMSNorm { eps: f64, weight: Tensor, + data_settings: DataSettings, } pub struct Attention { @@ -195,9 +201,15 @@ impl Transformer { result }) .collect::, UnpicklingError>>()?; - std::mem::drop(progress_bar); + drop(progress_bar); - let norm = RMSNorm::from_unpickled(unpickled, "norm.weight".to_string(), eps, data_dir)?; + let norm = RMSNorm::from_unpickled( + unpickled, + "norm.weight".to_string(), + eps, + data_settings.clone(), + data_dir, + )?; let output = Tensor::from_unpickled_pieces( unpickled, "output.weight", @@ -261,18 +273,23 @@ impl Transformer { embs.push(emb); } let mut emb_tensor: Tensor = Tensor::concat(&embs); - std::mem::drop(embs); + drop(embs); for (idx, layer) in self.layers.iter().enumerate() { emb_tensor = layer.forward( - &emb_tensor, + &mut emb_tensor, start_pos, &self.freqs_cis, &mask, &mut caches.layer_caches[idx], ); } - let out = self.norm.forward(&emb_tensor); + let mut out = self.norm.forward(&mut emb_tensor); + #[cfg(feature = "opencl")] + if out.is_on_gpu() { + out.to_cpu().unwrap(); + out = out.to_f32(); + } let out = out.row(out.rows() - 1); self.output.matrix_mul_transposed(&out) @@ -296,19 +313,21 @@ impl TransformerBlock { layer_id, n_local_heads, head_dim, - data_settings, + data_settings.clone(), data_dir, )?; let ffn_norm = RMSNorm::from_unpickled( unpickled, format!("layers.{}.ffn_norm.weight", layer_id), eps, + data_settings.clone(), data_dir, )?; let attn_norm = RMSNorm::from_unpickled( unpickled, format!("layers.{}.attention_norm.weight", layer_id), eps, + data_settings.clone(), data_dir, )?; Ok(Self { @@ -321,26 +340,61 @@ impl TransformerBlock { pub fn forward( &self, - x: &Tensor, + x: &mut Tensor, start_pos: usize, freqs_cis: &FreqsCis, mask: &Option, attention_cache: &mut AttentionCache, ) -> Tensor { + let now = std::time::Instant::now(); let mut attnorm_out = self.attention_norm.forward(x); - let att_out = self.attn.forward( + let now = std::time::Instant::now(); + let mut att_out = self.attn.forward( &mut attnorm_out, start_pos, freqs_cis, mask, attention_cache, ); - std::mem::drop(attnorm_out); + let now = std::time::Instant::now(); + drop(attnorm_out); - let h = x.add(&att_out); - let mut att_out = self.ffn_norm.forward(&h); + #[cfg(feature = "opencl")] + let mut x_was_on_cpu: bool; + 
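+        // Remember where the input started so the block's output can be handed back on the same device.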
#[cfg(feature = "opencl")] + { + x_was_on_cpu = x.is_on_cpu(); + if x_was_on_cpu { + *x = x.to_f16(); + x.to_gpu(self.attention_norm.data_settings.cl.as_ref().unwrap()) + .unwrap(); + } + if x.is_on_gpu() { + att_out = att_out.to_f16(); + att_out + .to_gpu(self.attention_norm.data_settings.cl.as_ref().unwrap()) + .unwrap(); + } + } + let mut h = x.add(&att_out); + let now = std::time::Instant::now(); + let mut att_out = self.ffn_norm.forward(&mut h); + let now = std::time::Instant::now(); let att_out = self.feed_forward.forward(&mut att_out).transpose(); - h.add(&att_out) + let mut result = h.add(&att_out); + #[cfg(feature = "opencl")] + { + if x_was_on_cpu { + result.to_cpu().unwrap(); + return result.to_f32(); + } else { + result + } + } + #[cfg(not(feature = "opencl"))] + { + result + } } } @@ -349,26 +403,64 @@ impl RMSNorm { unpickled: &[unpickler::Value], name: String, eps: f64, + data_settings: DataSettings, data_dir: P, ) -> Result { let data_dir: &Path = data_dir.as_ref(); - let weights = Tensor::from_unpickled_pieces( + let mut weights = Tensor::from_unpickled_pieces( &unpickled[0..=0], name.clone(), data_dir, FromPiecesDirection::Rows, - )? - .to_f32(); + )?; + + #[cfg(feature = "opencl")] + { + if data_settings.use_opencl_for_rmsnorm { + weights = weights.to_f16(); + let ds = data_settings.clone(); + weights.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap(); + } else { + weights = weights.to_f32(); + } + } + #[cfg(not(feature = "opencl"))] + { + weights = weights.to_f32(); + } + Ok(Self { eps, weight: weights, + data_settings, }) } - fn forward(&self, x: &Tensor) -> Tensor { + fn forward(&self, x: &mut Tensor) -> Tensor { + #[cfg(feature = "opencl")] + let x_was_on_cpu: bool; + #[cfg(feature = "opencl")] + { + x_was_on_cpu = x.is_on_cpu(); + if self.data_settings.use_opencl_for_rmsnorm && x_was_on_cpu { + *x = x.to_f16(); + x.to_gpu(self.data_settings.cl.as_ref().unwrap()).unwrap(); + } + } let inner = x.pow(2.0).mean_cols().add_scalar(self.eps as f32); let out1 = x.scalar_multiply_broadcast(&inner.rsqrt()); - out1.hadamard_product_broadcast(&self.weight) + let mut result = out1.hadamard_product_broadcast(&self.weight); + #[cfg(feature = "opencl")] + { + if x_was_on_cpu { + result.to_cpu().unwrap(); + } + result + } + #[cfg(not(feature = "opencl"))] + { + result + } } } @@ -410,6 +502,10 @@ impl FeedForward { w1.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap(); w2.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap(); w3.to_gpu(&ds.cl.unwrap()).unwrap(); + } else { + w1 = w1.to_f32(); + w2 = w2.to_f32(); + w3 = w3.to_f32(); } } #[cfg(not(feature = "opencl"))] @@ -433,7 +529,7 @@ impl FeedForward { #[cfg(feature = "opencl")] { x_was_on_cpu = x.is_on_cpu(); - if self.data_settings.use_opencl_for_feedforward { + if self.data_settings.use_opencl_for_feedforward && x_was_on_cpu { *x = x.to_f16(); x.to_gpu(self.data_settings.cl.as_ref().unwrap()).unwrap(); } @@ -514,6 +610,11 @@ impl Attention { wk.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap(); wv.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap(); wo.to_gpu(&ds.cl.unwrap()).unwrap(); + } else { + wq = wq.to_f32(); + wk = wk.to_f32(); + wv = wv.to_f32(); + wo = wo.to_f32(); } } #[cfg(not(feature = "opencl"))] @@ -548,7 +649,7 @@ impl Attention { #[cfg(feature = "opencl")] { x_was_on_cpu = x.is_on_cpu(); - if self.data_settings.use_opencl_for_attention { + if self.data_settings.use_opencl_for_attention && x_was_on_cpu { *x = x.to_f16(); x.to_gpu(self.data_settings.cl.as_ref().unwrap()).unwrap(); } @@ -686,8 +787,21 @@ impl Attention { 
.collect(); let output3: Vec<&Tensor> = output2.iter().collect(); - let output2: Tensor = Tensor::concat(&output3); - output2 + let mut output2: Tensor = Tensor::concat(&output3); + + #[cfg(feature = "opencl")] + { + if x_was_on_cpu { + output2.to_cpu().unwrap(); + return output2.to_f32(); + } else { + return output2; + } + } + #[cfg(not(feature = "opencl"))] + { + output2 + } } }
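Minimal usage sketch (not part of the diff above) of the element-wise GPU path these changes add. It assumes the `opencl` feature is enabled and mirrors the new agreement tests, so `OpenCL::new`, `to_f16`, `to_gpu`, `rsqrt` and `to_cpu` are the APIs already shown in the patch rather than new ones:

    let cl = OpenCL::new(false, 0).unwrap();

    // Any Float16 tensor works; rsqrt() dispatches to rsqrt_gpu() once the data is on the device.
    let mat = Tensor::random(64, 64, TensorDType::Float16);
    let cpu_result = mat.rsqrt();

    // Move a copy to the GPU and run the same operation there.
    let mut gpu_mat = mat.to_f16();
    gpu_mat.to_gpu(&cl).unwrap();
    let mut gpu_result = gpu_mat.rsqrt();

    // Bring the result back to the host; it should agree with the CPU path within f16 tolerance.
    gpu_result.to_cpu().unwrap();
    assert_eq!(cpu_result.rows(), gpu_result.rows());
    assert_eq!(cpu_result.cols(), gpu_result.cols());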