diff --git a/src/transformer.rs b/src/transformer.rs
index fba2fb7..4d158b8 100644
--- a/src/transformer.rs
+++ b/src/transformer.rs
@@ -36,6 +36,8 @@ pub struct DataSettings {
     #[cfg(feature = "opencl")]
     use_opencl_for_feedforward: bool,
     #[cfg(feature = "opencl")]
+    use_opencl_for_attention: bool,
+    #[cfg(feature = "opencl")]
     cl: Option<OpenCL>,
 }
 
@@ -48,6 +50,7 @@ impl DataSettings {
     pub fn new(cl: Option<OpenCL>) -> Self {
         DataSettings {
             use_opencl_for_feedforward: false,
+            use_opencl_for_attention: false,
             cl: cl.clone(),
         }
     }
@@ -63,6 +66,7 @@ impl DataSettings {
             panic!("OpenCL is not available, cannot call use_opencl() on DataSettings.");
         }
         self.use_opencl_for_feedforward = true;
+        self.use_opencl_for_attention = true;
         self
     }
 }
@@ -142,6 +146,7 @@ pub struct Attention {
     wo: Tensor,
     n_local_heads: usize,
     head_dim: usize,
+    data_settings: DataSettings,
 }
 
 #[allow(dead_code)]
@@ -285,9 +290,15 @@ impl TransformerBlock {
         data_dir: P,
     ) -> Result<Self, UnpicklingError> {
         let data_dir: &Path = data_dir.as_ref();
-        let ff = FeedForward::from_unpickled(unpickled, layer_id, data_dir, data_settings)?;
-        let attn =
-            Attention::from_unpickled(unpickled, layer_id, n_local_heads, head_dim, data_dir)?;
+        let ff = FeedForward::from_unpickled(unpickled, layer_id, data_dir, data_settings.clone())?;
+        let attn = Attention::from_unpickled(
+            unpickled,
+            layer_id,
+            n_local_heads,
+            head_dim,
+            data_settings,
+            data_dir,
+        )?;
         let ffn_norm = RMSNorm::from_unpickled(
             unpickled,
             format!("layers.{}.ffn_norm.weight", layer_id),
@@ -316,10 +327,16 @@ impl TransformerBlock {
         mask: &Option<Tensor>,
         attention_cache: &mut AttentionCache,
     ) -> Tensor {
-        let attnorm_out = self.attention_norm.forward(x);
-        let att_out = self
-            .attn
-            .forward(&attnorm_out, start_pos, freqs_cis, mask, attention_cache);
+        let mut attnorm_out = self.attention_norm.forward(x);
+        let att_out = self.attn.forward(
+            &mut attnorm_out,
+            start_pos,
+            freqs_cis,
+            mask,
+            attention_cache,
+        );
+        std::mem::drop(attnorm_out);
+
         let h = x.add(&att_out);
         let mut att_out = self.ffn_norm.forward(&h);
         let att_out = self.feed_forward.forward(&mut att_out).transpose();
@@ -416,9 +433,6 @@ impl FeedForward {
         #[cfg(feature = "opencl")]
         {
             x_was_on_cpu = x.is_on_cpu();
-        }
-        #[cfg(feature = "opencl")]
-        {
             if self.data_settings.use_opencl_for_feedforward {
                 *x = x.to_f16();
                 x.to_gpu(self.data_settings.cl.as_ref().unwrap()).unwrap();
@@ -458,38 +472,57 @@ impl Attention {
         layer_id: usize,
         n_local_heads: usize,
         head_dim: usize,
+        data_settings: DataSettings,
         data_dir: P,
     ) -> Result<Self, UnpicklingError> {
         let data_dir: &Path = data_dir.as_ref();
 
-        let wq = Tensor::from_unpickled_pieces(
+        let mut wq = Tensor::from_unpickled_pieces(
             unpickled,
             format!("layers.{}.attention.wq.weight", layer_id),
             data_dir,
             FromPiecesDirection::Rows,
-        )?
-        .to_f32();
-        let wk = Tensor::from_unpickled_pieces(
+        )?;
+        let mut wk = Tensor::from_unpickled_pieces(
             unpickled,
             format!("layers.{}.attention.wk.weight", layer_id),
             data_dir,
             FromPiecesDirection::Rows,
-        )?
-        .to_f32();
-        let wv = Tensor::from_unpickled_pieces(
+        )?;
+        let mut wv = Tensor::from_unpickled_pieces(
             unpickled,
             format!("layers.{}.attention.wv.weight", layer_id),
             data_dir,
             FromPiecesDirection::Rows,
-        )?
-        .to_f32();
-        let wo = Tensor::from_unpickled_pieces(
+        )?;
+        let mut wo = Tensor::from_unpickled_pieces(
             unpickled,
             format!("layers.{}.attention.wo.weight", layer_id),
             data_dir,
             FromPiecesDirection::Cols,
-        )?
-        .to_f32();
+        )?;
+
+        #[cfg(feature = "opencl")]
+        {
+            if data_settings.use_opencl_for_attention {
+                wq = wq.to_f16();
+                wk = wk.to_f16();
+                wv = wv.to_f16();
+                wo = wo.to_f16();
+                let ds = data_settings.clone();
+                wq.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap();
+                wk.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap();
+                wv.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap();
+                wo.to_gpu(&ds.cl.unwrap()).unwrap();
+            }
+        }
+        #[cfg(not(feature = "opencl"))]
+        {
+            wq = wq.to_f32();
+            wk = wk.to_f32();
+            wv = wv.to_f32();
+            wo = wo.to_f32();
+        }
 
         Ok(Self {
             wq,
@@ -498,18 +531,42 @@ impl Attention {
             wo,
             n_local_heads,
             head_dim,
+            data_settings,
         })
     }
 
     fn forward(
         &self,
-        x: &Tensor,
+        x: &mut Tensor,
         start_pos: usize,
         freqs_cis: &FreqsCis,
         mask: &Option<Tensor>,
         attention_cache: &mut AttentionCache,
     ) -> Tensor {
+        #[cfg(feature = "opencl")]
+        let x_was_on_cpu: bool;
+        #[cfg(feature = "opencl")]
+        {
+            x_was_on_cpu = x.is_on_cpu();
+            if self.data_settings.use_opencl_for_attention {
+                *x = x.to_f16();
+                x.to_gpu(self.data_settings.cl.as_ref().unwrap()).unwrap();
+            }
+        }
+
         let seq_len = x.rows();
+        #[cfg(feature = "opencl")]
+        let (xq_out, xk_out, xv_out) = {
+            let mut xq_out = x.matrix_mul_transposed(&self.wq);
+            let mut xk_out = x.matrix_mul_transposed(&self.wk);
+            let mut xv_out = x.matrix_mul_transposed(&self.wv);
+            xq_out.to_cpu().unwrap();
+            xk_out.to_cpu().unwrap();
+            xv_out.to_cpu().unwrap();
+            (xq_out.to_f32(), xk_out.to_f32(), xv_out.to_f32())
+        };
+
+        #[cfg(not(feature = "opencl"))]
         let (xq_out, (xk_out, xv_out)) = rayon::join(
             || x.matrix_mul_transposed(&self.wq),
             || {
@@ -604,8 +661,27 @@ impl Attention {
                     concat_vec.push(output.row(idx));
                 }
                 let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect();
-                let xq_row = Tensor::concat(&concat_vec2).view(1, self.wo.rows());
-                xq_row.matrix_mul_transposed(&self.wo)
+                #[cfg(not(feature = "opencl"))]
+                {
+                    let xq_row = Tensor::concat(&concat_vec2).view(1, self.wo.rows());
+                    xq_row.matrix_mul_transposed(&self.wo)
+                }
+                #[cfg(feature = "opencl")]
+                {
+                    let mut xq_row = Tensor::concat(&concat_vec2)
+                        .view(1, self.wo.rows())
+                        .to_f16();
+                    if self.wo.is_on_gpu() {
+                        xq_row
+                            .to_gpu(&self.data_settings.cl.as_ref().unwrap())
+                            .unwrap();
+                        let mut result = xq_row.matrix_mul_transposed(&self.wo);
+                        result.to_cpu().unwrap();
+                        result.to_f32()
+                    } else {
+                        xq_row.matrix_mul_transposed(&self.wo)
+                    }
+                }
             })
             .collect();
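
Usage note (appended, not part of the diff): a minimal sketch of how the new attention toggle is expected to be driven, based only on the DataSettings hunks above. How an OpenCL context is obtained is not covered by this diff, so acquire_opencl_context() below is a hypothetical placeholder rather than an API of the crate.

    // Hypothetical placeholder: this diff does not show how an Option<OpenCL> is produced.
    let cl: Option<OpenCL> = acquire_opencl_context();

    // From the hunks above: DataSettings::new() starts with both OpenCL flags set to
    // false, and use_opencl() now enables use_opencl_for_attention alongside
    // use_opencl_for_feedforward (it panics when the "opencl" feature is not compiled in).
    let data_settings = DataSettings::new(cl).use_opencl();

    // The settings are then cloned into FeedForward and moved into Attention through
    // TransformerBlock::from_unpickled, per the hunk at @@ -285,9 +290,15 @@.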