diff --git a/src/tensor.rs b/src/tensor.rs
index 6f80224..2626743 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -91,7 +91,11 @@ pub struct Tensor {
     data: *mut u8,
 
     // for quantization, only used if dtype == TensorDType::K4BitQuantization
-    // Contains 16 values per row in f16 (i.e. 32 bytes per row)
+    // q4_data is (NxM) where N is:
+    //
+    //   ((cols + 511) / 512) * 32   (i.e. 32 bytes for every 512 columns, rounded up)
+    //
+    // and M is the number of rows in the tensor.
     q4_data: *mut u8,
 
     #[cfg(feature = "opencl")]
@@ -345,11 +349,9 @@ impl Tensor {
                BitSide::Lower => (addr_val & 0x0F),
            }
        };
-        let table: I16x8 = if quant_val <= 7 {
-            unsafe { load_i16x8(self.q4_data.add(row as usize * 32) as *const I16x8) }
-        } else {
-            unsafe { load_i16x8(self.q4_data.add(row as usize * 32 + 16) as *const I16x8) }
-        };
+
+        let (table1, table2) = self.q4_lookup_table(row, col);
+        let table = if quant_val <= 7 { table1 } else { table2 };
        let table = i16x8_as_f16_to_f32x8(table);
        f32x8_get(table, (quant_val % 8) as usize)
    }
@@ -1261,11 +1263,6 @@ impl Tensor {
                        continue;
                    }
 
-                    let quant0 = load_i16x8(other_q4_data.add(col * 32) as *const I16x8);
-                    let quant1 = load_i16x8(other_q4_data.add(col * 32 + 16) as *const I16x8);
-                    let quants: [F32x8; 2] =
-                        [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
-
                    #[inline]
                    fn load_f32(
                        src: *const f32,
@@ -1360,6 +1357,10 @@ impl Tensor {
                        [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
 
                    for p in 0..src_cols_its {
+                        let (quant0, quant1) = other.q4_lookup_table(col as i64, p as i64 * 32);
+                        let quants: [F32x8; 2] =
+                            [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
+
                        let (other8_0, other8_1, other8_2, other8_3) =
                            load_k4_to_f32(&other, col, p * 32, other_rows, quants.as_ptr());
                        let src8_0 = load_f32(
@@ -1454,11 +1455,6 @@ impl Tensor {
            let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
 
            for row in 0..self_rows {
-                let quant0 = load_i16x8(src_q4_data.add(row * 32) as *const I16x8);
-                let quant1 = load_i16x8(src_q4_data.add(row * 32 + 16) as *const I16x8);
-                let quants: [F32x8; 2] =
-                    [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
-
                for col_raw in 0..self_cols_its {
                    let row_col = row * self_cols_its + col_raw;
                    if row_col % nthreads != thread_idx {
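
The struct comment at the top of this diff is the key to the rest of the changes: every row of q4_data now carries one 32-byte lookup table (16 f16 values) per started 512-column block instead of a single table per row. A minimal sketch of the offset arithmetic this implies; the helper name q4_table_offset and the shapes below are illustrative, not code from this patch:

    // --- illustrative sketch, not part of the patch ---
    // Byte offset of the 32-byte lookup table covering (row, col), mirroring the indexing
    // used by q4_lookup_table later in this diff. `q4_table_offset` is a hypothetical helper.
    fn q4_table_offset(cols: i64, row: i64, col: i64) -> usize {
        let q4_col_capacity = ((cols + 511) / 512) * 32; // table bytes per row
        let col_block = col / 512; // which 512-column block `col` falls into
        (row * q4_col_capacity + col_block * 32) as usize
    }

    fn main() {
        // A row of 1300 columns has ceil(1300 / 512) = 3 blocks, i.e. 96 table bytes per row.
        assert_eq!(q4_table_offset(1300, 0, 0), 0);
        assert_eq!(q4_table_offset(1300, 0, 700), 32); // second block of row 0
        assert_eq!(q4_table_offset(1300, 1, 0), 96); // first block of row 1
    }
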
@@ -1567,6 +1563,10 @@ impl Tensor {
                    ];
 
                    for p in 0..src_cols_its {
+                        let (quant0, quant1) = src.q4_lookup_table(row as i64, p as i64 * 32);
+                        let quants: [F32x8; 2] =
+                            [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
+
                        // Macro to make code shorter
                        macro_rules! lo {
                            ($col:expr, $p:expr) => {
@@ -1693,11 +1693,6 @@ impl Tensor {
            let other_data: *const u8 = other_data.unwrap() as *const u8;
            let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
 
-            let quant0 = load_i16x8(other_q4_data.add(0) as *const I16x8);
-            let quant1 = load_i16x8(other_q4_data.add(16) as *const I16x8);
-            let quants: [F32x8; 2] =
-                [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
-
            let col = 0;
            for row in 0..self_rows {
                if row % nthreads != thread_idx {
@@ -1797,6 +1792,10 @@ impl Tensor {
                    [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
 
                for p in 0..src_cols_its {
+                    let (quant0, quant1) = other.q4_lookup_table(col as i64, p as i64 * 32);
+                    let quants: [F32x8; 2] =
+                        [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
+
                    let (other8_0, other8_1, other8_2, other8_3) =
                        load_k4_to_f32(&other, col, p * 32, other_rows, quants.as_ptr());
                    let src8_0 =
@@ -1884,10 +1883,6 @@ impl Tensor {
                if row % nthreads != thread_idx {
                    continue;
                }
-                let quant0 = load_i16x8(src_q4_data.add(row * 32) as *const I16x8);
-                let quant1 = load_i16x8(src_q4_data.add(row * 32 + 16) as *const I16x8);
-                let quants: [F32x8; 2] =
-                    [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
 
                let col = 0;
 
@@ -1985,6 +1980,9 @@ impl Tensor {
                    [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
 
                for p in 0..src_cols_its {
+                    let (quant0, quant1) = src.q4_lookup_table(row as i64, p as i64 * 32);
+                    let quants: [F32x8; 2] =
+                        [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
                    let other8_0: F32x8 = load_f32(
                        other_data,
                        col,
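
Every matmul hunk above makes the same change: the table loads move from outside the loops into the per-chunk `p` loop, because a table is now only valid for one 512-column block and each `p` iteration covers a 32-column chunk. The block index therefore changes every 16 chunks, which a tiny standalone check illustrates (table_block_for_chunk is a made-up helper, not part of the patch):

    // --- illustrative sketch, not part of the patch ---
    // Which lookup-table block a 32-wide column chunk `p` falls into.
    fn table_block_for_chunk(p: usize) -> usize {
        (p * 32) / 512
    }

    fn main() {
        assert_eq!(table_block_for_chunk(0), 0);
        assert_eq!(table_block_for_chunk(15), 0); // columns 480..512 still use block 0
        assert_eq!(table_block_for_chunk(16), 1); // columns 512..544 switch to block 1
    }
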
@@ -2816,9 +2814,12 @@ impl Tensor {
    /// values and quantization lookup table.
    ///
    /// The first function must always return numbers from 0 to 15. These are put in the matrix.
-    /// The second function must return, for each row, 16 floats that correspond the numbers
+    /// The second function must return, for each row and 512-long column block, 16 floats that correspond to the numbers
    /// returns from the first function. For this quantization scheme, each row has its own set of
    /// floats.
+    ///
+    /// The second function is called for each (row, column) pair where column is 0, 512, 1024, etc.,
+    /// so that it covers the entire matrix.
    #[inline]
    pub fn make_k4bit_from_fn(
        rows: i64,
@@ -2828,82 +2829,76 @@ impl Tensor {
    ) -> Self
    where
        F: FnMut(i64, i64) -> u8,
-        F2: FnMut(i64) -> [f32; 16],
+        F2: FnMut(i64, i64) -> [f32; 16],
    {
        let mut result =
            unsafe { Tensor::uninitialized(rows, cols, TensorDType::K4BitQuantization) };
        result.allocate_q4_data();
        assert!(!result.q4_data.is_null());
+        let col_blocks = (cols + 511) / 512;
+        let q4_col_capacity = col_blocks as usize * 32;
+
        unsafe {
            for row in 0..rows {
-                // Set the lookup table to
-                let lookup_table = get_lookup_table(row);
-                /*
-                let table1 = f32x8_from_values(
-                    lookup_table[0],
-                    lookup_table[1],
-                    lookup_table[2],
-                    lookup_table[3],
-                    lookup_table[4],
-                    lookup_table[5],
-                    lookup_table[6],
-                    lookup_table[7],
-                );
-                let table2 = f32x8_from_values(
-                    lookup_table[8],
-                    lookup_table[9],
-                    lookup_table[10],
-                    lookup_table[11],
-                    lookup_table[12],
-                    lookup_table[13],
-                    lookup_table[14],
-                    lookup_table[15],
-                );
-                */
-                let table1 = f32x8_from_values(
-                    lookup_table[7],
-                    lookup_table[6],
-                    lookup_table[5],
-                    lookup_table[4],
-                    lookup_table[3],
-                    lookup_table[2],
-                    lookup_table[1],
-                    lookup_table[0],
-                );
-                let table2 = f32x8_from_values(
-                    lookup_table[15],
-                    lookup_table[14],
-                    lookup_table[13],
-                    lookup_table[12],
-                    lookup_table[11],
-                    lookup_table[10],
-                    lookup_table[9],
-                    lookup_table[8],
-                );
+                for block in 0..col_blocks {
+                    // Set the lookup table for this 512-column block
+                    let lookup_table = get_lookup_table(row, block * 512);
+                    let table1 = f32x8_from_values(
+                        lookup_table[7],
+                        lookup_table[6],
+                        lookup_table[5],
+                        lookup_table[4],
+                        lookup_table[3],
+                        lookup_table[2],
+                        lookup_table[1],
+                        lookup_table[0],
+                    );
+                    let table2 = f32x8_from_values(
+                        lookup_table[15],
+                        lookup_table[14],
+                        lookup_table[13],
+                        lookup_table[12],
+                        lookup_table[11],
+                        lookup_table[10],
+                        lookup_table[9],
+                        lookup_table[8],
+                    );
 
-                let table1 = f32x8_to_i16x8_as_f16(table1);
-                let table2 = f32x8_to_i16x8_as_f16(table2);
-                store_i16x8(result.q4_data.add(row as usize * 32) as *mut I16x8, table1);
-                store_i16x8(
-                    result.q4_data.add(row as usize * 32 + 16) as *mut I16x8,
-                    table2,
-                );
+                    let table1 = f32x8_to_i16x8_as_f16(table1);
+                    let table2 = f32x8_to_i16x8_as_f16(table2);
+                    store_i16x8(
+                        result
+                            .q4_data
+                            .add(row as usize * q4_col_capacity + block as usize * 32)
+                            as *mut I16x8,
+                        table1,
+                    );
+                    store_i16x8(
+                        result
+                            .q4_data
+                            .add(row as usize * q4_col_capacity + block as usize * 32 + 16)
+                            as *mut I16x8,
+                        table2,
+                    );
 
-                for col in 0..cols {
-                    let v = get_value(row, col);
+                    let start = block * 512;
+                    let end = std::cmp::min(start + 512, cols);
+                    for col in start..end {
+                        let v = get_value(row, col);
 
-                    let (addr, side) = result.q4_address(row, col);
-                    let mut addr_value = *addr;
-                    match side {
-                        BitSide::Upper => {
-                            addr_value = (addr_value & 0x0F) | (v << 4);
-                        }
-                        BitSide::Lower => {
-                            addr_value = (addr_value & 0xF0) | v;
+                        let (addr, side) = result.q4_address(row, col);
+                        let mut addr_value = *addr;
+                        match side {
+                            BitSide::Upper => {
+                                addr_value = (addr_value & 0x0F) | (v << 4);
+                            }
+                            BitSide::Lower => {
+                                addr_value = (addr_value & 0xF0) | v;
+                            }
                        }
+                        *addr = addr_value;
                    }
-                    *addr = addr_value;
                }
            }
        }
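
Because the lookup-table closure now also receives the block's starting column, a caller can vary the 16 levels per 512-column block. A usage sketch of the new signature, assuming the crate's Tensor API as used elsewhere in this diff; k4_from_levels, data and levels are illustrative names, and here every block simply reuses the same levels:

    // --- illustrative sketch, not part of the patch ---
    // Quantize a row-major f32 slice against a caller-provided set of 16 levels.
    fn k4_from_levels(data: &[f32], rows: i64, cols: i64, levels: [f32; 16]) -> Tensor {
        Tensor::make_k4bit_from_fn(
            rows,
            cols,
            // First closure: index (0..=15) of the level closest to the value at (row, col).
            |row, col| {
                let v = data[(row * cols + col) as usize];
                let mut best = 0u8;
                let mut best_dist = f32::MAX;
                for (i, level) in levels.iter().enumerate() {
                    let dist = (v - level).abs();
                    if dist < best_dist {
                        best = i as u8;
                        best_dist = dist;
                    }
                }
                best
            },
            // Second closure: called once per (row, block-start column). Every block reuses
            // the same 16 levels here, but a real quantizer could inspect `_col` to vary them.
            |_row, _col| levels,
        )
    }
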
@@ -2940,6 +2935,26 @@ impl Tensor {
        }
    }
 
+    /// K4 bit quantization; loads the quantization table for the given row/column.
+    #[inline]
+    fn q4_lookup_table(&self, row: i64, col: i64) -> (I16x8, I16x8) {
+        let q4_capacity = (((self.cols + 511) / 512) * 32) as usize;
+        let col_block = col / 512;
+        unsafe {
+            let v1 = load_i16x8(
+                self.q4_data
+                    .add(row as usize * q4_capacity + col_block as usize * 32)
+                    as *const I16x8,
+            );
+            let v2 = load_i16x8(
+                self.q4_data
+                    .add(row as usize * q4_capacity + col_block as usize * 32 + 16)
+                    as *const I16x8,
+            );
+            (v1, v2)
+        }
+    }
+
    fn allocate_q4_data(&mut self) {
        if self.dtype != TensorDType::K4BitQuantization {
            panic!("Can only allocate q4 data for K4BitQuantization");
@@ -2949,7 +2964,10 @@ impl Tensor {
            return;
        }
 
-        let layout = Layout::from_size_align(self.rows as usize * 32, 32).unwrap();
+        let q4_cols_capacity = ((self.cols + 511) / 512) * 32;
+
+        let layout =
+            Layout::from_size_align(self.rows as usize * q4_cols_capacity as usize, 32).unwrap();
        let q4_data = unsafe { std::alloc::alloc_zeroed(layout) };
        if q4_data.is_null() {
            panic!("Failed to allocate q4 data");
@@ -4116,7 +4134,7 @@ mod tests {
            16,
            16,
            |_row, col| col as u8,
-            |row| {
+            |row, _col| {
                let mut result: [f32; 16] = [0.0; 16];
                for col in 0..16 {
                    result[col] = reference.get_f32(row, col as i64);
@@ -4174,7 +4192,7 @@ mod tests {
            a,
            b,
            |row, col| quantized_values[row as usize][col as usize],
-            |row| {
+            |row, _col| {
                let mut result: [f32; 16] = [0.0; 16];
                for col in 0..16 {
                    result[col] = quant_values[row as usize][col];
@@ -4249,7 +4267,7 @@ mod tests {
            b,
            c,
            |row, col| quantized_values[row as usize][col as usize],
-            |row| {
+            |row, _col| {
                let mut result: [f32; 16] = [0.0; 16];
                for col in 0..16 {
                    result[col] = quant_values[row as usize][col];
@@ -4325,7 +4343,7 @@ mod tests {
            b,
            c,
            |row, col| quantized_values[row as usize][col as usize],
-            |row| {
+            |row, _col| {
                let mut result: [f32; 16] = [0.0; 16];
                for col in 0..16 {
                    result[col] = quant_values[row as usize][col];
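
With the allocate_q4_data change above, the lookup-table storage scales with the column count instead of being a flat 32 bytes per row. For an illustrative 4096 x 11008 matrix (an example shape, not one taken from this repository), the arithmetic works out as follows:

    // --- illustrative sketch, not part of the patch ---
    // Table bytes allocated by the new allocate_q4_data (same formula, made explicit).
    fn q4_table_bytes(rows: i64, cols: i64) -> i64 {
        rows * ((cols + 511) / 512) * 32
    }

    fn main() {
        // ceil(11008 / 512) = 22 blocks per row, 22 * 32 = 704 table bytes per row.
        // The old layout used a flat 4096 * 32 = 128 KiB for the same matrix.
        assert_eq!(q4_table_bytes(4096, 11008), 2_883_584); // about 2.75 MiB
    }
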
diff --git a/src/weight_compression.rs b/src/weight_compression.rs
index a39af64..bdb0374 100644
--- a/src/weight_compression.rs
+++ b/src/weight_compression.rs
@@ -1,5 +1,4 @@
 use crate::tensor::Tensor;
-use rand::{thread_rng, Rng};
 use rayon::prelude::*;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::{Arc, RwLock};
@@ -9,49 +8,58 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
    /*
     * This is a simplistic rounding quantizer. It splits each row in a tensor to 16 buckets and
     * takes the average value in said buckets as the quantized weight.
     */
-    let mut allowed_values_by_row: Vec<Vec<f32>> = Vec::with_capacity(tensor.rows() as usize);
+    let mut allowed_values_by_row_block: Vec<Vec<Vec<f32>>> =
+        Vec::with_capacity(tensor.rows() as usize);
+    let col_blocks = (tensor.cols() + 511) / 512;
    for row in 0..tensor.rows() {
-        let mut values: Vec<f32> = Vec::with_capacity(tensor.cols() as usize);
-        values.truncate(0);
-        let mut mi: f32 = std::f32::MAX;
-        let mut ma: f32 = std::f32::MIN;
+        let mut block_values: Vec<Vec<f32>> = Vec::with_capacity(col_blocks as usize);
+        for block in 0..col_blocks {
+            let start = block * 512;
+            let end = std::cmp::min(start + 512, tensor.cols());
 
-        for col in 0..tensor.cols() {
-            let val = tensor.get_f32(row, col);
-            if val < mi {
-                mi = val;
-            }
-            if val > ma {
-                ma = val;
+            let mut values: Vec<f32> = Vec::with_capacity(512);
+            values.truncate(0);
+            let mut mi: f32 = std::f32::MAX;
+            let mut ma: f32 = std::f32::MIN;
+
+            for col in start..end {
+                let val = tensor.get_f32(row, col);
+                if val < mi {
+                    mi = val;
+                }
+                if val > ma {
+                    ma = val;
+                }
+                values.push(val);
            }
-            values.push(val);
-        }
-        values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
-        let mut allowed_values: Vec<f32> = Vec::with_capacity(16);
-        let mut rng = thread_rng();
-        for i in 0..16 {
-            let start_idx = i * values.len() / 16;
-            let end_idx = (i + 1) * values.len() / 16;
+            values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+            let mut allowed_values: Vec<f32> = Vec::with_capacity(16);
+            for i in 0..16 {
+                let start_idx = i * values.len() / 16;
+                let end_idx = (i + 1) * values.len() / 16;
 
-            let mut avg = 0.0;
-            for j in start_idx..end_idx {
-                avg += values[j];
+                let mut avg = 0.0;
+                for j in start_idx..end_idx {
+                    avg += values[j];
+                }
+                avg /= (end_idx - start_idx) as f32;
+                allowed_values.push(avg);
            }
-            avg /= (end_idx - start_idx) as f32;
-            allowed_values.push(avg);
+            allowed_values[0] = mi;
+            allowed_values[15] = ma;
+            allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+            block_values.push(allowed_values);
        }
-        allowed_values[0] = mi;
-        allowed_values[15] = ma;
-        allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
-        allowed_values_by_row.push(allowed_values);
+        allowed_values_by_row_block.push(block_values);
    }
-    //let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype());
    let result = Tensor::make_k4bit_from_fn(
        tensor.rows(),
        tensor.cols(),
        |row, col| {
-            let allowed_values: &[f32] = &allowed_values_by_row[row as usize];
+            let col_block = col / 512;
+            let allowed_values: &[f32] =
+                &allowed_values_by_row_block[row as usize][col_block as usize];
            let val = tensor.get_f32(row, col);
            let mut best = 0;
            let mut best_dist = std::f32::MAX;
@@ -64,8 +72,9 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
            }
            best as u8
        },
-        |row: i64| -> [f32; 16] {
-            let allowed_values: &[f32] = &allowed_values_by_row[row as usize];
+        |row: i64, col: i64| -> [f32; 16] {
+            let allowed_values: &[f32] =
+                &allowed_values_by_row_block[row as usize][col as usize / 512];
            let mut result: [f32; 16] = [0.0; 16];
            for i in 0..16 {
                result[i] = allowed_values[i];
@@ -75,3 +84,63 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
    );
    result
 }
+
+// Same as quantize, but it does not change the dtype of the tensor; only the values themselves
+// are quantized. Used to test new quantization schemes without writing support for them.
+pub fn quantize_test(tensor: &Tensor) -> Tensor {
+    let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype());
+    for row in 0..tensor.rows() {
+        let col_blocks = (tensor.cols() + 511) / 512;
+        for block in 0..col_blocks {
+            let mut values: Vec<f32> = Vec::with_capacity(tensor.cols() as usize);
+            values.truncate(0);
+            let mut mi: f32 = std::f32::MAX;
+            let mut ma: f32 = std::f32::MIN;
+
+            let start = block * 512;
+            let end = std::cmp::min(start + 512, tensor.cols());
+
+            for col in start..end {
+                let val = tensor.get_f32(row, col);
+                if val < mi {
+                    mi = val;
+                }
+                if val > ma {
+                    ma = val;
+                }
+                values.push(val);
+            }
+            values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+            let mut allowed_values: Vec<f32> = Vec::with_capacity(16);
+            for i in 0..16 {
+                let start_idx = i * values.len() / 16;
+                let end_idx = (i + 1) * values.len() / 16;
+
+                let mut avg = 0.0;
+                for j in start_idx..end_idx {
+                    avg += values[j];
+                }
+                avg /= (end_idx - start_idx) as f32;
+                allowed_values.push(avg);
+            }
+            allowed_values[0] = mi;
+            allowed_values[15] = ma;
+            allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+
+            for col in start..end {
+                let val = tensor.get_f32(row, col);
+                let mut best = 0;
+                let mut best_dist = std::f32::MAX;
+                for i in 0..16 {
+                    let dist = (val - allowed_values[i] as f32).abs();
+                    if dist < best_dist {
+                        best = i;
+                        best_dist = dist;
+                    }
+                }
+                result.set_f32(row, col, allowed_values[best as usize]);
+            }
+        }
+    }
+    result
+}
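
Both quantize and quantize_test implement the same per-block scheme: sort the block's values, average each of the 16 equal-size buckets, then pin the lowest and highest levels to the block's minimum and maximum. A self-contained sketch of that per-block step (illustrative code, not part of the patch; it assumes a block with at least 16 values, as a full 512-column block provides):

    // --- illustrative sketch, not part of the patch ---
    // The 16 allowed values for one block of weights, mirroring quantize()/quantize_test() above.
    fn block_levels(mut values: Vec<f32>) -> [f32; 16] {
        let (mut mi, mut ma) = (f32::MAX, f32::MIN);
        for &v in &values {
            mi = mi.min(v);
            ma = ma.max(v);
        }
        values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
        let mut levels = [0.0f32; 16];
        for i in 0..16 {
            let start = i * values.len() / 16;
            let end = (i + 1) * values.len() / 16;
            let sum: f32 = values[start..end].iter().sum();
            levels[i] = sum / (end - start) as f32; // average of the i-th bucket
        }
        // Keep the block's extremes exactly representable, as quantize() does.
        levels[0] = mi;
        levels[15] = ma;
        levels.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
        levels
    }

    fn main() {
        let levels = block_levels((0..512).map(|i| i as f32 / 511.0).collect());
        assert_eq!(levels[0], 0.0);
        assert_eq!(levels[15], 1.0);
    }
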