diff --git a/src/rllama_main.rs b/src/rllama_main.rs
index aaff326..d1328be 100644
--- a/src/rllama_main.rs
+++ b/src/rllama_main.rs
@@ -39,6 +39,9 @@ struct Cli {
     #[arg(long)]
     repetition_penalty: Option<f32>,
 
+    #[arg(long)]
+    max_threads: Option<usize>,
+
     #[arg(long, action)]
     f16: bool,
 
@@ -63,6 +66,17 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     let tokenizer_path = cli.tokenizer_path;
     let param_path = cli.param_path;
 
+    let max_threads: usize = match cli.max_threads {
+        None => rayon::current_num_threads(),
+        Some(max_threads) => {
+            rayon::ThreadPoolBuilder::new()
+                .num_threads(max_threads)
+                .build_global()
+                .unwrap();
+            max_threads
+        }
+    };
+
     let mut be_quiet: bool = false;
     if !colored::control::SHOULD_COLORIZE.should_colorize() {
         be_quiet = true;
@@ -218,6 +232,8 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     pln!(" norm_eps: {}", params.norm_eps);
     pln!(" vocab_size: {}", params.vocab_size);
     pln!("---");
+    pln!(" maximum number of threads: {}", max_threads);
+    pln!("---");
     pln!("Max sequence length: {}", max_seq_len);
     pln!("Temperature: {}", token_sampler.get_temperature());
     pln!("Top P: {}", token_sampler.get_top_p());
diff --git a/src/tensor.rs b/src/tensor.rs
index 0691ae7..2bb35c4 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -1096,6 +1096,8 @@ impl Tensor {
     /// Matrix multiplication done in-place, but the second matrix is transposed.
     /// With this, you can avoid using .transpose() on the second matrix.
     pub fn matrix_mul_inplace_transposed(&mut self, src: &Tensor, other: &Tensor) {
+        let nthreads: usize = rayon::current_num_threads();
+
         #[cfg(feature = "opencl")]
         if self.is_on_gpu() && src.is_on_gpu() && other.is_on_gpu() {
             self.matrix_mul_inplace_transposed_gpu(src, other);
@@ -1165,7 +1167,8 @@ impl Tensor {
         let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
         let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
         let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
-        (0..32).into_par_iter().for_each(|thread_idx| {
+
+        (0..nthreads).into_par_iter().for_each(|thread_idx| {
             let src_data: *const f32 = src_data_wrap.unwrap() as *const f32;
             let other_data: *const f32 = other_data.unwrap() as *const f32;
             let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
@@ -1176,7 +1179,7 @@
                 let row3 = row * 4 + 3;
                 for col in 0..self_cols_its {
                     let row_col = row * self_cols_its + col;
-                    if row_col % 32 != thread_idx {
+                    if row_col % nthreads != thread_idx {
                         continue;
                     }
                     let col0 = col * 4;
@@ -1386,7 +1389,7 @@
         let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
         let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
         let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
-        (0..32).into_par_iter().for_each(|thread_idx| {
+        (0..nthreads).into_par_iter().for_each(|thread_idx| {
             let src_data: *const f16 = src_data_wrap.unwrap() as *const f16;
             let other_data: *const f16 = other_data.unwrap() as *const f16;
             let tgt_data: *mut f16 = tgt_data.unwrap() as *mut f16;
@@ -1397,7 +1400,7 @@
                 let row3 = row * 4 + 3;
                 for col in 0..self_cols_its {
                     let row_col = row * self_cols_its + col;
-                    if row_col % 32 != thread_idx {
+                    if row_col % nthreads != thread_idx {
                         continue;
                     }
                     let col0 = col * 4;
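
A note on the `rllama_main.rs` hunk: rayon's global thread pool can be configured only once, and `ThreadPoolBuilder::build_global()` returns an error if the pool has already been initialized. That is presumably why the patch runs it early in `main()`, before any parallel work; the `unwrap()` would panic if anything had already touched the pool. The sketch below is a minimal, standalone reproduction of the same selection logic, assuming only the `rayon` crate; the positional-argument parsing is a hypothetical stand-in for the clap-based `--max-threads` flag and is not part of the patch.

```rust
// Minimal sketch of the thread-count selection logic, assuming only the
// `rayon` crate. The positional CLI argument is a hypothetical stand-in
// for the clap-based `--max-threads` flag in the patch.
use rayon::prelude::*;

fn main() {
    let requested: Option<usize> = std::env::args().nth(1).and_then(|s| s.parse().ok());

    let max_threads = match requested {
        // No cap requested: rayon defaults to one thread per logical CPU.
        None => rayon::current_num_threads(),
        Some(n) => {
            // Must run before any other rayon call: build_global() errors
            // if the global pool has already been initialized.
            rayon::ThreadPoolBuilder::new()
                .num_threads(n)
                .build_global()
                .expect("global rayon pool already initialized");
            n
        }
    };
    println!("using up to {} threads", max_threads);

    // Every subsequent parallel iterator now runs on the configured pool.
    let sum: u64 = (0..1_000u64).into_par_iter().sum();
    println!("sum = {}", sum);
}
```

Run as e.g. `cargo run -- 8` to cap the pool at eight threads; with no argument it reports rayon's default of one thread per logical CPU.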
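
The `tensor.rs` hunks replace a hard-coded 32-way split with one stripe per rayon thread: work item `row_col` is handled by the worker whose index equals `row_col % nthreads`, so the stripes are disjoint and every item is written by exactly one thread. The toy program below demonstrates that partitioning scheme; it is not the rllama kernel itself, and it uses atomics in place of the kernel's wrapped raw pointers so the sketch stays in safe Rust.

```rust
// Toy model of the striping in matrix_mul_inplace_transposed: item i is
// owned by the worker with index i % nthreads, so the stripes are disjoint
// and every item is written exactly once. The real kernel writes through
// wrapped raw pointers; atomics keep this sketch in safe Rust.
use rayon::prelude::*;
use std::sync::atomic::{AtomicUsize, Ordering};

fn main() {
    let nthreads = rayon::current_num_threads();
    let items = 100usize;
    let out: Vec<AtomicUsize> = (0..items).map(|_| AtomicUsize::new(0)).collect();

    (0..nthreads).into_par_iter().for_each(|thread_idx| {
        for i in 0..items {
            if i % nthreads != thread_idx {
                continue; // this stripe belongs to another worker
            }
            out[i].store(i * 2, Ordering::Relaxed);
        }
    });

    assert!(out
        .iter()
        .enumerate()
        .all(|(i, v)| v.load(Ordering::Relaxed) == i * 2));
    println!("{} items written once each across {} stripes", items, nthreads);
}
```

The previous hard-coded `(0..32)` range fixed the kernel at 32 stripes regardless of core count; deriving the count from `rayon::current_num_threads()` yields one stripe per pool thread, which also respects a `--max-threads` cap set in `main()`.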