Multithread the k4 * f32 matrix multiplication.

k4bit
Mikko Juola 3 years ago
parent b8946da2d8
commit 40121e1c82

@@ -314,9 +314,11 @@ pub fn horizontal_sum_and_f32_to_f16(mut ymm: __m256) -> f16 {
 /// Prints a binary representation of i16x8 to stdout in this form:
 ///
+/// ```ignore
 /// 0 0 0 0
 /// 0x0000 0x0000 0x0000 0x0000
 /// 0000000000000000 0000000000000000 0000000000000000 0000000000000000 etc.
+/// ```
 ///
 /// decimal on first line, hex on second, binary on third.
 pub fn print_i16x8(a: I16x8) {
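The `ignore` attribute added above matters because rustdoc compiles fenced blocks
inside doc comments as doctests by default; without it, `cargo test` would try to
build the sample output as Rust code and fail. A minimal sketch of the same
pattern, using an illustrative function that is not from this repo:

    /// Adds one to its argument.
    ///
    /// The block below is sample output, not compilable Rust, so the
    /// fence is marked `ignore` to keep rustdoc from running it as a
    /// doctest:
    ///
    /// ```ignore
    /// 1 2 3
    /// 0x0001 0x0002 0x0003
    /// ```
    pub fn add_one(x: i32) -> i32 {
        x + 1
    }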

@@ -1187,10 +1187,6 @@ impl Tensor {
         let self_cols: usize = self.cols as usize;
         let self_cols_capacity: usize = self.capacity_cols as usize;
-        let src_data: *const u8 = src.data;
-        let other_data: *const f32 = other.data as *const f32;
-        let tgt_data: *mut f32 = self.data as *mut f32;
         // src_cols_its == also the shared dimension between src and other.
         let src_cols_its = if src_cols % 32 == 0 {
             src_cols / 32
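The three raw-pointer locals removed here reappear in the next hunk wrapped in a
`WrappedPtr`. Raw pointers are neither `Send` nor `Sync`, so they cannot be
captured by the rayon closure directly; the wrapper exists to assert those bounds
manually. The commit does not show `WrappedPtr`'s definition, so the following is
only a plausible sketch of it, sound only because the striping in the next hunk
guarantees that no two threads touch the same output cell:

    // Hypothetical sketch; k4bit's actual WrappedPtr is not part of this
    // commit. A newtype around a raw pointer that is unsafely declared
    // Send + Sync so rayon closures can capture it by value.
    #[derive(Copy, Clone)]
    pub struct WrappedPtr(*mut u8);

    // SAFETY: callers must guarantee that no two threads write the same
    // address through the pointer and that reads never race with writes.
    unsafe impl Send for WrappedPtr {}
    unsafe impl Sync for WrappedPtr {}

    impl WrappedPtr {
        pub fn wrap<T>(ptr: *const T) -> Self {
            // `*mut T` arguments coerce to `*const T`, so one constructor
            // covers both mutable and const pointers.
            WrappedPtr(ptr as *mut u8)
        }
        pub fn unwrap(self) -> *mut u8 {
            self.0
        }
    }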
@@ -1199,170 +1195,188 @@ impl Tensor {
         };
         debug_assert!(!src.q4_data.is_null());
+        let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
+        let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
+        let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
+        let src_q4_data: WrappedPtr = WrappedPtr::wrap(src.q4_data);
+
+        let nthreads: usize = rayon::current_num_threads();
+        (0..nthreads).into_par_iter().for_each(|thread_idx| {
+            let src_q4_data: *const u8 = src_q4_data.unwrap() as *const u8;
+            let src_data: *const u8 = src_data_wrap.unwrap() as *const u8;
+            let other_data: *const f32 = other_data.unwrap() as *const f32;
+            let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
+
             for row in 0..self_rows {
-            let quant0 = load_i16x8(src.q4_data.add(row * 32) as *const I16x8);
-            let quant1 = load_i16x8(src.q4_data.add(row * 32 + 16) as *const I16x8);
+                let quant0 = load_i16x8(src_q4_data.add(row * 32) as *const I16x8);
+                let quant1 = load_i16x8(src_q4_data.add(row * 32 + 16) as *const I16x8);
                 let quants: [F32x8; 2] =
                     [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
                 for col in 0..self_cols {
+                    let row_col = row * self_cols + col;
+                    if row_col % nthreads != thread_idx {
+                        continue;
+                    }
+
                     #[inline]
                     fn load_f32(
                         other: *const f32,
                         row: usize,
                         col: usize,
                         ncols: usize,
                         nrows: usize,
                         cols_capacity: usize,
                     ) -> F32x8 {
                         unsafe {
                             if row >= nrows || col >= ncols {
                                 f32x8_zero()
                             } else {
                                 load_f32x8(other.add(row * cols_capacity + col) as *const F32x8)
                             }
                         }
                     }
                     #[inline]
                     fn load_k4_to_f32(
                         tensor: &Tensor,
                         row: usize,
                         col: usize,
                         nrows: usize,
                         quants: *const F32x8,
                     ) -> (F32x8, F32x8, F32x8, F32x8) {
                         unsafe {
                             let M: u32 = 0xFFFFFFFF;
                             let MASKS: [I32x8; 8] = [
                                 i32x8_from_values_u32(M, M, M, M, M, M, M, M),
                                 i32x8_from_values_u32(0, M, M, M, M, M, M, M),
                                 i32x8_from_values_u32(0, 0, M, M, M, M, M, M),
                                 i32x8_from_values_u32(0, 0, 0, M, M, M, M, M),
                                 i32x8_from_values_u32(0, 0, 0, 0, M, M, M, M),
                                 i32x8_from_values_u32(0, 0, 0, 0, 0, M, M, M),
                                 i32x8_from_values_u32(0, 0, 0, 0, 0, 0, M, M),
                                 i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, M),
                             ];
                             let NOMASK: I32x8 = i32x8_from_values_u32(M, M, M, M, M, M, M, M);
                             let FULLMASK: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
                             if row < nrows {
                                 let col = col as i64;
                                 let ncols = tensor.cols;
                                 let (addr, side) = tensor.q4_address(row as i64, col);
                                 let i = load_i16x8(addr as *const I16x8);
                                 let even_mask = i16x8_singleton_u16(0x0F0F);
                                 let odd_mask = i16x8_singleton_u16(0xF0F0);
                                 let evens = and_i16x8(i, even_mask);
                                 let odds = and_i16x8(i, odd_mask);
                                 let odds = shift_right_by_4_i16x8(odds);
                                 let indices1 = extend_i8_to_i32_i32x8(odds);
                                 let odds_shifted = shift_right_by_64_i128(odds);
                                 let indices2 = extend_i8_to_i32_i32x8(odds_shifted);
                                 let indices3 = extend_i8_to_i32_i32x8(evens);
                                 let indices4 =
                                     extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens));
                                 let unquantized1: F32x8 =
                                     gather_scale4_f32x8(quants as *const f32, indices1);
                                 let unquantized2: F32x8 =
                                     gather_scale4_f32x8(quants as *const f32, indices2);
                                 let unquantized3: F32x8 =
                                     gather_scale4_f32x8(quants as *const f32, indices3);
                                 let unquantized4: F32x8 =
                                     gather_scale4_f32x8(quants as *const f32, indices4);
                                 let quan1_mask: I32x8 = if col <= ncols - 8 {
                                     NOMASK
                                 } else if col < ncols {
                                     MASKS[(col % 8) as usize]
                                 } else {
                                     FULLMASK
                                 };
                                 let quan2_mask: I32x8 = if col <= ncols - 16 {
                                     NOMASK
                                 } else if col < ncols - 8 {
                                     MASKS[(col % 8) as usize]
                                 } else {
                                     FULLMASK
                                 };
                                 let quan3_mask: I32x8 = if col <= ncols - 24 {
                                     NOMASK
                                 } else if col < ncols - 16 {
                                     MASKS[(col % 8) as usize]
                                 } else {
                                     FULLMASK
                                 };
                                 let quan4_mask: I32x8 = if col <= ncols - 32 {
                                     NOMASK
                                 } else if col < ncols - 24 {
                                     MASKS[(col % 8) as usize]
                                 } else {
                                     FULLMASK
                                 };
                                 let unquantized1 = and_f32x8(unquantized1, quan1_mask);
                                 let unquantized2 = and_f32x8(unquantized2, quan2_mask);
                                 let unquantized3 = and_f32x8(unquantized3, quan3_mask);
                                 let unquantized4 = and_f32x8(unquantized4, quan4_mask);
                                 (unquantized1, unquantized2, unquantized3, unquantized4)
                             } else {
                                 (f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero())
                             }
                         }
                     }
                     let mut targets8: [F32x8; 4] =
                         [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
                     for p in 0..src_cols_its {
                         let other8_0: F32x8 = load_f32(
                             other_data,
                             col,
                             p * 32,
                             other_cols,
                             other_rows,
                             other_cols_capacity,
                         );
                         let other8_1: F32x8 = load_f32(
                             other_data,
                             col,
                             p * 32 + 8,
                             other_cols,
                             other_rows,
                             other_cols_capacity,
                         );
                         let other8_2: F32x8 = load_f32(
                             other_data,
                             col,
                             p * 32 + 16,
                             other_cols,
                             other_rows,
                             other_cols_capacity,
                         );
                         let other8_3: F32x8 = load_f32(
                             other_data,
                             col,
                             p * 32 + 24,
                             other_cols,
                             other_rows,
                             other_cols_capacity,
                         );
                         let (src8_0, src8_1, src8_2, src8_3): (F32x8, F32x8, F32x8, F32x8) =
                             load_k4_to_f32(&src, row, p * 32, src_rows, quants.as_ptr());
                         targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]);
                         targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]);
                         targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]);
                         targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]);
                     }
                     let target0 = horizontal_sum_f32x8(targets8[0]);
                     let target1 = horizontal_sum_f32x8(targets8[1]);
                     let target2 = horizontal_sum_f32x8(targets8[2]);
                     let target3 = horizontal_sum_f32x8(targets8[3]);
                     let target = target0 + target1 + target2 + target3;
                     *tgt_data.add(row * self_cols_capacity + col) = target;
                 }
             }
+        });
     }
 }
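The parallel scheme above is round-robin striping over output cells: each of the
`nthreads` rayon jobs walks the full row/column space and claims a cell exactly
when its flat index `row * self_cols + col` is congruent to `thread_idx` modulo
`nthreads`, so every output element is written by exactly one thread. The sketch
below shows the same partitioning in isolation; `cell_value` is a stand-in for
the quantized dot product, and none of these names come from the repo:

    use rayon::prelude::*;

    // Stand-in for the per-cell dot product in the real kernel.
    fn cell_value(row: usize, col: usize) -> f32 {
        (row * 31 + col) as f32
    }

    fn striped_matmul_skeleton(rows: usize, cols: usize) -> Vec<f32> {
        let nthreads = rayon::current_num_threads();
        // One job per thread; each job filters the flat index space by
        // index % nthreads, so the claimed cells are pairwise disjoint.
        let parts: Vec<Vec<(usize, f32)>> = (0..nthreads)
            .into_par_iter()
            .map(|thread_idx| {
                (0..rows * cols)
                    .filter(|i| i % nthreads == thread_idx)
                    .map(|i| (i, cell_value(i / cols, i % cols)))
                    .collect()
            })
            .collect();
        // Merge the per-thread results; the commit instead writes through
        // a shared WrappedPtr, safe for the same disjointness reason.
        let mut out = vec![0.0f32; rows * cols];
        for part in parts {
            for (i, v) in part {
                out[i] = v;
            }
        }
        out
    }

    fn main() {
        let out = striped_matmul_skeleton(3, 5);
        assert_eq!(out[7], cell_value(1, 2));
    }

Cell-granularity striping balances load well, at the cost that every thread still
runs the whole row loop and decodes `quants` even for rows where it owns few
cells; chunking whole rows per thread would be the usual alternative trade-off.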
