K4 bit inference works now. Performance isn't as good as I'd like it to be though.

k4bit
Mikko Juola 3 years ago
parent 40121e1c82
commit 2f3e9bc0f5

@ -53,6 +53,9 @@ struct Cli {
#[arg(long, action)]
f16: bool,
#[arg(long, action)]
k4: bool,
#[cfg(feature = "opencl")]
#[arg(long)]
opencl_device: Option<usize>,
@ -233,6 +236,9 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
if cli.f16 {
data_settings = data_settings.force_f16();
}
if cli.k4 {
data_settings = data_settings.force_k4();
}
pln!("Loading transformer weights from {}...", model_path);
let tr = Transformer::from_unpickled(

@ -921,7 +921,10 @@ impl Tensor {
// We don't have implementation for f16, so don't use the vector function if we have
// f16
#[cfg(not(feature = "opencl"))]
if other.rows == 1 && other.dtype != TensorDType::Float16 {
if other.rows == 1
&& (other.dtype != TensorDType::K4BitQuantization
&& self.dtype != TensorDType::K4BitQuantization)
{
return self.matrix_vector_mul_transposed(other);
}
#[cfg(feature = "opencl")]
@ -1117,8 +1120,8 @@ impl Tensor {
// Casts data type to whatever the other tensors data type is.
pub fn to_same_type(&self, other: &Tensor) -> Tensor {
let result = self.clone();
if result.dtype() == other.dtype() {
if self.dtype() == other.dtype() {
let result = self.clone();
return result;
}
match other.dtype {
@ -1128,6 +1131,35 @@ impl Tensor {
}
}
// Casts data type so that it is valid to do: other.matrix_mul_transposed(result)
pub fn to_compatible_matrix_mul_type(&self, other: &Tensor) -> Tensor {
if self.dtype() != TensorDType::K4BitQuantization
&& other.dtype() != TensorDType::K4BitQuantization
{
return self.to_same_type(other);
}
if other.dtype() == TensorDType::K4BitQuantization {
return self.to_f32();
}
unimplemented!()
}
// Casts data type so that it is valid to do: result.matrix_mul_transposed(other)
pub fn to_compatible_matrix_mul_type2(&self, other: &Tensor) -> Tensor {
if self.dtype() != TensorDType::K4BitQuantization
&& other.dtype() != TensorDType::K4BitQuantization
{
return self.to_same_type(other);
}
if other.dtype() == TensorDType::Float32 {
return self.clone();
}
if other.dtype() == TensorDType::K4BitQuantization {
return self.to_f32();
}
unimplemented!()
}
pub fn into_same_type(self, other: &Tensor) -> Tensor {
if self.dtype() == other.dtype() {
return self;
@ -1170,6 +1202,216 @@ impl Tensor {
std::mem::drop(other_od);
}
fn matrix_mul_inplace_transposed_f32_and_k4bit(&mut self, src: &Tensor, other: &Tensor) {
// Assume: size checks have been done already.
assert!(src.dtype == TensorDType::Float32);
assert!(other.dtype == TensorDType::K4BitQuantization);
assert!(self.dtype == TensorDType::Float32);
unsafe {
let src_rows: usize = src.rows as usize;
let src_cols: usize = src.cols as usize;
let src_cols_capacity: usize = src.capacity_cols as usize;
let other_cols: usize = other.cols as usize;
let other_rows: usize = other.rows as usize;
let other_cols_capacity: usize = other.capacity_cols as usize;
let self_rows: usize = self.rows as usize;
let self_cols: usize = self.cols as usize;
let self_cols_capacity: usize = self.capacity_cols as usize;
// src_cols_its == also the shared dimension between src and other.
let src_cols_its = if src_cols % 32 == 0 {
src_cols / 32
} else {
src_cols / 32 + 1
};
debug_assert!(!other.q4_data.is_null());
let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
let other_q4_data: WrappedPtr = WrappedPtr::wrap(other.q4_data);
let nthreads: usize = rayon::current_num_threads();
(0..nthreads).into_par_iter().for_each(|thread_idx| {
let other_q4_data: *const u8 = other_q4_data.unwrap() as *const u8;
let src_data: *const f32 = src_data_wrap.unwrap() as *const f32;
let other_data: *const u8 = other_data.unwrap() as *const u8;
let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
for row in 0..self_rows {
for col in 0..self_cols {
let row_col = row * self_cols + col;
if row_col % nthreads != thread_idx {
continue;
}
let quant0 = load_i16x8(other_q4_data.add(col * 32) as *const I16x8);
let quant1 = load_i16x8(other_q4_data.add(col * 32 + 16) as *const I16x8);
let quants: [F32x8; 2] =
[i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
#[inline]
fn load_f32(
src: *const f32,
row: usize,
col: usize,
ncols: usize,
nrows: usize,
cols_capacity: usize,
) -> F32x8 {
unsafe {
if row >= nrows || col >= ncols {
f32x8_zero()
} else {
load_f32x8(src.add(row * cols_capacity + col) as *const F32x8)
}
}
}
#[inline]
fn load_k4_to_f32(
tensor: &Tensor,
row: usize,
col: usize,
nrows: usize,
quants: *const F32x8,
) -> (F32x8, F32x8, F32x8, F32x8) {
unsafe {
let m: u32 = 0xFFFFFFFF;
let masks: [I32x8; 8] = [
i32x8_from_values_u32(m, m, m, m, m, m, m, m),
i32x8_from_values_u32(0, m, m, m, m, m, m, m),
i32x8_from_values_u32(0, 0, m, m, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, m, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m),
];
let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m);
let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
if row < nrows {
let col = col as i64;
let ncols = tensor.cols;
let (addr, side) = tensor.q4_address(row as i64, col);
let i = load_i16x8(addr as *const I16x8);
let even_mask = i16x8_singleton_u16(0x0F0F);
let odd_mask = i16x8_singleton_u16(0xF0F0);
let evens = and_i16x8(i, even_mask);
let odds = and_i16x8(i, odd_mask);
let odds = shift_right_by_4_i16x8(odds);
let indices1 = extend_i8_to_i32_i32x8(odds);
let odds_shifted = shift_right_by_64_i128(odds);
let indices2 = extend_i8_to_i32_i32x8(odds_shifted);
let indices3 = extend_i8_to_i32_i32x8(evens);
let indices4 =
extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens));
let unquantized1: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices1);
let unquantized2: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices2);
let unquantized3: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices3);
let unquantized4: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices4);
let quan1_mask: I32x8 = if col <= ncols - 8 {
nomask
} else if col < ncols {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan2_mask: I32x8 = if col <= ncols - 16 {
nomask
} else if col < ncols - 8 {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan3_mask: I32x8 = if col <= ncols - 24 {
nomask
} else if col < ncols - 16 {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan4_mask: I32x8 = if col <= ncols - 32 {
nomask
} else if col < ncols - 24 {
masks[(col % 8) as usize]
} else {
fullmask
};
let unquantized1 = and_f32x8(unquantized1, quan1_mask);
let unquantized2 = and_f32x8(unquantized2, quan2_mask);
let unquantized3 = and_f32x8(unquantized3, quan3_mask);
let unquantized4 = and_f32x8(unquantized4, quan4_mask);
(unquantized1, unquantized2, unquantized3, unquantized4)
} else {
(f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero())
}
}
}
let mut targets8: [F32x8; 4] =
[f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
for p in 0..src_cols_its {
let (other8_0, other8_1, other8_2, other8_3) =
load_k4_to_f32(&other, col, p * 32, other_rows, quants.as_ptr());
let src8_0 = load_f32(
src_data,
row,
p * 32,
src_cols,
src_rows,
src_cols_capacity,
);
let src8_1 = load_f32(
src_data,
row,
p * 32 + 8,
src_cols,
src_rows,
src_cols_capacity,
);
let src8_2 = load_f32(
src_data,
row,
p * 32 + 16,
src_cols,
src_rows,
src_cols_capacity,
);
let src8_3 = load_f32(
src_data,
row,
p * 32 + 24,
src_cols,
src_rows,
src_cols_capacity,
);
targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]);
targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]);
targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]);
targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]);
}
let target0 = horizontal_sum_f32x8(targets8[0]);
let target1 = horizontal_sum_f32x8(targets8[1]);
let target2 = horizontal_sum_f32x8(targets8[2]);
let target3 = horizontal_sum_f32x8(targets8[3]);
let target = target0 + target1 + target2 + target3;
*tgt_data.add(row * self_cols_capacity + col) = target;
}
}
});
}
}
fn matrix_mul_inplace_transposed_k4bit_and_f32(&mut self, src: &Tensor, other: &Tensor) {
// Assume: size checks have been done already.
assert!(src.dtype == TensorDType::K4BitQuantization);
@ -1246,19 +1488,19 @@ impl Tensor {
quants: *const F32x8,
) -> (F32x8, F32x8, F32x8, F32x8) {
unsafe {
let M: u32 = 0xFFFFFFFF;
let MASKS: [I32x8; 8] = [
i32x8_from_values_u32(M, M, M, M, M, M, M, M),
i32x8_from_values_u32(0, M, M, M, M, M, M, M),
i32x8_from_values_u32(0, 0, M, M, M, M, M, M),
i32x8_from_values_u32(0, 0, 0, M, M, M, M, M),
i32x8_from_values_u32(0, 0, 0, 0, M, M, M, M),
i32x8_from_values_u32(0, 0, 0, 0, 0, M, M, M),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, M, M),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, M),
let m: u32 = 0xFFFFFFFF;
let masks: [I32x8; 8] = [
i32x8_from_values_u32(m, m, m, m, m, m, m, m),
i32x8_from_values_u32(0, m, m, m, m, m, m, m),
i32x8_from_values_u32(0, 0, m, m, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, m, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m),
];
let NOMASK: I32x8 = i32x8_from_values_u32(M, M, M, M, M, M, M, M);
let FULLMASK: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m);
let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
if row < nrows {
let col = col as i64;
@ -1287,32 +1529,32 @@ impl Tensor {
let unquantized4: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices4);
let quan1_mask: I32x8 = if col <= ncols - 8 {
NOMASK
nomask
} else if col < ncols {
MASKS[(col % 8) as usize]
masks[(col % 8) as usize]
} else {
FULLMASK
fullmask
};
let quan2_mask: I32x8 = if col <= ncols - 16 {
NOMASK
nomask
} else if col < ncols - 8 {
MASKS[(col % 8) as usize]
masks[(col % 8) as usize]
} else {
FULLMASK
fullmask
};
let quan3_mask: I32x8 = if col <= ncols - 24 {
NOMASK
nomask
} else if col < ncols - 16 {
MASKS[(col % 8) as usize]
masks[(col % 8) as usize]
} else {
FULLMASK
fullmask
};
let quan4_mask: I32x8 = if col <= ncols - 32 {
NOMASK
nomask
} else if col < ncols - 24 {
MASKS[(col % 8) as usize]
masks[(col % 8) as usize]
} else {
FULLMASK
fullmask
};
let unquantized1 = and_f32x8(unquantized1, quan1_mask);
let unquantized2 = and_f32x8(unquantized2, quan2_mask);
@ -1401,6 +1643,7 @@ impl Tensor {
}
if src.dtype != other.dtype
&& (src.dtype != TensorDType::K4BitQuantization || other.dtype != TensorDType::Float32)
&& (src.dtype != TensorDType::Float32 || other.dtype != TensorDType::K4BitQuantization)
{
panic!("Invalid matrix multiplication, different dtypes");
}
@ -1414,6 +1657,9 @@ impl Tensor {
if src.dtype == TensorDType::K4BitQuantization && other.dtype == TensorDType::Float32 {
return self.matrix_mul_inplace_transposed_k4bit_and_f32(src, other);
}
if src.dtype == TensorDType::Float32 && other.dtype == TensorDType::K4BitQuantization {
return self.matrix_mul_inplace_transposed_f32_and_k4bit(src, other);
}
match src.dtype {
TensorDType::K4BitQuantization => unimplemented!(),
@ -1844,6 +2090,14 @@ impl Tensor {
);
}
assert_eq!(other.rows, 1);
// K4 bit currently has no implementation for matrix_vector_mul
if self.dtype == TensorDType::K4BitQuantization
|| other.dtype == TensorDType::K4BitQuantization
{
return self.matrix_mul_transposed(other);
}
assert_eq!(other.dtype, self.dtype);
#[allow(unreachable_patterns)]
@ -3471,7 +3725,7 @@ mod tests {
}
#[test]
fn quantized_matrices_matrix_mul_transposed_correctly() {
fn quantized_matrices_matrix_mul_transposed_correctly_k4_mul_f32() {
let mut rng = rand::thread_rng();
for _ in 0..100 {
let a = rng.gen_range(1..=128);
@ -3543,4 +3797,79 @@ mod tests {
}
}
}
#[test]
fn quantized_matrices_matrix_mul_transposed_correctly_f32_mul_k4() {
let mut rng = rand::thread_rng();
for _ in 0..100 {
let a = rng.gen_range(1..=128);
let b = rng.gen_range(1..=128);
let c = rng.gen_range(1..=128);
let other_matrix = Tensor::random(a, c, TensorDType::Float32);
let mut reference = Tensor::zeros(b, c, TensorDType::Float32);
let mut quant_values: Vec<Vec<f32>> = Vec::with_capacity(c as usize);
for row in 0..b {
let mut quant_values_for_row: Vec<f32> = Vec::with_capacity(16);
for _ in 0..16 {
quant_values_for_row.push(rng.gen_range(0.0..=1.0));
}
quant_values.push(quant_values_for_row);
}
let mut quantized_values: Vec<Vec<u8>> = Vec::with_capacity(b as usize);
for row in 0..b {
let mut quant_values_for_row: Vec<u8> = Vec::with_capacity(c as usize);
for col in 0..c {
let i = rng.gen_range(0..=15);
reference.set_f32(row, col, quant_values[row as usize][i as usize]);
quant_values_for_row.push(i as u8);
}
quantized_values.push(quant_values_for_row);
}
let quantized = Tensor::make_k4bit_from_fn(
b,
c,
|row, col| quantized_values[row as usize][col as usize],
|row| {
let mut result: [f32; 16] = [0.0; 16];
for col in 0..16 {
result[col] = quant_values[row as usize][col];
}
result
},
);
assert_eq!(reference.rows(), quantized.rows());
assert_eq!(reference.cols(), quantized.cols());
for row in 0..reference.rows {
for col in 0..reference.cols {
// The quantized table always uses f16 so values may not be 100% equal.
assert_relative_eq!(
reference.get_f32(row, col),
quantized.get_f32(row, col),
epsilon = 1e-1,
);
}
}
let mult1 = other_matrix.matrix_mul_transposed(&reference);
let mult2 = other_matrix.matrix_mul_transposed(&quantized);
assert_eq!(mult1.rows(), mult2.rows());
assert_eq!(mult1.cols(), mult2.cols());
for row in 0..mult1.rows {
for col in 0..mult1.cols {
assert_relative_eq!(
mult1.get_f32(row, col),
mult2.get_f32(row, col),
epsilon = 1e-1,
);
}
}
}
}
}

@ -33,7 +33,7 @@ pub struct Transformer {
}
// Clone is cheap
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct DataSettings {
#[cfg(feature = "opencl")]
use_opencl_for_feedforward: bool,
@ -43,6 +43,7 @@ pub struct DataSettings {
cl: Option<OpenCL>,
force_f16: bool,
force_k4: bool,
}
// OpenCL is safe to send to threads but Rust doesn't know that
@ -56,6 +57,7 @@ impl DataSettings {
use_opencl_for_feedforward: false,
use_opencl_for_attention: false,
force_f16: false,
force_k4: false,
cl: cl.clone(),
}
}
@ -63,7 +65,10 @@ impl DataSettings {
#[allow(clippy::new_without_default)]
#[cfg(not(feature = "opencl"))]
pub fn new() -> Self {
DataSettings { force_f16: false }
DataSettings {
force_f16: false,
force_k4: false,
}
}
#[cfg(feature = "opencl")]
@ -78,6 +83,17 @@ impl DataSettings {
pub fn force_f16(mut self) -> DataSettings {
self.force_f16 = true;
if self.force_k4 {
panic!("Cannot force both f16 and k4");
}
self
}
pub fn force_k4(mut self) -> DataSettings {
self.force_k4 = true;
if self.force_f16 {
panic!("Cannot force both f16 and k4");
}
self
}
}
@ -457,13 +473,14 @@ impl FeedForward {
FromPiecesDirection::Rows,
)?;
w1 = crate::weight_compression::quantize(&w1);
panic!("stop");
if data_settings.force_f16 {
w1 = w1.to_f16();
w2 = w2.to_f16();
w3 = w3.to_f16();
} else if data_settings.force_k4 {
w1 = w1.quantize();
w2 = w2.quantize();
w3 = w3.quantize();
}
#[cfg(feature = "opencl")]
@ -491,7 +508,7 @@ impl FeedForward {
pub fn forward(&self, x: &mut Tensor) -> Tensor {
let original_x_dtype = x.dtype();
if x.dtype() != self.w1.dtype() {
*x = x.to_same_type(&self.w1);
*x = x.to_compatible_matrix_mul_type(&self.w1);
}
#[cfg(feature = "opencl")]
let x_was_on_cpu: bool;
@ -516,7 +533,7 @@ impl FeedForward {
let w1_out = w1_out.silu();
let mut w1w3_out = w1_out.hadamard_product(&w3_out).transpose();
if w1w3_out.dtype() != self.w2.dtype() {
w1w3_out = w1w3_out.to_same_type(&self.w2);
w1w3_out = w1w3_out.to_compatible_matrix_mul_type(&self.w2);
}
#[cfg(not(feature = "opencl"))]
{
@ -578,6 +595,11 @@ impl Attention {
wk = wk.to_f16();
wv = wv.to_f16();
wo = wo.to_f16();
} else if data_settings.force_k4 {
wq = wq.quantize();
wk = wk.quantize();
wv = wv.quantize();
wo = wo.quantize();
}
#[cfg(feature = "opencl")]
@ -616,7 +638,7 @@ impl Attention {
) -> Tensor {
let original_x_dtype = x.dtype();
if x.dtype() != self.wq.dtype() {
*x = x.to_same_type(&self.wq);
*x = x.to_compatible_matrix_mul_type2(&self.wq);
}
#[cfg(feature = "opencl")]
@ -744,7 +766,7 @@ impl Attention {
{
let xq_row = Tensor::concat(&concat_vec2).view(1, self.wo.rows());
xq_row
.into_same_type(&self.wo)
.to_compatible_matrix_mul_type2(&self.wo)
.matrix_mul_transposed(&self.wo)
}
#[cfg(feature = "opencl")]

@ -9,7 +9,7 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
* This is a simplistic rounding quantizer. It splits each row in a tensor to 16 buckets and
* takes the average value in said buckets as the quantized weight.
*/
let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype());
let mut allowed_values_by_row: Vec<Vec<f32>> = Vec::with_capacity(tensor.rows() as usize);
for row in 0..tensor.rows() {
let mut values: Vec<f32> = Vec::with_capacity(tensor.cols() as usize);
values.truncate(0);
@ -43,8 +43,15 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
allowed_values[0] = mi;
allowed_values[15] = ma;
allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
allowed_values_by_row.push(allowed_values);
}
for col in 0..tensor.cols() {
//let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype());
let result = Tensor::make_k4bit_from_fn(
tensor.rows(),
tensor.cols(),
|row, col| {
let allowed_values: &[f32] = &allowed_values_by_row[row as usize];
let val = tensor.get_f32(row, col);
let mut best = 0;
let mut best_dist = std::f32::MAX;
@ -55,8 +62,16 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
best_dist = dist;
}
}
result.set_f32(row, col, allowed_values[best] as f32);
}
}
best as u8
},
|row: i64| -> [f32; 16] {
let allowed_values: &[f32] = &allowed_values_by_row[row as usize];
let mut result: [f32; 16] = [0.0; 16];
for i in 0..16 {
result[i] = allowed_values[i];
}
result
},
);
result
}

Loading…
Cancel
Save