Compare commits

...

5 Commits

Author SHA1 Message Date
Mikko Juola d7d13cd474 Bucketize the 4-bit quantization for more accuracy.
Currently going with 512 columns per bucket. Need to test a bit should I
go even smaller, 256 columns per bucket.
3 years ago
Mikko Juola 8cc82ae7e2 Make separate matrix_vector_muls for 4-bit quantization rather than using matrix_mul for them. 3 years ago
Mikko Juola 2f3e9bc0f5 K4 bit inference works now. Performance isn't as good as I'd like it to be though. 3 years ago
Mikko Juola 40121e1c82 Multithread the k4 * f32 matrix multiplication. 3 years ago
Mikko Juola b8946da2d8 Implement matrix multiplication for 4-bit * 32-bit floats.
As of this commit, test works. But I want to optimize this a bit, seeing
if increasing load instruction : arithmetic instruction ratio will make
single-threaded performance a bit speedier.
3 years ago

@ -97,6 +97,7 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
let orig_84096_1 = Tensor::zeros(8, 4096, TensorDType::Float32);
let orig_84096_2 = Tensor::zeros(4096, 4096, TensorDType::Float32);
let orig_84096_quant = orig_84096_1.quantize();
let mut result_84096 = Tensor::zeros(8, 4096, TensorDType::Float32);
let orig_84096_1_f16 = Tensor::zeros(8, 4096, TensorDType::Float16);
@ -111,6 +112,38 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
let m1_f16 = m1.to_f16();
let m2_f16 = m2.to_f16();
let quant = m1.quantize();
let quant2 = m2.quantize();
c.bench_function(
"1024x128 * 1x128 matrix vector transposed multiplication, k4 quantized * f32",
|b| {
b.iter(|| {
let _ = quant.matrix_vector_mul_transposed(black_box(&m2));
})
},
);
c.bench_function(
"1024x128 * 1x128 matrix vector transposed multiplication, f32 quantized * k4",
|b| {
b.iter(|| {
let _ = m1.matrix_vector_mul_transposed(black_box(&quant2));
})
},
);
c.bench_function(
"matrix multiplication 8x4096 @ 4096x4096 k4 quantized * f32 in-place, transposed",
|b| {
b.iter(|| {
let _ = result_84096.matrix_mul_inplace_transposed(
black_box(&orig_84096_quant),
black_box(&orig_84096_2),
);
})
},
);
c.bench_function(
"1024x128 * 1x128 matrix vector transposed multiplication, f32",
|b| {

@ -53,6 +53,9 @@ struct Cli {
#[arg(long, action)]
f16: bool,
#[arg(long, action)]
k4: bool,
#[cfg(feature = "opencl")]
#[arg(long)]
opencl_device: Option<usize>,
@ -233,6 +236,9 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
if cli.f16 {
data_settings = data_settings.force_f16();
}
if cli.k4 {
data_settings = data_settings.force_k4();
}
pln!("Loading transformer weights from {}...", model_path);
let tr = Transformer::from_unpickled(

@ -3,6 +3,7 @@
use core::arch::x86_64::*;
use half::f16;
use std::fmt::Write;
pub type I32x8 = __m256i;
pub type F32x8 = __m256;
@ -37,6 +38,11 @@ pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 {
unsafe { _mm256_i32gather_ps(ptr, indices, 1) }
}
#[inline]
pub fn gather_scale4_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 {
unsafe { _mm256_i32gather_ps(ptr, indices, 4) }
}
/* ------------------ */
/* Conversions */
/* ------------------ */
@ -51,22 +57,121 @@ pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 {
unsafe { _mm256_cvtps_ph(a, 0) }
}
#[inline]
/// Converts f32x8 to i32x8, by just casting the bits. I.e. it does not round any numbers or
/// anything, it just copies the bits.
pub fn f32x8_to_i32x8_bitcast(a: F32x8) -> I32x8 {
unsafe { _mm256_castps_si256(a) }
}
/*
* ------------------
* Accessing individual elements
*/
// Rust has no const arguments (yet, maybe in future).
// So we have this awkward match statement in each of these.
#[inline]
pub fn f32x8_get(a: F32x8, idx: usize) -> f32 {
unsafe {
let a = f32x8_to_i32x8_bitcast(a);
let a = match idx {
0 => _mm256_extract_epi32(a, 0),
1 => _mm256_extract_epi32(a, 1),
2 => _mm256_extract_epi32(a, 2),
3 => _mm256_extract_epi32(a, 3),
4 => _mm256_extract_epi32(a, 4),
5 => _mm256_extract_epi32(a, 5),
6 => _mm256_extract_epi32(a, 6),
7 => _mm256_extract_epi32(a, 7),
_ => panic!("f32x8_get: index out of bounds"),
};
// bitcast the i32 back to f32
core::mem::transmute(a)
}
}
#[inline]
pub fn i32x8_get(a: I32x8, idx: usize) -> i32 {
unsafe {
let a = match idx {
0 => _mm256_extract_epi32(a, 0),
1 => _mm256_extract_epi32(a, 1),
2 => _mm256_extract_epi32(a, 2),
3 => _mm256_extract_epi32(a, 3),
4 => _mm256_extract_epi32(a, 4),
5 => _mm256_extract_epi32(a, 5),
6 => _mm256_extract_epi32(a, 6),
7 => _mm256_extract_epi32(a, 7),
_ => panic!("i32x8_get: index out of bounds"),
};
a
}
}
#[inline]
pub fn i16x8_get(a: I16x8, idx: usize) -> i16 {
unsafe {
let a = match idx {
0 => _mm_extract_epi16(a, 0),
1 => _mm_extract_epi16(a, 1),
2 => _mm_extract_epi16(a, 2),
3 => _mm_extract_epi16(a, 3),
4 => _mm_extract_epi16(a, 4),
5 => _mm_extract_epi16(a, 5),
6 => _mm_extract_epi16(a, 6),
7 => _mm_extract_epi16(a, 7),
_ => panic!("i16x8_get: index out of bounds"),
};
a as i16
}
}
/*
* Constants, creating from constants
*/
#[inline]
pub fn f32x8_zero() -> F32x8 {
unsafe { _mm256_setzero_ps() }
}
#[inline]
pub fn i16x8_zero() -> I16x8 {
unsafe { _mm_setzero_si128() }
}
#[inline]
pub fn i16x8_singleton(value: i16) -> I16x8 {
unsafe { _mm_set1_epi16(value) }
}
#[inline]
pub fn i16x8_singleton_u16(value: u16) -> I16x8 {
unsafe { _mm_set1_epi16(value as i16) }
}
#[inline]
pub fn f32x8_singleton(value: f32) -> F32x8 {
unsafe { _mm256_set1_ps(value) }
}
#[inline]
pub fn f32x8_from_values(
val0: f32,
val1: f32,
val2: f32,
val3: f32,
val4: f32,
val5: f32,
val6: f32,
val7: f32,
) -> F32x8 {
unsafe { _mm256_set_ps(val0, val1, val2, val3, val4, val5, val6, val7) }
}
#[inline]
pub fn i32x8_from_values(
val0: i32,
val1: i32,
@ -80,6 +185,45 @@ pub fn i32x8_from_values(
unsafe { _mm256_set_epi32(val0, val1, val2, val3, val4, val5, val6, val7) }
}
#[inline]
pub fn i32x8_from_values_u32(
val0: u32,
val1: u32,
val2: u32,
val3: u32,
val4: u32,
val5: u32,
val6: u32,
val7: u32,
) -> I32x8 {
unsafe {
_mm256_set_epi32(
val0 as i32,
val1 as i32,
val2 as i32,
val3 as i32,
val4 as i32,
val5 as i32,
val6 as i32,
val7 as i32,
)
}
}
#[inline]
pub fn i16x8_from_values(
val0: i16,
val1: i16,
val2: i16,
val3: i16,
val4: i16,
val5: i16,
val6: i16,
val7: i16,
) -> I16x8 {
unsafe { _mm_set_epi16(val0, val1, val2, val3, val4, val5, val6, val7) }
}
/*
* Operations
*/
@ -87,10 +231,54 @@ pub fn i32x8_from_values(
// FMA
// a * b + c
#[inline]
pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 {
unsafe { _mm256_fmadd_ps(a, b, c) }
}
// bitwise and
#[inline]
pub fn and_i16x8(a: I16x8, b: I16x8) -> I16x8 {
unsafe { _mm_and_si128(a, b) }
}
#[inline]
pub fn and_i32x8(a: I32x8, b: I32x8) -> I32x8 {
unsafe { _mm256_and_si256(a, b) }
}
#[inline]
pub fn and_f32x8(a: F32x8, b: I32x8) -> F32x8 {
unsafe { std::mem::transmute(_mm256_and_si256(std::mem::transmute(a), b)) }
}
// shift right by 4 bits exactly, for each individual i16 value.
// extends by zeros from left.
#[inline]
pub fn shift_right_by_4_i16x8(a: I16x8) -> I16x8 {
unsafe { _mm_srli_epi16(a, 4) }
}
// shift right by half of an entire i16x8
// extends by zeros from left.
#[inline]
pub fn shift_right_by_64_i128(a: I16x8) -> I16x8 {
unsafe { _mm_srli_si128(a, 64 / 8) }
}
// Extends 8 i8 values into 7 i16 values
//
// XXYYZZ -> 00XX00YY00ZZ
pub fn extend_i8_to_i16_i16x8(a: I16x8) -> I16x8 {
unsafe { _mm_cvtepi8_epi16(a) }
}
// Extends 8 i8 values into 4 i32 values
pub fn extend_i8_to_i32_i32x8(a: I16x8) -> I32x8 {
let i = extend_i8_to_i16_i16x8(a);
unsafe { _mm256_cvtepu16_epi32(i) }
}
// Horizontal sums
#[inline]
@ -114,3 +302,50 @@ pub fn horizontal_sum_and_f32_to_f16(mut ymm: __m256) -> f16 {
f16::from_f32(_mm256_cvtss_f32(ymm))
}
}
/*
* Debugging
*/
/// Prints a binary representation of i16x8 to stdout in this form:
///
/// ```ignore
/// 0 0 0 0
/// 0x0000 0x0000 0x0000 0x0000
/// 0000000000000000 0000000000000000 0000000000000000 0000000000000000 etc.
/// ```
///
/// decimal on first line, hex on second, binary on third.
pub fn print_i16x8(a: I16x8) {
let mut decimal_line = String::new();
let mut hex_line = String::new();
let mut binary_line = String::new();
for i in 0..8 {
let val = i16x8_get(a, i);
write!(decimal_line, "{:>5} ", val).unwrap();
write!(hex_line, "0x{:04X} ", val).unwrap();
write!(binary_line, "{:016b} ", val).unwrap();
}
println!("{}", decimal_line.trim_end());
println!("{}", hex_line.trim_end());
println!("{}", binary_line.trim_end());
}
pub fn print_i32x8(a: I32x8) {
let mut decimal_line = String::new();
let mut hex_line = String::new();
let mut binary_line = String::new();
for i in 0..8 {
let val = i32x8_get(a, i);
write!(decimal_line, "{:>10} ", val).unwrap();
write!(hex_line, "0x{:08X} ", val).unwrap();
write!(binary_line, "{:032b} ", val).unwrap();
}
println!("{}", decimal_line.trim_end());
println!("{}", hex_line.trim_end());
println!("{}", binary_line.trim_end());
}

@ -0,0 +1,116 @@
// This file contains platform-specific SIMD so that rest of rllama does not need to care which
// platform it is on.
use core::arch::aarch64::*;
use half::f16;
pub type I32x8 = int32x4x2_t;
pub type F32x8 = float32x4x2_t;
pub type I16x8 = int16x8_t;
/* ------------------ */
/* Loading and storing things */
/* ------------------ */
#[inline]
pub fn load_i16x8(ptr: *const I16x8) -> I16x8 {
unsafe { vld1q_s16(ptr) }
}
#[inline]
pub fn store_i16x8(ptr: *mut I16x8, a: I16x8) {
unsafe { vst1q_s16(ptr, a) }
}
#[inline]
pub fn load_f32x8(ptr: *const F32x8) -> F32x8 {
unsafe { vld1q_f32_x2(ptr as *const f32) }
}
#[inline]
pub fn store_f32x8(ptr: *mut F32x8, a: F32x8) {
unsafe { vst1q_f32_x2(ptr as *mut f32, a) }
}
#[inline]
pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 {
unsafe { _mm256_i32gather_ps(ptr, indices, 1) }
}
/* ------------------ */
/* Conversions */
/* ------------------ */
#[inline]
pub fn i16x8_as_f16_to_f32x8(a: I16x8) -> F32x8 {
unsafe { _mm256_cvtph_ps(a) }
}
#[inline]
pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 {
unsafe { _mm256_cvtps_ph(a, 0) }
}
/*
* Constants, creating from constants
*/
pub fn f32x8_zero() -> F32x8 {
unsafe { _mm256_setzero_ps() }
}
pub fn i16x8_zero() -> I16x8 {
unsafe { _mm_setzero_si128() }
}
pub fn f32x8_singleton(value: f32) -> F32x8 {
unsafe { _mm256_set1_ps(value) }
}
pub fn i32x8_from_values(
val0: i32,
val1: i32,
val2: i32,
val3: i32,
val4: i32,
val5: i32,
val6: i32,
val7: i32,
) -> I32x8 {
unsafe { _mm256_set_epi32(val0, val1, val2, val3, val4, val5, val6, val7) }
}
/*
* Operations
*/
// FMA
// a * b + c
pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 {
unsafe { _mm256_fmadd_ps(a, b, c) }
}
// Horizontal sums
#[inline]
pub fn horizontal_sum_f32x8(mut ymm: __m256) -> f32 {
unsafe {
let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
ymm = _mm256_add_ps(ymm, ymm2);
ymm = _mm256_hadd_ps(ymm, ymm);
ymm = _mm256_hadd_ps(ymm, ymm);
_mm256_cvtss_f32(ymm)
}
}
#[inline]
pub fn horizontal_sum_and_f32_to_f16(mut ymm: __m256) -> f16 {
unsafe {
let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
ymm = _mm256_add_ps(ymm, ymm2);
ymm = _mm256_hadd_ps(ymm, ymm);
ymm = _mm256_hadd_ps(ymm, ymm);
f16::from_f32(_mm256_cvtss_f32(ymm))
}
}

@ -0,0 +1,116 @@
// This file contains platform-specific SIMD so that rest of rllama does not need to care which
// platform it is on.
use core::arch::x86_64::*;
use half::f16;
pub type I32x8 = __m256i;
pub type F32x8 = __m256;
pub type I16x8 = __m128i;
/* ------------------ */
/* Loading and storing things */
/* ------------------ */
#[inline]
pub fn load_i16x8(ptr: *const I16x8) -> I16x8 {
unsafe { _mm_loadu_si128(ptr) }
}
#[inline]
pub fn store_i16x8(ptr: *mut I16x8, a: I16x8) {
unsafe { _mm_storeu_si128(ptr, a) }
}
#[inline]
pub fn load_f32x8(ptr: *const F32x8) -> F32x8 {
unsafe { _mm256_loadu_ps(ptr as *const f32) }
}
#[inline]
pub fn store_f32x8(ptr: *mut F32x8, a: F32x8) {
unsafe { _mm256_storeu_ps(ptr as *mut f32, a) }
}
#[inline]
pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 {
unsafe { _mm256_i32gather_ps(ptr, indices, 1) }
}
/* ------------------ */
/* Conversions */
/* ------------------ */
#[inline]
pub fn i16x8_as_f16_to_f32x8(a: I16x8) -> F32x8 {
unsafe { _mm256_cvtph_ps(a) }
}
#[inline]
pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 {
unsafe { _mm256_cvtps_ph(a, 0) }
}
/*
* Constants, creating from constants
*/
pub fn f32x8_zero() -> F32x8 {
unsafe { _mm256_setzero_ps() }
}
pub fn i16x8_zero() -> I16x8 {
unsafe { _mm_setzero_si128() }
}
pub fn f32x8_singleton(value: f32) -> F32x8 {
unsafe { _mm256_set1_ps(value) }
}
pub fn i32x8_from_values(
val0: i32,
val1: i32,
val2: i32,
val3: i32,
val4: i32,
val5: i32,
val6: i32,
val7: i32,
) -> I32x8 {
unsafe { _mm256_set_epi32(val0, val1, val2, val3, val4, val5, val6, val7) }
}
/*
* Operations
*/
// FMA
// a * b + c
pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 {
unsafe { _mm256_fmadd_ps(a, b, c) }
}
// Horizontal sums
#[inline]
pub fn horizontal_sum_f32x8(mut ymm: __m256) -> f32 {
unsafe {
let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
ymm = _mm256_add_ps(ymm, ymm2);
ymm = _mm256_hadd_ps(ymm, ymm);
ymm = _mm256_hadd_ps(ymm, ymm);
_mm256_cvtss_f32(ymm)
}
}
#[inline]
pub fn horizontal_sum_and_f32_to_f16(mut ymm: __m256) -> f16 {
unsafe {
let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
ymm = _mm256_add_ps(ymm, ymm2);
ymm = _mm256_hadd_ps(ymm, ymm);
ymm = _mm256_hadd_ps(ymm, ymm);
f16::from_f32(_mm256_cvtss_f32(ymm))
}
}

@ -0,0 +1,419 @@
fn matrix_mul_inplace_transposed_f32_and_k4bit(&mut self, src: &Tensor, other: &Tensor) {
// Assume: size checks have been done already.
assert!(src.dtype == TensorDType::Float32);
assert!(other.dtype == TensorDType::K4BitQuantization);
assert!(self.dtype == TensorDType::Float32);
unsafe {
let src_rows: usize = src.rows as usize;
let src_cols: usize = src.cols as usize;
let src_cols_capacity: usize = src.capacity_cols as usize;
let other_cols: usize = other.cols as usize;
let other_rows: usize = other.rows as usize;
let other_cols_capacity: usize = other.capacity_cols as usize;
let self_rows: usize = self.rows as usize;
let self_cols: usize = self.cols as usize;
let self_cols_capacity: usize = self.capacity_cols as usize;
// src_cols_its == also the shared dimension between src and other.
let src_cols_its = if src_cols % 32 == 0 {
src_cols / 32
} else {
src_cols / 32 + 1
};
debug_assert!(!other.q4_data.is_null());
let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
let other_q4_data: WrappedPtr = WrappedPtr::wrap(other.q4_data);
let nthreads: usize = rayon::current_num_threads();
(0..nthreads).into_par_iter().for_each(|thread_idx| {
let other_q4_data: *const u8 = other_q4_data.unwrap() as *const u8;
let src_data: *const f32 = src_data_wrap.unwrap() as *const f32;
let other_data: *const u8 = other_data.unwrap() as *const u8;
let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
for row in 0..self_rows {
for col in 0..self_cols {
let row_col = row * self_cols + col;
if row_col % nthreads != thread_idx {
continue;
}
let quant0 = load_i16x8(other_q4_data.add(col * 32) as *const I16x8);
let quant1 = load_i16x8(other_q4_data.add(col * 32 + 16) as *const I16x8);
let quants: [F32x8; 2] =
[i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
#[inline]
fn load_f32(
src: *const f32,
row: usize,
col: usize,
ncols: usize,
nrows: usize,
cols_capacity: usize,
) -> F32x8 {
unsafe {
if row >= nrows || col >= ncols {
f32x8_zero()
} else {
load_f32x8(src.add(row * cols_capacity + col) as *const F32x8)
}
}
}
#[inline]
fn load_k4_to_f32(
tensor: &Tensor,
row: usize,
col: usize,
nrows: usize,
quants: *const F32x8,
) -> (F32x8, F32x8, F32x8, F32x8) {
unsafe {
let m: u32 = 0xFFFFFFFF;
let masks: [I32x8; 8] = [
i32x8_from_values_u32(m, m, m, m, m, m, m, m),
i32x8_from_values_u32(0, m, m, m, m, m, m, m),
i32x8_from_values_u32(0, 0, m, m, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, m, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m),
];
let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m);
let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
if row < nrows {
let col = col as i64;
let ncols = tensor.cols;
let (addr, side) = tensor.q4_address(row as i64, col);
let i = load_i16x8(addr as *const I16x8);
let even_mask = i16x8_singleton_u16(0x0F0F);
let odd_mask = i16x8_singleton_u16(0xF0F0);
let evens = and_i16x8(i, even_mask);
let odds = and_i16x8(i, odd_mask);
let odds = shift_right_by_4_i16x8(odds);
let indices1 = extend_i8_to_i32_i32x8(odds);
let odds_shifted = shift_right_by_64_i128(odds);
let indices2 = extend_i8_to_i32_i32x8(odds_shifted);
let indices3 = extend_i8_to_i32_i32x8(evens);
let indices4 =
extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens));
let unquantized1: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices1);
let unquantized2: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices2);
let unquantized3: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices3);
let unquantized4: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices4);
let quan1_mask: I32x8 = if col <= ncols - 8 {
nomask
} else if col < ncols {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan2_mask: I32x8 = if col <= ncols - 16 {
nomask
} else if col < ncols - 8 {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan3_mask: I32x8 = if col <= ncols - 24 {
nomask
} else if col < ncols - 16 {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan4_mask: I32x8 = if col <= ncols - 32 {
nomask
} else if col < ncols - 24 {
masks[(col % 8) as usize]
} else {
fullmask
};
let unquantized1 = and_f32x8(unquantized1, quan1_mask);
let unquantized2 = and_f32x8(unquantized2, quan2_mask);
let unquantized3 = and_f32x8(unquantized3, quan3_mask);
let unquantized4 = and_f32x8(unquantized4, quan4_mask);
(unquantized1, unquantized2, unquantized3, unquantized4)
} else {
(f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero())
}
}
}
let mut targets8: [F32x8; 4] =
[f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
for p in 0..src_cols_its {
let (other8_0, other8_1, other8_2, other8_3) =
load_k4_to_f32(&other, col, p * 32, other_rows, quants.as_ptr());
let src8_0 = load_f32(
src_data,
row,
p * 32,
src_cols,
src_rows,
src_cols_capacity,
);
let src8_1 = load_f32(
src_data,
row,
p * 32 + 8,
src_cols,
src_rows,
src_cols_capacity,
);
let src8_2 = load_f32(
src_data,
row,
p * 32 + 16,
src_cols,
src_rows,
src_cols_capacity,
);
let src8_3 = load_f32(
src_data,
row,
p * 32 + 24,
src_cols,
src_rows,
src_cols_capacity,
);
targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]);
targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]);
targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]);
targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]);
}
let target0 = horizontal_sum_f32x8(targets8[0]);
let target1 = horizontal_sum_f32x8(targets8[1]);
let target2 = horizontal_sum_f32x8(targets8[2]);
let target3 = horizontal_sum_f32x8(targets8[3]);
let target = target0 + target1 + target2 + target3;
*tgt_data.add(row * self_cols_capacity + col) = target;
}
}
});
}
}
fn matrix_mul_inplace_transposed_k4bit_and_f32(&mut self, src: &Tensor, other: &Tensor) {
// Assume: size checks have been done already.
assert!(src.dtype == TensorDType::K4BitQuantization);
assert!(other.dtype == TensorDType::Float32);
assert!(self.dtype == TensorDType::Float32);
unsafe {
let src_rows: usize = src.rows as usize;
let src_cols: usize = src.cols as usize;
let src_cols_capacity: usize = src.capacity_cols as usize;
let other_cols: usize = other.cols as usize;
let other_rows: usize = other.rows as usize;
let other_cols_capacity: usize = other.capacity_cols as usize;
let self_rows: usize = self.rows as usize;
let self_cols: usize = self.cols as usize;
let self_cols_capacity: usize = self.capacity_cols as usize;
// src_cols_its == also the shared dimension between src and other.
let src_cols_its = if src_cols % 32 == 0 {
src_cols / 32
} else {
src_cols / 32 + 1
};
debug_assert!(!src.q4_data.is_null());
let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
let src_q4_data: WrappedPtr = WrappedPtr::wrap(src.q4_data);
let nthreads: usize = rayon::current_num_threads();
(0..nthreads).into_par_iter().for_each(|thread_idx| {
let src_q4_data: *const u8 = src_q4_data.unwrap() as *const u8;
let src_data: *const u8 = src_data_wrap.unwrap() as *const u8;
let other_data: *const f32 = other_data.unwrap() as *const f32;
let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
for row in 0..self_rows {
let quant0 = load_i16x8(src_q4_data.add(row * 32) as *const I16x8);
let quant1 = load_i16x8(src_q4_data.add(row * 32 + 16) as *const I16x8);
let quants: [F32x8; 2] =
[i16x8_as_f16_to_f32x8(quant0), i16x8_as_f16_to_f32x8(quant1)];
for col in 0..self_cols {
let row_col = row * self_cols + col;
if row_col % nthreads != thread_idx {
continue;
}
#[inline]
fn load_f32(
other: *const f32,
row: usize,
col: usize,
ncols: usize,
nrows: usize,
cols_capacity: usize,
) -> F32x8 {
unsafe {
if row >= nrows || col >= ncols {
f32x8_zero()
} else {
load_f32x8(other.add(row * cols_capacity + col) as *const F32x8)
}
}
}
#[inline]
fn load_k4_to_f32(
tensor: &Tensor,
row: usize,
col: usize,
nrows: usize,
quants: *const F32x8,
) -> (F32x8, F32x8, F32x8, F32x8) {
unsafe {
let m: u32 = 0xFFFFFFFF;
let masks: [I32x8; 8] = [
i32x8_from_values_u32(m, m, m, m, m, m, m, m),
i32x8_from_values_u32(0, m, m, m, m, m, m, m),
i32x8_from_values_u32(0, 0, m, m, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, m, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, m, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, m, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, m, m),
i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, m),
];
let nomask: I32x8 = i32x8_from_values_u32(m, m, m, m, m, m, m, m);
let fullmask: I32x8 = i32x8_from_values_u32(0, 0, 0, 0, 0, 0, 0, 0);
if row < nrows {
let col = col as i64;
let ncols = tensor.cols;
let (addr, side) = tensor.q4_address(row as i64, col);
let i = load_i16x8(addr as *const I16x8);
let even_mask = i16x8_singleton_u16(0x0F0F);
let odd_mask = i16x8_singleton_u16(0xF0F0);
let evens = and_i16x8(i, even_mask);
let odds = and_i16x8(i, odd_mask);
let odds = shift_right_by_4_i16x8(odds);
let indices1 = extend_i8_to_i32_i32x8(odds);
let odds_shifted = shift_right_by_64_i128(odds);
let indices2 = extend_i8_to_i32_i32x8(odds_shifted);
let indices3 = extend_i8_to_i32_i32x8(evens);
let indices4 =
extend_i8_to_i32_i32x8(shift_right_by_64_i128(evens));
let unquantized1: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices1);
let unquantized2: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices2);
let unquantized3: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices3);
let unquantized4: F32x8 =
gather_scale4_f32x8(quants as *const f32, indices4);
let quan1_mask: I32x8 = if col <= ncols - 8 {
nomask
} else if col < ncols {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan2_mask: I32x8 = if col <= ncols - 16 {
nomask
} else if col < ncols - 8 {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan3_mask: I32x8 = if col <= ncols - 24 {
nomask
} else if col < ncols - 16 {
masks[(col % 8) as usize]
} else {
fullmask
};
let quan4_mask: I32x8 = if col <= ncols - 32 {
nomask
} else if col < ncols - 24 {
masks[(col % 8) as usize]
} else {
fullmask
};
let unquantized1 = and_f32x8(unquantized1, quan1_mask);
let unquantized2 = and_f32x8(unquantized2, quan2_mask);
let unquantized3 = and_f32x8(unquantized3, quan3_mask);
let unquantized4 = and_f32x8(unquantized4, quan4_mask);
(unquantized1, unquantized2, unquantized3, unquantized4)
} else {
(f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero())
}
}
}
let mut targets8: [F32x8; 4] =
[f32x8_zero(), f32x8_zero(), f32x8_zero(), f32x8_zero()];
for p in 0..src_cols_its {
let other8_0: F32x8 = load_f32(
other_data,
col,
p * 32,
other_cols,
other_rows,
other_cols_capacity,
);
let other8_1: F32x8 = load_f32(
other_data,
col,
p * 32 + 8,
other_cols,
other_rows,
other_cols_capacity,
);
let other8_2: F32x8 = load_f32(
other_data,
col,
p * 32 + 16,
other_cols,
other_rows,
other_cols_capacity,
);
let other8_3: F32x8 = load_f32(
other_data,
col,
p * 32 + 24,
other_cols,
other_rows,
other_cols_capacity,
);
let (src8_0, src8_1, src8_2, src8_3): (F32x8, F32x8, F32x8, F32x8) =
load_k4_to_f32(&src, row, p * 32, src_rows, quants.as_ptr());
targets8[0] = fma_f32x8(src8_0, other8_0, targets8[0]);
targets8[1] = fma_f32x8(src8_1, other8_1, targets8[1]);
targets8[2] = fma_f32x8(src8_2, other8_2, targets8[2]);
targets8[3] = fma_f32x8(src8_3, other8_3, targets8[3]);
}
let target0 = horizontal_sum_f32x8(targets8[0]);
let target1 = horizontal_sum_f32x8(targets8[1]);
let target2 = horizontal_sum_f32x8(targets8[2]);
let target3 = horizontal_sum_f32x8(targets8[3]);
let target = target0 + target1 + target2 + target3;
*tgt_data.add(row * self_cols_capacity + col) = target;
}
}
});
}
}

File diff suppressed because it is too large Load Diff

@ -33,7 +33,7 @@ pub struct Transformer {
}
// Clone is cheap
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct DataSettings {
#[cfg(feature = "opencl")]
use_opencl_for_feedforward: bool,
@ -43,6 +43,7 @@ pub struct DataSettings {
cl: Option<OpenCL>,
force_f16: bool,
force_k4: bool,
}
// OpenCL is safe to send to threads but Rust doesn't know that
@ -56,6 +57,7 @@ impl DataSettings {
use_opencl_for_feedforward: false,
use_opencl_for_attention: false,
force_f16: false,
force_k4: false,
cl: cl.clone(),
}
}
@ -63,7 +65,10 @@ impl DataSettings {
#[allow(clippy::new_without_default)]
#[cfg(not(feature = "opencl"))]
pub fn new() -> Self {
DataSettings { force_f16: false }
DataSettings {
force_f16: false,
force_k4: false,
}
}
#[cfg(feature = "opencl")]
@ -78,6 +83,17 @@ impl DataSettings {
pub fn force_f16(mut self) -> DataSettings {
self.force_f16 = true;
if self.force_k4 {
panic!("Cannot force both f16 and k4");
}
self
}
pub fn force_k4(mut self) -> DataSettings {
self.force_k4 = true;
if self.force_f16 {
panic!("Cannot force both f16 and k4");
}
self
}
}
@ -461,6 +477,10 @@ impl FeedForward {
w1 = w1.to_f16();
w2 = w2.to_f16();
w3 = w3.to_f16();
} else if data_settings.force_k4 {
w1 = w1.quantize();
w2 = w2.quantize();
w3 = w3.quantize();
}
#[cfg(feature = "opencl")]
@ -488,7 +508,7 @@ impl FeedForward {
pub fn forward(&self, x: &mut Tensor) -> Tensor {
let original_x_dtype = x.dtype();
if x.dtype() != self.w1.dtype() {
*x = x.to_same_type(&self.w1);
*x = x.to_compatible_matrix_mul_type(&self.w1);
}
#[cfg(feature = "opencl")]
let x_was_on_cpu: bool;
@ -513,7 +533,7 @@ impl FeedForward {
let w1_out = w1_out.silu();
let mut w1w3_out = w1_out.hadamard_product(&w3_out).transpose();
if w1w3_out.dtype() != self.w2.dtype() {
w1w3_out = w1w3_out.to_same_type(&self.w2);
w1w3_out = w1w3_out.to_compatible_matrix_mul_type(&self.w2);
}
#[cfg(not(feature = "opencl"))]
{
@ -575,6 +595,11 @@ impl Attention {
wk = wk.to_f16();
wv = wv.to_f16();
wo = wo.to_f16();
} else if data_settings.force_k4 {
wq = wq.quantize();
wk = wk.quantize();
wv = wv.quantize();
wo = wo.quantize();
}
#[cfg(feature = "opencl")]
@ -613,7 +638,7 @@ impl Attention {
) -> Tensor {
let original_x_dtype = x.dtype();
if x.dtype() != self.wq.dtype() {
*x = x.to_same_type(&self.wq);
*x = x.to_compatible_matrix_mul_type2(&self.wq);
}
#[cfg(feature = "opencl")]
@ -741,7 +766,7 @@ impl Attention {
{
let xq_row = Tensor::concat(&concat_vec2).view(1, self.wo.rows());
xq_row
.into_same_type(&self.wo)
.to_compatible_matrix_mul_type2(&self.wo)
.matrix_mul_transposed(&self.wo)
}
#[cfg(feature = "opencl")]

@ -1,5 +1,4 @@
use crate::tensor::Tensor;
use rand::{thread_rng, Rng};
use rayon::prelude::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
@ -9,17 +8,99 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
* This is a simplistic rounding quantizer. It splits each row in a tensor to 16 buckets and
* takes the average value in said buckets as the quantized weight.
*/
let mut allowed_values_by_row_block: Vec<Vec<Vec<f32>>> =
Vec::with_capacity(tensor.rows() as usize);
let col_blocks = (tensor.cols() + 511) / 512;
for row in 0..tensor.rows() {
let mut block_values: Vec<Vec<f32>> = Vec::with_capacity(col_blocks as usize);
for block in 0..col_blocks {
let start = block * 512;
let end = std::cmp::min(start + 512, tensor.cols());
let mut values: Vec<f32> = Vec::with_capacity(512);
values.truncate(0);
let mut mi: f32 = std::f32::MAX;
let mut ma: f32 = std::f32::MIN;
for col in start..end {
let val = tensor.get_f32(row, col);
if val < mi {
mi = val;
}
if val > ma {
ma = val;
}
values.push(val);
}
values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
let mut allowed_values: Vec<f32> = Vec::with_capacity(16);
for i in 0..16 {
let start_idx = i * values.len() / 16;
let end_idx = (i + 1) * values.len() / 16;
let mut avg = 0.0;
for j in start_idx..end_idx {
avg += values[j];
}
avg /= (end_idx - start_idx) as f32;
allowed_values.push(avg);
}
allowed_values[0] = mi;
allowed_values[15] = ma;
allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
block_values.push(allowed_values);
}
allowed_values_by_row_block.push(block_values);
}
let result = Tensor::make_k4bit_from_fn(
tensor.rows(),
tensor.cols(),
|row, col| {
let col_block = col / 512;
let allowed_values: &[f32] =
&allowed_values_by_row_block[row as usize][col_block as usize];
let val = tensor.get_f32(row, col);
let mut best = 0;
let mut best_dist = std::f32::MAX;
for i in 0..16 {
let dist = (val - allowed_values[i] as f32).abs();
if dist < best_dist {
best = i;
best_dist = dist;
}
}
best as u8
},
|row: i64, col: i64| -> [f32; 16] {
let allowed_values: &[f32] =
&allowed_values_by_row_block[row as usize][col as usize / 512];
let mut result: [f32; 16] = [0.0; 16];
for i in 0..16 {
result[i] = allowed_values[i];
}
result
},
);
result
}
// Same as quantize but doesn't actually change the type of the tensor. It just changes the tensor
// itself. Used to test new quantization schemes without writing support for them.
pub fn quantize_test(tensor: &Tensor) -> Tensor {
let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype());
for row in 0..tensor.rows() {
let col_blocks = (tensor.cols() + 511) / 512;
for block in 0..col_blocks {
let mut values: Vec<f32> = Vec::with_capacity(tensor.cols() as usize);
if row % 500 == 0 {
println!("{}", row,);
}
values.truncate(0);
let mut mi: f32 = std::f32::MAX;
let mut ma: f32 = std::f32::MIN;
for col in 0..tensor.cols() {
let start = block * 512;
let end = std::cmp::min(start + 512, tensor.cols());
for col in start..end {
let val = tensor.get_f32(row, col);
if val < mi {
mi = val;
@ -31,7 +112,6 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
}
values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
let mut allowed_values: Vec<f32> = Vec::with_capacity(16);
let mut rng = thread_rng();
for i in 0..16 {
let start_idx = i * values.len() / 16;
let end_idx = (i + 1) * values.len() / 16;
@ -47,7 +127,7 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
allowed_values[15] = ma;
allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
for col in 0..tensor.cols() {
for col in start..end {
let val = tensor.get_f32(row, col);
let mut best = 0;
let mut best_dist = std::f32::MAX;
@ -58,7 +138,8 @@ pub fn quantize(tensor: &Tensor) -> Tensor {
best_dist = dist;
}
}
result.set_f32(row, col, allowed_values[best] as f32);
result.set_f32(row, col, allowed_values[best as usize]);
}
}
}
result

Loading…
Cancel
Save