diff --git a/src/benches/benchmark.rs b/src/benches/benchmark.rs index 2453d79..80175ce 100644 --- a/src/benches/benchmark.rs +++ b/src/benches/benchmark.rs @@ -97,6 +97,7 @@ pub fn tensor_benchmarks(c: &mut Criterion) { let orig_84096_1 = Tensor::zeros(8, 4096, TensorDType::Float32); let orig_84096_2 = Tensor::zeros(4096, 4096, TensorDType::Float32); + let orig_84096_quant = orig_84096_1.quantize(); let mut result_84096 = Tensor::zeros(8, 4096, TensorDType::Float32); let orig_84096_1_f16 = Tensor::zeros(8, 4096, TensorDType::Float16); @@ -111,6 +112,29 @@ pub fn tensor_benchmarks(c: &mut Criterion) { let m1_f16 = m1.to_f16(); let m2_f16 = m2.to_f16(); + let quant = m1.quantize(); + + c.bench_function( + "1024x128 * 1x128 matrix vector transposed multiplication, k4 quantized * f32", + |b| { + b.iter(|| { + let _ = quant.matrix_vector_mul_transposed(black_box(&m2)); + }) + }, + ); + + c.bench_function( + "matrix multiplication 8x4096 @ 4096x4096 k8 quantized * f32 in-place, transposed", + |b| { + b.iter(|| { + let _ = result_84096.matrix_mul_inplace_transposed( + black_box(&orig_84096_quant), + black_box(&orig_84096_2), + ); + }) + }, + ); + c.bench_function( "1024x128 * 1x128 matrix vector transposed multiplication, f32", |b| { diff --git a/src/simd_support.rs b/src/simd_support.rs index 660677b..2324ded 100644 --- a/src/simd_support.rs +++ b/src/simd_support.rs @@ -3,6 +3,7 @@ use core::arch::x86_64::*; use half::f16; +use std::fmt::Write; pub type I32x8 = __m256i; pub type F32x8 = __m256; @@ -37,6 +38,11 @@ pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 { unsafe { _mm256_i32gather_ps(ptr, indices, 1) } } +#[inline] +pub fn gather_scale4_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 { + unsafe { _mm256_i32gather_ps(ptr, indices, 4) } +} + /* ------------------ */ /* Conversions */ /* ------------------ */ @@ -51,22 +57,121 @@ pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 { unsafe { _mm256_cvtps_ph(a, 0) } } +#[inline] +/// Converts f32x8 to i32x8, by just casting the bits. I.e. it does not round any numbers or +/// anything, it just copies the bits. +pub fn f32x8_to_i32x8_bitcast(a: F32x8) -> I32x8 { + unsafe { _mm256_castps_si256(a) } +} + +/* + * ------------------ + * Accessing individual elements + */ + +// Rust has no const arguments (yet, maybe in future). +// So we have this awkward match statement in each of these. 
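+// (The lane index of e.g. _mm256_extract_epi32 must be a compile-time constant, so it cannot
+// simply be forwarded from the runtime `idx` argument; hence a match arm with a literal index
+// for every lane.)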
+ +#[inline] +pub fn f32x8_get(a: F32x8, idx: usize) -> f32 { + unsafe { + let a = f32x8_to_i32x8_bitcast(a); + let a = match idx { + 0 => _mm256_extract_epi32(a, 0), + 1 => _mm256_extract_epi32(a, 1), + 2 => _mm256_extract_epi32(a, 2), + 3 => _mm256_extract_epi32(a, 3), + 4 => _mm256_extract_epi32(a, 4), + 5 => _mm256_extract_epi32(a, 5), + 6 => _mm256_extract_epi32(a, 6), + 7 => _mm256_extract_epi32(a, 7), + _ => panic!("f32x8_get: index out of bounds"), + }; + // bitcast the i32 back to f32 + core::mem::transmute(a) + } +} + +#[inline] +pub fn i32x8_get(a: I32x8, idx: usize) -> i32 { + unsafe { + let a = match idx { + 0 => _mm256_extract_epi32(a, 0), + 1 => _mm256_extract_epi32(a, 1), + 2 => _mm256_extract_epi32(a, 2), + 3 => _mm256_extract_epi32(a, 3), + 4 => _mm256_extract_epi32(a, 4), + 5 => _mm256_extract_epi32(a, 5), + 6 => _mm256_extract_epi32(a, 6), + 7 => _mm256_extract_epi32(a, 7), + _ => panic!("i32x8_get: index out of bounds"), + }; + a + } +} + +#[inline] +pub fn i16x8_get(a: I16x8, idx: usize) -> i16 { + unsafe { + let a = match idx { + 0 => _mm_extract_epi16(a, 0), + 1 => _mm_extract_epi16(a, 1), + 2 => _mm_extract_epi16(a, 2), + 3 => _mm_extract_epi16(a, 3), + 4 => _mm_extract_epi16(a, 4), + 5 => _mm_extract_epi16(a, 5), + 6 => _mm_extract_epi16(a, 6), + 7 => _mm_extract_epi16(a, 7), + _ => panic!("i16x8_get: index out of bounds"), + }; + a as i16 + } +} + /* * Constants, creating from constants */ +#[inline] pub fn f32x8_zero() -> F32x8 { unsafe { _mm256_setzero_ps() } } +#[inline] pub fn i16x8_zero() -> I16x8 { unsafe { _mm_setzero_si128() } } +#[inline] +pub fn i16x8_singleton(value: i16) -> I16x8 { + unsafe { _mm_set1_epi16(value) } +} + +#[inline] +pub fn i16x8_singleton_u16(value: u16) -> I16x8 { + unsafe { _mm_set1_epi16(value as i16) } +} + +#[inline] pub fn f32x8_singleton(value: f32) -> F32x8 { unsafe { _mm256_set1_ps(value) } } +#[inline] +pub fn f32x8_from_values( + val0: f32, + val1: f32, + val2: f32, + val3: f32, + val4: f32, + val5: f32, + val6: f32, + val7: f32, +) -> F32x8 { + unsafe { _mm256_set_ps(val0, val1, val2, val3, val4, val5, val6, val7) } +} + +#[inline] pub fn i32x8_from_values( val0: i32, val1: i32, @@ -80,6 +185,45 @@ pub fn i32x8_from_values( unsafe { _mm256_set_epi32(val0, val1, val2, val3, val4, val5, val6, val7) } } +#[inline] +pub fn i32x8_from_values_u32( + val0: u32, + val1: u32, + val2: u32, + val3: u32, + val4: u32, + val5: u32, + val6: u32, + val7: u32, +) -> I32x8 { + unsafe { + _mm256_set_epi32( + val0 as i32, + val1 as i32, + val2 as i32, + val3 as i32, + val4 as i32, + val5 as i32, + val6 as i32, + val7 as i32, + ) + } +} + +#[inline] +pub fn i16x8_from_values( + val0: i16, + val1: i16, + val2: i16, + val3: i16, + val4: i16, + val5: i16, + val6: i16, + val7: i16, +) -> I16x8 { + unsafe { _mm_set_epi16(val0, val1, val2, val3, val4, val5, val6, val7) } +} + /* * Operations */ @@ -87,10 +231,59 @@ pub fn i32x8_from_values( // FMA // a * b + c +#[inline] pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 { unsafe { _mm256_fmadd_ps(a, b, c) } } +// bitwise and +#[inline] +pub fn and_i16x8(a: I16x8, b: I16x8) -> I16x8 { + unsafe { _mm_and_si128(a, b) } +} + +#[inline] +pub fn and_i32x8(a: I32x8, b: I32x8) -> I32x8 { + unsafe { _mm256_and_si256(a, b) } +} + +#[inline] +pub fn and_f32x8(a: F32x8, b: I32x8) -> F32x8 { + unsafe { std::mem::transmute(_mm256_and_si256(std::mem::transmute(a), b)) } +} + +// shift right by 4 bits exactly, for each individual i16 value. +// extends by zeros from left. 
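+// e.g. each 16-bit lane 0x00AB becomes 0x000A.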
+#[inline]
+pub fn shift_right_by_4_i16x8(a: I16x8) -> I16x8 {
+    unsafe { _mm_srli_epi16(a, 4) }
+}
+
+// Shifts the whole 128-bit register right by 64 bits (8 bytes).
+// Extends by zeros from the left.
+#[inline]
+pub fn shift_right_by_64_i128(a: I16x8) -> I16x8 {
+    unsafe { _mm_srli_si128(a, 64 / 8) }
+}
+
+// Shuffle/permute
+#[inline]
+pub fn shuffle_i16x8(a: I16x8, permutation: I16x8) -> I16x8 {
+    unsafe { _mm_shuffle_epi8(a, permutation) }
+}
+
+// Sign-extends the low 8 i8 values into 8 i16 values. For the nibble values we feed it
+// (0..=15) this is the same as zero extension.
+//
+// XXYYZZ -> 00XX00YY00ZZ
+#[inline]
+pub fn extend_i8_to_i16_i16x8(a: I16x8) -> I16x8 {
+    unsafe { _mm_cvtepi8_epi16(a) }
+}
+
+// Extends 8 i8 values into 8 i32 values (going through i16).
+#[inline]
+pub fn extend_i8_to_i32_i32x8(a: I16x8) -> I32x8 {
+    let i = extend_i8_to_i16_i16x8(a);
+    unsafe { _mm256_cvtepu16_epi32(i) }
+}
+
 // Horizontal sums
 
 #[inline]
@@ -114,3 +307,48 @@ pub fn horizontal_sum_and_f32_to_f16(mut ymm: __m256) -> f16 {
         f16::from_f32(_mm256_cvtss_f32(ymm))
     }
 }
+
+/*
+ * Debugging
+ */
+
+/// Prints a binary representation of i16x8 to stdout in this form:
+///
+///      0      0      0      0
+/// 0x0000 0x0000 0x0000 0x0000
+/// 0000000000000000 0000000000000000 0000000000000000 0000000000000000   etc.
+///
+/// Decimal on the first line, hex on the second, binary on the third.
+pub fn print_i16x8(a: I16x8) {
+    let mut decimal_line = String::new();
+    let mut hex_line = String::new();
+    let mut binary_line = String::new();
+
+    for i in 0..8 {
+        let val = i16x8_get(a, i);
+        write!(decimal_line, "{:>5} ", val).unwrap();
+        write!(hex_line, "0x{:04X} ", val).unwrap();
+        write!(binary_line, "{:016b} ", val).unwrap();
+    }
+
+    println!("{}", decimal_line.trim_end());
+    println!("{}", hex_line.trim_end());
+    println!("{}", binary_line.trim_end());
+}
+
+pub fn print_i32x8(a: I32x8) {
+    let mut decimal_line = String::new();
+    let mut hex_line = String::new();
+    let mut binary_line = String::new();
+
+    for i in 0..8 {
+        let val = i32x8_get(a, i);
+        write!(decimal_line, "{:>10} ", val).unwrap();
+        write!(hex_line, "0x{:08X} ", val).unwrap();
+        write!(binary_line, "{:032b} ", val).unwrap();
+    }
+
+    println!("{}", decimal_line.trim_end());
+    println!("{}", hex_line.trim_end());
+    println!("{}", binary_line.trim_end());
+}
diff --git a/src/simd_support_aarch64.rs b/src/simd_support_aarch64.rs
new file mode 100644
index 0000000..a9e0461
--- /dev/null
+++ b/src/simd_support_aarch64.rs
@@ -0,0 +1,116 @@
+// This file contains platform-specific SIMD so that the rest of rllama does not need to care
+// which platform it is on.
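+//
+// There is no single 256-bit SIMD register on aarch64, so the 8-wide f32 and i32 types below
+// are represented as pairs of 128-bit NEON registers (the *x4x2_t types), while I16x8 fits in
+// one 128-bit register.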
+
+use core::arch::aarch64::*;
+use half::f16;
+
+pub type I32x8 = int32x4x2_t;
+pub type F32x8 = float32x4x2_t;
+pub type I16x8 = int16x8_t;
+
+/* ------------------ */
+/* Loading and storing things */
+/* ------------------ */
+
+#[inline]
+pub fn load_i16x8(ptr: *const I16x8) -> I16x8 {
+    unsafe { vld1q_s16(ptr as *const i16) }
+}
+
+#[inline]
+pub fn store_i16x8(ptr: *mut I16x8, a: I16x8) {
+    unsafe { vst1q_s16(ptr as *mut i16, a) }
+}
+
+#[inline]
+pub fn load_f32x8(ptr: *const F32x8) -> F32x8 {
+    unsafe { vld1q_f32_x2(ptr as *const f32) }
+}
+
+#[inline]
+pub fn store_f32x8(ptr: *mut F32x8, a: F32x8) {
+    unsafe { vst1q_f32_x2(ptr as *mut f32, a) }
+}
+
+#[inline]
+pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 {
+    // NEON has no gather instruction, so gather one lane at a time. Like the x86 version
+    // (which uses a scale of 1), the indices are byte offsets from `ptr`.
+    unsafe {
+        let mut idx: [i32; 8] = [0; 8];
+        vst1q_s32(idx.as_mut_ptr(), indices.0);
+        vst1q_s32(idx.as_mut_ptr().add(4), indices.1);
+        let mut vals: [f32; 8] = [0.0; 8];
+        for i in 0..8 {
+            vals[i] =
+                std::ptr::read_unaligned((ptr as *const u8).add(idx[i] as usize) as *const f32);
+        }
+        float32x4x2_t(vld1q_f32(vals.as_ptr()), vld1q_f32(vals.as_ptr().add(4)))
+    }
+}
+
+/* ------------------ */
+/* Conversions */
+/* ------------------ */
+
+#[inline]
+pub fn i16x8_as_f16_to_f32x8(a: I16x8) -> F32x8 {
+    // There is no stable f16 conversion intrinsic to lean on here, so convert lane by lane
+    // through the `half` crate.
+    unsafe {
+        let mut bits: [u16; 8] = [0; 8];
+        vst1q_u16(bits.as_mut_ptr(), vreinterpretq_u16_s16(a));
+        let mut vals: [f32; 8] = [0.0; 8];
+        for i in 0..8 {
+            vals[i] = f16::from_bits(bits[i]).to_f32();
+        }
+        float32x4x2_t(vld1q_f32(vals.as_ptr()), vld1q_f32(vals.as_ptr().add(4)))
+    }
+}
+
+#[inline]
+pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 {
+    unsafe {
+        let mut vals: [f32; 8] = [0.0; 8];
+        vst1q_f32(vals.as_mut_ptr(), a.0);
+        vst1q_f32(vals.as_mut_ptr().add(4), a.1);
+        let mut bits: [u16; 8] = [0; 8];
+        for i in 0..8 {
+            bits[i] = f16::from_f32(vals[i]).to_bits();
+        }
+        vreinterpretq_s16_u16(vld1q_u16(bits.as_ptr()))
+    }
+}
+
+/*
+ * Constants, creating from constants
+ */
+
+pub fn f32x8_zero() -> F32x8 {
+    unsafe { float32x4x2_t(vdupq_n_f32(0.0), vdupq_n_f32(0.0)) }
+}
+
+pub fn i16x8_zero() -> I16x8 {
+    unsafe { vdupq_n_s16(0) }
+}
+
+pub fn f32x8_singleton(value: f32) -> F32x8 {
+    unsafe { float32x4x2_t(vdupq_n_f32(value), vdupq_n_f32(value)) }
+}
+
+pub fn i32x8_from_values(
+    val0: i32,
+    val1: i32,
+    val2: i32,
+    val3: i32,
+    val4: i32,
+    val5: i32,
+    val6: i32,
+    val7: i32,
+) -> I32x8 {
+    // Match the lane order of the x86 version: _mm256_set_epi32 puts val0 in the highest lane.
+    let vals: [i32; 8] = [val7, val6, val5, val4, val3, val2, val1, val0];
+    unsafe { int32x4x2_t(vld1q_s32(vals.as_ptr()), vld1q_s32(vals.as_ptr().add(4))) }
+}
+
+/*
+ * Operations
+ */
+
+// FMA
+
+// a * b + c
+pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 {
+    // vfmaq_f32(acc, x, y) computes acc + x * y.
+    unsafe { float32x4x2_t(vfmaq_f32(c.0, a.0, b.0), vfmaq_f32(c.1, a.1, b.1)) }
+}
+
+// Horizontal sums
+
+#[inline]
+pub fn horizontal_sum_f32x8(x: F32x8) -> f32 {
+    unsafe { vaddvq_f32(x.0) + vaddvq_f32(x.1) }
+}
+
+#[inline]
+pub fn horizontal_sum_and_f32_to_f16(x: F32x8) -> f16 {
+    unsafe { f16::from_f32(vaddvq_f32(x.0) + vaddvq_f32(x.1)) }
+}
diff --git a/src/simd_support_amd64.rs b/src/simd_support_amd64.rs
new file mode 100644
index 0000000..660677b
--- /dev/null
+++ b/src/simd_support_amd64.rs
@@ -0,0 +1,116 @@
+// This file contains platform-specific SIMD so that the rest of rllama does not need to care
+// which platform it is on.
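+//
+// The implementations in this file assume AVX2 (for the gathers), FMA and F16C (for the
+// f16 <-> f32 conversions).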
+ +use core::arch::x86_64::*; +use half::f16; + +pub type I32x8 = __m256i; +pub type F32x8 = __m256; +pub type I16x8 = __m128i; + +/* ------------------ */ +/* Loading and storing things */ +/* ------------------ */ + +#[inline] +pub fn load_i16x8(ptr: *const I16x8) -> I16x8 { + unsafe { _mm_loadu_si128(ptr) } +} + +#[inline] +pub fn store_i16x8(ptr: *mut I16x8, a: I16x8) { + unsafe { _mm_storeu_si128(ptr, a) } +} + +#[inline] +pub fn load_f32x8(ptr: *const F32x8) -> F32x8 { + unsafe { _mm256_loadu_ps(ptr as *const f32) } +} + +#[inline] +pub fn store_f32x8(ptr: *mut F32x8, a: F32x8) { + unsafe { _mm256_storeu_ps(ptr as *mut f32, a) } +} + +#[inline] +pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 { + unsafe { _mm256_i32gather_ps(ptr, indices, 1) } +} + +/* ------------------ */ +/* Conversions */ +/* ------------------ */ + +#[inline] +pub fn i16x8_as_f16_to_f32x8(a: I16x8) -> F32x8 { + unsafe { _mm256_cvtph_ps(a) } +} + +#[inline] +pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 { + unsafe { _mm256_cvtps_ph(a, 0) } +} + +/* + * Constants, creating from constants + */ + +pub fn f32x8_zero() -> F32x8 { + unsafe { _mm256_setzero_ps() } +} + +pub fn i16x8_zero() -> I16x8 { + unsafe { _mm_setzero_si128() } +} + +pub fn f32x8_singleton(value: f32) -> F32x8 { + unsafe { _mm256_set1_ps(value) } +} + +pub fn i32x8_from_values( + val0: i32, + val1: i32, + val2: i32, + val3: i32, + val4: i32, + val5: i32, + val6: i32, + val7: i32, +) -> I32x8 { + unsafe { _mm256_set_epi32(val0, val1, val2, val3, val4, val5, val6, val7) } +} + +/* + * Operations + */ + +// FMA + +// a * b + c +pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 { + unsafe { _mm256_fmadd_ps(a, b, c) } +} + +// Horizontal sums + +#[inline] +pub fn horizontal_sum_f32x8(mut ymm: __m256) -> f32 { + unsafe { + let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1); + ymm = _mm256_add_ps(ymm, ymm2); + ymm = _mm256_hadd_ps(ymm, ymm); + ymm = _mm256_hadd_ps(ymm, ymm); + _mm256_cvtss_f32(ymm) + } +} + +#[inline] +pub fn horizontal_sum_and_f32_to_f16(mut ymm: __m256) -> f16 { + unsafe { + let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1); + ymm = _mm256_add_ps(ymm, ymm2); + ymm = _mm256_hadd_ps(ymm, ymm); + ymm = _mm256_hadd_ps(ymm, ymm); + f16::from_f32(_mm256_cvtss_f32(ymm)) + } +} diff --git a/src/tensor.rs b/src/tensor.rs index 803c58b..8aeb70e 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -187,6 +187,12 @@ impl WrappedPtr { } } +#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd)] +enum BitSide { + Upper, + Lower, +} + fn compute_capacity_cols(dtype: TensorDType, cols: i64) -> i64 { match dtype { TensorDType::K4BitQuantization => compute_capacity_cols_k4(cols), @@ -311,7 +317,24 @@ impl Tensor { let idx = row * self.capacity_cols + col; match self.dtype { - TensorDType::K4BitQuantization => unimplemented!(), + TensorDType::K4BitQuantization => { + assert!(!self.q4_data.is_null()); + let (addr, side) = self.q4_address(row, col); + let addr_val: u8 = unsafe { *(addr as *const u8) }; + let quant_val: u8 = unsafe { + match side { + BitSide::Upper => (addr_val >> 4), + BitSide::Lower => (addr_val & 0x0F), + } + }; + let table: I16x8 = if quant_val <= 7 { + unsafe { load_i16x8(self.q4_data.add(row as usize * 32) as *const I16x8) } + } else { + unsafe { load_i16x8(self.q4_data.add(row as usize * 32 + 16) as *const I16x8) } + }; + let table = i16x8_as_f16_to_f32x8(table); + f32x8_get(table, (quant_val % 8) as usize) + } TensorDType::Float16 => { let val: f16 = unsafe { *(self.data.add(idx as usize * 2) as *const f16) }; val.to_f32() 
@@ -390,6 +413,22 @@ impl Tensor { panic!("Failed to allocate tensor"); } TENSORS_BYTES_ALLOCATED.fetch_add(layout.size(), std::sync::atomic::Ordering::Relaxed); + + let result = Self { + data, + q4_data: std::ptr::null_mut(), + #[cfg(feature = "opencl")] + opencl_data: Arc::new(RwLock::new(None)), + #[cfg(feature = "opencl")] + waiting_for_data: None, + dtype, + rows, + cols, + capacity_cols, + layout, + q4_layout: Layout::from_size_align(1, 1).unwrap(), + }; + // Even though we are uninitialized, we should zero out the extra space between the // columns. // Otherwise there might be problems later as other operations assume it is zeroed. @@ -397,7 +436,13 @@ impl Tensor { for row in 0..rows { let idx = row * capacity_cols + extra_col; match dtype { - TensorDType::K4BitQuantization => unimplemented!(), + TensorDType::K4BitQuantization => { + // We traverse each byte twice in this particular loop but eh who cares + let (addr, _side) = result.q4_address(row, extra_col); + unsafe { + *addr = 0; + } + } TensorDType::Float16 => { let val: f16 = f16::from_f32(0.0); unsafe { *(data.add(idx as usize * 2) as *mut f16) = val }; @@ -409,20 +454,7 @@ impl Tensor { } } - Self { - data, - q4_data: std::ptr::null_mut(), - #[cfg(feature = "opencl")] - opencl_data: Arc::new(RwLock::new(None)), - #[cfg(feature = "opencl")] - waiting_for_data: None, - dtype, - rows, - cols, - capacity_cols, - layout, - q4_layout: Layout::from_size_align(1, 1).unwrap(), - } + result } pub fn full(rows: i64, cols: i64, dtype: TensorDType, value: f32) -> Self { @@ -896,7 +928,13 @@ impl Tensor { if other.rows == 1 && self.is_on_cpu() { return self.matrix_vector_mul_transposed(other); } - let mut result = unsafe { Tensor::uninitialized(self.rows, other.rows, self.dtype) }; + // k4bit * float32 = float32 (not k4bit) + let result_dtype = if self.dtype != TensorDType::K4BitQuantization { + self.dtype + } else { + TensorDType::Float32 + }; + let mut result = unsafe { Tensor::uninitialized(self.rows, other.rows, result_dtype) }; #[cfg(feature = "opencl")] if self.is_on_gpu() { let od = self.opencl_data.write().unwrap(); @@ -1109,6 +1147,10 @@ impl Tensor { } } + pub fn quantize(&self) -> Tensor { + crate::weight_compression::quantize(self) + } + #[cfg(feature = "opencl")] fn matrix_mul_inplace_transposed_gpu(&mut self, src: &Tensor, other: &Tensor) { let mut self_od = self.opencl_data.write().unwrap(); @@ -1128,6 +1170,202 @@ impl Tensor { std::mem::drop(other_od); } + fn matrix_mul_inplace_transposed_k4bit_and_f32(&mut self, src: &Tensor, other: &Tensor) { + // Assume: size checks have been done already. + assert!(src.dtype == TensorDType::K4BitQuantization); + assert!(other.dtype == TensorDType::Float32); + assert!(self.dtype == TensorDType::Float32); + + unsafe { + let src_rows: usize = src.rows as usize; + let src_cols: usize = src.cols as usize; + let src_cols_capacity: usize = src.capacity_cols as usize; + let other_cols: usize = other.cols as usize; + let other_rows: usize = other.rows as usize; + let other_cols_capacity: usize = other.capacity_cols as usize; + let self_rows: usize = self.rows as usize; + let self_cols: usize = self.cols as usize; + let self_cols_capacity: usize = self.capacity_cols as usize; + + let src_data: *const u8 = src.data; + let other_data: *const f32 = other.data as *const f32; + let tgt_data: *mut f32 = self.data as *mut f32; + + // src_cols_its == also the shared dimension between src and other. 
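+            // In other words, the number of 32-column blocks to step through: ceil(src_cols / 32).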
+ let src_cols_its = if src_cols % 32 == 0 { + src_cols / 32 + } else { + src_cols / 32 + 1 + }; + debug_assert!(!src.q4_data.is_null()); + + for row in 0..self_rows { + let quant0 = load_i16x8(src.q4_data.add(row * 32) as *const I16x8); + let quant1 = load_i16x8(src.q4_data.add(row * 32 + 16) as *const I16x8); + let quants: [F32x8; 2] = + [i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)]; + + for col in 0..self_cols { + #[inline] + fn load_f32( + other: *const f32, + row: usize, + col: usize, + ncols: usize, + nrows: usize, + cols_capacity: usize, + ) -> F32x8 { + unsafe { + if row >= nrows || col >= ncols { + f32x8_zero() + } else { + load_f32x8(other.add(row * cols_capacity + col) as *const F32x8) + } + } + } + + #[inline] + fn load_k4_to_f32( + tensor: &Tensor, + row: usize, + col: usize, + nrows: usize, + quants: *const F32x8, + ) -> (F32x8, F32x8, F32x8, F32x8) { + unsafe { + let M: u32 = 0xFFFFFFFF; + let MASKS: [I32x8; 8] = [ + i32x8_from_values_u32(M, M, M, M, M, M, M, M), + i32x8_from_values_u32(0, M, M, M, M, M, M, M), + i32x8_from_values_u32(0, 0, M, M, M, M, M, M), + i32x8_from_values_u32(0, 0, 0, M, M, M, M, M), + i32x8_from_values_u32(0, 0, 0, 0, M, M, M, M), + i32x8_from_values_u32(0, 0, 0, 0, 0, M, M, M), + i32x8_from_values_u32(0, 0, 0, 0, 0, 0, M, M), + i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, M), + ]; + let NOMASK: I32x8 = i32x8_from_values_u32(M, M, M, M, M, M, M, M); + let FULLMASK: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0); + + if row < nrows { + let col = col as i64; + let ncols = tensor.cols; + let (addr, side) = tensor.q4_address(row as i64, col); + let i = load_i16x8(addr as *const I16x8); + let even_mask = i16x8_singleton_u16(0x0F0F); + let odd_mask = i16x8_singleton_u16(0xF0F0); + let evens = and_i16x8(i, even_mask); + let odds = and_i16x8(i, odd_mask); + let odds = shift_right_by_4_i16x8(odds); + + let indices1 = extend_i8_to_i32_i32x8(odds); + let odds_shifted = shift_right_by_64_i128(odds); + let indices2 = extend_i8_to_i32_i32x8(odds_shifted); + let indices3 = extend_i8_to_i32_i32x8(evens); + let indices4 = + extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens)); + + let unquantized1: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices1); + let unquantized2: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices2); + let unquantized3: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices3); + let unquantized4: F32x8 = + gather_scale4_f32x8(quants as *const f32, indices4); + let quan1_mask: I32x8 = if col <= ncols - 8 { + NOMASK + } else if col < ncols { + MASKS[(col % 8) as usize] + } else { + FULLMASK + }; + let quan2_mask: I32x8 = if col <= ncols - 16 { + NOMASK + } else if col < ncols - 8 { + MASKS[(col % 8) as usize] + } else { + FULLMASK + }; + let quan3_mask: I32x8 = if col <= ncols - 24 { + NOMASK + } else if col < ncols - 16 { + MASKS[(col % 8) as usize] + } else { + FULLMASK + }; + let quan4_mask: I32x8 = if col <= ncols - 32 { + NOMASK + } else if col < ncols - 24 { + MASKS[(col % 8) as usize] + } else { + FULLMASK + }; + let unquantized1 = and_f32x8(unquantized1, quan1_mask); + let unquantized2 = and_f32x8(unquantized2, quan2_mask); + let unquantized3 = and_f32x8(unquantized3, quan3_mask); + let unquantized4 = and_f32x8(unquantized4, quan4_mask); + (unquantized1, unquantized2, unquantized3, unquantized4) + } else { + (f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()) + } + } + } + + let mut targets8: [F32x8; 4] = + [f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()]; + + for p in 
0..src_cols_its {
+                        let other8_0: F32x8 = load_f32(
+                            other_data,
+                            col,
+                            p * 32,
+                            other_cols,
+                            other_rows,
+                            other_cols_capacity,
+                        );
+                        let other8_1: F32x8 = load_f32(
+                            other_data,
+                            col,
+                            p * 32 + 8,
+                            other_cols,
+                            other_rows,
+                            other_cols_capacity,
+                        );
+                        let other8_2: F32x8 = load_f32(
+                            other_data,
+                            col,
+                            p * 32 + 16,
+                            other_cols,
+                            other_rows,
+                            other_cols_capacity,
+                        );
+                        let other8_3: F32x8 = load_f32(
+                            other_data,
+                            col,
+                            p * 32 + 24,
+                            other_cols,
+                            other_rows,
+                            other_cols_capacity,
+                        );
+                        let (src8_0, src8_1, src8_2, src8_3): (F32x8, F32x8, F32x8, F32x8) =
+                            load_k4_to_f32(&src, row, p * 32, src_rows, quants.as_ptr());
+                        targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]);
+                        targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]);
+                        targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]);
+                        targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]);
+                    }
+                    let target0 = horizontal_sum_f32x8(targets8[0]);
+                    let target1 = horizontal_sum_f32x8(targets8[1]);
+                    let target2 = horizontal_sum_f32x8(targets8[2]);
+                    let target3 = horizontal_sum_f32x8(targets8[3]);
+                    let target = target0 + target1 + target2 + target3;
+                    *tgt_data.add(row * self_cols_capacity + col) = target;
+                }
+            }
+        }
+    }
+
     /// Matrix multiplication done in-place, but the second matrix is transposed.
     /// With this, you can avoid using .transpose() on the second matrix.
     pub fn matrix_mul_inplace_transposed(&mut self, src: &Tensor, other: &Tensor) {
@@ -1147,7 +1385,9 @@ impl Tensor {
                 self.rows, self.cols, other.rows, other.cols
             );
         }
-        if src.dtype != other.dtype {
+        if src.dtype != other.dtype
+            && (src.dtype != TensorDType::K4BitQuantization || other.dtype != TensorDType::Float32)
+        {
             panic!("Invalid matrix multiplication, different dtypes");
         }
         if self.rows != src.rows {
             panic!("Invalid matrix multiplication, different number of rows");
         }
             panic!("Invalid matrix multiplication, different number of cols");
         }
+        if src.dtype == TensorDType::K4BitQuantization && other.dtype == TensorDType::Float32 {
+            return self.matrix_mul_inplace_transposed_k4bit_and_f32(src, other);
+        }
+
         match src.dtype {
             TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float32 => {
@@ -1883,6 +2127,152 @@ impl Tensor {
         result
     }
 
+    /// Creates a tensor with TensorDType::K4BitQuantization, using two functions to initialize
+    /// the values and the quantization lookup table.
+    ///
+    /// The first function must always return numbers from 0 to 15. These are put in the matrix.
+    /// The second function must return, for each row, the 16 floats that correspond to the
+    /// numbers returned by the first function. For this quantization scheme, each row has its
+    /// own set of floats.
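+    ///
+    /// A rough usage sketch (hypothetical toy values, not taken from the real loading code):
+    /// every cell stores its column index modulo 16 and the lookup table is the identity, so
+    /// `get_f32` recovers the index itself:
+    ///
+    /// ```ignore
+    /// let t = Tensor::make_k4bit_from_fn(2, 32, |_row, col| (col % 16) as u8, |_row| {
+    ///     let mut lut = [0.0f32; 16];
+    ///     for (i, v) in lut.iter_mut().enumerate() {
+    ///         *v = i as f32;
+    ///     }
+    ///     lut
+    /// });
+    /// assert_eq!(t.get_f32(0, 17), 1.0);
+    /// ```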
+    #[inline]
+    pub fn make_k4bit_from_fn<F, F2>(
+        rows: i64,
+        cols: i64,
+        mut get_value: F,
+        mut get_lookup_table: F2,
+    ) -> Self
+    where
+        F: FnMut(i64, i64) -> u8,
+        F2: FnMut(i64) -> [f32; 16],
+    {
+        let mut result =
+            unsafe { Tensor::uninitialized(rows, cols, TensorDType::K4BitQuantization) };
+        result.allocate_q4_data();
+        assert!(!result.q4_data.is_null());
+
+        unsafe {
+            for row in 0..rows {
+                // Set up the lookup table for this row. f32x8_from_values follows the
+                // _mm256_set_ps convention (the first argument ends up in the highest lane), so
+                // the table is passed in reverse to get lookup_table[0] into lane 0.
+                let lookup_table = get_lookup_table(row);
+                let table1 = f32x8_from_values(
+                    lookup_table[7],
+                    lookup_table[6],
+                    lookup_table[5],
+                    lookup_table[4],
+                    lookup_table[3],
+                    lookup_table[2],
+                    lookup_table[1],
+                    lookup_table[0],
+                );
+                let table2 = f32x8_from_values(
+                    lookup_table[15],
+                    lookup_table[14],
+                    lookup_table[13],
+                    lookup_table[12],
+                    lookup_table[11],
+                    lookup_table[10],
+                    lookup_table[9],
+                    lookup_table[8],
+                );
+
+                let table1 = f32x8_to_i16x8_as_f16(table1);
+                let table2 = f32x8_to_i16x8_as_f16(table2);
+                store_i16x8(result.q4_data.add(row as usize * 32) as *mut I16x8, table1);
+                store_i16x8(
+                    result.q4_data.add(row as usize * 32 + 16) as *mut I16x8,
+                    table2,
+                );
+
+                for col in 0..cols {
+                    let v = get_value(row, col);
+
+                    let (addr, side) = result.q4_address(row, col);
+                    let mut addr_value = *addr;
+                    match side {
+                        BitSide::Upper => {
+                            addr_value = (addr_value & 0x0F) | (v << 4);
+                        }
+                        BitSide::Lower => {
+                            addr_value = (addr_value & 0xF0) | v;
+                        }
+                    }
+                    *addr = addr_value;
+                }
+            }
+        }
+        result
+    }
+
+    /// K4 bit quantization does not store the values in successive bits, but rather interleaved.
+    ///
+    ///  byte
+    /// <--->
+    /// 00 88 11 99 22 AA 33 BB 44 CC 55 DD 66 EE 77 FF
+    /// (4 bits each, i.e. nibbles; the diagram only shows 16 columns, but the real group is 32
+    /// columns wide, so byte k of a group holds column k in its upper nibble and column k+16 in
+    /// its lower nibble)
+    ///
+    /// Upper 4 bits are used if: col % 32 < 16
+    /// Lower 4 bits are used if: col % 32 >= 16
+    ///
+    /// The reason it works like this is to make the matrix multiplication SIMD code a bit
+    /// simpler. The instructions don't like 4-bit pieces.
+    #[inline]
+    fn q4_address(&self, row: i64, col: i64) -> (*mut u8, BitSide) {
+        let row = row as usize;
+        let col = col as usize;
+        let col_base = (col / 32) * 32;
+        let mut offset = (row * self.capacity_cols as usize + col_base) / 2;
+        unsafe {
+            if col % 32 < 16 {
+                offset += col % 16;
+                (self.data.add(offset), BitSide::Upper)
+            } else {
+                offset += col % 16;
+                (self.data.add(offset), BitSide::Lower)
+            }
+        }
+    }
+
+    fn allocate_q4_data(&mut self) {
+        if self.dtype != TensorDType::K4BitQuantization {
+            panic!("Can only allocate q4 data for K4BitQuantization");
+        }
+        // Already allocated? Then back off.
+        if !self.q4_data.is_null() {
+            return;
+        }
+
+        let layout = Layout::from_size_align(self.rows as usize * 32, 32).unwrap();
+        let q4_data = unsafe { std::alloc::alloc_zeroed(layout) };
+        if q4_data.is_null() {
+            panic!("Failed to allocate q4 data");
+        }
+        self.q4_data = q4_data;
+        self.q4_layout = layout;
+    }
+
     pub fn zeros(rows: i64, cols: i64, dtype: TensorDType) -> Self {
         if rows == 0 || cols == 0 {
             let mut tensor = Self::empty();
@@ -3031,4 +3421,112 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn tiny_quantized_16x16_matrix_equals_regular_16x16_matrix() {
+        for _ in 0..100 {
+            let reference = Tensor::random(16, 16, TensorDType::Float32);
+
+            let quantized = Tensor::make_k4bit_from_fn(
+                16,
+                16,
+                |_row, col| col as u8,
+                |row| {
+                    let mut result: [f32; 16] = [0.0; 16];
+                    for col in 0..16 {
+                        result[col] = reference.get_f32(row, col as i64);
+                    }
+                    result
+                },
+            );
+
+            assert_eq!(reference.rows(), quantized.rows());
+            assert_eq!(reference.cols(), quantized.cols());
+
+            for row in 0..reference.rows {
+                for col in 0..reference.cols {
+                    // The quantized table always uses f16 so values may not be 100% equal.
+                    assert_relative_eq!(
+                        reference.get_f32(row, col),
+                        quantized.get_f32(row, col),
+                        epsilon = 1e-3,
+                    );
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn quantized_matrices_matrix_mul_transposed_correctly() {
+        let mut rng = rand::thread_rng();
+        for _ in 0..100 {
+            let a = rng.gen_range(1..=128);
+            let b = rng.gen_range(1..=128);
+            let mut reference = Tensor::zeros(a, b, TensorDType::Float32);
+            let other_matrix = Tensor::random(128, b, TensorDType::Float32);
+
+            let mut quant_values: Vec<Vec<f32>> = Vec::with_capacity(a as usize);
+            for row in 0..a {
+                let mut quant_values_for_row: Vec<f32> = Vec::with_capacity(16);
+                for _ in 0..16 {
+                    quant_values_for_row.push(rng.gen_range(0.0..=1.0));
+                }
+                quant_values.push(quant_values_for_row);
+            }
+
+            let mut quantized_values: Vec<Vec<u8>> = Vec::with_capacity(a as usize);
+            for row in 0..a {
+                let mut quant_values_for_row: Vec<u8> = Vec::with_capacity(b as usize);
+                for col in 0..b {
+                    let i = rng.gen_range(0..=15);
+                    reference.set_f32(row, col, quant_values[row as usize][i as usize]);
+                    quant_values_for_row.push(i as u8);
+                }
+                quantized_values.push(quant_values_for_row);
+            }
+
+            let quantized = Tensor::make_k4bit_from_fn(
+                a,
+                b,
+                |row, col| quantized_values[row as usize][col as usize],
+                |row| {
+                    let mut result: [f32; 16] = [0.0; 16];
+                    for col in 0..16 {
+                        result[col] = quant_values[row as usize][col];
+                    }
+                    result
+                },
+            );
+
+            assert_eq!(reference.rows(), quantized.rows());
+            assert_eq!(reference.cols(), quantized.cols());
+
+            for row in 0..reference.rows {
+                for col in 0..reference.cols {
+                    // The quantized table always uses f16 so values may not be 100% equal.
+                    assert_relative_eq!(
+                        reference.get_f32(row, col),
+                        quantized.get_f32(row, col),
+                        epsilon = 1e-1,
+                    );
+                }
+            }
+
+            let mult1 = reference.matrix_mul_transposed(&other_matrix);
+            let mult2 = quantized.matrix_mul_transposed(&other_matrix);
+
+            assert_eq!(mult1.rows(), mult2.rows());
+            assert_eq!(mult1.cols(), mult2.cols());
+
+            for row in 0..mult1.rows {
+                for col in 0..mult1.cols {
+                    assert_relative_eq!(
+                        mult1.get_f32(row, col),
+                        mult2.get_f32(row, col),
+                        epsilon = 1e-1,
+                    );
+                }
+            }
+        }
+    }
 }
diff --git a/src/transformer.rs b/src/transformer.rs
index f6f823d..cf3dc6c 100644
--- a/src/transformer.rs
+++ b/src/transformer.rs
@@ -457,6 +457,9 @@ impl FeedForward {
             FromPiecesDirection::Rows,
         )?;
 
+        w1 = crate::weight_compression::quantize(&w1);
+        panic!("stop");
+
         if data_settings.force_f16 {
             w1 = w1.to_f16();
             w2 = w2.to_f16();
diff --git a/src/weight_compression.rs b/src/weight_compression.rs
index 92f59f3..6f5f9d1 100644
--- a/src/weight_compression.rs
+++ b/src/weight_compression.rs
@@ -12,9 +12,6 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
     let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype());
     for row in 0..tensor.rows() {
         let mut values: Vec<f32> = Vec::with_capacity(tensor.cols() as usize);
-        if row % 500 == 0 {
-            println!("{}", row,);
-        }
         values.truncate(0);
         let mut mi: f32 = std::f32::MAX;
         let mut ma: f32 = std::f32::MIN;