Bucketize the 4-bit quantization for more accuracy.

Currently going with 512 columns per bucket. Need to test a bit whether I
should go even smaller, e.g. 256 columns per bucket.
k4bit
Mikko Juola 3 years ago
parent 8cc82ae7e2
commit d7d13cd474
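For context, a minimal sketch of the bucketized lookup-table layout this commit introduces: every 512-column block of a row gets its own 16-entry f16 table (32 bytes), so q4_data holds ((cols + 511) / 512) * 32 bytes per row. The helper names below are made up for illustration only; they are not part of the crate.

const BLOCK_COLS: usize = 512; // columns per bucket/block
const TABLE_BYTES: usize = 32; // 16 f16 lookup values per block

// Bytes of q4_data needed per row: one 32-byte table per (rounded-up) block.
fn q4_row_capacity(cols: usize) -> usize {
    ((cols + BLOCK_COLS - 1) / BLOCK_COLS) * TABLE_BYTES
}

// Byte offset into q4_data of the table that covers (row, col).
fn q4_table_offset(row: usize, col: usize, cols: usize) -> usize {
    let block = col / BLOCK_COLS;
    row * q4_row_capacity(cols) + block * TABLE_BYTES
}

fn main() {
    // A row of 4096 columns has 8 blocks, i.e. 256 bytes of table data per row.
    assert_eq!(q4_row_capacity(4096), 8 * 32);
    // Column 1000 of row 2 falls in block 1.
    assert_eq!(q4_table_offset(2, 1000, 4096), 2 * 256 + 32);
}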

@@ -91,7 +91,11 @@ pub struct Tensor {
     data: *mut u8,
     // for quantization, only used if dtype == TensorDType::K4BitQuantization
-    // Contains 16 values per row in f16 (i.e. 32 bytes per row)
+    // q4_data is (NxM) where N is:
+    //
+    // ((cols + 511) / 512) * 32 (i.e. 32 bytes for every 512 columns, rounded up)
+    //
+    // and M is number of rows in the tensor.
     q4_data: *mut u8,
     #[cfg(feature = "opencl")]
@@ -345,11 +349,9 @@ impl Tensor {
                 BitSide::Lower => (addr_val & 0x0F),
             }
         };
-        let table: I16x8 = if quant_val <= 7 {
-            unsafe { load_i16x8(self.q4_data.add(row as usize * 32) as *const I16x8) }
-        } else {
-            unsafe { load_i16x8(self.q4_data.add(row as usize * 32 + 16) as *const I16x8) }
-        };
+        let (table1, table2) = self.q4_lookup_table(row, col);
+        let table = if quant_val <= 7 { table1 } else { table2 };
         let table = i16x8_as_f16_to_f32x8(table);
         f32x8_get(table, (quant_val % 8) as usize)
     }
@@ -1261,11 +1263,6 @@ impl Tensor {
                     continue;
                 }
-                let quant0 = load_i16x8(other_q4_data.add(col * 32) as *const I16x8);
-                let quant1 = load_i16x8(other_q4_data.add(col * 32 + 16) as *const I16x8);
-                let quants: [F32x8; 2] =
-                    [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
                 #[inline]
                 fn load_f32(
                     src: *const f32,
@@ -1360,6 +1357,10 @@ impl Tensor {
                     [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
                 for p in 0..src_cols_its {
+                    let (quant0, quant1) = other.q4_lookup_table(col as i64, p as i64 * 32);
+                    let quants: [F32x8; 2] =
+                        [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
                     let (other8_0, other8_1, other8_2, other8_3) =
                         load_k4_to_f32(&other, col, p * 32, other_rows, quants.as_ptr());
                     let src8_0 = load_f32(
@@ -1454,11 +1455,6 @@ impl Tensor {
             let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
             for row in 0..self_rows {
-                let quant0 = load_i16x8(src_q4_data.add(row * 32) as *const I16x8);
-                let quant1 = load_i16x8(src_q4_data.add(row * 32 + 16) as *const I16x8);
-                let quants: [F32x8; 2] =
-                    [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
                 for col_raw in 0..self_cols_its {
                     let row_col = row * self_cols_its + col_raw;
                     if row_col % nthreads != thread_idx {
@@ -1567,6 +1563,10 @@ impl Tensor {
                     ];
                     for p in 0..src_cols_its {
+                        let (quant0, quant1) = src.q4_lookup_table(row as i64, p as i64 * 32);
+                        let quants: [F32x8; 2] =
+                            [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
                         // Macro to make code shorter
                         macro_rules! lo {
                             ($col:expr, $p:expr) => {
@@ -1693,11 +1693,6 @@ impl Tensor {
             let other_data: *const u8 = other_data.unwrap() as *const u8;
             let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
-            let quant0 = load_i16x8(other_q4_data.add(0) as *const I16x8);
-            let quant1 = load_i16x8(other_q4_data.add(16) as *const I16x8);
-            let quants: [F32x8; 2] =
-                [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
             let col = 0;
             for row in 0..self_rows {
                 if row % nthreads != thread_idx {
@@ -1797,6 +1792,10 @@ impl Tensor {
                     [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
                 for p in 0..src_cols_its {
+                    let (quant0, quant1) = other.q4_lookup_table(col as i64, p as i64 * 32);
+                    let quants: [F32x8; 2] =
+                        [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
                     let (other8_0, other8_1, other8_2, other8_3) =
                         load_k4_to_f32(&other, col, p * 32, other_rows, quants.as_ptr());
                     let src8_0 =
@@ -1884,10 +1883,6 @@ impl Tensor {
                 if row % nthreads != thread_idx {
                     continue;
                 }
-                let quant0 = load_i16x8(src_q4_data.add(row * 32) as *const I16x8);
-                let quant1 = load_i16x8(src_q4_data.add(row * 32 + 16) as *const I16x8);
-                let quants: [F32x8; 2] =
-                    [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
                 let col = 0;
@@ -1985,6 +1980,9 @@ impl Tensor {
                     [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
                 for p in 0..src_cols_its {
+                    let (quant0, quant1) = src.q4_lookup_table(row as i64, p as i64 * 32);
+                    let quants: [F32x8; 2] =
+                        [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
                     let other8_0: F32x8 = load_f32(
                         other_data,
                         col,
@@ -2816,9 +2814,12 @@ impl Tensor {
     /// values and quantization lookup table.
     ///
    /// The first function must always return numbers from 0 to 15. These are put in the matrix.
-    /// The second function must return, for each row, 16 floats that correspond the numbers
+    /// The second function must return, for each row and 512-long column block, 16 floats that correspond the numbers
     /// returns from the first function. For this quantization scheme, each row has its own set of
     /// floats.
+    ///
+    /// The second function is called for each (row, column) where column is 0, 512, 1024, etc. so
+    /// that it would cover the entire matrix.
     #[inline]
     pub fn make_k4bit_from_fn<F, F2>(
         rows: i64,
@@ -2828,82 +2829,76 @@ impl Tensor {
     ) -> Self
     where
         F: FnMut(i64, i64) -> u8,
-        F2: FnMut(i64) -> [f32; 16],
+        F2: FnMut(i64, i64) -> [f32; 16],
     {
         let mut result =
             unsafe { Tensor::uninitialized(rows, cols, TensorDType::K4BitQuantization) };
         result.allocate_q4_data();
         assert!(!result.q4_data.is_null());
+        let col_blocks = (cols + 511) / 512;
+        let q4_col_capacity = col_blocks as usize * 32;
         unsafe {
             for row in 0..rows {
-                // Set the lookup table to
-                let lookup_table = get_lookup_table(row);
-                /*
-                let table1 = f32x8_from_values(
-                    lookup_table[0],
-                    lookup_table[1],
-                    lookup_table[2],
-                    lookup_table[3],
-                    lookup_table[4],
-                    lookup_table[5],
-                    lookup_table[6],
-                    lookup_table[7],
-                );
-                let table2 = f32x8_from_values(
-                    lookup_table[8],
-                    lookup_table[9],
-                    lookup_table[10],
-                    lookup_table[11],
-                    lookup_table[12],
-                    lookup_table[13],
-                    lookup_table[14],
-                    lookup_table[15],
-                );
-                */
-                let table1 = f32x8_from_values(
-                    lookup_table[7],
-                    lookup_table[6],
-                    lookup_table[5],
-                    lookup_table[4],
-                    lookup_table[3],
-                    lookup_table[2],
-                    lookup_table[1],
-                    lookup_table[0],
-                );
-                let table2 = f32x8_from_values(
-                    lookup_table[15],
-                    lookup_table[14],
-                    lookup_table[13],
-                    lookup_table[12],
-                    lookup_table[11],
-                    lookup_table[10],
-                    lookup_table[9],
-                    lookup_table[8],
-                );
-                let table1 = f32x8_to_i16x8_as_f16(table1);
-                let table2 = f32x8_to_i16x8_as_f16(table2);
-                store_i16x8(result.q4_data.add(row as usize * 32) as *mut I16x8, table1);
-                store_i16x8(
-                    result.q4_data.add(row as usize * 32 + 16) as *mut I16x8,
-                    table2,
-                );
-                for col in 0..cols {
-                    let v = get_value(row, col);
-                    let (addr, side) = result.q4_address(row, col);
-                    let mut addr_value = *addr;
-                    match side {
-                        BitSide::Upper => {
-                            addr_value = (addr_value & 0x0F) | (v << 4);
-                        }
-                        BitSide::Lower => {
-                            addr_value = (addr_value & 0xF0) | v;
-                        }
-                    }
-                    *addr = addr_value;
-                }
+                for block in 0..col_blocks {
+                    // Set the lookup table to
+                    let lookup_table = get_lookup_table(row, block * 512);
+                    let table1 = f32x8_from_values(
+                        lookup_table[7],
+                        lookup_table[6],
+                        lookup_table[5],
+                        lookup_table[4],
+                        lookup_table[3],
+                        lookup_table[2],
+                        lookup_table[1],
+                        lookup_table[0],
+                    );
+                    let table2 = f32x8_from_values(
+                        lookup_table[15],
+                        lookup_table[14],
+                        lookup_table[13],
+                        lookup_table[12],
+                        lookup_table[11],
+                        lookup_table[10],
+                        lookup_table[9],
+                        lookup_table[8],
+                    );
+                    let table1 = f32x8_to_i16x8_as_f16(table1);
+                    let table2 = f32x8_to_i16x8_as_f16(table2);
+                    store_i16x8(
+                        result
+                            .q4_data
+                            .add(row as usize * q4_col_capacity + block as usize * 32)
+                            as *mut I16x8,
+                        table1,
+                    );
+                    store_i16x8(
+                        result
+                            .q4_data
+                            .add(row as usize * q4_col_capacity + block as usize * 32 + 16)
+                            as *mut I16x8,
+                        table2,
+                    );
+                    let start = block * 512;
+                    let end = std::cmp::min(start + 512, cols);
+                    for col in start..end {
+                        let v = get_value(row, col);
+                        let (addr, side) = result.q4_address(row, col);
+                        let mut addr_value = *addr;
+                        match side {
+                            BitSide::Upper => {
+                                addr_value = (addr_value & 0x0F) | (v << 4);
+                            }
+                            BitSide::Lower => {
+                                addr_value = (addr_value & 0xF0) | v;
+                            }
+                        }
+                        *addr = addr_value;
+                    }
+                }
             }
         }
@@ -2940,6 +2935,26 @@ impl Tensor {
         }
     }

+    /// K4 bit quantization; loads the quantization table for given row/column
+    #[inline]
+    fn q4_lookup_table(&self, row: i64, col: i64) -> (I16x8, I16x8) {
+        let q4_capacity = (((self.cols + 511) / 512) * 32) as usize;
+        let col_block = col / 512;
+        unsafe {
+            let v1 = load_i16x8(
+                self.q4_data
+                    .add(row as usize * q4_capacity + col_block as usize * 32)
+                    as *const I16x8,
+            );
+            let v2 = load_i16x8(
+                self.q4_data
+                    .add(row as usize * q4_capacity + col_block as usize * 32 + 16)
+                    as *const I16x8,
+            );
+            (v1, v2)
+        }
+    }
+
     fn allocate_q4_data(&mut self) {
         if self.dtype != TensorDType::K4BitQuantization {
             panic!("Can only allocate q4 data for K4BitQuantization");
@@ -2949,7 +2964,10 @@ impl Tensor {
             return;
         }
-        let layout = Layout::from_size_align(self.rows as usize * 32, 32).unwrap();
+        let q4_cols_capacity = ((self.cols + 511) / 512) * 32;
+        let layout =
+            Layout::from_size_align(self.rows as usize * q4_cols_capacity as usize, 32).unwrap();
         let q4_data = unsafe { std::alloc::alloc_zeroed(layout) };
         if q4_data.is_null() {
             panic!("Failed to allocate q4 data");
@@ -4116,7 +4134,7 @@ mod tests {
             16,
             16,
             |_row, col| col as u8,
-            |row| {
+            |row, _col| {
                 let mut result: [f32; 16] = [0.0; 16];
                 for col in 0..16 {
                     result[col] = reference.get_f32(row, col as i64);
@@ -4174,7 +4192,7 @@ mod tests {
            a,
            b,
            |row, col| quantized_values[row as usize][col as usize],
-            |row| {
+            |row, _col| {
                let mut result: [f32; 16] = [0.0; 16];
                for col in 0..16 {
                    result[col] = quant_values[row as usize][col];
@@ -4249,7 +4267,7 @@ mod tests {
            b,
            c,
            |row, col| quantized_values[row as usize][col as usize],
-            |row| {
+            |row, _col| {
                let mut result: [f32; 16] = [0.0; 16];
                for col in 0..16 {
                    result[col] = quant_values[row as usize][col];
@@ -4325,7 +4343,7 @@ mod tests {
            b,
            c,
            |row, col| quantized_values[row as usize][col as usize],
-            |row| {
+            |row, _col| {
                let mut result: [f32; 16] = [0.0; 16];
                for col in 0..16 {
                    result[col] = quant_values[row as usize][col];

@@ -1,5 +1,4 @@
 use crate::tensor::Tensor;
-use rand::{thread_rng, Rng};
 use rayon::prelude::*;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::{Arc, RwLock};
@@ -9,49 +8,58 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
     * This is a simplistic rounding quantizer. It splits each row in a tensor to 16 buckets and
     * takes the average value in said buckets as the quantized weight.
     */
-    let mut allowed_values_by_row: Vec<Vec<f32>> = Vec::with_capacity(tensor.rows() as usize);
+    let mut allowed_values_by_row_block: Vec<Vec<Vec<f32>>> =
+        Vec::with_capacity(tensor.rows() as usize);
+    let col_blocks = (tensor.cols() + 511) / 512;
     for row in 0..tensor.rows() {
-        let mut values: Vec<f32> = Vec::with_capacity(tensor.cols() as usize);
-        values.truncate(0);
-        let mut mi: f32 = std::f32::MAX;
-        let mut ma: f32 = std::f32::MIN;
-        for col in 0..tensor.cols() {
-            let val = tensor.get_f32(row, col);
-            if val < mi {
-                mi = val;
-            }
-            if val > ma {
-                ma = val;
-            }
-            values.push(val);
-        }
-        values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
-        let mut allowed_values: Vec<f32> = Vec::with_capacity(16);
-        let mut rng = thread_rng();
-        for i in 0..16 {
-            let start_idx = i * values.len() / 16;
-            let end_idx = (i + 1) * values.len() / 16;
-            let mut avg = 0.0;
-            for j in start_idx..end_idx {
-                avg += values[j];
-            }
-            avg /= (end_idx - start_idx) as f32;
-            allowed_values.push(avg);
-        }
-        allowed_values[0] = mi;
-        allowed_values[15] = ma;
-        allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
-        allowed_values_by_row.push(allowed_values);
+        let mut block_values: Vec<Vec<f32>> = Vec::with_capacity(col_blocks as usize);
+        for block in 0..col_blocks {
+            let start = block * 512;
+            let end = std::cmp::min(start + 512, tensor.cols());
+            let mut values: Vec<f32> = Vec::with_capacity(512);
+            values.truncate(0);
+            let mut mi: f32 = std::f32::MAX;
+            let mut ma: f32 = std::f32::MIN;
+
+            for col in start..end {
+                let val = tensor.get_f32(row, col);
+                if val < mi {
+                    mi = val;
+                }
+                if val > ma {
+                    ma = val;
+                }
+                values.push(val);
+            }
+            values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+            let mut allowed_values: Vec<f32> = Vec::with_capacity(16);
+            for i in 0..16 {
+                let start_idx = i * values.len() / 16;
+                let end_idx = (i + 1) * values.len() / 16;
+                let mut avg = 0.0;
+                for j in start_idx..end_idx {
+                    avg += values[j];
+                }
+                avg /= (end_idx - start_idx) as f32;
+                allowed_values.push(avg);
+            }
+            allowed_values[0] = mi;
+            allowed_values[15] = ma;
+            allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+            block_values.push(allowed_values);
+        }
+        allowed_values_by_row_block.push(block_values);
     }
-    //let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype());
     let result = Tensor::make_k4bit_from_fn(
         tensor.rows(),
         tensor.cols(),
         |row, col| {
-            let allowed_values: &[f32] = &allowed_values_by_row[row as usize];
+            let col_block = col / 512;
+            let allowed_values: &[f32] =
+                &allowed_values_by_row_block[row as usize][col_block as usize];
             let val = tensor.get_f32(row, col);
             let mut best = 0;
             let mut best_dist = std::f32::MAX;
@@ -64,8 +72,9 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
             }
             best as u8
         },
-        |row: i64| -> [f32; 16] {
-            let allowed_values: &[f32] = &allowed_values_by_row[row as usize];
+        |row: i64, col: i64| -> [f32; 16] {
+            let allowed_values: &[f32] =
+                &allowed_values_by_row_block[row as usize][col as usize / 512];
             let mut result: [f32; 16] = [0.0; 16];
             for i in 0..16 {
                 result[i] = allowed_values[i];
@@ -75,3 +84,63 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
     );
     result
 }
+
+// Same as quantize but doesn't actually change the type of the tensor. It just changes the tensor
+// itself. Used to test new quantization schemes without writing support for them.
+pub fn quantize_test(tensor: &Tensor) -> Tensor {
+    let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype());
+    for row in 0..tensor.rows() {
+        let col_blocks = (tensor.cols() + 511) / 512;
+        for block in 0..col_blocks {
+            let mut values: Vec<f32> = Vec::with_capacity(tensor.cols() as usize);
+            values.truncate(0);
+            let mut mi: f32 = std::f32::MAX;
+            let mut ma: f32 = std::f32::MIN;
+            let start = block * 512;
+            let end = std::cmp::min(start + 512, tensor.cols());
+            for col in start..end {
+                let val = tensor.get_f32(row, col);
+                if val < mi {
+                    mi = val;
+                }
+                if val > ma {
+                    ma = val;
+                }
+                values.push(val);
+            }
+            values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+            let mut allowed_values: Vec<f32> = Vec::with_capacity(16);
+            for i in 0..16 {
+                let start_idx = i * values.len() / 16;
+                let end_idx = (i + 1) * values.len() / 16;
+                let mut avg = 0.0;
+                for j in start_idx..end_idx {
+                    avg += values[j];
+                }
+                avg /= (end_idx - start_idx) as f32;
+                allowed_values.push(avg);
+            }
+            allowed_values[0] = mi;
+            allowed_values[15] = ma;
+            allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+            for col in start..end {
+                let val = tensor.get_f32(row, col);
+                let mut best = 0;
+                let mut best_dist = std::f32::MAX;
+                for i in 0..16 {
+                    let dist = (val - allowed_values[i] as f32).abs();
+                    if dist < best_dist {
+                        best = i;
+                        best_dist = dist;
+                    }
+                }
+                result.set_f32(row, col, allowed_values[best as usize]);
+            }
+        }
+    }
+    result
+}
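Usage note: with this change, the lookup-table closure passed to Tensor::make_k4bit_from_fn receives both a row and the starting column of a 512-column block (0, 512, 1024, ...). Below is a rough sketch of a caller, modeled on the tests in this diff; the tables/codes buffers and the build_example function are invented for illustration and assume the per-block tables and 4-bit codes were computed elsewhere.

use crate::tensor::Tensor;

fn build_example(rows: i64, cols: i64) -> Tensor {
    let col_blocks = (cols + 511) / 512;
    // One 16-float lookup table per (row, block), precomputed elsewhere.
    let tables: Vec<Vec<[f32; 16]>> =
        vec![vec![[0.0f32; 16]; col_blocks as usize]; rows as usize];
    // One 4-bit code (0..=15) per element, precomputed elsewhere.
    let codes: Vec<Vec<u8>> = vec![vec![0u8; cols as usize]; rows as usize];

    Tensor::make_k4bit_from_fn(
        rows,
        cols,
        |row, col| codes[row as usize][col as usize],
        |row, col| tables[row as usize][(col / 512) as usize],
    )
}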
