Make separate matrix_vector_muls for 4-bit quantization rather than using matrix_mul for them.

k4bit
Mikko Juola 3 years ago
parent 2f3e9bc0f5
commit 8cc82ae7e2
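For orientation (this sketch is editorial, not part of the diff): in the k4 format each quantized row carries a 16-entry lookup table of f16 values plus one 4-bit code per element, so a matrix-vector product reduces to table lookups and multiply-adds. A scalar sketch of the dot product the new routines compute, with illustrative names only:

fn k4_dot(codes: &[u8], table: &[f32; 16], x: &[f32]) -> f32 {
    // codes[i] is the 4-bit code (0..=15) of element i of the quantized row;
    // table is that row's dequantization table (f16 in the real tensor,
    // widened to f32 here); x is the f32 vector being multiplied against.
    assert_eq!(codes.len(), x.len());
    codes
        .iter()
        .zip(x.iter())
        .map(|(&c, &xi)| table[c as usize] * xi)
        .sum()
}

fn main() {
    let table = [
        0.0f32, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
        0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5,
    ];
    let codes = [3u8, 15, 0, 7];
    let x = [1.0f32, 2.0, 3.0, 4.0];
    println!("{}", k4_dot(&codes, &table, &x)); // 0.3 + 3.0 + 0.0 + 2.8 = 6.1
}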

@ -113,6 +113,7 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
let m2_f16 = m2.to_f16();
let quant = m1.quantize();
let quant2 = m2.quantize();
c.bench_function(
"1024x128 * 1x128 matrix vector transposed multiplication, k4 quantized * f32",
@ -122,9 +123,17 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
})
},
);
c.bench_function(
"1024x128 * 1x128 matrix vector transposed multiplication, f32 quantized * k4",
|b| {
b.iter(|| {
let _ = m1.matrix_vector_mul_transposed(black_box(&quant2));
})
},
);
c.bench_function(
- "matrix multiplication 8x4096 @ 4096x4096 k8 quantized * f32 in-place, transposed",
"matrix multiplication 8x4096 @ 4096x4096 k4 quantized * f32 in-place, transposed",
|b| {
b.iter(|| {
let _ = result_84096.matrix_mul_inplace_transposed(

@ -266,11 +266,6 @@ pub fn shift_right_by_64_i128(a: I16x8) -> I16x8 {
unsafe { _mm_srli_si128(a, 64 / 8) }
}
- // Shuffle/premute
- pub fn shuffle_i16x8(a: I16x8, permutation: I16x8) -> I16x8 {
- unsafe { _mm_shuffle_epi8(a, permutation) }
- }
// Extends 8 i8 values into 7 i16 values
//
// XXYYZZ -> 00XX00YY00ZZ

@ -0,0 +1,419 @@
fn matrix_mul_inplace_transposed_f32_and_k4bit(&mut self, src: &Tensor, other: &Tensor) {
// Assume: size checks have been done already.
assert!(src.dtype == TensorDType::Float32);
assert!(other.dtype == TensorDType::K4BitQuantization);
assert!(self.dtype == TensorDType::Float32);
unsafe {
let src_rows: usize = src.rows as usize;
let src_cols: usize = src.cols as usize;
let src_cols_capacity: usize = src.capacity_cols as usize;
let other_cols: usize = other.cols as usize;
let other_rows: usize = other.rows as usize;
let other_cols_capacity: usize = other.capacity_cols as usize;
let self_rows: usize = self.rows as usize;
let self_cols: usize = self.cols as usize;
let self_cols_capacity: usize = self.capacity_cols as usize;
// src_cols_its == also the shared dimension between src and other.
let src_cols_its = if src_cols % 32 == 0 {
src_cols / 32
} else {
src_cols / 32 + 1
};
debug_assert!(!other.q4_data.is_null());
let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
let other_q4_data: WrappedPtr = WrappedPtr::wrap(other.q4_data);
let nthreads: usize = rayon::current_num_threads();
(0..nthreads).into_par_iter().for_each(|thread_idx| {
let other_q4_data: *const u8 = other_q4_data.unwrap() as *const u8;
let src_data: *const f32 = src_data_wrap.unwrap() as *const f32;
let other_data: *const u8 = other_data.unwrap() as *const u8;
let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
for row in 0..self_rows {
for col in 0..self_cols {
let row_col = row * self_cols + col;
if row_col % nthreads != thread_idx {
continue;
}
let quant0 = load_i16x8(other_q4_data.add(col * 32) as *const I16x8);
let quant1 = load_i16x8(other_q4_data.add(col * 32 + 16) as *const I16x8);
let quants: [F32x8; 2] =
[i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
#[inline]
fn load_f32(
src: *const f32,
row: usize,
col: usize,
ncols: usize,
nrows: usize,
cols_capacity: usize,
) -> F32x8 {
unsafe {
if row >= nrows || col >= ncols {
f32x8_zero()
} else {
load_f32x8(src.add(row * cols_capacity + col) as *const F32x8)
}
}
}
#[inline]
fn load_k4_to_f32(
tensor: &Tensor,
row: usize,
col: usize,
nrows: usize,
quants: *const F32x8,
) -> (F32x8, F32x8, F32x8, F32x8) {
unsafe {
let m: u32 = 0xFFFFFFFF;
let masks: [I32x8; 8] = [
i32x8_from_values_u32(m, m, m, m, m, m, m, m),
i32x8_from_values_u32(0, m, m, m, m, m, m, m),
i32x8_from_values_u32(0, 0, m, m, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, m, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m),
];
let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m);
let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
if row < nrows {
let col = col as i64;
let ncols = tensor.cols;
let (addr, side) = tensor.q4_address(row as i64, col);
let i = load_i16x8(addr as *const I16x8);
let even_mask = i16x8_singleton_u16(0x0F0F);
let odd_mask = i16x8_singleton_u16(0xF0F0);
let evens = and_i16x8(i, even_mask);
let odds = and_i16x8(i, odd_mask);
let odds = shift_right_by_4_i16x8(odds);
let indices1 = extend_i8_to_i32_i32x8(odds);
let odds_shifted = shift_right_by_64_i128(odds);
let indices2 = extend_i8_to_i32_i32x8(odds_shifted);
let indices3 = extend_i8_to_i32_i32x8(evens);
let indices4 =
extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens));
let unquantized1: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices1);
let unquantized2: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices2);
let unquantized3: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices3);
let unquantized4: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices4);
let quan1_mask: I32x8 = if col <= ncols - 8 {
nomask
} else if col < ncols {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan2_mask: I32x8 = if col <= ncols - 16 {
nomask
} else if col < ncols - 8 {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan3_mask: I32x8 = if col <= ncols - 24 {
nomask
} else if col < ncols - 16 {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan4_mask: I32x8 = if col <= ncols - 32 {
nomask
} else if col < ncols - 24 {
masks[(col % 8) as usize]
} else {
fullmask
};
let unquantized1 = and_f32x8(unquantized1, quan1_mask);
let unquantized2 = and_f32x8(unquantized2, quan2_mask);
let unquantized3 = and_f32x8(unquantized3, quan3_mask);
let unquantized4 = and_f32x8(unquantized4, quan4_mask);
(unquantized1, unquantized2, unquantized3, unquantized4)
} else {
(f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero())
}
}
}
let mut targets8: [F32x8; 4] =
[f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
for p in 0..src_cols_its {
let (other8_0, other8_1, other8_2, other8_3) =
load_k4_to_f32(&other, col, p * 32, other_rows, quants.as_ptr());
let src8_0 = load_f32(
src_data,
row,
p * 32,
src_cols,
src_rows,
src_cols_capacity,
);
let src8_1 = load_f32(
src_data,
row,
p * 32 + 8,
src_cols,
src_rows,
src_cols_capacity,
);
let src8_2 = load_f32(
src_data,
row,
p * 32 + 16,
src_cols,
src_rows,
src_cols_capacity,
);
let src8_3 = load_f32(
src_data,
row,
p * 32 + 24,
src_cols,
src_rows,
src_cols_capacity,
);
targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]);
targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]);
targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]);
targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]);
}
let target0 = horizontal_sum_f32x8(targets8[0]);
let target1 = horizontal_sum_f32x8(targets8[1]);
let target2 = horizontal_sum_f32x8(targets8[2]);
let target3 = horizontal_sum_f32x8(targets8[3]);
let target = target0 + target1 + target2 + target3;
*tgt_data.add(row * self_cols_capacity + col) = target;
}
}
});
}
}
fn matrix_mul_inplace_transposed_k4bit_and_f32(&mut self, src: &Tensor, other: &Tensor) {
// Assume: size checks have been done already.
assert!(src.dtype == TensorDType::K4BitQuantization);
assert!(other.dtype == TensorDType::Float32);
assert!(self.dtype == TensorDType::Float32);
unsafe {
let src_rows: usize = src.rows as usize;
let src_cols: usize = src.cols as usize;
let src_cols_capacity: usize = src.capacity_cols as usize;
let other_cols: usize = other.cols as usize;
let other_rows: usize = other.rows as usize;
let other_cols_capacity: usize = other.capacity_cols as usize;
let self_rows: usize = self.rows as usize;
let self_cols: usize = self.cols as usize;
let self_cols_capacity: usize = self.capacity_cols as usize;
// src_cols_its == also the shared dimension between src and other.
let src_cols_its = if src_cols % 32 == 0 {
src_cols / 32
} else {
src_cols / 32 + 1
};
debug_assert!(!src.q4_data.is_null());
let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
let src_q4_data: WrappedPtr = WrappedPtr::wrap(src.q4_data);
let nthreads: usize = rayon::current_num_threads();
(0..nthreads).into_par_iter().for_each(|thread_idx| {
let src_q4_data: *const u8 = src_q4_data.unwrap() as *const u8;
let src_data: *const u8 = src_data_wrap.unwrap() as *const u8;
let other_data: *const f32 = other_data.unwrap() as *const f32;
let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
for row in 0..self_rows {
let quant0 = load_i16x8(src_q4_data.add(row * 32) as *const I16x8);
let quant1 = load_i16x8(src_q4_data.add(row * 32 + 16) as *const I16x8);
let quants: [F32x8; 2] =
[i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
for col in 0..self_cols {
let row_col = row * self_cols + col;
if row_col % nthreads != thread_idx {
continue;
}
#[inline]
fn load_f32(
other: *const f32,
row: usize,
col: usize,
ncols: usize,
nrows: usize,
cols_capacity: usize,
) -> F32x8 {
unsafe {
if row >= nrows || col >= ncols {
f32x8_zero()
} else {
load_f32x8(other.add(row * cols_capacity + col) as *const F32x8)
}
}
}
#[inline]
fn load_k4_to_f32(
tensor: &Tensor,
row: usize,
col: usize,
nrows: usize,
quants: *const F32x8,
) -> (F32x8, F32x8, F32x8, F32x8) {
unsafe {
let m: u32 = 0xFFFFFFFF;
let masks: [I32x8; 8] = [
i32x8_from_values_u32(m, m, m, m, m, m, m, m),
i32x8_from_values_u32(0, m, m, m, m, m, m, m),
i32x8_from_values_u32(0, 0, m, m, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, m, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m),
];
let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m);
let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
if row < nrows {
let col = col as i64;
let ncols = tensor.cols;
let (addr, side) = tensor.q4_address(row as i64, col);
let i = load_i16x8(addr as *const I16x8);
let even_mask = i16x8_singleton_u16(0x0F0F);
let odd_mask = i16x8_singleton_u16(0xF0F0);
let evens = and_i16x8(i, even_mask);
let odds = and_i16x8(i, odd_mask);
let odds = shift_right_by_4_i16x8(odds);
let indices1 = extend_i8_to_i32_i32x8(odds);
let odds_shifted = shift_right_by_64_i128(odds);
let indices2 = extend_i8_to_i32_i32x8(odds_shifted);
let indices3 = extend_i8_to_i32_i32x8(evens);
let indices4 =
extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens));
let unquantized1: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices1);
let unquantized2: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices2);
let unquantized3: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices3);
let unquantized4: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices4);
let quan1_mask: I32x8 = if col <= ncols - 8 {
nomask
} else if col < ncols {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan2_mask: I32x8 = if col <= ncols - 16 {
nomask
} else if col < ncols - 8 {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan3_mask: I32x8 = if col <= ncols - 24 {
nomask
} else if col < ncols - 16 {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan4_mask: I32x8 = if col <= ncols - 32 {
nomask
} else if col < ncols - 24 {
masks[(col % 8) as usize]
} else {
fullmask
};
let unquantized1 = and_f32x8(unquantized1, quan1_mask);
let unquantized2 = and_f32x8(unquantized2, quan2_mask);
let unquantized3 = and_f32x8(unquantized3, quan3_mask);
let unquantized4 = and_f32x8(unquantized4, quan4_mask);
(unquantized1, unquantized2, unquantized3, unquantized4)
} else {
(f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero())
}
}
}
let mut targets8: [F32x8; 4] =
[f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
for p in 0..src_cols_its {
let other8_0: F32x8 = load_f32(
other_data,
col,
p * 32,
other_cols,
other_rows,
other_cols_capacity,
);
let other8_1: F32x8 = load_f32(
other_data,
col,
p * 32 + 8,
other_cols,
other_rows,
other_cols_capacity,
);
let other8_2: F32x8 = load_f32(
other_data,
col,
p * 32 + 16,
other_cols,
other_rows,
other_cols_capacity,
);
let other8_3: F32x8 = load_f32(
other_data,
col,
p * 32 + 24,
other_cols,
other_rows,
other_cols_capacity,
);
let (src8_0, src8_1, src8_2, src8_3): (F32x8, F32x8, F32x8, F32x8) =
load_k4_to_f32(&src, row, p * 32, src_rows, quants.as_ptr());
targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]);
targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]);
targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]);
targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]);
}
let target0 = horizontal_sum_f32x8(targets8[0]);
let target1 = horizontal_sum_f32x8(targets8[1]);
let target2 = horizontal_sum_f32x8(targets8[2]);
let target3 = horizontal_sum_f32x8(targets8[3]);
let target = target0 + target1 + target2 + target3;
*tgt_data.add(row * self_cols_capacity + col) = target;
}
}
});
}
}
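Both routines above split work by dealing output cells round-robin to rayon threads on (row * self_cols + col) % nthreads, so each cell is written by exactly one thread. A minimal standalone sketch of that scheme (hypothetical helper, not the repo's API):

use rayon::prelude::*;

fn for_each_cell_parallel(rows: usize, cols: usize, f: impl Fn(usize, usize) + Sync) {
    let nthreads = rayon::current_num_threads();
    (0..nthreads).into_par_iter().for_each(|thread_idx| {
        for row in 0..rows {
            for col in 0..cols {
                // Cells are assigned round-robin; skip the ones owned by other tasks.
                if (row * cols + col) % nthreads != thread_idx {
                    continue;
                }
                f(row, col);
            }
        }
    });
}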

@ -225,6 +225,24 @@ fn compute_capacity_cols_f16(cols: i64) -> i64 {
}
}
lazy_static! {
static ref m: u32 = 0xFFFFFFFF;
static ref masks: [I32x8; 8] = [
i32x8_from_values_u32(*m, *m, *m, *m, *m, *m, *m, *m),
i32x8_from_values_u32(0, *m, *m, *m, *m, *m, *m, *m),
i32x8_from_values_u32(0, 0, *m, *m, *m, *m, *m, *m),
i32x8_from_values_u32(0, 0, 0, *m, *m, *m, *m, *m),
i32x8_from_values_u32(0, 0, 0, 0, *m, *m, *m, *m),
i32x8_from_values_u32(0, 0, 0, 0, 0, *m, *m, *m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, *m, *m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, *m),
];
static ref nomask: I32x8 = i32x8_from_values_u32(*m, *m, *m, *m, *m, *m, *m, *m);
static ref fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
static ref even_mask: I16x8 = i16x8_singleton_u16(0x0F0F);
static ref odd_mask: I16x8 = i16x8_singleton_u16(0xF0F0);
}
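These statics back the 4-bit unpacking used throughout this file: 0x0F0F selects the low ("even") nibbles and 0xF0F0 the high ("odd") nibbles of each packed byte, while the partial masks zero out lanes past a row's last column. A scalar sketch of the nibble unpacking (element ordering here is illustrative):

fn unpack_k4_pair(byte: u8) -> (u8, u8) {
    // Mirrors and_i16x8 with the 0x0F0F mask...
    let even = byte & 0x0F;
    // ...and and_i16x8 with 0xF0F0 followed by shift_right_by_4_i16x8.
    let odd = (byte & 0xF0) >> 4;
    (even, odd) // e.g. unpack_k4_pair(0xB3) == (0x3, 0xB)
}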
impl Tensor {
#[inline]
pub fn assume_on_gpu(&self) {
@ -921,10 +939,7 @@ impl Tensor {
// We don't have implementation for f16, so don't use the vector function if we have
// f16
#[cfg(not(feature = "opencl"))]
- if other.rows == 1
-     && (other.dtype != TensorDType::K4BitQuantization
-         && self.dtype != TensorDType::K4BitQuantization)
- {
if other.rows == 1 {
return self.matrix_vector_mul_transposed(other);
}
#[cfg(feature = "opencl")]
@ -1278,29 +1293,13 @@ impl Tensor {
quants: *const F32x8,
) -> (F32x8, F32x8, F32x8, F32x8) {
unsafe {
- let m: u32 = 0xFFFFFFFF;
- let masks: [I32x8; 8] = [
- i32x8_from_values_u32(m, m, m, m, m, m, m, m),
- i32x8_from_values_u32(0, m, m, m, m, m, m, m),
- i32x8_from_values_u32(0, 0, m, m, m, m, m, m),
- i32x8_from_values_u32(0, 0, 0, m, m, m, m, m),
- i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m),
- i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m),
- i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m),
- i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m),
- ];
- let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m);
- let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
if row < nrows {
let col = col as i64;
let ncols = tensor.cols;
let (addr, side) = tensor.q4_address(row as i64, col);
let i = load_i16x8(addr as *const I16x8);
- let even_mask = i16x8_singleton_u16(0x0F0F);
- let odd_mask = i16x8_singleton_u16(0xF0F0);
- let evens = and_i16x8(i, even_mask);
- let odds = and_i16x8(i, odd_mask);
let evens = and_i16x8(i, *even_mask);
let odds = and_i16x8(i, *odd_mask);
let odds = shift_right_by_4_i16x8(odds);
let indices1 = extend_i8_to_i32_i32x8(odds);
@ -1319,32 +1318,32 @@ impl Tensor {
let unquantized4: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices4);
let quan1_mask: I32x8 = if col <= ncols - 8 {
- nomask
*nomask
} else if col < ncols {
masks[(col % 8) as usize]
} else {
- fullmask
*fullmask
};
let quan2_mask: I32x8 = if col <= ncols - 16 {
- nomask
*nomask
} else if col < ncols - 8 {
masks[(col % 8) as usize]
} else {
- fullmask
*fullmask
};
let quan3_mask: I32x8 = if col <= ncols - 24 {
- nomask
*nomask
} else if col < ncols - 16 {
masks[(col % 8) as usize]
} else {
- fullmask
*fullmask
};
let quan4_mask: I32x8 = if col <= ncols - 32 {
- nomask
*nomask
} else if col < ncols - 24 {
masks[(col % 8) as usize]
} else {
- fullmask
*fullmask
};
let unquantized1 = and_f32x8(unquantized1, quan1_mask);
let unquantized2 = and_f32x8(unquantized2, quan2_mask);
@ -1429,6 +1428,11 @@ impl Tensor {
let self_cols: usize = self.cols as usize;
let self_cols_capacity: usize = self.capacity_cols as usize;
let self_cols_its = if self_cols % 4 == 0 {
self_cols / 4
} else {
self_cols / 4 + 1
};
// src_cols_its == also the shared dimension between src and other.
let src_cols_its = if src_cols % 32 == 0 {
src_cols / 32
@ -1455,11 +1459,15 @@ impl Tensor {
let quants: [F32x8; 2] =
[i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
- for col in 0..self_cols {
- let row_col = row * self_cols + col;
for col_raw in 0..self_cols_its {
let row_col = row * self_cols_its + col_raw;
if row_col % nthreads != thread_idx {
continue;
}
let col0 = col_raw * 4;
let col1 = col_raw * 4 + 1;
let col2 = col_raw * 4 + 2;
let col3 = col_raw * 4 + 3;
#[inline]
fn load_f32(
@ -1488,29 +1496,435 @@ impl Tensor {
quants: *const F32x8,
) -> (F32x8, F32x8, F32x8, F32x8) {
unsafe {
- let m: u32 = 0xFFFFFFFF;
- let masks: [I32x8; 8] = [
- i32x8_from_values_u32(m, m, m, m, m, m, m, m),
- i32x8_from_values_u32(0, m, m, m, m, m, m, m),
- i32x8_from_values_u32(0, 0, m, m, m, m, m, m),
- i32x8_from_values_u32(0, 0, 0, m, m, m, m, m),
- i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m),
- i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m),
- i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m),
- i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m),
if row < nrows {
let col = col as i64;
let ncols = tensor.cols;
let (addr, side) = tensor.q4_address(row as i64, col);
let i = load_i16x8(addr as *const I16x8);
let evens = and_i16x8(i, *even_mask);
let odds = and_i16x8(i, *odd_mask);
let odds = shift_right_by_4_i16x8(odds);

let indices1 = extend_i8_to_i32_i32x8(odds);
let odds_shifted = shift_right_by_64_i128(odds);
let indices2 = extend_i8_to_i32_i32x8(odds_shifted);
let indices3 = extend_i8_to_i32_i32x8(evens);
let indices4 =
extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens));
let unquantized1: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices1);
let unquantized2: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices2);
let unquantized3: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices3);
let unquantized4: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices4);
let quan1_mask: I32x8 = if col <= ncols - 8 {
*nomask
} else if col < ncols {
masks[(col % 8) as usize]
} else {
*fullmask
};
let quan2_mask: I32x8 = if col <= ncols - 16 {
*nomask
} else if col < ncols - 8 {
masks[(col % 8) as usize]
} else {
*fullmask
};
let quan3_mask: I32x8 = if col <= ncols - 24 {
*nomask
} else if col < ncols - 16 {
masks[(col % 8) as usize]
} else {
*fullmask
};
let quan4_mask: I32x8 = if col <= ncols - 32 {
*nomask
} else if col < ncols - 24 {
masks[(col % 8) as usize]
} else {
*fullmask
};
let unquantized1 = and_f32x8(unquantized1, quan1_mask);
let unquantized2 = and_f32x8(unquantized2, quan2_mask);
let unquantized3 = and_f32x8(unquantized3, quan3_mask);
let unquantized4 = and_f32x8(unquantized4, quan4_mask);
(unquantized1, unquantized2, unquantized3, unquantized4)
} else {
(f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero())
}
}
}
let mut targets8: [[F32x8; 4]; 4] = [
[f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
[f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
[f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
[f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()],
]; ];
- let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m);
- let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
for p in 0..src_cols_its {
// Macro to make code shorter
macro_rules! lo {
($col:expr, $p:expr) => {
load_f32(
other_data,
$col,
$p,
other_cols,
other_rows,
other_cols_capacity,
)
};
}
let other8_00: F32x8 = lo!(col0, p * 32);
let other8_01: F32x8 = lo!(col0, p * 32 + 8);
let other8_02: F32x8 = lo!(col0, p * 32 + 16);
let other8_03: F32x8 = lo!(col0, p * 32 + 24);
let other8_10: F32x8 = lo!(col1, p * 32);
let other8_11: F32x8 = lo!(col1, p * 32 + 8);
let other8_12: F32x8 = lo!(col1, p * 32 + 16);
let other8_13: F32x8 = lo!(col1, p * 32 + 24);
let other8_20: F32x8 = lo!(col2, p * 32);
let other8_21: F32x8 = lo!(col2, p * 32 + 8);
let other8_22: F32x8 = lo!(col2, p * 32 + 16);
let other8_23: F32x8 = lo!(col2, p * 32 + 24);
let other8_30: F32x8 = lo!(col3, p * 32);
let other8_31: F32x8 = lo!(col3, p * 32 + 8);
let other8_32: F32x8 = lo!(col3, p * 32 + 16);
let other8_33: F32x8 = lo!(col3, p * 32 + 24);
let (src8_0, src8_1, src8_2, src8_3): (F32x8, F32x8, F32x8, F32x8) =
load_k4_to_f32(&src, row, p * 32, src_rows, quants.as_ptr());
targets8[0][0] = fma_f32x8(src8_0, other8_00, targets8[0][0]);
targets8[0][1] = fma_f32x8(src8_1, other8_01, targets8[0][1]);
targets8[0][2] = fma_f32x8(src8_2, other8_02, targets8[0][2]);
targets8[0][3] = fma_f32x8(src8_3, other8_03, targets8[0][3]);
targets8[1][0] = fma_f32x8(src8_0, other8_10, targets8[1][0]);
targets8[1][1] = fma_f32x8(src8_1, other8_11, targets8[1][1]);
targets8[1][2] = fma_f32x8(src8_2, other8_12, targets8[1][2]);
targets8[1][3] = fma_f32x8(src8_3, other8_13, targets8[1][3]);
targets8[2][0] = fma_f32x8(src8_0, other8_20, targets8[2][0]);
targets8[2][1] = fma_f32x8(src8_1, other8_21, targets8[2][1]);
targets8[2][2] = fma_f32x8(src8_2, other8_22, targets8[2][2]);
targets8[2][3] = fma_f32x8(src8_3, other8_23, targets8[2][3]);
targets8[3][0] = fma_f32x8(src8_0, other8_30, targets8[3][0]);
targets8[3][1] = fma_f32x8(src8_1, other8_31, targets8[3][1]);
targets8[3][2] = fma_f32x8(src8_2, other8_32, targets8[3][2]);
targets8[3][3] = fma_f32x8(src8_3, other8_33, targets8[3][3]);
}
let target00 = horizontal_sum_f32x8(targets8[0][0]);
let target01 = horizontal_sum_f32x8(targets8[0][1]);
let target02 = horizontal_sum_f32x8(targets8[0][2]);
let target03 = horizontal_sum_f32x8(targets8[0][3]);
let target0 = target00 + target01 + target02 + target03;
let target10 = horizontal_sum_f32x8(targets8[1][0]);
let target11 = horizontal_sum_f32x8(targets8[1][1]);
let target12 = horizontal_sum_f32x8(targets8[1][2]);
let target13 = horizontal_sum_f32x8(targets8[1][3]);
let target1 = target10 + target11 + target12 + target13;
let target20 = horizontal_sum_f32x8(targets8[2][0]);
let target21 = horizontal_sum_f32x8(targets8[2][1]);
let target22 = horizontal_sum_f32x8(targets8[2][2]);
let target23 = horizontal_sum_f32x8(targets8[2][3]);
let target2 = target20 + target21 + target22 + target23;
let target30 = horizontal_sum_f32x8(targets8[3][0]);
let target31 = horizontal_sum_f32x8(targets8[3][1]);
let target32 = horizontal_sum_f32x8(targets8[3][2]);
let target33 = horizontal_sum_f32x8(targets8[3][3]);
let target3 = target30 + target31 + target32 + target33;
*tgt_data.add(row * self_cols_capacity + col0) = target0;
if col1 < self_cols {
*tgt_data.add(row * self_cols_capacity + col1) = target1;
}
if col2 < self_cols {
*tgt_data.add(row * self_cols_capacity + col2) = target2;
}
if col3 < self_cols {
*tgt_data.add(row * self_cols_capacity + col3) = target3;
}
}
}
});
}
}
fn matrix_vector_mul_inplace_transposed_f32_and_k4bit(&mut self, src: &Tensor, other: &Tensor) {
// Assume: size checks have been done already.
assert!(src.dtype == TensorDType::Float32);
assert!(other.dtype == TensorDType::K4BitQuantization);
assert!(self.dtype == TensorDType::Float32);
assert_eq!(other.rows, 1);
assert_eq!(self.cols, 1);
unsafe {
let src_rows: usize = src.rows as usize;
let src_cols: usize = src.cols as usize;
let src_cols_capacity: usize = src.capacity_cols as usize;
let other_cols: usize = other.cols as usize;
let other_rows: usize = other.rows as usize;
let other_cols_capacity: usize = other.capacity_cols as usize;
let self_rows: usize = self.rows as usize;
let self_cols: usize = self.cols as usize;
let self_cols_capacity: usize = self.capacity_cols as usize;
// src_cols_its == also the shared dimension between src and other.
let src_cols_its = if src_cols % 32 == 0 {
src_cols / 32
} else {
src_cols / 32 + 1
};
debug_assert!(!other.q4_data.is_null());
let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
let other_q4_data: WrappedPtr = WrappedPtr::wrap(other.q4_data);
let nthreads: usize = rayon::current_num_threads();
(0..nthreads).into_par_iter().for_each(|thread_idx| {
let other_q4_data: *const u8 = other_q4_data.unwrap() as *const u8;
let src_data: *const f32 = src_data_wrap.unwrap() as *const f32;
let other_data: *const u8 = other_data.unwrap() as *const u8;
let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
let quant0 = load_i16x8(other_q4_data.add(0) as *const I16x8);
let quant1 = load_i16x8(other_q4_data.add(16) as *const I16x8);
let quants: [F32x8; 2] =
[i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
let col = 0;
for row in 0..self_rows {
if row % nthreads != thread_idx {
continue;
}
#[inline]
fn load_f32(
src: *const f32,
row: usize,
col: usize,
ncols: usize,
nrows: usize,
cols_capacity: usize,
) -> F32x8 {
unsafe {
if row >= nrows || col >= ncols {
f32x8_zero()
} else {
load_f32x8(src.add(row * cols_capacity + col) as *const F32x8)
}
}
}
#[inline]
fn load_k4_to_f32(
tensor: &Tensor,
row: usize,
col: usize,
nrows: usize,
quants: *const F32x8,
) -> (F32x8, F32x8, F32x8, F32x8) {
unsafe {
if row < nrows {
let col = col as i64;
let ncols = tensor.cols;
let (addr, side) = tensor.q4_address(row as i64, col);
let i = load_i16x8(addr as *const I16x8);
let evens = and_i16x8(i, *even_mask);
let odds = and_i16x8(i, *odd_mask);
let odds = shift_right_by_4_i16x8(odds);
let indices1 = extend_i8_to_i32_i32x8(odds);
let odds_shifted = shift_right_by_64_i128(odds);
let indices2 = extend_i8_to_i32_i32x8(odds_shifted);
let indices3 = extend_i8_to_i32_i32x8(evens);
let indices4 =
extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens));
let unquantized1: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices1);
let unquantized2: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices2);
let unquantized3: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices3);
let unquantized4: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices4);
let quan1_mask: I32x8 = if col <= ncols - 8 {
*nomask
} else if col < ncols {
masks[(col % 8) as usize]
} else {
*fullmask
};
let quan2_mask: I32x8 = if col <= ncols - 16 {
*nomask
} else if col < ncols - 8 {
masks[(col % 8) as usize]
} else {
*fullmask
};
let quan3_mask: I32x8 = if col <= ncols - 24 {
*nomask
} else if col < ncols - 16 {
masks[(col % 8) as usize]
} else {
*fullmask
};
let quan4_mask: I32x8 = if col <= ncols - 32 {
*nomask
} else if col < ncols - 24 {
masks[(col % 8) as usize]
} else {
*fullmask
};
let unquantized1 = and_f32x8(unquantized1, quan1_mask);
let unquantized2 = and_f32x8(unquantized2, quan2_mask);
let unquantized3 = and_f32x8(unquantized3, quan3_mask);
let unquantized4 = and_f32x8(unquantized4, quan4_mask);
(unquantized1, unquantized2, unquantized3, unquantized4)
} else {
(f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero())
}
}
}
let mut targets8: [F32x8; 4] =
[f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
for p in 0..src_cols_its {
let (other8_0, other8_1, other8_2, other8_3) =
load_k4_to_f32(&other, col, p * 32, other_rows, quants.as_ptr());
let src8_0 =
load_f32(src_data, row, p * 32, src_cols, src_rows, src_cols_capacity);
let src8_1 = load_f32(
src_data,
row,
p * 32 + 8,
src_cols,
src_rows,
src_cols_capacity,
);
let src8_2 = load_f32(
src_data,
row,
p * 32 + 16,
src_cols,
src_rows,
src_cols_capacity,
);
let src8_3 = load_f32(
src_data,
row,
p * 32 + 24,
src_cols,
src_rows,
src_cols_capacity,
);
targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]);
targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]);
targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]);
targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]);
}
let target0 = horizontal_sum_f32x8(targets8[0]);
let target1 = horizontal_sum_f32x8(targets8[1]);
let target2 = horizontal_sum_f32x8(targets8[2]);
let target3 = horizontal_sum_f32x8(targets8[3]);
let target = target0 + target1 + target2 + target3;
*tgt_data.add(row * self_cols_capacity + col) = target;
}
});
}
}
fn matrix_vector_mul_inplace_transposed_k4bit_and_f32(&mut self, src: &Tensor, other: &Tensor) {
// Assume: size checks have been done already.
assert!(src.dtype == TensorDType::K4BitQuantization);
assert!(other.dtype == TensorDType::Float32);
assert!(self.dtype == TensorDType::Float32);
assert_eq!(other.rows, 1);
assert_eq!(self.cols, 1);
unsafe {
let src_rows: usize = src.rows as usize;
let src_cols: usize = src.cols as usize;
let src_cols_capacity: usize = src.capacity_cols as usize;
let other_cols: usize = other.cols as usize;
let other_rows: usize = other.rows as usize;
let other_cols_capacity: usize = other.capacity_cols as usize;
let self_rows: usize = self.rows as usize;
let self_cols: usize = self.cols as usize;
let self_cols_capacity: usize = self.capacity_cols as usize;
// src_cols_its == also the shared dimension between src and other.
let src_cols_its = if src_cols % 32 == 0 {
src_cols / 32
} else {
src_cols / 32 + 1
};
debug_assert!(!src.q4_data.is_null());
let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
let src_q4_data: WrappedPtr = WrappedPtr::wrap(src.q4_data);
let nthreads: usize = rayon::current_num_threads();
(0..nthreads).into_par_iter().for_each(|thread_idx| {
let src_q4_data: *const u8 = src_q4_data.unwrap() as *const u8;
let src_data: *const u8 = src_data_wrap.unwrap() as *const u8;
let other_data: *const f32 = other_data.unwrap() as *const f32;
let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
for row in 0..self_rows {
if row % nthreads != thread_idx {
continue;
}
let quant0 = load_i16x8(src_q4_data.add(row * 32) as *const I16x8);
let quant1 = load_i16x8(src_q4_data.add(row * 32 + 16) as *const I16x8);
let quants: [F32x8; 2] =
[i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
let col = 0;
#[inline]
fn load_f32(
other: *const f32,
row: usize,
col: usize,
ncols: usize,
nrows: usize,
cols_capacity: usize,
) -> F32x8 {
unsafe {
if row >= nrows || col >= ncols {
f32x8_zero()
} else {
load_f32x8(other.add(row * cols_capacity + col) as *const F32x8)
}
}
}
#[inline]
fn load_k4_to_f32(
tensor: &Tensor,
row: usize,
col: usize,
nrows: usize,
quants: *const F32x8,
) -> (F32x8, F32x8, F32x8, F32x8) {
unsafe {
if row < nrows {
let col = col as i64;
let ncols = tensor.cols;
let (addr, side) = tensor.q4_address(row as i64, col);
let i = load_i16x8(addr as *const I16x8);
- let even_mask = i16x8_singleton_u16(0x0F0F);
- let odd_mask = i16x8_singleton_u16(0xF0F0);
- let evens = and_i16x8(i, even_mask);
- let odds = and_i16x8(i, odd_mask);
let evens = and_i16x8(i, *even_mask);
let odds = and_i16x8(i, *odd_mask);
let odds = shift_right_by_4_i16x8(odds);
let indices1 = extend_i8_to_i32_i32x8(odds);
@ -1529,32 +1943,32 @@ impl Tensor {
let unquantized4: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices4);
let quan1_mask: I32x8 = if col <= ncols - 8 {
- nomask
*nomask
} else if col < ncols {
masks[(col % 8) as usize]
} else {
- fullmask
*fullmask
};
let quan2_mask: I32x8 = if col <= ncols - 16 {
- nomask
*nomask
} else if col < ncols - 8 {
masks[(col % 8) as usize]
} else {
- fullmask
*fullmask
};
let quan3_mask: I32x8 = if col <= ncols - 24 {
- nomask
*nomask
} else if col < ncols - 16 {
masks[(col % 8) as usize]
} else {
- fullmask
*fullmask
};
let quan4_mask: I32x8 = if col <= ncols - 32 {
- nomask
*nomask
} else if col < ncols - 24 {
masks[(col % 8) as usize]
} else {
- fullmask
*fullmask
};
let unquantized1 = and_f32x8(unquantized1, quan1_mask);
let unquantized2 = and_f32x8(unquantized2, quan2_mask);
@ -1617,7 +2031,6 @@ impl Tensor {
let target = target0 + target1 + target2 + target3;
*tgt_data.add(row * self_cols_capacity + col) = target;
}
- }
});
}
}
@ -2091,11 +2504,15 @@ impl Tensor {
}
assert_eq!(other.rows, 1);
- // K4 bit currently has no implementation for matrix_vector_mul
- if self.dtype == TensorDType::K4BitQuantization
-     || other.dtype == TensorDType::K4BitQuantization
- {
-     return self.matrix_mul_transposed(other);
- }
if self.dtype == TensorDType::K4BitQuantization {
let mut result = unsafe { Tensor::uninitialized(self.rows, 1, other.dtype) };
result.matrix_vector_mul_inplace_transposed_k4bit_and_f32(self, other);
return result;
}
if other.dtype == TensorDType::K4BitQuantization {
let mut result = unsafe { Tensor::uninitialized(self.rows, 1, self.dtype) };
result.matrix_vector_mul_inplace_transposed_f32_and_k4bit(self, other);
return result;
}
assert_eq!(other.dtype, self.dtype);
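With this dispatch in place, a k4-quantized operand can be handed straight to matrix_vector_mul_transposed. A usage sketch in the spirit of the new benchmarks, assuming (as the benchmark names suggest) that quantize() produces the k4 dtype:

let m1 = Tensor::random(1024, 128, TensorDType::Float32);
let v = Tensor::random(1, 128, TensorDType::Float32).quantize();
// f32 matrix times k4-quantized row vector; takes the new specialized path.
let out = m1.matrix_vector_mul_transposed(&v);
assert_eq!(out.rows, 1024);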
@ -3872,4 +4289,80 @@ mod tests {
}
}
}
#[test]
fn quantized_matrices_matrix_vector_mul_transposed_correctly_f32_mul_k4() {
// TODO: this test is mostly a copypaste from the matrix_mul tests except let b = 1;
let mut rng = rand::thread_rng();
for _ in 0..100 {
let a = rng.gen_range(1..=128);
let b = 1;
let c = rng.gen_range(1..=128);
let other_matrix = Tensor::random(a, c, TensorDType::Float32);
let mut reference = Tensor::zeros(b, c, TensorDType::Float32);
let mut quant_values: Vec<Vec<f32>> = Vec::with_capacity(c as usize);
for row in 0..b {
let mut quant_values_for_row: Vec<f32> = Vec::with_capacity(16);
for _ in 0..16 {
quant_values_for_row.push(rng.gen_range(0.0..=1.0));
}
quant_values.push(quant_values_for_row);
}
let mut quantized_values: Vec<Vec<u8>> = Vec::with_capacity(b as usize);
for row in 0..b {
let mut quant_values_for_row: Vec<u8> = Vec::with_capacity(c as usize);
for col in 0..c {
let i = rng.gen_range(0..=15);
reference.set_f32(row, col, quant_values[row as usize][i as usize]);
quant_values_for_row.push(i as u8);
}
quantized_values.push(quant_values_for_row);
}
let quantized = Tensor::make_k4bit_from_fn(
b,
c,
|row, col| quantized_values[row as usize][col as usize],
|row| {
let mut result: [f32; 16] = [0.0; 16];
for col in 0..16 {
result[col] = quant_values[row as usize][col];
}
result
},
);
assert_eq!(reference.rows(), quantized.rows());
assert_eq!(reference.cols(), quantized.cols());
for row in 0..reference.rows {
for col in 0..reference.cols {
// The quantized table always uses f16 so values may not be 100% equal.
assert_relative_eq!(
reference.get_f32(row, col),
quantized.get_f32(row, col),
epsilon = 1e-1,
);
}
}
let mult1 = other_matrix.matrix_mul_transposed(&reference);
let mult2 = other_matrix.matrix_mul_transposed(&quantized);
assert_eq!(mult1.rows(), mult2.rows());
assert_eq!(mult1.cols(), mult2.cols());
for row in 0..mult1.rows {
for col in 0..mult1.cols {
assert_relative_eq!(
mult1.get_f32(row, col),
mult2.get_f32(row, col),
epsilon = 1e-1,
);
}
}
}
}
}
