diff --git a/src/rllama_main.rs b/src/rllama_main.rs
index 84dfa1d..e6df193 100644
--- a/src/rllama_main.rs
+++ b/src/rllama_main.rs
@@ -53,6 +53,9 @@ struct Cli {
     #[arg(long, action)]
     f16: bool,
 
+    #[arg(long, action)]
+    k4: bool,
+
     #[cfg(feature = "opencl")]
     #[arg(long)]
     opencl_device: Option,
@@ -233,6 +236,9 @@ pub fn main() -> Result<(), Box> {
     if cli.f16 {
         data_settings = data_settings.force_f16();
     }
+    if cli.k4 {
+        data_settings = data_settings.force_k4();
+    }
 
     pln!("Loading transformer weights from {}...", model_path);
     let tr = Transformer::from_unpickled(
diff --git a/src/tensor.rs b/src/tensor.rs
index df1e793..27c8216 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -921,7 +921,10 @@ impl Tensor {
         // We don't have implementation for f16, so don't use the vector function if we have
         // f16
         #[cfg(not(feature = "opencl"))]
-        if other.rows == 1 && other.dtype != TensorDType::Float16 {
+        if other.rows == 1
+            && (other.dtype != TensorDType::K4BitQuantization
+                && self.dtype != TensorDType::K4BitQuantization)
+        {
             return self.matrix_vector_mul_transposed(other);
         }
         #[cfg(feature = "opencl")]
@@ -1117,8 +1120,8 @@ impl Tensor {
 
     // Casts data type to whatever the other tensors data type is.
     pub fn to_same_type(&self, other: &Tensor) -> Tensor {
-        let result = self.clone();
-        if result.dtype() == other.dtype() {
+        if self.dtype() == other.dtype() {
+            let result = self.clone();
             return result;
         }
         match other.dtype {
@@ -1128,6 +1131,35 @@ impl Tensor {
         }
     }
 
+    // Casts data type so that it is valid to do: other.matrix_mul_transposed(result)
+    pub fn to_compatible_matrix_mul_type(&self, other: &Tensor) -> Tensor {
+        if self.dtype() != TensorDType::K4BitQuantization
+            && other.dtype() != TensorDType::K4BitQuantization
+        {
+            return self.to_same_type(other);
+        }
+        if other.dtype() == TensorDType::K4BitQuantization {
+            return self.to_f32();
+        }
+        unimplemented!()
+    }
+
+    // Casts data type so that it is valid to do: result.matrix_mul_transposed(other)
+    pub fn to_compatible_matrix_mul_type2(&self, other: &Tensor) -> Tensor {
+        if self.dtype() != TensorDType::K4BitQuantization
+            && other.dtype() != TensorDType::K4BitQuantization
+        {
+            return self.to_same_type(other);
+        }
+        if other.dtype() == TensorDType::Float32 {
+            return self.clone();
+        }
+        if other.dtype() == TensorDType::K4BitQuantization {
+            return self.to_f32();
+        }
+        unimplemented!()
+    }
+
     pub fn into_same_type(self, other: &Tensor) -> Tensor {
         if self.dtype() == other.dtype() {
             return self;
@@ -1170,6 +1202,216 @@ impl Tensor {
         std::mem::drop(other_od);
     }
 
+    fn matrix_mul_inplace_transposed_f32_and_k4bit(&mut self, src: &Tensor, other: &Tensor) {
+        // Assume: size checks have been done already.
+        assert!(src.dtype == TensorDType::Float32);
+        assert!(other.dtype == TensorDType::K4BitQuantization);
+        assert!(self.dtype == TensorDType::Float32);
+
+        unsafe {
+            let src_rows: usize = src.rows as usize;
+            let src_cols: usize = src.cols as usize;
+            let src_cols_capacity: usize = src.capacity_cols as usize;
+            let other_cols: usize = other.cols as usize;
+            let other_rows: usize = other.rows as usize;
+            let other_cols_capacity: usize = other.capacity_cols as usize;
+            let self_rows: usize = self.rows as usize;
+            let self_cols: usize = self.cols as usize;
+            let self_cols_capacity: usize = self.capacity_cols as usize;
+
+            // src_cols_its == also the shared dimension between src and other.
+            let src_cols_its = if src_cols % 32 == 0 {
+                src_cols / 32
+            } else {
+                src_cols / 32 + 1
+            };
+            debug_assert!(!other.q4_data.is_null());
+
+            let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
+            let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
+            let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
+            let other_q4_data: WrappedPtr = WrappedPtr::wrap(other.q4_data);
+
+            let nthreads: usize = rayon::current_num_threads();
+            (0..nthreads).into_par_iter().for_each(|thread_idx| {
+                let other_q4_data: *const u8 = other_q4_data.unwrap() as *const u8;
+                let src_data: *const f32 = src_data_wrap.unwrap() as *const f32;
+                let other_data: *const u8 = other_data.unwrap() as *const u8;
+                let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
+
+                for row in 0..self_rows {
+                    for col in 0..self_cols {
+                        let row_col = row * self_cols + col;
+                        if row_col % nthreads != thread_idx {
+                            continue;
+                        }
+
+                        let quant0 = load_i16x8(other_q4_data.add(col * 32) as *const I16x8);
+                        let quant1 = load_i16x8(other_q4_data.add(col * 32 + 16) as *const I16x8);
+                        let quants: [F32x8; 2] =
+                            [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
+
+                        #[inline]
+                        fn load_f32(
+                            src: *const f32,
+                            row: usize,
+                            col: usize,
+                            ncols: usize,
+                            nrows: usize,
+                            cols_capacity: usize,
+                        ) -> F32x8 {
+                            unsafe {
+                                if row >= nrows || col >= ncols {
+                                    f32x8_zero()
+                                } else {
+                                    load_f32x8(src.add(row * cols_capacity + col) as *const F32x8)
+                                }
+                            }
+                        }
+
+                        #[inline]
+                        fn load_k4_to_f32(
+                            tensor: &Tensor,
+                            row: usize,
+                            col: usize,
+                            nrows: usize,
+                            quants: *const F32x8,
+                        ) -> (F32x8, F32x8, F32x8, F32x8) {
+                            unsafe {
+                                let m: u32 = 0xFFFFFFFF;
+                                let masks: [I32x8; 8] = [
+                                    i32x8_from_values_u32(m, m, m, m, m, m, m, m),
+                                    i32x8_from_values_u32(0, m, m, m, m, m, m, m),
+                                    i32x8_from_values_u32(0, 0, m, m, m, m, m, m),
+                                    i32x8_from_values_u32(0, 0, 0, m, m, m, m, m),
+                                    i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m),
+                                    i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m),
+                                    i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m),
+                                    i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m),
+                                ];
+                                let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m);
+                                let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
+
+                                if row < nrows {
+                                    let col = col as i64;
+                                    let ncols = tensor.cols;
+                                    let (addr, side) = tensor.q4_address(row as i64, col);
+                                    let i = load_i16x8(addr as *const I16x8);
+                                    let even_mask = i16x8_singleton_u16(0x0F0F);
+                                    let odd_mask = i16x8_singleton_u16(0xF0F0);
+                                    let evens = and_i16x8(i, even_mask);
+                                    let odds = and_i16x8(i, odd_mask);
+                                    let odds = shift_right_by_4_i16x8(odds);
+
+                                    let indices1 = extend_i8_to_i32_i32x8(odds);
+                                    let odds_shifted = shift_right_by_64_i128(odds);
+                                    let indices2 = extend_i8_to_i32_i32x8(odds_shifted);
+                                    let indices3 = extend_i8_to_i32_i32x8(evens);
+                                    let indices4 =
+                                        extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens));
+
+                                    let unquantized1: F32x8 =
+                                        gather_scale4_f32x8(quants as *const f32, indices1);
+                                    let unquantized2: F32x8 =
+                                        gather_scale4_f32x8(quants as *const f32, indices2);
+                                    let unquantized3: F32x8 =
+                                        gather_scale4_f32x8(quants as *const f32, indices3);
+                                    let unquantized4: F32x8 =
+                                        gather_scale4_f32x8(quants as *const f32, indices4);
+                                    let quan1_mask: I32x8 = if col <= ncols - 8 {
+                                        nomask
+                                    } else if col < ncols {
+                                        masks[(col % 8) as usize]
+                                    } else {
+                                        fullmask
+                                    };
+                                    let quan2_mask: I32x8 = if col <= ncols - 16 {
+                                        nomask
+                                    } else if col < ncols - 8 {
+                                        masks[(col % 8) as usize]
+                                    } else {
+                                        fullmask
+                                    };
+                                    let quan3_mask: I32x8 = if col <= ncols - 24 {
+                                        nomask
+                                    } else if col < ncols - 16 {
+                                        masks[(col % 8) as usize]
+                                    } else {
+                                        fullmask
+                                    };
+                                    let quan4_mask: I32x8 = if col <= ncols - 32 {
+                                        nomask
+                                    } else if col < ncols - 24 {
+                                        masks[(col % 8) as usize]
+                                    } else {
+                                        fullmask
+                                    };
+                                    let unquantized1 = and_f32x8(unquantized1, quan1_mask);
+                                    let unquantized2 = and_f32x8(unquantized2, quan2_mask);
+                                    let unquantized3 = and_f32x8(unquantized3, quan3_mask);
+                                    let unquantized4 = and_f32x8(unquantized4, quan4_mask);
+                                    (unquantized1, unquantized2, unquantized3, unquantized4)
+                                } else {
+                                    (f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero())
+                                }
+                            }
+                        }
+
+                        let mut targets8: [F32x8; 4] =
+                            [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
+
+                        for p in 0..src_cols_its {
+                            let (other8_0, other8_1, other8_2, other8_3) =
+                                load_k4_to_f32(&other, col, p * 32, other_rows, quants.as_ptr());
+                            let src8_0 = load_f32(
+                                src_data,
+                                row,
+                                p * 32,
+                                src_cols,
+                                src_rows,
+                                src_cols_capacity,
+                            );
+                            let src8_1 = load_f32(
+                                src_data,
+                                row,
+                                p * 32 + 8,
+                                src_cols,
+                                src_rows,
+                                src_cols_capacity,
+                            );
+                            let src8_2 = load_f32(
+                                src_data,
+                                row,
+                                p * 32 + 16,
+                                src_cols,
+                                src_rows,
+                                src_cols_capacity,
+                            );
+                            let src8_3 = load_f32(
+                                src_data,
+                                row,
+                                p * 32 + 24,
+                                src_cols,
+                                src_rows,
+                                src_cols_capacity,
+                            );
+                            targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]);
+                            targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]);
+                            targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]);
+                            targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]);
+                        }
+                        let target0 = horizontal_sum_f32x8(targets8[0]);
+                        let target1 = horizontal_sum_f32x8(targets8[1]);
+                        let target2 = horizontal_sum_f32x8(targets8[2]);
+                        let target3 = horizontal_sum_f32x8(targets8[3]);
+                        let target = target0 + target1 + target2 + target3;
+                        *tgt_data.add(row * self_cols_capacity + col) = target;
+                    }
+                }
+            });
+        }
+    }
+
     fn matrix_mul_inplace_transposed_k4bit_and_f32(&mut self, src: &Tensor, other: &Tensor) {
         // Assume: size checks have been done already.
         assert!(src.dtype == TensorDType::K4BitQuantization);
@@ -1246,19 +1488,19 @@ impl Tensor {
                             quants: *const F32x8,
                         ) -> (F32x8, F32x8, F32x8, F32x8) {
                             unsafe {
-                                let M: u32 = 0xFFFFFFFF;
-                                let MASKS: [I32x8; 8] = [
-                                    i32x8_from_values_u32(M, M, M, M, M, M, M, M),
-                                    i32x8_from_values_u32(0, M, M, M, M, M, M, M),
-                                    i32x8_from_values_u32(0, 0, M, M, M, M, M, M),
-                                    i32x8_from_values_u32(0, 0, 0, M, M, M, M, M),
-                                    i32x8_from_values_u32(0, 0, 0, 0, M, M, M, M),
-                                    i32x8_from_values_u32(0, 0, 0, 0, 0, M, M, M),
-                                    i32x8_from_values_u32(0, 0, 0, 0, 0, 0, M, M),
-                                    i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, M),
+                                let m: u32 = 0xFFFFFFFF;
+                                let masks: [I32x8; 8] = [
+                                    i32x8_from_values_u32(m, m, m, m, m, m, m, m),
+                                    i32x8_from_values_u32(0, m, m, m, m, m, m, m),
+                                    i32x8_from_values_u32(0, 0, m, m, m, m, m, m),
+                                    i32x8_from_values_u32(0, 0, 0, m, m, m, m, m),
+                                    i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m),
+                                    i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m),
+                                    i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m),
+                                    i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m),
                                 ];
-                                let NOMASK: I32x8 = i32x8_from_values_u32(M, M, M, M, M, M, M, M);
-                                let FULLMASK: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
+                                let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m);
+                                let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
 
                                 if row < nrows {
                                     let col = col as i64;
@@ -1287,32 +1529,32 @@ impl Tensor {
                                     let unquantized4: F32x8 =
                                         gather_scale4_f32x8(quants as *const f32, indices4);
                                     let quan1_mask: I32x8 = if col <= ncols - 8 {
-                                        NOMASK
+                                        nomask
                                     } else if col < ncols {
-                                        MASKS[(col % 8) as usize]
+                                        masks[(col % 8) as usize]
                                     } else {
-                                        FULLMASK
+                                        fullmask
                                     };
                                     let quan2_mask: I32x8 = if col <= ncols - 16 {
-                                        NOMASK
+                                        nomask
                                     } else if col < ncols - 8 {
-                                        MASKS[(col % 8) as usize]
+                                        masks[(col % 8) as usize]
                                     } else {
-                                        FULLMASK
+                                        fullmask
                                     };
                                     let quan3_mask: I32x8 = if col <= ncols - 24 {
-                                        NOMASK
+                                        nomask
                                     } else if col < ncols - 16 {
-                                        MASKS[(col % 8) as usize]
+                                        masks[(col % 8) as usize]
                                     } else {
-                                        FULLMASK
+                                        fullmask
                                     };
                                     let quan4_mask: I32x8 = if col <= ncols - 32 {
-                                        NOMASK
+                                        nomask
                                     } else if col < ncols - 24 {
-                                        MASKS[(col % 8) as usize]
+                                        masks[(col % 8) as usize]
                                     } else {
-                                        FULLMASK
+                                        fullmask
                                     };
                                     let unquantized1 = and_f32x8(unquantized1, quan1_mask);
                                     let unquantized2 = and_f32x8(unquantized2, quan2_mask);
@@ -1401,6 +1643,7 @@ impl Tensor {
         }
         if src.dtype != other.dtype
             && (src.dtype != TensorDType::K4BitQuantization || other.dtype != TensorDType::Float32)
+            && (src.dtype != TensorDType::Float32 || other.dtype != TensorDType::K4BitQuantization)
         {
             panic!("Invalid matrix multiplication, different dtypes");
         }
@@ -1414,6 +1657,9 @@ impl Tensor {
         if src.dtype == TensorDType::K4BitQuantization && other.dtype == TensorDType::Float32 {
             return self.matrix_mul_inplace_transposed_k4bit_and_f32(src, other);
         }
+        if src.dtype == TensorDType::Float32 && other.dtype == TensorDType::K4BitQuantization {
+            return self.matrix_mul_inplace_transposed_f32_and_k4bit(src, other);
+        }
 
         match src.dtype {
             TensorDType::K4BitQuantization => unimplemented!(),
@@ -1844,6 +2090,14 @@ impl Tensor {
             );
         }
         assert_eq!(other.rows, 1);
+
+        // K4 bit currently has no implementation for matrix_vector_mul
+        if self.dtype == TensorDType::K4BitQuantization
+            || other.dtype == TensorDType::K4BitQuantization
+        {
+            return self.matrix_mul_transposed(other);
+        }
+
         assert_eq!(other.dtype, self.dtype);
 
         #[allow(unreachable_patterns)]
@@ -3471,7 +3725,7 @@ mod tests {
     }
 
     #[test]
-    fn quantized_matrices_matrix_mul_transposed_correctly() {
+    fn quantized_matrices_matrix_mul_transposed_correctly_k4_mul_f32() {
         let mut rng = rand::thread_rng();
         for _ in 0..100 {
             let a = rng.gen_range(1..=128);
@@ -3543,4 +3797,79 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn quantized_matrices_matrix_mul_transposed_correctly_f32_mul_k4() {
+        let mut rng = rand::thread_rng();
+        for _ in 0..100 {
+            let a = rng.gen_range(1..=128);
+            let b = rng.gen_range(1..=128);
+            let c = rng.gen_range(1..=128);
+            let other_matrix = Tensor::random(a, c, TensorDType::Float32);
+            let mut reference = Tensor::zeros(b, c, TensorDType::Float32);
+
+            let mut quant_values: Vec<Vec<f32>> = Vec::with_capacity(c as usize);
+            for row in 0..b {
+                let mut quant_values_for_row: Vec<f32> = Vec::with_capacity(16);
+                for _ in 0..16 {
+                    quant_values_for_row.push(rng.gen_range(0.0..=1.0));
+                }
+                quant_values.push(quant_values_for_row);
+            }
+
+            let mut quantized_values: Vec<Vec<u8>> = Vec::with_capacity(b as usize);
+            for row in 0..b {
+                let mut quant_values_for_row: Vec<u8> = Vec::with_capacity(c as usize);
+                for col in 0..c {
+                    let i = rng.gen_range(0..=15);
+                    reference.set_f32(row, col, quant_values[row as usize][i as usize]);
+                    quant_values_for_row.push(i as u8);
+                }
+                quantized_values.push(quant_values_for_row);
+            }
+
+            let quantized = Tensor::make_k4bit_from_fn(
+                b,
+                c,
+                |row, col| quantized_values[row as usize][col as usize],
+                |row| {
+                    let mut result: [f32; 16] = [0.0; 16];
+                    for col in 0..16 {
+                        result[col] = quant_values[row as usize][col];
+                    }
+                    result
+                },
+            );
+
+            assert_eq!(reference.rows(), quantized.rows());
+            assert_eq!(reference.cols(), quantized.cols());
+
+            for row in 0..reference.rows {
+                for col in 0..reference.cols {
+                    // The quantized table always uses f16 so values may not be 100% equal.
+                    assert_relative_eq!(
+                        reference.get_f32(row, col),
+                        quantized.get_f32(row, col),
+                        epsilon = 1e-1,
+                    );
+                }
+            }
+
+            let mult1 = other_matrix.matrix_mul_transposed(&reference);
+            let mult2 = other_matrix.matrix_mul_transposed(&quantized);
+
+            assert_eq!(mult1.rows(), mult2.rows());
+            assert_eq!(mult1.cols(), mult2.cols());
+
+            for row in 0..mult1.rows {
+                for col in 0..mult1.cols {
+                    assert_relative_eq!(
+                        mult1.get_f32(row, col),
+                        mult2.get_f32(row, col),
+                        epsilon = 1e-1,
+                    );
+                }
+            }
+        }
+    }
 }
diff --git a/src/transformer.rs b/src/transformer.rs
index cf3dc6c..5730d3c 100644
--- a/src/transformer.rs
+++ b/src/transformer.rs
@@ -33,7 +33,7 @@ pub struct Transformer {
 }
 
 // Clone is cheap
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct DataSettings {
     #[cfg(feature = "opencl")]
     use_opencl_for_feedforward: bool,
@@ -43,6 +43,7 @@ pub struct DataSettings {
     cl: Option,
 
     force_f16: bool,
+    force_k4: bool,
 }
 
 // OpenCL is safe to send to threads but Rust doesn't know that
@@ -56,6 +57,7 @@ impl DataSettings {
             use_opencl_for_feedforward: false,
             use_opencl_for_attention: false,
             force_f16: false,
+            force_k4: false,
             cl: cl.clone(),
         }
     }
@@ -63,7 +65,10 @@ impl DataSettings {
     #[allow(clippy::new_without_default)]
    #[cfg(not(feature = "opencl"))]
     pub fn new() -> Self {
-        DataSettings { force_f16: false }
+        DataSettings {
+            force_f16: false,
+            force_k4: false,
+        }
     }
 
     #[cfg(feature = "opencl")]
@@ -78,6 +83,17 @@ impl DataSettings {
     pub fn force_f16(mut self) -> DataSettings {
         self.force_f16 = true;
+        if self.force_k4 {
+            panic!("Cannot force both f16 and k4");
+        }
+        self
+    }
+
+    pub fn force_k4(mut self) -> DataSettings {
+        self.force_k4 = true;
+        if self.force_f16 {
+            panic!("Cannot force both f16 and k4");
+        }
         self
     }
 }
@@ -457,13 +473,14 @@ impl FeedForward {
             FromPiecesDirection::Rows,
         )?;
 
-        w1 = crate::weight_compression::quantize(&w1);
-        panic!("stop");
-
         if data_settings.force_f16 {
             w1 = w1.to_f16();
             w2 = w2.to_f16();
             w3 = w3.to_f16();
+        } else if data_settings.force_k4 {
+            w1 = w1.quantize();
+            w2 = w2.quantize();
+            w3 = w3.quantize();
         }
 
         #[cfg(feature = "opencl")]
@@ -491,7 +508,7 @@ impl FeedForward {
     pub fn forward(&self, x: &mut Tensor) -> Tensor {
         let original_x_dtype = x.dtype();
         if x.dtype() != self.w1.dtype() {
-            *x = x.to_same_type(&self.w1);
+            *x = x.to_compatible_matrix_mul_type(&self.w1);
         }
         #[cfg(feature = "opencl")]
         let x_was_on_cpu: bool;
@@ -516,7 +533,7 @@ impl FeedForward {
         let w1_out = w1_out.silu();
         let mut w1w3_out = w1_out.hadamard_product(&w3_out).transpose();
         if w1w3_out.dtype() != self.w2.dtype() {
-            w1w3_out = w1w3_out.to_same_type(&self.w2);
+            w1w3_out = w1w3_out.to_compatible_matrix_mul_type(&self.w2);
         }
         #[cfg(not(feature = "opencl"))]
         {
@@ -578,6 +595,11 @@ impl Attention {
             wk = wk.to_f16();
             wv = wv.to_f16();
             wo = wo.to_f16();
+        } else if data_settings.force_k4 {
+            wq = wq.quantize();
+            wk = wk.quantize();
+            wv = wv.quantize();
+            wo = wo.quantize();
         }
 
         #[cfg(feature = "opencl")]
@@ -616,7 +638,7 @@ impl Attention {
     ) -> Tensor {
         let original_x_dtype = x.dtype();
         if x.dtype() != self.wq.dtype() {
-            *x = x.to_same_type(&self.wq);
+            *x = x.to_compatible_matrix_mul_type2(&self.wq);
         }
 
         #[cfg(feature = "opencl")]
@@ -744,7 +766,7 @@ impl Attention {
         {
             let xq_row = Tensor::concat(&concat_vec2).view(1, self.wo.rows());
             xq_row
-                .into_same_type(&self.wo)
+                .to_compatible_matrix_mul_type2(&self.wo)
                 .matrix_mul_transposed(&self.wo)
         }
         #[cfg(feature = "opencl")]
diff --git a/src/weight_compression.rs b/src/weight_compression.rs
index 6f5f9d1..a39af64 100644
--- a/src/weight_compression.rs
+++ b/src/weight_compression.rs
@@ -9,7 +9,7 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
      * This is a simplistic rounding quantizer. It splits each row in a tensor to 16 buckets and
      * takes the average value in said buckets as the quantized weight.
      */
-    let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype());
+    let mut allowed_values_by_row: Vec<Vec<f32>> = Vec::with_capacity(tensor.rows() as usize);
     for row in 0..tensor.rows() {
         let mut values: Vec<f32> = Vec::with_capacity(tensor.cols() as usize);
         values.truncate(0);
@@ -43,8 +43,15 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
         allowed_values[0] = mi;
         allowed_values[15] = ma;
         allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+        allowed_values_by_row.push(allowed_values);
+    }
 
-        for col in 0..tensor.cols() {
+    //let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype());
+    let result = Tensor::make_k4bit_from_fn(
+        tensor.rows(),
+        tensor.cols(),
+        |row, col| {
+            let allowed_values: &[f32] = &allowed_values_by_row[row as usize];
             let val = tensor.get_f32(row, col);
             let mut best = 0;
             let mut best_dist = std::f32::MAX;
@@ -55,8 +62,16 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
                     best_dist = dist;
                 }
             }
-            result.set_f32(row, col, allowed_values[best] as f32);
-        }
-    }
+            best as u8
+        },
+        |row: i64| -> [f32; 16] {
+            let allowed_values: &[f32] = &allowed_values_by_row[row as usize];
+            let mut result: [f32; 16] = [0.0; 16];
+            for i in 0..16 {
+                result[i] = allowed_values[i];
+            }
+            result
+        },
+    );
     result
 }
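
Background note: the k4 format that weight_compression.rs now emits keeps, for every weight row, a 16-entry codebook (stored as f16 in the tensor, which is why the tests allow a small epsilon) and replaces every weight by the 4-bit index of its nearest codebook entry, two indices per byte. The standalone scalar sketch below restates that scheme; nearest_index, pack_row and unpack are illustrative names rather than rllama API, and the low/high nibble order is an assumption made only for this example (the SIMD kernel derives its own even/odd masks).

    // Scalar model of the k4 scheme: one 16-entry codebook per row, one nibble per weight.
    fn nearest_index(codebook: &[f32; 16], x: f32) -> u8 {
        let mut best = 0usize;
        let mut best_dist = f32::MAX;
        for (i, c) in codebook.iter().enumerate() {
            let d = (x - c).abs();
            if d < best_dist {
                best = i;
                best_dist = d;
            }
        }
        best as u8
    }

    fn pack_row(codebook: &[f32; 16], row: &[f32]) -> Vec<u8> {
        // Illustrative layout: even columns in the low nibble, odd columns in the high nibble.
        let mut packed = vec![0u8; (row.len() + 1) / 2];
        for (col, &x) in row.iter().enumerate() {
            let idx = nearest_index(codebook, x) & 0x0F;
            packed[col / 2] |= if col % 2 == 0 { idx } else { idx << 4 };
        }
        packed
    }

    fn unpack(codebook: &[f32; 16], packed: &[u8], col: usize) -> f32 {
        let byte = packed[col / 2];
        let idx = if col % 2 == 0 { byte & 0x0F } else { byte >> 4 };
        codebook[idx as usize]
    }

    fn main() {
        // 16 evenly spaced codebook entries in [0, 1]; the real codebooks come from the
        // per-row bucketing in weight_compression::quantize.
        let codebook: [f32; 16] = std::array::from_fn(|i| i as f32 / 15.0);
        let row = [0.02_f32, 0.49, 0.97, 0.33];
        let packed = pack_row(&codebook, &row);
        assert_eq!(packed.len(), 2); // 4 weights fit in 2 bytes
        for col in 0..row.len() {
            println!("{} ~ {}", row[col], unpack(&codebook, &packed, col));
        }
    }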
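The two to_compatible_matrix_mul_type helpers in tensor.rs encode one rule: whichever operand of the transposed multiply is K4-quantized keeps its dtype, the other operand is promoted to f32, and pairs with no K4 operand fall back to the old to_same_type behaviour. A minimal restatement of that decision table, using a stand-in Dtype enum and an illustrative promote_lhs_for function (not crate API), might look like:

    #[derive(Clone, Copy, PartialEq, Debug)]
    enum Dtype {
        Float16,
        Float32,
        K4,
    }

    // For lhs.matrix_mul_transposed(rhs): a K4 operand keeps its dtype, the other side
    // is promoted to f32; when neither side is K4, cast the lhs to the rhs dtype as
    // to_same_type always did.
    fn promote_lhs_for(lhs: Dtype, rhs: Dtype) -> Dtype {
        match (lhs, rhs) {
            (_, Dtype::K4) => Dtype::Float32,         // served by the new f32 x k4 kernel
            (Dtype::K4, Dtype::Float32) => Dtype::K4, // served by the existing k4 x f32 kernel
            (Dtype::K4, _) => unimplemented!(),       // mirrors the patch's unimplemented!()
            (_, rhs) => rhs,
        }
    }

    fn main() {
        assert_eq!(promote_lhs_for(Dtype::Float16, Dtype::K4), Dtype::Float32);
        assert_eq!(promote_lhs_for(Dtype::K4, Dtype::Float32), Dtype::K4);
        assert_eq!(promote_lhs_for(Dtype::Float32, Dtype::Float16), Dtype::Float16);
    }

This is why the attention and feed-forward paths only switch from to_same_type to the compatible-type helpers: the K4 weight matrices never leave their quantized form, while activations are carried in f32.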
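Finally, the new matrix_mul_inplace_transposed_f32_and_k4bit kernel parallelizes over output elements: thread t computes every output whose flattened (row, col) index is congruent to t modulo rayon::current_num_threads(). A much-simplified f32-only sketch of the same transposed multiply is shown below; it hands the identical flat index space to rayon instead of striping by hand, and matmul_transposed is an illustrative function, not part of the crate.

    use rayon::prelude::*;

    // a is rows x k, b is cols x k (already "transposed"); the result is rows x cols.
    fn matmul_transposed(a: &[f32], b: &[f32], rows: usize, cols: usize, k: usize) -> Vec<f32> {
        (0..rows * cols)
            .into_par_iter()
            .map(|idx| {
                let (row, col) = (idx / cols, idx % cols);
                // Dot product of one row of a with one row of b, i.e. one column of b^T.
                (0..k).map(|p| a[row * k + p] * b[col * k + p]).sum::<f32>()
            })
            .collect()
    }

    fn main() {
        // A 2x3 matrix times the transpose of a 2x3 matrix gives a 2x2 result.
        let a: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b: [f32; 6] = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        assert_eq!(matmul_transposed(&a, &b, 2, 2, 3), vec![1.0f32, 2.0, 4.0, 5.0]);
    }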