From 63d27dba9091823f8ba11a270ab5790d6f597311 Mon Sep 17 00:00:00 2001
From: Mikko Juola
Date: Mon, 13 Mar 2023 17:11:00 -0700
Subject: [PATCH] Add partial OpenCL support; it is used in the feed-forward
 network only.

---
 src/rllama_main.rs           |  20 ++-
 src/tensor.rs                | 260 ++++++++++++++++++++++++++++++++++-
 src/tensor_opencl_support.rs | 255 ++++++++++++++++++++++++++++------
 src/transformer.rs           | 121 ++++++++++++++--
 4 files changed, 588 insertions(+), 68 deletions(-)

diff --git a/src/rllama_main.rs b/src/rllama_main.rs
index 2698fb1..f097346 100644
--- a/src/rllama_main.rs
+++ b/src/rllama_main.rs
@@ -3,7 +3,7 @@ use crate::embedding::Embedding;
 use crate::tensor_opencl_support::OpenCL;
 use crate::token_sampler::TokenSampler;
 use crate::tokenizer::{TokenId, Tokenizer};
-use crate::transformer::Transformer;
+use crate::transformer::{DataSettings, Transformer};
 use crate::unpickler;
 use crate::unpickler::Value;
 use clap::Parser;
@@ -38,6 +38,7 @@ struct Cli {
     top_k: Option<usize>,
 
     #[cfg(feature = "opencl")]
+    #[arg(long)]
     opencl_device: Option<usize>,
 }
 
@@ -63,7 +64,7 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     }
 
     #[cfg(feature = "opencl")]
-    let _opencl: Option<OpenCL> = {
+    let opencl: Option<OpenCL> = {
         let opencl_device = cli.opencl_device.unwrap_or(0);
         match OpenCL::new(!be_quiet, opencl_device) {
             Err(openclerr) => {
@@ -154,6 +155,20 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
         None => 1024,
     };
 
+    let data_settings = {
+        #[cfg(feature = "opencl")]
+        {
+            if let Some(opencl) = opencl {
+                let ds = DataSettings::new(Some(opencl));
+                ds.use_opencl()
+            } else {
+                DataSettings::new(None)
+            }
+        }
+        #[cfg(not(feature = "opencl"))]
+        DataSettings::new()
+    };
+
     pln!("Loading transformer weights from {}...", model_path);
     let tr = Transformer::from_unpickled(
         &unpickle_results,
@@ -163,6 +178,7 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
         params.n_heads,
         max_seq_len,
         params.norm_eps,
+        data_settings,
         model_path,
     )?;
     pln!("All is loaded. Starting inference.");

diff --git a/src/tensor.rs b/src/tensor.rs
index de6306f..176ec93 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -151,6 +151,18 @@ fn horizontal_sum(mut ymm: __m256) -> f32 {
 }
 
 impl Tensor {
+    #[inline]
+    pub fn assume_on_gpu(&self) {
+        #[cfg(feature = "opencl")]
+        {
+            self.process_waiting_for_data();
+            let od = self.opencl_data.read().unwrap();
+            if od.is_none() {
+                panic!("Tried to assume_on_gpu on a tensor that is on the CPU");
+            }
+        }
+    }
+
     #[inline]
     pub fn assume_on_cpu(&self) {
         #[cfg(feature = "opencl")]
@@ -544,14 +556,52 @@ impl Tensor {
     }
 
     pub fn hadamard_product(&self, other: &Tensor) -> Tensor {
-        self.assume_on_cpu();
-        other.assume_on_cpu();
         if self.cols != other.cols || self.rows != other.rows {
             panic!(
                 "Invalid hadamard product: incompatible shapes, {}x{} vs {}x{}",
                 self.rows, self.cols, other.rows, other.cols
             );
         }
+        #[cfg(feature = "opencl")]
+        {
+            if self.is_on_gpu() {
+                self.hadamard_product_gpu(other)
+            } else {
+                self.hadamard_product_cpu(other)
+            }
+        }
+        #[cfg(not(feature = "opencl"))]
+        {
+            self.hadamard_product_cpu(other)
+        }
+    }
+
+    #[cfg(feature = "opencl")]
+    fn hadamard_product_gpu(&self, other: &Tensor) -> Tensor {
+        // Assume: sizes have been checked already
+        self.assume_on_gpu();
+        other.assume_on_gpu();
+
+        self.with_opencl_data(|self_tensor| {
+            let cl = self_tensor.cl();
+            // TODO: do not create a CPU-side copy
+            let result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
+            let mut result = result.to_f16();
+            result.to_gpu(&cl).unwrap();
+            result.with_opencl_data_mut(|tgt_tensor| {
+                tgt_tensor.copy_inplace(self_tensor).unwrap();
+                other.with_opencl_data(|other_tensor| {
+                    tgt_tensor.hadamard_product_inplace(other_tensor).unwrap();
+                });
+            });
+            result
+        })
+    }
+
+    fn hadamard_product_cpu(&self, other: &Tensor) -> Tensor {
+        // Assume: sizes have been checked already
+        self.assume_on_cpu();
+        other.assume_on_cpu();
         let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
         for row in 0..self.rows {
             for col in 0..self.cols {
@@ -595,7 +645,60 @@ impl Tensor {
     }
 
     pub fn silu(&self) -> Tensor {
-        self.assume_on_cpu();
+        #[cfg(feature = "opencl")]
+        {
+            if self.is_on_gpu() {
+                self.silu_gpu()
+            } else {
+                self.silu_cpu()
+            }
+        }
+        #[cfg(not(feature = "opencl"))]
+        {
+            self.silu_cpu()
+        }
+    }
+
+    // with_opencl_data & with_opencl_data_mut are utilities to get access to the underlying
+    // OpenCLTensor when the tensor is on the GPU. They panic if the tensor is not on the GPU.
+    #[cfg(feature = "opencl")]
+    fn with_opencl_data<F, R>(&self, f: F) -> R
+    where
+        F: FnOnce(&OpenCLTensor) -> R,
+    {
+        let opencl_data = self.opencl_data.read().unwrap();
+        let opencl_data = opencl_data.as_ref();
+        f(opencl_data.unwrap())
+    }
+
+    #[cfg(feature = "opencl")]
+    fn with_opencl_data_mut<F, R>(&mut self, f: F) -> R
+    where
+        F: FnOnce(&mut OpenCLTensor) -> R,
+    {
+        let mut opencl_data = self.opencl_data.write().unwrap();
+        let opencl_data = opencl_data.as_mut();
+        f(opencl_data.unwrap())
+    }
+
+    #[cfg(feature = "opencl")]
+    fn silu_gpu(&self) -> Tensor {
+        self.assume_on_gpu();
+        self.with_opencl_data(|src_tensor| {
+            let cl: OpenCL = src_tensor.cl();
+            // TODO: don't generate a CPU-side copy; create the result directly on the OpenCL side
+            let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
+            result = result.to_f16();
+            result.to_gpu(&cl).unwrap();
+            result.with_opencl_data_mut(|tgt_tensor| {
+                tgt_tensor.copy_inplace(src_tensor).unwrap();
+                tgt_tensor.silu_inplace().unwrap();
+            });
+            result
+        })
+    }
+
+    fn silu_cpu(&self) -> Tensor {
         let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
         for row in 0..self.rows {
             for col in 0..self.cols {
@@ -608,6 +711,37 @@ impl Tensor {
     }
 
     pub fn transpose(&self) -> Tensor {
+        #[cfg(feature = "opencl")]
+        {
+            if self.is_on_gpu() {
+                self.transpose_gpu()
+            } else {
+                self.transpose_cpu()
+            }
+        }
+        #[cfg(not(feature = "opencl"))]
+        {
+            self.transpose_cpu()
+        }
+    }
+
+    #[cfg(feature = "opencl")]
+    fn transpose_gpu(&self) -> Tensor {
+        self.assume_on_gpu();
+        self.with_opencl_data(|src_tensor| {
+            let cl: OpenCL = src_tensor.cl();
+            // TODO: don't generate a CPU-side copy; create the result directly on the OpenCL side
+            let mut result = unsafe { Tensor::uninitialized(self.cols, self.rows, self.dtype) };
+            result = result.to_f16();
+            result.to_gpu(&cl).unwrap();
+            result.with_opencl_data_mut(|tgt_tensor| {
+                tgt_tensor.transpose_from(src_tensor).unwrap();
+            });
+            result
+        })
+    }
+
+    fn transpose_cpu(&self) -> Tensor {
         self.assume_on_cpu();
         let mut result = unsafe { Tensor::uninitialized(self.cols, self.rows, self.dtype) };
         for row in 0..self.rows {
@@ -665,18 +799,27 @@ impl Tensor {
     }
 
     pub fn matrix_mul_transposed(&self, other: &Tensor) -> Tensor {
-        self.assume_on_cpu();
-        other.assume_on_cpu();
         if self.cols != other.cols {
             panic!(
                 "Invalid matrix transposed multiplication {}x{} vs {}x{}",
                 self.rows, self.cols, other.cols, other.rows
             );
         }
+        #[cfg(not(feature = "opencl"))]
         if other.rows == 1 {
             return self.matrix_vector_mul_transposed(other);
         }
+        #[cfg(feature = "opencl")]
+        if other.rows == 1 && self.is_on_cpu() {
+            return self.matrix_vector_mul_transposed(other);
+        }
         let mut result = unsafe { Tensor::uninitialized(self.rows, other.rows, self.dtype) };
+        #[cfg(feature = "opencl")]
+        if self.is_on_gpu() {
+            let od = self.opencl_data.write().unwrap();
+            result.to_gpu(&od.as_ref().unwrap().cl()).unwrap();
+        }
+
         result.matrix_mul_inplace_transposed(self, other);
         result
     }
@@ -839,6 +982,11 @@ impl Tensor {
         false
     }
 
+    #[cfg(feature = "opencl")]
+    pub fn is_on_cpu(&self) -> bool {
+        !self.is_on_gpu()
+    }
+
     #[cfg(feature = "opencl")]
     fn matrix_mul_inplace_transposed_gpu(&mut self, src: &Tensor, other: &Tensor) {
         let mut self_od = self.opencl_data.write().unwrap();
@@ -2031,10 +2179,110 @@ mod tests {
         }
     }
 
+    #[cfg(feature = "opencl")]
+    #[test]
+    fn gpu_silu_and_cpu_silu_agree() {
+        let cl = OpenCL::new(false, 0).unwrap();
+
+        for _trial in 0..300 {
+            let mut rng = rand::thread_rng();
+            let a = rng.gen_range(1..=300);
+            let b = rng.gen_range(1..=300);
+            let mat1 = Tensor::random(a, b, TensorDType::Float16);
+            let mat2 = mat1.clone();
+            let mut mat2 = mat2.to_f16();
+            mat2.to_gpu(&cl).unwrap();
+
+            let mat1_result = mat1.silu();
+            let mut mat2_result = mat2.silu();
+            mat2_result.to_cpu().unwrap();
+
+            assert_eq!(mat1_result.rows(), mat2_result.rows());
+            assert_eq!(mat1_result.cols(), mat2_result.cols());
+
+            for row in 0..mat1_result.rows {
+                for col in 0..mat1_result.cols {
+                    assert_relative_eq!(
+                        mat1_result.get_f32(row, col),
+                        mat2_result.get_f32(row, col),
+                        epsilon = 1e-2
+                    );
+                }
+            }
+        }
+    }
+
+    #[cfg(feature = "opencl")]
+    #[test]
+    fn gpu_hadamard_product_and_cpu_hadamard_product_agree() {
+        let cl = OpenCL::new(false, 0).unwrap();
+
+        for _trial in 0..300 {
+            let mut rng = rand::thread_rng();
+            let a = rng.gen_range(1..=300);
+            let b = rng.gen_range(1..=300);
+            let mat1 = Tensor::random(a, b, TensorDType::Float16);
+            let mat2 = Tensor::random(a, b, TensorDType::Float16);
+
+            let mut mat1_gpu = mat1.to_f16();
+            let mut mat2_gpu = mat2.to_f16();
+            mat1_gpu.to_gpu(&cl).unwrap();
+            mat2_gpu.to_gpu(&cl).unwrap();
+
+            let result1 = mat1.hadamard_product(&mat2);
+            let mut result2 = mat1_gpu.hadamard_product(&mat2_gpu);
+            result2.to_cpu().unwrap();
+
+            assert_eq!(result1.rows(), result2.rows());
+            assert_eq!(result1.cols(), result2.cols());
+
+            for row in 0..result1.rows() {
+                for col in 0..result2.cols() {
+                    assert_relative_eq!(
+                        result1.get_f32(row, col),
+                        result2.get_f32(row, col),
+                        epsilon = 1e-2
+                    );
+                }
+            }
+        }
+    }
+
+    #[cfg(feature = "opencl")]
+    #[test]
+    fn gpu_transpose_and_cpu_transpose_agree() {
+        let cl = OpenCL::new(false, 0).unwrap();
+        let mut rng = rand::thread_rng();
+        for _trial in 0..300 {
+            let a = rng.gen_range(1..=100);
+            let b = rng.gen_range(1..=100);
+            let mat1 = Tensor::random(a, b, TensorDType::Float16);
+            let mut mat1_gpu = mat1.to_f16();
+            mat1_gpu.to_gpu(&cl).unwrap();
+
+            let mat1_transposed = mat1.transpose();
+            let mut mat1_gpu_transposed = mat1_gpu.transpose();
+            mat1_gpu_transposed.to_cpu().unwrap();
+
+            assert_eq!(mat1_transposed.rows(), mat1_gpu_transposed.rows());
+            assert_eq!(mat1_transposed.cols(), mat1_gpu_transposed.cols());
+
+            for row in 0..mat1_transposed.rows {
+                for col in 0..mat1_transposed.cols {
+                    assert_relative_eq!(
+                        mat1_transposed.get_f32(row, col),
+                        mat1_gpu_transposed.get_f32(row, col),
+                        epsilon = 1e-2,
+                    );
+                }
+            }
+        }
+    }
+
     #[cfg(feature = "opencl")]
     #[test]
     fn gpu_matrix_mul_transposed_is_close_to_cpu_matrix_mul_transposed() {
-        let cl = OpenCL::new(true, 1).unwrap();
+        let cl = OpenCL::new(false, 0).unwrap();
         let mut rng = rand::thread_rng();
 
         for _trial in 0..300 {
diff --git a/src/tensor_opencl_support.rs b/src/tensor_opencl_support.rs
index 7f571ab..59e8ace 100644
--- a/src/tensor_opencl_support.rs
+++ b/src/tensor_opencl_support.rs
@@ -12,9 +12,15 @@ use thiserror::Error;
 struct Programs {
     matrix_mul_transposed_by_row_f16_program: Program,
     matrix_mul_transposed_by_row_f16: Kernel,
+    silu_f16_program: Program,
+    silu_f16: Kernel,
+    hadamard_product_f16_program: Program,
+    hadamard_product_f16: Kernel,
+    transpose_f16_program: Program,
+    transpose_f16: Kernel,
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(dead_code)]
 pub struct OpenCL {
     ctx: Context,
@@ -34,7 +40,7 @@ pub struct OpenCLTensor {
     cols: i64,
     cols_capacity: i64,
     queue: Queue,
-    programs: Arc<RwLock<Programs>>,
+    cl: OpenCL,
 }
 
 #[derive(Debug)]
@@ -143,13 +149,17 @@ impl OpenCL {
                 cols,
                 cols_capacity,
                 queue: self.queue.clone(),
-                programs: self.programs.clone(),
+                cl: self.clone(),
             })
         }
     }
 }
 
 impl OpenCLTensor {
+    pub fn cl(&self) -> OpenCL {
+        self.cl.clone()
+    }
+
     pub fn wait_until_ready(&mut self) {
         if self.last_event.is_some() {
             self.last_event.as_ref().unwrap().wait_for().unwrap();
@@ -187,6 +197,93 @@ impl OpenCLTensor {
         }
     }
 
+    /// Copies all values from another tensor
+    pub fn copy_inplace(&mut self, other: &OpenCLTensor) -> Result<OpenCLEvent, OpenCLError> {
+        if other.rows != self.rows || other.cols != self.cols {
+            panic!(
+                "Cannot in-place copy tensors of different sizes: {}x{} <-- {}x{}",
+                self.rows, self.cols, other.rows, other.cols
+            );
+        }
+        let mut event = Event::empty();
+        other
+            .buf
+            .cmd()
+            .queue(&other.queue)
+            .copy(&self.buf, None, None)
+            .enew(&mut event)
+            .enq()?;
+        self.last_event = Some(event.clone());
+        Ok(OpenCLEvent { event })
+    }
+
+    pub fn transpose_from(&mut self, other: &OpenCLTensor) -> Result<OpenCLEvent, OpenCLError> {
+        let prg = self.cl.programs.write().unwrap();
+        prg.transpose_f16.set_arg(0, self.buf.clone()).unwrap();
+        prg.transpose_f16.set_arg(1, other.buf.clone()).unwrap();
+        prg.transpose_f16
+            .set_arg(2, self.cols_capacity as i32)
+            .unwrap();
+        prg.transpose_f16
+            .set_arg(3, other.cols_capacity as i32)
+            .unwrap();
+        let mut event = Event::empty();
+        unsafe {
+            let b = prg
+                .transpose_f16
+                .cmd()
+                .queue(&self.queue)
+                .global_work_size([self.rows as usize, self.cols as usize])
+                .enew(&mut event);
+            b.enq().unwrap();
+        }
+        self.last_event = Some(event.clone());
+        Ok(OpenCLEvent { event })
+    }
+
+    pub fn hadamard_product_inplace(
+        &mut self,
+        other: &OpenCLTensor,
+    ) -> Result<OpenCLEvent, OpenCLError> {
+        let prg = self.cl.programs.write().unwrap();
+        prg.hadamard_product_f16.set_arg(0, self.buf.clone())?;
+        prg.hadamard_product_f16.set_arg(1, other.buf.clone())?;
+        prg.hadamard_product_f16
+            .set_arg(2, self.cols_capacity as i32)?;
+        prg.hadamard_product_f16
+            .set_arg(3, other.cols_capacity as i32)?;
+        let mut event = Event::empty();
+        unsafe {
+            let b = prg
+                .hadamard_product_f16
+                .cmd()
+                .queue(&self.queue)
+                .global_work_size([self.rows as usize, self.cols as usize])
+                .enew(&mut event);
+            b.enq()?;
+        }
+        self.last_event = Some(event.clone());
+        Ok(OpenCLEvent { event })
+    }
+
+    pub fn silu_inplace(&mut self) -> Result<OpenCLEvent, OpenCLError> {
+        let prg = self.cl.programs.write().unwrap();
+        prg.silu_f16.set_arg(0, self.buf.clone())?;
+        prg.silu_f16.set_arg(1, self.cols_capacity as i32)?;
+        let mut event = Event::empty();
+        unsafe {
+            let b = prg
+                .silu_f16
+                .cmd()
+                .queue(&self.queue)
+                .global_work_size([self.rows as usize, self.cols as usize])
+                .enew(&mut event);
+            b.enq()?;
+        }
+        self.last_event = Some(event.clone());
+        Ok(OpenCLEvent { event })
+    }
+
     pub fn matrix_mul_inplace_transposed(
         &mut self,
         src: &OpenCLTensor,
@@ -208,7 +305,7 @@ impl OpenCLTensor {
         // Clear out the target memory
         unsafe { self.buf.cmd().fill(0u16, None).block(false).enq()? };
 
-        let prg = self.programs.write().unwrap();
+        let prg = self.cl.programs.write().unwrap();
         prg.matrix_mul_transposed_by_row_f16
             .set_arg(0, self.buf.clone())?;
         prg.matrix_mul_transposed_by_row_f16
@@ -234,7 +331,7 @@ impl OpenCLTensor {
             .matrix_mul_transposed_by_row_f16
             .cmd()
             .queue(&self.queue)
-            .global_work_size([self.rows as usize, self.cols_capacity as usize])
+            .global_work_size([self.rows as usize, self.cols as usize])
             .enew(&mut event);
         b.enq()?;
     }
@@ -251,47 +348,65 @@ impl OpenCLEvent {
 }
 
 fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError> {
-    let mut last_err: Option<OpenCLError> = None;
-    // There used to be more sources here but now it's just one. This can go through programs and
-    // accept first one that compiles
-    for src in &[MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC] {
-        fn make_programs_with_src(
-            ctx: &Context,
-            queue: &Queue,
-            src: &str,
-        ) -> Result<Programs, OpenCLError> {
-            let program = Program::builder().src(src).build(&ctx)?;
-            let kernel = Kernel::builder()
-                .program(&program)
-                .name("matrix_mul_transposed_by_row_f16")
-                .arg(None::<&Buffer<u16>>)
-                .arg(None::<&Buffer<u16>>)
-                .arg(None::<&Buffer<u16>>)
-                .arg(&0)
-                .arg(&0)
-                .arg(&0)
-                .arg(&0)
-                .arg(&0)
-                .arg(&0)
-                .queue(queue.clone())
-                .build()?;
-            Ok(Programs {
-                matrix_mul_transposed_by_row_f16_program: program,
-                matrix_mul_transposed_by_row_f16: kernel,
-            })
-        }
-        match make_programs_with_src(ctx, queue, src) {
-            Err(e) => {
-                last_err = Some(e);
-                continue;
-            }
-            Ok(p) => return Ok(p),
-        }
-    }
-    if last_err.is_none() {
-        unreachable!();
-    }
-    Err(last_err.unwrap())
+    fn make_program_with_src(ctx: &Context, src: &str) -> Result<Program, OpenCLError> {
+        let program = Program::builder().src(src).build(&ctx)?;
+        Ok(program)
+    }
+
+    let matrix_mul_transposed_by_row_f16_program =
+        make_program_with_src(ctx, MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC)?;
+    let matrix_mul_transposed_by_row_f16 = Kernel::builder()
+        .program(&matrix_mul_transposed_by_row_f16_program)
+        .name("matrix_mul_transposed_by_row_f16")
+        .arg(None::<&Buffer<u16>>)
+        .arg(None::<&Buffer<u16>>)
+        .arg(None::<&Buffer<u16>>)
+        .arg(&0)
+        .arg(&0)
+        .arg(&0)
+        .arg(&0)
+        .arg(&0)
+        .arg(&0)
+        .queue(queue.clone())
+        .build()?;
+    let silu_f16_program = make_program_with_src(ctx, SILU_F16_SRC)?;
+    let silu_f16 = Kernel::builder()
+        .program(&silu_f16_program)
+        .name("silu_f16")
+        .arg(None::<&Buffer<u16>>)
+        .arg(&0)
+        .queue(queue.clone())
+        .build()?;
+    let hadamard_product_f16_program = make_program_with_src(ctx, HADAMARD_PRODUCT_F16_SRC)?;
+    let hadamard_product_f16 = Kernel::builder()
+        .program(&hadamard_product_f16_program)
+        .name("hadamard_product_f16")
+        .arg(None::<&Buffer<u16>>)
+        .arg(None::<&Buffer<u16>>)
+        .arg(&0)
+        .arg(&0)
+        .queue(queue.clone())
+        .build()?;
+    let transpose_f16_program = make_program_with_src(ctx, TRANSPOSE_F16_SRC)?;
+    let transpose_f16 = Kernel::builder()
+        .program(&transpose_f16_program)
+        .name("transpose_f16")
+        .arg(None::<&Buffer<u16>>)
+        .arg(None::<&Buffer<u16>>)
+        .arg(&0)
+        .arg(&0)
+        .queue(queue.clone())
+        .build()?;
+    Ok(Programs {
+        matrix_mul_transposed_by_row_f16_program,
+        matrix_mul_transposed_by_row_f16,
+        silu_f16_program,
+        silu_f16,
+        hadamard_product_f16_program,
+        hadamard_product_f16,
+        transpose_f16_program,
+        transpose_f16,
+    })
 }
 
 const MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC: &str = r#"
@@ -367,3 +482,53 @@ __kernel void matrix_mul_transposed_by_row_f16(
     vstore_half(total, 0, (__global half*) &tgt[tgt_row * ncols_capacity + tgt_col]);
 }
 "#;
+
+/// Computes SiLU for every f16 value in the tensor
+const SILU_F16_SRC: &str = r#"
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+__kernel void silu_f16(__global half *tgt,
+                       const int ncols_capacity)
+{
+    const int tgt_row = get_global_id(0);
+    const int tgt_col = get_global_id(1);
+    const float val = vload_half(tgt_row * ncols_capacity + tgt_col, (__global const half*) tgt);
+    const float result = val * (1.0 / (1.0 + exp(-val)));
+    vstore_half(result, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt);
+}
+"#;
+
+/// Computes the Hadamard product of two identically sized tensors
+const HADAMARD_PRODUCT_F16_SRC: &str = r#"
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+__kernel void hadamard_product_f16(__global half *tgt,
+                                   __global const half *left,
+                                   const int ncols_capacity,
+                                   const int left_cols_capacity) {
+    const int tgt_row = get_global_id(0);
+    const int tgt_col = get_global_id(1);
+    const float tgt_value = vload_half(tgt_row * ncols_capacity + tgt_col, (__global const half*) tgt);
+    const float left_value = vload_half(tgt_row * left_cols_capacity + tgt_col, (__global const half*) left);
+    const float result = tgt_value * left_value;
+    vstore_half(result, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt);
+}
+"#;
+
+/// Computes the transpose of a matrix
+const TRANSPOSE_F16_SRC: &str = r#"
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+__kernel void transpose_f16(__global half *tgt,
+                            __global const half *left,
+                            const int ncols_capacity,
+                            const int left_cols_capacity)
+{
+    const int tgt_row = get_global_id(0);
+    const int tgt_col = get_global_id(1);
+    const int src_row = tgt_col;
+    const int src_col = tgt_row;
+    const float val = vload_half(src_row * left_cols_capacity + src_col, (__global const half*) left);
+    vstore_half(val, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt);
+}
+"#;
diff --git a/src/transformer.rs b/src/transformer.rs
index 1eeb0f5..fba2fb7 100644
--- a/src/transformer.rs
+++ b/src/transformer.rs
@@ -1,5 +1,7 @@
 use crate::embedding::Embedding;
 use crate::tensor::{FromPiecesDirection, Tensor, TensorDType};
+#[cfg(feature = "opencl")]
+use crate::tensor_opencl_support::OpenCL;
 use crate::tokenizer::TokenId;
 use crate::unpickler;
 use crate::unpickler::UnpicklingError;
@@ -28,6 +30,43 @@ pub struct Transformer {
     layers: Vec<TransformerBlock>,
 }
 
+// Clone is cheap
+#[derive(Clone)]
+pub struct DataSettings {
+    #[cfg(feature = "opencl")]
+    use_opencl_for_feedforward: bool,
+    #[cfg(feature = "opencl")]
+    cl: Option<OpenCL>,
+}
+
+// OpenCL is safe to send to threads but Rust doesn't know that
+unsafe impl Send for DataSettings {}
+unsafe impl Sync for DataSettings {}
+
+impl DataSettings {
+    #[cfg(feature = "opencl")]
+    pub fn new(cl: Option<OpenCL>) -> Self {
+        DataSettings {
+            use_opencl_for_feedforward: false,
+            cl: cl.clone(),
+        }
+    }
+
+    #[cfg(not(feature = "opencl"))]
+    pub fn new() -> Self {
+        DataSettings {}
+    }
+
+    #[cfg(feature = "opencl")]
+    pub fn use_opencl(mut self) -> DataSettings {
+        if self.cl.is_none() {
+            panic!("OpenCL is not available, cannot call use_opencl() on DataSettings.");
+        }
+        self.use_opencl_for_feedforward = true;
+        self
+    }
+}
+
 pub struct TransformerCaches {
     layer_caches: Vec<AttentionCache>,
 }
@@ -105,10 +144,12 @@ pub struct Attention {
     head_dim: usize,
 }
 
+#[allow(dead_code)]
 pub struct FeedForward {
     w1: Tensor,
     w2: Tensor,
     w3: Tensor,
+    data_settings: DataSettings,
 }
 
 impl Transformer {
@@ -121,6 +162,7 @@ impl Transformer {
         n_heads: usize,
         max_seq_len: usize,
         eps: f64,
+        data_settings: DataSettings,
         data_dir: P,
     ) -> Result<Transformer, UnpicklingError> {
         assert_eq!(dim % n_heads, 0);
@@ -141,6 +183,7 @@ impl Transformer {
                 eps,
                 n_local_heads,
                 head_dim,
+                data_settings.clone(),
                 data_dir,
             );
             progress_bar.inc(1);
@@ -238,10 +281,11 @@ impl TransformerBlock {
         eps: f64,
         n_local_heads: usize,
         head_dim: usize,
+        data_settings: DataSettings,
         data_dir: P,
     ) -> Result<Self, UnpicklingError> {
         let data_dir: &Path = data_dir.as_ref();
-        let ff = FeedForward::from_unpickled(unpickled, layer_id, data_dir)?;
+        let ff = FeedForward::from_unpickled(unpickled, layer_id, data_dir, data_settings)?;
         let attn =
             Attention::from_unpickled(unpickled, layer_id, n_local_heads, head_dim, data_dir)?;
         let ffn_norm = RMSNorm::from_unpickled(
@@ -277,8 +321,8 @@ impl TransformerBlock {
             .attn
             .forward(&attnorm_out, start_pos, freqs_cis, mask, attention_cache);
         let h = x.add(&att_out);
-        let att_out = self.ffn_norm.forward(&h);
-        let att_out = self.feed_forward.forward(&att_out).transpose();
+        let mut att_out = self.ffn_norm.forward(&h);
+        let att_out = self.feed_forward.forward(&mut att_out).transpose();
         h.add(&att_out)
     }
 }
@@ -316,35 +360,70 @@ impl FeedForward {
         unpickled: &[unpickler::Value],
         layer_id: usize,
         data_dir: P,
+        data_settings: DataSettings,
     ) -> Result<FeedForward, UnpicklingError> {
         let data_dir: &Path = data_dir.as_ref();
 
-        let w1 = Tensor::from_unpickled_pieces(
+        let mut w1 = Tensor::from_unpickled_pieces(
             unpickled,
             format!("layers.{}.feed_forward.w1.weight", layer_id),
             data_dir,
             FromPiecesDirection::Rows,
-        )?
-        .to_f32();
-        let w2 = Tensor::from_unpickled_pieces(
+        )?;
+        let mut w2 = Tensor::from_unpickled_pieces(
             unpickled,
             format!("layers.{}.feed_forward.w2.weight", layer_id),
             data_dir,
             FromPiecesDirection::Cols,
-        )?
-        .to_f32();
-        let w3 = Tensor::from_unpickled_pieces(
+        )?;
+        let mut w3 = Tensor::from_unpickled_pieces(
             unpickled,
             format!("layers.{}.feed_forward.w3.weight", layer_id),
             data_dir,
             FromPiecesDirection::Rows,
-        )?
-        .to_f32();
+        )?;
 
-        Ok(Self { w1, w2, w3 })
+        #[cfg(feature = "opencl")]
+        {
+            if data_settings.use_opencl_for_feedforward {
+                w1 = w1.to_f16();
+                w2 = w2.to_f16();
+                w3 = w3.to_f16();
+                let ds = data_settings.clone();
+                w1.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap();
+                w2.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap();
+                w3.to_gpu(&ds.cl.unwrap()).unwrap();
+            }
+        }
+        #[cfg(not(feature = "opencl"))]
+        {
+            w1 = w1.to_f32();
+            w2 = w2.to_f32();
+            w3 = w3.to_f32();
+        }
+
+        Ok(Self {
+            w1,
+            w2,
+            w3,
+            data_settings,
+        })
     }
 
-    pub fn forward(&self, x: &Tensor) -> Tensor {
+    pub fn forward(&self, x: &mut Tensor) -> Tensor {
+        #[cfg(feature = "opencl")]
+        let x_was_on_cpu: bool;
+        #[cfg(feature = "opencl")]
+        {
+            x_was_on_cpu = x.is_on_cpu();
+        }
+        #[cfg(feature = "opencl")]
+        {
+            if self.data_settings.use_opencl_for_feedforward {
+                *x = x.to_f16();
+                x.to_gpu(self.data_settings.cl.as_ref().unwrap()).unwrap();
+            }
+        }
         let (w1_out, w3_out) = rayon::join(
             || self.w1.matrix_mul_transposed(x),
             || self.w3.matrix_mul_transposed(x),
@@ -352,12 +431,24 @@ impl FeedForward {
         let w1_out = w1_out.silu();
         let w1w3_out = w1_out.hadamard_product(&w3_out).transpose();
+        #[cfg(not(feature = "opencl"))]
         if w1w3_out.rows() == 1 {
             return self
                 .w2
                 .matrix_vector_mul_transposed_multithreaded(&w1w3_out);
+        } else {
+            return self.w2.matrix_mul_transposed(&w1w3_out);
+        }
+        #[cfg(feature = "opencl")]
+        {
+            let mut result = self.w2.matrix_mul_transposed(&w1w3_out);
+            if x_was_on_cpu {
+                result.to_cpu().unwrap();
+                result
+            } else {
+                result
+            }
         }
-        self.w2.matrix_mul_transposed(&w1w3_out)
     }
 }
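
Usage sketch (editor's note, not part of the patch): with the `opencl` feature built in, the pieces above are wired together as in the rllama_main.rs hunks. This is a minimal sketch of the new call sequence; the `--opencl-device` flag name assumes clap's usual kebab-case derivation from the `opencl_device` field, and everything else mirrors the diff:

    // Minimal sketch, mirroring rllama_main.rs above.
    let cl = OpenCL::new(!be_quiet, opencl_device)?;              // compiles the f16 kernels once
    let data_settings = DataSettings::new(Some(cl)).use_opencl(); // route the feed-forward path through OpenCL
    // Transformer::from_unpickled(..., data_settings, model_path) then converts each
    // layer's w1/w2/w3 to f16 and uploads them with to_gpu().

At inference time, FeedForward::forward() moves the activation tensor to the GPU on entry and, when the input started on the CPU (x_was_on_cpu), copies the result back with to_cpu(), so the attention path can stay on the CPU unchanged.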