From 40121e1c82e00a3f9567c7af66d632b3d41fca22 Mon Sep 17 00:00:00 2001 From: Mikko Juola Date: Wed, 22 Mar 2023 21:13:13 -0700 Subject: [PATCH] Multithread the k4 * f32 matrix multiplication. --- src/simd_support.rs | 2 + src/tensor.rs | 332 +++++++++++++++++++++++--------------------- 2 files changed, 175 insertions(+), 159 deletions(-) diff --git a/src/simd_support.rs b/src/simd_support.rs index 2324ded..613c537 100644 --- a/src/simd_support.rs +++ b/src/simd_support.rs @@ -314,9 +314,11 @@ pub fn horizontal_sum_and_f32_to_f16(mut ymm: __m256) -> f16 { /// Prints a binary representation of i16x8 to stdout in this form: /// +/// ```ignore /// 0 0 0 0 /// 0x0000 0x0000 0x0000 0x0000 /// 0000000000000000 0000000000000000 0000000000000000 0000000000000000 etc. +/// ``` /// /// decimal on first line, hex on second, binary on third. pub fn print_i16x8(a: I16x8) { diff --git a/src/tensor.rs b/src/tensor.rs index 8aeb70e..df1e793 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1187,10 +1187,6 @@ impl Tensor { let self_cols: usize = self.cols as usize; let self_cols_capacity: usize = self.capacity_cols as usize; - let src_data: *const u8 = src.data; - let other_data: *const f32 = other.data as *const f32; - let tgt_data: *mut f32 = self.data as *mut f32; - // src_cols_its == also the shared dimension between src and other. let src_cols_its = if src_cols % 32 == 0 { src_cols / 32 @@ -1199,170 +1195,188 @@ impl Tensor { }; debug_assert!(!src.q4_data.is_null()); - for row in 0..self_rows { - let quant0 = load_i16x8(src.q4_data.add(row * 32) as *const I16x8); - let quant1 = load_i16x8(src.q4_data.add(row * 32 + 16) as *const I16x8); - let quants: [F32x8; 2] = - [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)]; - - for col in 0..self_cols { - #[inline] - fn load_f32( - other: *const f32, - row: usize, - col: usize, - ncols: usize, - nrows: usize, - cols_capacity: usize, - ) -> F32x8 { - unsafe { - if row >= nrows || col >= ncols { - f32x8_zero() - } else { - load_f32x8(other.add(row * cols_capacity + col) as *const F32x8) - } + let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data); + let other_data: WrappedPtr = WrappedPtr::wrap(other.data); + let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data); + let src_q4_data: WrappedPtr = WrappedPtr::wrap(src.q4_data); + + let nthreads: usize = rayon::current_num_threads(); + (0..nthreads).into_par_iter().for_each(|thread_idx| { + let src_q4_data: *const u8 = src_q4_data.unwrap() as *const u8; + let src_data: *const u8 = src_data_wrap.unwrap() as *const u8; + let other_data: *const f32 = other_data.unwrap() as *const f32; + let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32; + + for row in 0..self_rows { + let quant0 = load_i16x8(src_q4_data.add(row * 32) as *const I16x8); + let quant1 = load_i16x8(src_q4_data.add(row * 32 + 16) as *const I16x8); + let quants: [F32x8; 2] = + [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)]; + + for col in 0..self_cols { + let row_col = row * self_cols + col; + if row_col % nthreads != thread_idx { + continue; } - } - #[inline] - fn load_k4_to_f32( - tensor: &Tensor, - row: usize, - col: usize, - nrows: usize, - quants: *const F32x8, - ) -> (F32x8, F32x8, F32x8, F32x8) { - unsafe { - let M: u32 = 0xFFFFFFFF; - let MASKS: [I32x8; 8] = [ - i32x8_from_values_u32(M, M, M, M, M, M, M, M), - i32x8_from_values_u32(0, M, M, M, M, M, M, M), - i32x8_from_values_u32(0, 0, M, M, M, M, M, M), - i32x8_from_values_u32(0, 0, 0, M, M, M, M, M), - i32x8_from_values_u32(0, 0, 0, 0, M, M, M, M), - 
i32x8_from_values_u32(0, 0, 0, 0, 0, M, M, M), - i32x8_from_values_u32(0, 0, 0, 0, 0, 0, M, M), - i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, M), - ]; - let NOMASK: I32x8 = i32x8_from_values_u32(M, M, M, M, M, M, M, M); - let FULLMASK: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0); - - if row < nrows { - let col = col as i64; - let ncols = tensor.cols; - let (addr, side) = tensor.q4_address(row as i64, col); - let i = load_i16x8(addr as *const I16x8); - let even_mask = i16x8_singleton_u16(0x0F0F); - let odd_mask = i16x8_singleton_u16(0xF0F0); - let evens = and_i16x8(i, even_mask); - let odds = and_i16x8(i, odd_mask); - let odds = shift_right_by_4_i16x8(odds); - - let indices1 = extend_i8_to_i32_i32x8(odds); - let odds_shifted = shift_right_by_64_i128(odds); - let indices2 = extend_i8_to_i32_i32x8(odds_shifted); - let indices3 = extend_i8_to_i32_i32x8(evens); - let indices4 = - extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens)); - - let unquantized1: F32x8 = - gather_scale4_f32x8(quants as *const f32, indices1); - let unquantized2: F32x8 = - gather_scale4_f32x8(quants as *const f32, indices2); - let unquantized3: F32x8 = - gather_scale4_f32x8(quants as *const f32, indices3); - let unquantized4: F32x8 = - gather_scale4_f32x8(quants as *const f32, indices4); - let quan1_mask: I32x8 = if col <= ncols - 8 { - NOMASK - } else if col < ncols { - MASKS[(col % 8) as usize] - } else { - FULLMASK - }; - let quan2_mask: I32x8 = if col <= ncols - 16 { - NOMASK - } else if col < ncols - 8 { - MASKS[(col % 8) as usize] - } else { - FULLMASK - }; - let quan3_mask: I32x8 = if col <= ncols - 24 { - NOMASK - } else if col < ncols - 16 { - MASKS[(col % 8) as usize] + #[inline] + fn load_f32( + other: *const f32, + row: usize, + col: usize, + ncols: usize, + nrows: usize, + cols_capacity: usize, + ) -> F32x8 { + unsafe { + if row >= nrows || col >= ncols { + f32x8_zero() } else { - FULLMASK - }; - let quan4_mask: I32x8 = if col <= ncols - 32 { - NOMASK - } else if col < ncols - 24 { - MASKS[(col % 8) as usize] + load_f32x8(other.add(row * cols_capacity + col) as *const F32x8) + } + } + } + + #[inline] + fn load_k4_to_f32( + tensor: &Tensor, + row: usize, + col: usize, + nrows: usize, + quants: *const F32x8, + ) -> (F32x8, F32x8, F32x8, F32x8) { + unsafe { + let M: u32 = 0xFFFFFFFF; + let MASKS: [I32x8; 8] = [ + i32x8_from_values_u32(M, M, M, M, M, M, M, M), + i32x8_from_values_u32(0, M, M, M, M, M, M, M), + i32x8_from_values_u32(0, 0, M, M, M, M, M, M), + i32x8_from_values_u32(0, 0, 0, M, M, M, M, M), + i32x8_from_values_u32(0, 0, 0, 0, M, M, M, M), + i32x8_from_values_u32(0, 0, 0, 0, 0, M, M, M), + i32x8_from_values_u32(0, 0, 0, 0, 0, 0, M, M), + i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, M), + ]; + let NOMASK: I32x8 = i32x8_from_values_u32(M, M, M, M, M, M, M, M); + let FULLMASK: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0); + + if row < nrows { + let col = col as i64; + let ncols = tensor.cols; + let (addr, side) = tensor.q4_address(row as i64, col); + let i = load_i16x8(addr as *const I16x8); + let even_mask = i16x8_singleton_u16(0x0F0F); + let odd_mask = i16x8_singleton_u16(0xF0F0); + let evens = and_i16x8(i, even_mask); + let odds = and_i16x8(i, odd_mask); + let odds = shift_right_by_4_i16x8(odds); + + let indices1 = extend_i8_to_i32_i32x8(odds); + let odds_shifted = shift_right_by_64_i128(odds); + let indices2 = extend_i8_to_i32_i32x8(odds_shifted); + let indices3 = extend_i8_to_i32_i32x8(evens); + let indices4 = + extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens)); + + let 
unquantized1: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices1); + let unquantized2: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices2); + let unquantized3: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices3); + let unquantized4: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices4); + let quan1_mask: I32x8 = if col <= ncols - 8 { + NOMASK + } else if col < ncols { + MASKS[(col % 8) as usize] + } else { + FULLMASK + }; + let quan2_mask: I32x8 = if col <= ncols - 16 { + NOMASK + } else if col < ncols - 8 { + MASKS[(col % 8) as usize] + } else { + FULLMASK + }; + let quan3_mask: I32x8 = if col <= ncols - 24 { + NOMASK + } else if col < ncols - 16 { + MASKS[(col % 8) as usize] + } else { + FULLMASK + }; + let quan4_mask: I32x8 = if col <= ncols - 32 { + NOMASK + } else if col < ncols - 24 { + MASKS[(col % 8) as usize] + } else { + FULLMASK + }; + let unquantized1 = and_f32x8(unquantized1, quan1_mask); + let unquantized2 = and_f32x8(unquantized2, quan2_mask); + let unquantized3 = and_f32x8(unquantized3, quan3_mask); + let unquantized4 = and_f32x8(unquantized4, quan4_mask); + (unquantized1, unquantized2, unquantized3, unquantized4) } else { - FULLMASK - }; - let unquantized1 = and_f32x8(unquantized1, quan1_mask); - let unquantized2 = and_f32x8(unquantized2, quan2_mask); - let unquantized3 = and_f32x8(unquantized3, quan3_mask); - let unquantized4 = and_f32x8(unquantized4, quan4_mask); - (unquantized1, unquantized2, unquantized3, unquantized4) - } else { - (f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()) + (f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()) + } } } - } - let mut targets8: [F32x8; 4] = - [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()]; - - for p in 0..src_cols_its { - let other8_0: F32x8 = load_f32( - other_data, - col, - p * 32, - other_cols, - other_rows, - other_cols_capacity, - ); - let other8_1: F32x8 = load_f32( - other_data, - col, - p * 32 + 8, - other_cols, - other_rows, - other_cols_capacity, - ); - let other8_2: F32x8 = load_f32( - other_data, - col, - p * 32 + 16, - other_cols, - other_rows, - other_cols_capacity, - ); - let other8_3: F32x8 = load_f32( - other_data, - col, - p * 32 + 24, - other_cols, - other_rows, - other_cols_capacity, - ); - let (src8_0, src8_1, src8_2, src8_3): (F32x8, F32x8, F32x8, F32x8) = - load_k4_to_f32(&src, row, p * 32, src_rows, quants.as_ptr()); - targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]); - targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]); - targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]); - targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]); + let mut targets8: [F32x8; 4] = + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()]; + + for p in 0..src_cols_its { + let other8_0: F32x8 = load_f32( + other_data, + col, + p * 32, + other_cols, + other_rows, + other_cols_capacity, + ); + let other8_1: F32x8 = load_f32( + other_data, + col, + p * 32 + 8, + other_cols, + other_rows, + other_cols_capacity, + ); + let other8_2: F32x8 = load_f32( + other_data, + col, + p * 32 + 16, + other_cols, + other_rows, + other_cols_capacity, + ); + let other8_3: F32x8 = load_f32( + other_data, + col, + p * 32 + 24, + other_cols, + other_rows, + other_cols_capacity, + ); + let (src8_0, src8_1, src8_2, src8_3): (F32x8, F32x8, F32x8, F32x8) = + load_k4_to_f32(&src, row, p * 32, src_rows, quants.as_ptr()); + targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]); + targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]); + targets8[2] = fma_f32x8(src8_2, other8_2, 
targets8[2]); + targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]); + } + let target0 = horizontal_sum_f32x8(targets8[0]); + let target1 = horizontal_sum_f32x8(targets8[1]); + let target2 = horizontal_sum_f32x8(targets8[2]); + let target3 = horizontal_sum_f32x8(targets8[3]); + let target = target0 + target1 + target2 + target3; + *tgt_data.add(row * self_cols_capacity + col) = target; } - let target0 = horizontal_sum_f32x8(targets8[0]); - let target1 = horizontal_sum_f32x8(targets8[1]); - let target2 = horizontal_sum_f32x8(targets8[2]); - let target3 = horizontal_sum_f32x8(targets8[3]); - let target = target0 + target1 + target2 + target3; - *tgt_data.add(row * self_cols_capacity + col) = target; } - } + }); } }
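
The parallel loop above moves `src.data`, `other.data`, `self.data`, and `src.q4_data` into the rayon closure through a `WrappedPtr` helper whose definition is not included in this diff. Raw pointers are not `Send`/`Sync`, so the helper is presumably a thin newtype that opts in manually; the following is only a sketch under that assumption, not the crate's actual implementation:

```rust
/// Hypothetical reconstruction of the `WrappedPtr` helper used by the patch;
/// the real definition is not shown in this diff.
#[derive(Copy, Clone)]
pub struct WrappedPtr(*mut u8);

// SAFETY (assumed): the caller guarantees the pointed-to buffers outlive the
// parallel region and that no two threads write to the same element.
unsafe impl Send for WrappedPtr {}
unsafe impl Sync for WrappedPtr {}

impl WrappedPtr {
    /// Type-erase the pointer so it can be captured by a rayon closure.
    pub fn wrap<T>(ptr: *const T) -> Self {
        WrappedPtr(ptr as *mut u8)
    }

    /// Recover the raw address; the call site casts it back to the concrete
    /// pointer type, e.g. `tgt_data.unwrap() as *mut f32`.
    pub fn unwrap(self) -> *mut u8 {
        self.0
    }
}
```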
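
The unit of parallel work is a single output cell: thread `t` of `rayon::current_num_threads()` handles exactly the cells whose flattened index `row * self_cols + col` is congruent to `t` modulo the thread count, so each target element is written by one thread and no locking is needed. Below is a minimal, self-contained sketch of that split; the function and variable names are illustrative, not taken from the crate:

```rust
use rayon::prelude::*;

/// Returns, for each rayon worker index, the (row, col) cells it would own
/// under the round-robin split used in the patch.
fn round_robin_cells(rows: usize, cols: usize) -> Vec<Vec<(usize, usize)>> {
    let nthreads = rayon::current_num_threads();
    (0..nthreads)
        .into_par_iter()
        .map(|thread_idx| {
            let mut cells = Vec::new();
            for row in 0..rows {
                for col in 0..cols {
                    // Same test as the patch: skip cells owned by other threads.
                    if (row * cols + col) % nthreads != thread_idx {
                        continue;
                    }
                    cells.push((row, col));
                }
            }
            cells
        })
        .collect()
}

fn main() {
    // Every cell is claimed exactly once across all threads.
    let owned = round_robin_cells(4, 6);
    let total: usize = owned.iter().map(|c| c.len()).sum();
    assert_eq!(total, 4 * 6);
}
```

Interleaving at cell granularity keeps the threads balanced even when the output has very few rows (for example a single-row result), which a plain per-row split would not.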