diff --git a/src/benches/benchmark.rs b/src/benches/benchmark.rs index 80175ce..0ca62ef 100644 --- a/src/benches/benchmark.rs +++ b/src/benches/benchmark.rs @@ -113,6 +113,7 @@ pub fn tensor_benchmarks(c: &mut Criterion) { let m2_f16 = m2.to_f16(); let quant = m1.quantize(); + let quant2 = m2.quantize(); c.bench_function( "1024x128 * 1x128 matrix vector transposed multiplication, k4 quantized * f32", @@ -122,9 +123,17 @@ pub fn tensor_benchmarks(c: &mut Criterion) { }) }, ); + c.bench_function( + "1024x128 * 1x128 matrix vector transposed multiplication, f32 quantized * k4", + |b| { + b.iter(|| { + let _ = m1.matrix_vector_mul_transposed(black_box(&quant2)); + }) + }, + ); c.bench_function( - "matrix multiplication 8x4096 @ 4096x4096 k8 quantized * f32 in-place, transposed", + "matrix multiplication 8x4096 @ 4096x4096 k4 quantized * f32 in-place, transposed", |b| { b.iter(|| { let _ = result_84096.matrix_mul_inplace_transposed( diff --git a/src/simd_support.rs b/src/simd_support.rs index 613c537..487443c 100644 --- a/src/simd_support.rs +++ b/src/simd_support.rs @@ -266,11 +266,6 @@ pub fn shift_right_by_64_i128(a: I16x8) -> I16x8 { unsafe { _mm_srli_si128(a, 64 / 8) } } -// Shuffle/premute -pub fn shuffle_i16x8(a: I16x8, permutation: I16x8) -> I16x8 { - unsafe { _mm_shuffle_epi8(a, permutation) } -} - // Extends 8 i8 values into 7 i16 values // // XXYYZZ -> 00XX00YY00ZZ diff --git a/src/st.rs b/src/st.rs new file mode 100644 index 0000000..9da8ffb --- /dev/null +++ b/src/st.rs @@ -0,0 +1,419 @@ + fn matrix_mul_inplace_transposed_f32_and_k4bit(&mut self, src: &Tensor, other: &Tensor) { + // Assume: size checks have been done already. + assert!(src.dtype == TensorDType::Float32); + assert!(other.dtype == TensorDType::K4BitQuantization); + assert!(self.dtype == TensorDType::Float32); + + unsafe { + let src_rows: usize = src.rows as usize; + let src_cols: usize = src.cols as usize; + let src_cols_capacity: usize = src.capacity_cols as usize; + let other_cols: usize = other.cols as usize; + let other_rows: usize = other.rows as usize; + let other_cols_capacity: usize = other.capacity_cols as usize; + let self_rows: usize = self.rows as usize; + let self_cols: usize = self.cols as usize; + let self_cols_capacity: usize = self.capacity_cols as usize; + + // src_cols_its == also the shared dimension between src and other. 
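+            // Rounded up: a trailing group of fewer than 32 shared-dimension columns
+            // still needs one SIMD iteration; lanes past the logical end are intended
+            // to be masked to zero by the mask selection further below.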
+ let src_cols_its = if src_cols % 32 == 0 { + src_cols / 32 + } else { + src_cols / 32 + 1 + }; + debug_assert!(!other.q4_data.is_null()); + + let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data); + let other_data: WrappedPtr = WrappedPtr::wrap(other.data); + let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data); + let other_q4_data: WrappedPtr = WrappedPtr::wrap(other.q4_data); + + let nthreads: usize = rayon::current_num_threads(); + (0..nthreads).into_par_iter().for_each(|thread_idx| { + let other_q4_data: *const u8 = other_q4_data.unwrap() as *const u8; + let src_data: *const f32 = src_data_wrap.unwrap() as *const f32; + let other_data: *const u8 = other_data.unwrap() as *const u8; + let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32; + + for row in 0..self_rows { + for col in 0..self_cols { + let row_col = row * self_cols + col; + if row_col % nthreads != thread_idx { + continue; + } + + let quant0 = load_i16x8(other_q4_data.add(col * 32) as *const I16x8); + let quant1 = load_i16x8(other_q4_data.add(col * 32 + 16) as *const I16x8); + let quants: [F32x8; 2] = + [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)]; + + #[inline] + fn load_f32( + src: *const f32, + row: usize, + col: usize, + ncols: usize, + nrows: usize, + cols_capacity: usize, + ) -> F32x8 { + unsafe { + if row >= nrows || col >= ncols { + f32x8_zero() + } else { + load_f32x8(src.add(row * cols_capacity + col) as *const F32x8) + } + } + } + + #[inline] + fn load_k4_to_f32( + tensor: &Tensor, + row: usize, + col: usize, + nrows: usize, + quants: *const F32x8, + ) -> (F32x8, F32x8, F32x8, F32x8) { + unsafe { + let m: u32 = 0xFFFFFFFF; + let masks: [I32x8; 8] = [ + i32x8_from_values_u32(m, m, m, m, m, m, m, m), + i32x8_from_values_u32(0, m, m, m, m, m, m, m), + i32x8_from_values_u32(0, 0, m, m, m, m, m, m), + i32x8_from_values_u32(0, 0, 0, m, m, m, m, m), + i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m), + i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m), + i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m), + i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m), + ]; + let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m); + let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0); + + if row < nrows { + let col = col as i64; + let ncols = tensor.cols; + let (addr, side) = tensor.q4_address(row as i64, col); + let i = load_i16x8(addr as *const I16x8); + let even_mask = i16x8_singleton_u16(0x0F0F); + let odd_mask = i16x8_singleton_u16(0xF0F0); + let evens = and_i16x8(i, even_mask); + let odds = and_i16x8(i, odd_mask); + let odds = shift_right_by_4_i16x8(odds); + + let indices1 = extend_i8_to_i32_i32x8(odds); + let odds_shifted = shift_right_by_64_i128(odds); + let indices2 = extend_i8_to_i32_i32x8(odds_shifted); + let indices3 = extend_i8_to_i32_i32x8(evens); + let indices4 = + extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens)); + + let unquantized1: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices1); + let unquantized2: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices2); + let unquantized3: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices3); + let unquantized4: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices4); + let quan1_mask: I32x8 = if col <= ncols - 8 { + nomask + } else if col < ncols { + masks[(col % 8) as usize] + } else { + fullmask + }; + let quan2_mask: I32x8 = if col <= ncols - 16 { + nomask + } else if col < ncols - 8 { + masks[(col % 8) as usize] + } else { + fullmask + }; + let quan3_mask: I32x8 = if col <= ncols - 24 { + nomask + 
} else if col < ncols - 16 { + masks[(col % 8) as usize] + } else { + fullmask + }; + let quan4_mask: I32x8 = if col <= ncols - 32 { + nomask + } else if col < ncols - 24 { + masks[(col % 8) as usize] + } else { + fullmask + }; + let unquantized1 = and_f32x8(unquantized1, quan1_mask); + let unquantized2 = and_f32x8(unquantized2, quan2_mask); + let unquantized3 = and_f32x8(unquantized3, quan3_mask); + let unquantized4 = and_f32x8(unquantized4, quan4_mask); + (unquantized1, unquantized2, unquantized3, unquantized4) + } else { + (f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()) + } + } + } + + let mut targets8: [F32x8; 4] = + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()]; + + for p in 0..src_cols_its { + let (other8_0, other8_1, other8_2, other8_3) = + load_k4_to_f32(&other, col, p * 32, other_rows, quants.as_ptr()); + let src8_0 = load_f32( + src_data, + row, + p * 32, + src_cols, + src_rows, + src_cols_capacity, + ); + let src8_1 = load_f32( + src_data, + row, + p * 32 + 8, + src_cols, + src_rows, + src_cols_capacity, + ); + let src8_2 = load_f32( + src_data, + row, + p * 32 + 16, + src_cols, + src_rows, + src_cols_capacity, + ); + let src8_3 = load_f32( + src_data, + row, + p * 32 + 24, + src_cols, + src_rows, + src_cols_capacity, + ); + targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]); + targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]); + targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]); + targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]); + } + let target0 = horizontal_sum_f32x8(targets8[0]); + let target1 = horizontal_sum_f32x8(targets8[1]); + let target2 = horizontal_sum_f32x8(targets8[2]); + let target3 = horizontal_sum_f32x8(targets8[3]); + let target = target0 + target1 + target2 + target3; + *tgt_data.add(row * self_cols_capacity + col) = target; + } + } + }); + } + } + + fn matrix_mul_inplace_transposed_k4bit_and_f32(&mut self, src: &Tensor, other: &Tensor) { + // Assume: size checks have been done already. + assert!(src.dtype == TensorDType::K4BitQuantization); + assert!(other.dtype == TensorDType::Float32); + assert!(self.dtype == TensorDType::Float32); + + unsafe { + let src_rows: usize = src.rows as usize; + let src_cols: usize = src.cols as usize; + let src_cols_capacity: usize = src.capacity_cols as usize; + let other_cols: usize = other.cols as usize; + let other_rows: usize = other.rows as usize; + let other_cols_capacity: usize = other.capacity_cols as usize; + let self_rows: usize = self.rows as usize; + let self_cols: usize = self.cols as usize; + let self_cols_capacity: usize = self.capacity_cols as usize; + + // src_cols_its == also the shared dimension between src and other. 
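+            // Work is striped across threads below: output cell (row, col) is handled
+            // by the thread whose index equals (row * self_cols + col) % nthreads.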
+ let src_cols_its = if src_cols % 32 == 0 { + src_cols / 32 + } else { + src_cols / 32 + 1 + }; + debug_assert!(!src.q4_data.is_null()); + + let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data); + let other_data: WrappedPtr = WrappedPtr::wrap(other.data); + let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data); + let src_q4_data: WrappedPtr = WrappedPtr::wrap(src.q4_data); + + let nthreads: usize = rayon::current_num_threads(); + (0..nthreads).into_par_iter().for_each(|thread_idx| { + let src_q4_data: *const u8 = src_q4_data.unwrap() as *const u8; + let src_data: *const u8 = src_data_wrap.unwrap() as *const u8; + let other_data: *const f32 = other_data.unwrap() as *const f32; + let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32; + + for row in 0..self_rows { + let quant0 = load_i16x8(src_q4_data.add(row * 32) as *const I16x8); + let quant1 = load_i16x8(src_q4_data.add(row * 32 + 16) as *const I16x8); + let quants: [F32x8; 2] = + [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)]; + + for col in 0..self_cols { + let row_col = row * self_cols + col; + if row_col % nthreads != thread_idx { + continue; + } + + #[inline] + fn load_f32( + other: *const f32, + row: usize, + col: usize, + ncols: usize, + nrows: usize, + cols_capacity: usize, + ) -> F32x8 { + unsafe { + if row >= nrows || col >= ncols { + f32x8_zero() + } else { + load_f32x8(other.add(row * cols_capacity + col) as *const F32x8) + } + } + } + + #[inline] + fn load_k4_to_f32( + tensor: &Tensor, + row: usize, + col: usize, + nrows: usize, + quants: *const F32x8, + ) -> (F32x8, F32x8, F32x8, F32x8) { + unsafe { + let m: u32 = 0xFFFFFFFF; + let masks: [I32x8; 8] = [ + i32x8_from_values_u32(m, m, m, m, m, m, m, m), + i32x8_from_values_u32(0, m, m, m, m, m, m, m), + i32x8_from_values_u32(0, 0, m, m, m, m, m, m), + i32x8_from_values_u32(0, 0, 0, m, m, m, m, m), + i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m), + i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m), + i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m), + i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m), + ]; + let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m); + let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0); + + if row < nrows { + let col = col as i64; + let ncols = tensor.cols; + let (addr, side) = tensor.q4_address(row as i64, col); + let i = load_i16x8(addr as *const I16x8); + let even_mask = i16x8_singleton_u16(0x0F0F); + let odd_mask = i16x8_singleton_u16(0xF0F0); + let evens = and_i16x8(i, even_mask); + let odds = and_i16x8(i, odd_mask); + let odds = shift_right_by_4_i16x8(odds); + + let indices1 = extend_i8_to_i32_i32x8(odds); + let odds_shifted = shift_right_by_64_i128(odds); + let indices2 = extend_i8_to_i32_i32x8(odds_shifted); + let indices3 = extend_i8_to_i32_i32x8(evens); + let indices4 = + extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens)); + + let unquantized1: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices1); + let unquantized2: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices2); + let unquantized3: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices3); + let unquantized4: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices4); + let quan1_mask: I32x8 = if col <= ncols - 8 { + nomask + } else if col < ncols { + masks[(col % 8) as usize] + } else { + fullmask + }; + let quan2_mask: I32x8 = if col <= ncols - 16 { + nomask + } else if col < ncols - 8 { + masks[(col % 8) as usize] + } else { + fullmask + }; + let quan3_mask: I32x8 = if col <= ncols - 24 { + nomask + } else if 
col < ncols - 16 { + masks[(col % 8) as usize] + } else { + fullmask + }; + let quan4_mask: I32x8 = if col <= ncols - 32 { + nomask + } else if col < ncols - 24 { + masks[(col % 8) as usize] + } else { + fullmask + }; + let unquantized1 = and_f32x8(unquantized1, quan1_mask); + let unquantized2 = and_f32x8(unquantized2, quan2_mask); + let unquantized3 = and_f32x8(unquantized3, quan3_mask); + let unquantized4 = and_f32x8(unquantized4, quan4_mask); + (unquantized1, unquantized2, unquantized3, unquantized4) + } else { + (f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()) + } + } + } + + let mut targets8: [F32x8; 4] = + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()]; + + for p in 0..src_cols_its { + let other8_0: F32x8 = load_f32( + other_data, + col, + p * 32, + other_cols, + other_rows, + other_cols_capacity, + ); + let other8_1: F32x8 = load_f32( + other_data, + col, + p * 32 + 8, + other_cols, + other_rows, + other_cols_capacity, + ); + let other8_2: F32x8 = load_f32( + other_data, + col, + p * 32 + 16, + other_cols, + other_rows, + other_cols_capacity, + ); + let other8_3: F32x8 = load_f32( + other_data, + col, + p * 32 + 24, + other_cols, + other_rows, + other_cols_capacity, + ); + let (src8_0, src8_1, src8_2, src8_3): (F32x8, F32x8, F32x8, F32x8) = + load_k4_to_f32(&src, row, p * 32, src_rows, quants.as_ptr()); + targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]); + targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]); + targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]); + targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]); + } + let target0 = horizontal_sum_f32x8(targets8[0]); + let target1 = horizontal_sum_f32x8(targets8[1]); + let target2 = horizontal_sum_f32x8(targets8[2]); + let target3 = horizontal_sum_f32x8(targets8[3]); + let target = target0 + target1 + target2 + target3; + *tgt_data.add(row * self_cols_capacity + col) = target; + } + } + }); + } + } diff --git a/src/tensor.rs b/src/tensor.rs index 27c8216..6f80224 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -225,6 +225,24 @@ fn compute_capacity_cols_f16(cols: i64) -> i64 { } } +lazy_static! 
{ + static ref m: u32 = 0xFFFFFFFF; + static ref masks: [I32x8; 8] = [ + i32x8_from_values_u32(*m, *m, *m, *m, *m, *m, *m, *m), + i32x8_from_values_u32(0, *m, *m, *m, *m, *m, *m, *m), + i32x8_from_values_u32(0, 0, *m, *m, *m, *m, *m, *m), + i32x8_from_values_u32(0, 0, 0, *m, *m, *m, *m, *m), + i32x8_from_values_u32(0, 0, 0, 0, *m, *m, *m, *m), + i32x8_from_values_u32(0, 0, 0, 0, 0, *m, *m, *m), + i32x8_from_values_u32(0, 0, 0, 0, 0, 0, *m, *m), + i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, *m), + ]; + static ref nomask: I32x8 = i32x8_from_values_u32(*m, *m, *m, *m, *m, *m, *m, *m); + static ref fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0); + static ref even_mask: I16x8 = i16x8_singleton_u16(0x0F0F); + static ref odd_mask: I16x8 = i16x8_singleton_u16(0xF0F0); +} + impl Tensor { #[inline] pub fn assume_on_gpu(&self) { @@ -921,10 +939,7 @@ impl Tensor { // We don't have implementation for f16, so don't use the vector function if we have // f16 #[cfg(not(feature = "opencl"))] - if other.rows == 1 - && (other.dtype != TensorDType::K4BitQuantization - && self.dtype != TensorDType::K4BitQuantization) - { + if other.rows == 1 { return self.matrix_vector_mul_transposed(other); } #[cfg(feature = "opencl")] @@ -1278,29 +1293,13 @@ impl Tensor { quants: *const F32x8, ) -> (F32x8, F32x8, F32x8, F32x8) { unsafe { - let m: u32 = 0xFFFFFFFF; - let masks: [I32x8; 8] = [ - i32x8_from_values_u32(m, m, m, m, m, m, m, m), - i32x8_from_values_u32(0, m, m, m, m, m, m, m), - i32x8_from_values_u32(0, 0, m, m, m, m, m, m), - i32x8_from_values_u32(0, 0, 0, m, m, m, m, m), - i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m), - i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m), - i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m), - i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m), - ]; - let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m); - let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0); - if row < nrows { let col = col as i64; let ncols = tensor.cols; let (addr, side) = tensor.q4_address(row as i64, col); let i = load_i16x8(addr as *const I16x8); - let even_mask = i16x8_singleton_u16(0x0F0F); - let odd_mask = i16x8_singleton_u16(0xF0F0); - let evens = and_i16x8(i, even_mask); - let odds = and_i16x8(i, odd_mask); + let evens = and_i16x8(i, *even_mask); + let odds = and_i16x8(i, *odd_mask); let odds = shift_right_by_4_i16x8(odds); let indices1 = extend_i8_to_i32_i32x8(odds); @@ -1319,32 +1318,32 @@ impl Tensor { let unquantized4: F32x8 = gather_scale4_f32x8(quants as *const f32, indices4); let quan1_mask: I32x8 = if col <= ncols - 8 { - nomask + *nomask } else if col < ncols { masks[(col % 8) as usize] } else { - fullmask + *fullmask }; let quan2_mask: I32x8 = if col <= ncols - 16 { - nomask + *nomask } else if col < ncols - 8 { masks[(col % 8) as usize] } else { - fullmask + *fullmask }; let quan3_mask: I32x8 = if col <= ncols - 24 { - nomask + *nomask } else if col < ncols - 16 { masks[(col % 8) as usize] } else { - fullmask + *fullmask }; let quan4_mask: I32x8 = if col <= ncols - 32 { - nomask + *nomask } else if col < ncols - 24 { masks[(col % 8) as usize] } else { - fullmask + *fullmask }; let unquantized1 = and_f32x8(unquantized1, quan1_mask); let unquantized2 = and_f32x8(unquantized2, quan2_mask); @@ -1429,6 +1428,11 @@ impl Tensor { let self_cols: usize = self.cols as usize; let self_cols_capacity: usize = self.capacity_cols as usize; + let self_cols_its = if self_cols % 4 == 0 { + self_cols / 4 + } else { + self_cols / 4 + 1 + }; // src_cols_its == also the shared 
dimension between src and other. let src_cols_its = if src_cols % 32 == 0 { src_cols / 32 @@ -1455,11 +1459,15 @@ impl Tensor { let quants: [F32x8; 2] = [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)]; - for col in 0..self_cols { - let row_col = row * self_cols + col; + for col_raw in 0..self_cols_its { + let row_col = row * self_cols_its + col_raw; if row_col % nthreads != thread_idx { continue; } + let col0 = col_raw * 4; + let col1 = col_raw * 4 + 1; + let col2 = col_raw * 4 + 2; + let col3 = col_raw * 4 + 3; #[inline] fn load_f32( @@ -1488,29 +1496,13 @@ impl Tensor { quants: *const F32x8, ) -> (F32x8, F32x8, F32x8, F32x8) { unsafe { - let m: u32 = 0xFFFFFFFF; - let masks: [I32x8; 8] = [ - i32x8_from_values_u32(m, m, m, m, m, m, m, m), - i32x8_from_values_u32(0, m, m, m, m, m, m, m), - i32x8_from_values_u32(0, 0, m, m, m, m, m, m), - i32x8_from_values_u32(0, 0, 0, m, m, m, m, m), - i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m), - i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m), - i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m), - i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m), - ]; - let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m); - let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0); - if row < nrows { let col = col as i64; let ncols = tensor.cols; let (addr, side) = tensor.q4_address(row as i64, col); let i = load_i16x8(addr as *const I16x8); - let even_mask = i16x8_singleton_u16(0x0F0F); - let odd_mask = i16x8_singleton_u16(0xF0F0); - let evens = and_i16x8(i, even_mask); - let odds = and_i16x8(i, odd_mask); + let evens = and_i16x8(i, *even_mask); + let odds = and_i16x8(i, *odd_mask); let odds = shift_right_by_4_i16x8(odds); let indices1 = extend_i8_to_i32_i32x8(odds); @@ -1529,32 +1521,32 @@ impl Tensor { let unquantized4: F32x8 = gather_scale4_f32x8(quants as *const f32, indices4); let quan1_mask: I32x8 = if col <= ncols - 8 { - nomask + *nomask } else if col < ncols { masks[(col % 8) as usize] } else { - fullmask + *fullmask }; let quan2_mask: I32x8 = if col <= ncols - 16 { - nomask + *nomask } else if col < ncols - 8 { masks[(col % 8) as usize] } else { - fullmask + *fullmask }; let quan3_mask: I32x8 = if col <= ncols - 24 { - nomask + *nomask } else if col < ncols - 16 { masks[(col % 8) as usize] } else { - fullmask + *fullmask }; let quan4_mask: I32x8 = if col <= ncols - 32 { - nomask + *nomask } else if col < ncols - 24 { masks[(col % 8) as usize] } else { - fullmask + *fullmask }; let unquantized1 = and_f32x8(unquantized1, quan1_mask); let unquantized2 = and_f32x8(unquantized2, quan2_mask); @@ -1567,56 +1559,477 @@ impl Tensor { } } - let mut targets8: [F32x8; 4] = - [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()]; + let mut targets8: [[F32x8; 4]; 4] = [ + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()], + ]; for p in 0..src_cols_its { - let other8_0: F32x8 = load_f32( - other_data, - col, - p * 32, - other_cols, - other_rows, - other_cols_capacity, - ); - let other8_1: F32x8 = load_f32( - other_data, - col, - p * 32 + 8, - other_cols, - other_rows, - other_cols_capacity, - ); - let other8_2: F32x8 = load_f32( - other_data, - col, - p * 32 + 16, - other_cols, - other_rows, - other_cols_capacity, - ); - let other8_3: F32x8 = load_f32( - other_data, - col, - p * 32 + 24, - other_cols, - other_rows, - other_cols_capacity, - ); + // Macro to 
make code shorter + macro_rules! lo { + ($col:expr, $p:expr) => { + load_f32( + other_data, + $col, + $p, + other_cols, + other_rows, + other_cols_capacity, + ) + }; + } + let other8_00: F32x8 = lo!(col0, p * 32); + let other8_01: F32x8 = lo!(col0, p * 32 + 8); + let other8_02: F32x8 = lo!(col0, p * 32 + 16); + let other8_03: F32x8 = lo!(col0, p * 32 + 24); + let other8_10: F32x8 = lo!(col1, p * 32); + let other8_11: F32x8 = lo!(col1, p * 32 + 8); + let other8_12: F32x8 = lo!(col1, p * 32 + 16); + let other8_13: F32x8 = lo!(col1, p * 32 + 24); + let other8_20: F32x8 = lo!(col2, p * 32); + let other8_21: F32x8 = lo!(col2, p * 32 + 8); + let other8_22: F32x8 = lo!(col2, p * 32 + 16); + let other8_23: F32x8 = lo!(col2, p * 32 + 24); + let other8_30: F32x8 = lo!(col3, p * 32); + let other8_31: F32x8 = lo!(col3, p * 32 + 8); + let other8_32: F32x8 = lo!(col3, p * 32 + 16); + let other8_33: F32x8 = lo!(col3, p * 32 + 24); + let (src8_0, src8_1, src8_2, src8_3): (F32x8, F32x8, F32x8, F32x8) = load_k4_to_f32(&src, row, p * 32, src_rows, quants.as_ptr()); - targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]); - targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]); - targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]); - targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]); + + targets8[0][0] = fma_f32x8(src8_0, other8_00, targets8[0][0]); + targets8[0][1] = fma_f32x8(src8_1, other8_01, targets8[0][1]); + targets8[0][2] = fma_f32x8(src8_2, other8_02, targets8[0][2]); + targets8[0][3] = fma_f32x8(src8_3, other8_03, targets8[0][3]); + targets8[1][0] = fma_f32x8(src8_0, other8_10, targets8[1][0]); + targets8[1][1] = fma_f32x8(src8_1, other8_11, targets8[1][1]); + targets8[1][2] = fma_f32x8(src8_2, other8_12, targets8[1][2]); + targets8[1][3] = fma_f32x8(src8_3, other8_13, targets8[1][3]); + targets8[2][0] = fma_f32x8(src8_0, other8_20, targets8[2][0]); + targets8[2][1] = fma_f32x8(src8_1, other8_21, targets8[2][1]); + targets8[2][2] = fma_f32x8(src8_2, other8_22, targets8[2][2]); + targets8[2][3] = fma_f32x8(src8_3, other8_23, targets8[2][3]); + targets8[3][0] = fma_f32x8(src8_0, other8_30, targets8[3][0]); + targets8[3][1] = fma_f32x8(src8_1, other8_31, targets8[3][1]); + targets8[3][2] = fma_f32x8(src8_2, other8_32, targets8[3][2]); + targets8[3][3] = fma_f32x8(src8_3, other8_33, targets8[3][3]); } - let target0 = horizontal_sum_f32x8(targets8[0]); - let target1 = horizontal_sum_f32x8(targets8[1]); - let target2 = horizontal_sum_f32x8(targets8[2]); - let target3 = horizontal_sum_f32x8(targets8[3]); - let target = target0 + target1 + target2 + target3; - *tgt_data.add(row * self_cols_capacity + col) = target; + let target00 = horizontal_sum_f32x8(targets8[0][0]); + let target01 = horizontal_sum_f32x8(targets8[0][1]); + let target02 = horizontal_sum_f32x8(targets8[0][2]); + let target03 = horizontal_sum_f32x8(targets8[0][3]); + let target0 = target00 + target01 + target02 + target03; + let target10 = horizontal_sum_f32x8(targets8[1][0]); + let target11 = horizontal_sum_f32x8(targets8[1][1]); + let target12 = horizontal_sum_f32x8(targets8[1][2]); + let target13 = horizontal_sum_f32x8(targets8[1][3]); + let target1 = target10 + target11 + target12 + target13; + let target20 = horizontal_sum_f32x8(targets8[2][0]); + let target21 = horizontal_sum_f32x8(targets8[2][1]); + let target22 = horizontal_sum_f32x8(targets8[2][2]); + let target23 = horizontal_sum_f32x8(targets8[2][3]); + let target2 = target20 + target21 + target22 + target23; + let target30 = horizontal_sum_f32x8(targets8[3][0]); + let 
target31 = horizontal_sum_f32x8(targets8[3][1]); + let target32 = horizontal_sum_f32x8(targets8[3][2]); + let target33 = horizontal_sum_f32x8(targets8[3][3]); + let target3 = target30 + target31 + target32 + target33; + + *tgt_data.add(row * self_cols_capacity + col0) = target0; + if col1 < self_cols { + *tgt_data.add(row * self_cols_capacity + col1) = target1; + } + if col2 < self_cols { + *tgt_data.add(row * self_cols_capacity + col2) = target2; + } + if col3 < self_cols { + *tgt_data.add(row * self_cols_capacity + col3) = target3; + } + } + } + }); + } + } + + fn matrix_vector_mul_inplace_transposed_f32_and_k4bit(&mut self, src: &Tensor, other: &Tensor) { + // Assume: size checks have been done already. + assert!(src.dtype == TensorDType::Float32); + assert!(other.dtype == TensorDType::K4BitQuantization); + assert!(self.dtype == TensorDType::Float32); + assert_eq!(other.rows, 1); + assert_eq!(self.cols, 1); + + unsafe { + let src_rows: usize = src.rows as usize; + let src_cols: usize = src.cols as usize; + let src_cols_capacity: usize = src.capacity_cols as usize; + let other_cols: usize = other.cols as usize; + let other_rows: usize = other.rows as usize; + let other_cols_capacity: usize = other.capacity_cols as usize; + let self_rows: usize = self.rows as usize; + let self_cols: usize = self.cols as usize; + let self_cols_capacity: usize = self.capacity_cols as usize; + + // src_cols_its == also the shared dimension between src and other. + let src_cols_its = if src_cols % 32 == 0 { + src_cols / 32 + } else { + src_cols / 32 + 1 + }; + debug_assert!(!other.q4_data.is_null()); + + let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data); + let other_data: WrappedPtr = WrappedPtr::wrap(other.data); + let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data); + let other_q4_data: WrappedPtr = WrappedPtr::wrap(other.q4_data); + + let nthreads: usize = rayon::current_num_threads(); + (0..nthreads).into_par_iter().for_each(|thread_idx| { + let other_q4_data: *const u8 = other_q4_data.unwrap() as *const u8; + let src_data: *const f32 = src_data_wrap.unwrap() as *const f32; + let other_data: *const u8 = other_data.unwrap() as *const u8; + let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32; + + let quant0 = load_i16x8(other_q4_data.add(0) as *const I16x8); + let quant1 = load_i16x8(other_q4_data.add(16) as *const I16x8); + let quants: [F32x8; 2] = + [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)]; + + let col = 0; + for row in 0..self_rows { + if row % nthreads != thread_idx { + continue; } + #[inline] + fn load_f32( + src: *const f32, + row: usize, + col: usize, + ncols: usize, + nrows: usize, + cols_capacity: usize, + ) -> F32x8 { + unsafe { + if row >= nrows || col >= ncols { + f32x8_zero() + } else { + load_f32x8(src.add(row * cols_capacity + col) as *const F32x8) + } + } + } + + #[inline] + fn load_k4_to_f32( + tensor: &Tensor, + row: usize, + col: usize, + nrows: usize, + quants: *const F32x8, + ) -> (F32x8, F32x8, F32x8, F32x8) { + unsafe { + if row < nrows { + let col = col as i64; + let ncols = tensor.cols; + let (addr, side) = tensor.q4_address(row as i64, col); + let i = load_i16x8(addr as *const I16x8); + let evens = and_i16x8(i, *even_mask); + let odds = and_i16x8(i, *odd_mask); + let odds = shift_right_by_4_i16x8(odds); + + let indices1 = extend_i8_to_i32_i32x8(odds); + let odds_shifted = shift_right_by_64_i128(odds); + let indices2 = extend_i8_to_i32_i32x8(odds_shifted); + let indices3 = extend_i8_to_i32_i32x8(evens); + let indices4 = + 
extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens)); + + let unquantized1: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices1); + let unquantized2: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices2); + let unquantized3: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices3); + let unquantized4: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices4); + let quan1_mask: I32x8 = if col <= ncols - 8 { + *nomask + } else if col < ncols { + masks[(col % 8) as usize] + } else { + *fullmask + }; + let quan2_mask: I32x8 = if col <= ncols - 16 { + *nomask + } else if col < ncols - 8 { + masks[(col % 8) as usize] + } else { + *fullmask + }; + let quan3_mask: I32x8 = if col <= ncols - 24 { + *nomask + } else if col < ncols - 16 { + masks[(col % 8) as usize] + } else { + *fullmask + }; + let quan4_mask: I32x8 = if col <= ncols - 32 { + *nomask + } else if col < ncols - 24 { + masks[(col % 8) as usize] + } else { + *fullmask + }; + let unquantized1 = and_f32x8(unquantized1, quan1_mask); + let unquantized2 = and_f32x8(unquantized2, quan2_mask); + let unquantized3 = and_f32x8(unquantized3, quan3_mask); + let unquantized4 = and_f32x8(unquantized4, quan4_mask); + (unquantized1, unquantized2, unquantized3, unquantized4) + } else { + (f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()) + } + } + } + + let mut targets8: [F32x8; 4] = + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()]; + + for p in 0..src_cols_its { + let (other8_0, other8_1, other8_2, other8_3) = + load_k4_to_f32(&other, col, p * 32, other_rows, quants.as_ptr()); + let src8_0 = + load_f32(src_data, row, p * 32, src_cols, src_rows, src_cols_capacity); + let src8_1 = load_f32( + src_data, + row, + p * 32 + 8, + src_cols, + src_rows, + src_cols_capacity, + ); + let src8_2 = load_f32( + src_data, + row, + p * 32 + 16, + src_cols, + src_rows, + src_cols_capacity, + ); + let src8_3 = load_f32( + src_data, + row, + p * 32 + 24, + src_cols, + src_rows, + src_cols_capacity, + ); + targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]); + targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]); + targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]); + targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]); + } + let target0 = horizontal_sum_f32x8(targets8[0]); + let target1 = horizontal_sum_f32x8(targets8[1]); + let target2 = horizontal_sum_f32x8(targets8[2]); + let target3 = horizontal_sum_f32x8(targets8[3]); + let target = target0 + target1 + target2 + target3; + *tgt_data.add(row * self_cols_capacity + col) = target; + } + }); + } + } + + fn matrix_vector_mul_inplace_transposed_k4bit_and_f32(&mut self, src: &Tensor, other: &Tensor) { + // Assume: size checks have been done already. + assert!(src.dtype == TensorDType::K4BitQuantization); + assert!(other.dtype == TensorDType::Float32); + assert!(self.dtype == TensorDType::Float32); + assert_eq!(other.rows, 1); + assert_eq!(self.cols, 1); + + unsafe { + let src_rows: usize = src.rows as usize; + let src_cols: usize = src.cols as usize; + let src_cols_capacity: usize = src.capacity_cols as usize; + let other_cols: usize = other.cols as usize; + let other_rows: usize = other.rows as usize; + let other_cols_capacity: usize = other.capacity_cols as usize; + let self_rows: usize = self.rows as usize; + let self_cols: usize = self.cols as usize; + let self_cols_capacity: usize = self.capacity_cols as usize; + + // src_cols_its == also the shared dimension between src and other. 
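+            // Each K4 row stores a 16-entry f16 dequantization table (32 bytes) in
+            // q4_data; it is converted to f32 (`quants`) once per output row in the
+            // loop below.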
+ let src_cols_its = if src_cols % 32 == 0 { + src_cols / 32 + } else { + src_cols / 32 + 1 + }; + debug_assert!(!src.q4_data.is_null()); + + let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data); + let other_data: WrappedPtr = WrappedPtr::wrap(other.data); + let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data); + let src_q4_data: WrappedPtr = WrappedPtr::wrap(src.q4_data); + + let nthreads: usize = rayon::current_num_threads(); + (0..nthreads).into_par_iter().for_each(|thread_idx| { + let src_q4_data: *const u8 = src_q4_data.unwrap() as *const u8; + let src_data: *const u8 = src_data_wrap.unwrap() as *const u8; + let other_data: *const f32 = other_data.unwrap() as *const f32; + let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32; + + for row in 0..self_rows { + if row % nthreads != thread_idx { + continue; + } + let quant0 = load_i16x8(src_q4_data.add(row * 32) as *const I16x8); + let quant1 = load_i16x8(src_q4_data.add(row * 32 + 16) as *const I16x8); + let quants: [F32x8; 2] = + [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)]; + + let col = 0; + + #[inline] + fn load_f32( + other: *const f32, + row: usize, + col: usize, + ncols: usize, + nrows: usize, + cols_capacity: usize, + ) -> F32x8 { + unsafe { + if row >= nrows || col >= ncols { + f32x8_zero() + } else { + load_f32x8(other.add(row * cols_capacity + col) as *const F32x8) + } + } + } + + #[inline] + fn load_k4_to_f32( + tensor: &Tensor, + row: usize, + col: usize, + nrows: usize, + quants: *const F32x8, + ) -> (F32x8, F32x8, F32x8, F32x8) { + unsafe { + if row < nrows { + let col = col as i64; + let ncols = tensor.cols; + let (addr, side) = tensor.q4_address(row as i64, col); + let i = load_i16x8(addr as *const I16x8); + let evens = and_i16x8(i, *even_mask); + let odds = and_i16x8(i, *odd_mask); + let odds = shift_right_by_4_i16x8(odds); + + let indices1 = extend_i8_to_i32_i32x8(odds); + let odds_shifted = shift_right_by_64_i128(odds); + let indices2 = extend_i8_to_i32_i32x8(odds_shifted); + let indices3 = extend_i8_to_i32_i32x8(evens); + let indices4 = + extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens)); + + let unquantized1: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices1); + let unquantized2: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices2); + let unquantized3: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices3); + let unquantized4: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices4); + let quan1_mask: I32x8 = if col <= ncols - 8 { + *nomask + } else if col < ncols { + masks[(col % 8) as usize] + } else { + *fullmask + }; + let quan2_mask: I32x8 = if col <= ncols - 16 { + *nomask + } else if col < ncols - 8 { + masks[(col % 8) as usize] + } else { + *fullmask + }; + let quan3_mask: I32x8 = if col <= ncols - 24 { + *nomask + } else if col < ncols - 16 { + masks[(col % 8) as usize] + } else { + *fullmask + }; + let quan4_mask: I32x8 = if col <= ncols - 32 { + *nomask + } else if col < ncols - 24 { + masks[(col % 8) as usize] + } else { + *fullmask + }; + let unquantized1 = and_f32x8(unquantized1, quan1_mask); + let unquantized2 = and_f32x8(unquantized2, quan2_mask); + let unquantized3 = and_f32x8(unquantized3, quan3_mask); + let unquantized4 = and_f32x8(unquantized4, quan4_mask); + (unquantized1, unquantized2, unquantized3, unquantized4) + } else { + (f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()) + } + } + } + + let mut targets8: [F32x8; 4] = + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()]; + + for p in 0..src_cols_its { + let other8_0: 
F32x8 = load_f32( + other_data, + col, + p * 32, + other_cols, + other_rows, + other_cols_capacity, + ); + let other8_1: F32x8 = load_f32( + other_data, + col, + p * 32 + 8, + other_cols, + other_rows, + other_cols_capacity, + ); + let other8_2: F32x8 = load_f32( + other_data, + col, + p * 32 + 16, + other_cols, + other_rows, + other_cols_capacity, + ); + let other8_3: F32x8 = load_f32( + other_data, + col, + p * 32 + 24, + other_cols, + other_rows, + other_cols_capacity, + ); + let (src8_0, src8_1, src8_2, src8_3): (F32x8, F32x8, F32x8, F32x8) = + load_k4_to_f32(&src, row, p * 32, src_rows, quants.as_ptr()); + targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]); + targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]); + targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]); + targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]); + } + let target0 = horizontal_sum_f32x8(targets8[0]); + let target1 = horizontal_sum_f32x8(targets8[1]); + let target2 = horizontal_sum_f32x8(targets8[2]); + let target3 = horizontal_sum_f32x8(targets8[3]); + let target = target0 + target1 + target2 + target3; + *tgt_data.add(row * self_cols_capacity + col) = target; } }); } @@ -2091,11 +2504,15 @@ impl Tensor { } assert_eq!(other.rows, 1); - // K4 bit currently has no implementation for matrix_vector_mul - if self.dtype == TensorDType::K4BitQuantization - || other.dtype == TensorDType::K4BitQuantization - { - return self.matrix_mul_transposed(other); + if self.dtype == TensorDType::K4BitQuantization { + let mut result = unsafe { Tensor::uninitialized(self.rows, 1, other.dtype) }; + result.matrix_vector_mul_inplace_transposed_k4bit_and_f32(self, other); + return result; + } + if other.dtype == TensorDType::K4BitQuantization { + let mut result = unsafe { Tensor::uninitialized(self.rows, 1, self.dtype) }; + result.matrix_vector_mul_inplace_transposed_f32_and_k4bit(self, other); + return result; } assert_eq!(other.dtype, self.dtype); @@ -3872,4 +4289,80 @@ mod tests { } } } + + #[test] + fn quantized_matrices_matrix_vector_mul_transposed_correctly_f32_mul_k4() { + // TODO: this test is mostly a copypaste from the matrix_mul tests except let b = 1; + let mut rng = rand::thread_rng(); + for _ in 0..100 { + let a = rng.gen_range(1..=128); + let b = 1; + let c = rng.gen_range(1..=128); + let other_matrix = Tensor::random(a, c, TensorDType::Float32); + let mut reference = Tensor::zeros(b, c, TensorDType::Float32); + + let mut quant_values: Vec> = Vec::with_capacity(c as usize); + for row in 0..b { + let mut quant_values_for_row: Vec = Vec::with_capacity(16); + for _ in 0..16 { + quant_values_for_row.push(rng.gen_range(0.0..=1.0)); + } + quant_values.push(quant_values_for_row); + } + + let mut quantized_values: Vec> = Vec::with_capacity(b as usize); + for row in 0..b { + let mut quant_values_for_row: Vec = Vec::with_capacity(c as usize); + for col in 0..c { + let i = rng.gen_range(0..=15); + reference.set_f32(row, col, quant_values[row as usize][i as usize]); + quant_values_for_row.push(i as u8); + } + quantized_values.push(quant_values_for_row); + } + + let quantized = Tensor::make_k4bit_from_fn( + b, + c, + |row, col| quantized_values[row as usize][col as usize], + |row| { + let mut result: [f32; 16] = [0.0; 16]; + for col in 0..16 { + result[col] = quant_values[row as usize][col]; + } + result + }, + ); + + assert_eq!(reference.rows(), quantized.rows()); + assert_eq!(reference.cols(), quantized.cols()); + + for row in 0..reference.rows { + for col in 0..reference.cols { + // The quantized table always uses f16 
so values may not be 100% equal.
+                    assert_relative_eq!(
+                        reference.get_f32(row, col),
+                        quantized.get_f32(row, col),
+                        epsilon = 1e-1,
+                    );
+                }
+            }
+
+            let mult1 = other_matrix.matrix_mul_transposed(&reference);
+            let mult2 = other_matrix.matrix_mul_transposed(&quantized);
+
+            assert_eq!(mult1.rows(), mult2.rows());
+            assert_eq!(mult1.cols(), mult2.cols());
+
+            for row in 0..mult1.rows {
+                for col in 0..mult1.cols {
+                    assert_relative_eq!(
+                        mult1.get_f32(row, col),
+                        mult2.get_f32(row, col),
+                        epsilon = 1e-1,
+                    );
+                }
+            }
+        }
+    }
 }
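For orientation, a minimal standalone sketch of the dequantization idea the kernels and the test above rely on: each stored byte packs two 4-bit indices into a per-row, 16-entry reconstruction table (held as f16 in q4_data and converted to f32 once per row). This is illustrative only; the exact nibble ordering and group layout used by q4_address and the SIMD gathers in the patch may differ, and dequantize_row is a hypothetical helper, not an API from this crate.

/// Scalar sketch of K4-style dequantization: two 4-bit indices per byte,
/// each indexing a 16-entry table of reconstruction values.
fn dequantize_row(packed: &[u8], table: &[f32; 16], ncols: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(ncols);
    for &byte in packed {
        for index in [byte & 0x0F, byte >> 4] {
            if out.len() == ncols {
                break;
            }
            out.push(table[index as usize]);
        }
    }
    out
}

fn main() {
    let table: [f32; 16] = core::array::from_fn(|i| i as f32 * 0.5);
    // Two bytes carry four packed values; ncols = 3 exercises a partial tail.
    assert_eq!(dequantize_row(&[0x21, 0x43], &table, 3), vec![0.5, 1.0, 1.5]);
}

The SIMD path in the patch performs the same lookup eight lanes at a time: the 0x0F0F and 0xF0F0 masks split low and high nibbles, and gather_scale4_f32x8 does the table lookups against the per-row quants values.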