diff --git a/Cargo.lock b/Cargo.lock
index a15feda..065ab25 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -125,6 +125,12 @@ version = "1.3.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
 
+[[package]]
+name = "bitflags"
+version = "2.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "487f1e0fcbe47deb8b0574e646def1c903389d95241dd1bbcc6ce4a715dfc0c1"
+
 [[package]]
 name = "block-buffer"
 version = "0.9.0"
@@ -215,7 +221,7 @@ version = "3.2.23"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "71655c45cb9845d3270c9d6df84ebe72b4dad3c2ba3f7023ad47c144e4e473a5"
 dependencies = [
- "bitflags",
+ "bitflags 1.3.2",
  "clap_lex 0.2.4",
  "indexmap",
  "textwrap",
@@ -223,11 +229,11 @@
 
 [[package]]
 name = "clap"
-version = "4.1.10"
+version = "4.1.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ce38afc168d8665cfc75c7b1dd9672e50716a137f433f070991619744a67342a"
+checksum = "42dfd32784433290c51d92c438bb72ea5063797fc3cc9a21a8c4346bebbb2098"
 dependencies = [
- "bitflags",
+ "bitflags 2.0.2",
  "clap_derive",
  "clap_lex 0.3.3",
  "is-terminal",
@@ -476,7 +482,7 @@ version = "0.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d1053e9d5d5aade9bcedb5ab53b78df2b56ff9408a3138ce77eaaef87f932373"
 dependencies = [
- "bitflags",
+ "bitflags 1.3.2",
  "proc-macro2 0.4.30",
  "quote 0.6.13",
  "syn 0.15.44",
@@ -749,9 +755,9 @@ dependencies = [
 
 [[package]]
 name = "io-lifetimes"
-version = "1.0.7"
+version = "1.0.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "76e86b86ae312accbf05ade23ce76b625e0e47a255712b7414037385a1c05380"
+checksum = "09270fd4fa1111bc614ed2246c7ef56239a3063d5be0d1ec3b589c505d400aeb"
 dependencies = [
  "hermit-abi 0.3.1",
  "libc",
@@ -760,9 +766,9 @@ dependencies = [
 
 [[package]]
 name = "is-terminal"
-version = "0.4.4"
+version = "0.4.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "21b6b32576413a8e69b90e952e4a026476040d81017b80445deda5f2d3921857"
+checksum = "8687c819457e979cc940d09cb16e42a1bf70aa6b60a549de6d3a62a0ee90c69e"
 dependencies = [
  "hermit-abi 0.3.1",
  "io-lifetimes",
@@ -954,7 +960,7 @@ version = "0.11.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "66d46a57dfd1bb6fbee28986e92d39db9b941719dcf6b539c64242006d3c030d"
 dependencies = [
- "bitflags",
+ "bitflags 1.3.2",
  "cl-sys",
  "enum_primitive",
  "num-complex",
@@ -993,9 +999,9 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
 
 [[package]]
 name = "os_str_bytes"
-version = "6.4.1"
+version = "6.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee"
+checksum = "ceedf44fb00f2d1984b0bc98102627ce622e083e49a5bacdb3e514fa4238e267"
 
 [[package]]
 name = "pear"
@@ -1261,7 +1267,7 @@ version = "0.2.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
 dependencies = [
- "bitflags",
+ "bitflags 1.3.2",
 ]
 
 [[package]]
@@ -1286,7 +1292,7 @@ name = "rllama"
 version = "0.3.0"
 dependencies = [
  "approx",
- "clap 4.1.10",
+ "clap 4.1.11",
  "colored",
  "criterion",
  "embedded-profiling",
@@ -1371,11 +1377,11 @@ dependencies = [
 
 [[package]]
 name = "rustix"
-version = "0.36.9"
+version = "0.36.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd5c6ff11fecd55b40746d1995a02f2eb375bf8c00d192d521ee09f42bef37bc" +checksum = "db4165c9963ab29e422d6c26fbc1d37f15bace6b2810221f9d925023480fcf0e" dependencies = [ - "bitflags", + "bitflags 1.3.2", "errno", "io-lifetimes", "libc", @@ -1418,22 +1424,22 @@ checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" [[package]] name = "serde" -version = "1.0.157" +version = "1.0.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "707de5fcf5df2b5788fca98dd7eab490bc2fd9b7ef1404defc462833b83f25ca" +checksum = "771d4d9c4163ee138805e12c710dd365e4f44be8be0503cb1bb9eb989425d9c9" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.157" +version = "1.0.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78997f4555c22a7971214540c4a661291970619afd56de19f77e0de86296e1e5" +checksum = "e801c1712f48475582b7696ac71e0ca34ebb30e09338425384269d9717c62cad" dependencies = [ "proc-macro2 1.0.52", "quote 1.0.26", - "syn 2.0.0", + "syn 2.0.4", ] [[package]] @@ -1508,9 +1514,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.0" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cff13bb1732bccfe3b246f3fdb09edfd51c01d6f5299b7ccd9457c2e4e37774" +checksum = "2c622ae390c9302e214c31013517c2061ecb2699935882c60a9b37f82f8625ae" dependencies = [ "proc-macro2 1.0.52", "quote 1.0.26", @@ -1562,7 +1568,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2 1.0.52", "quote 1.0.26", - "syn 2.0.0", + "syn 2.0.4", ] [[package]] @@ -1639,9 +1645,9 @@ dependencies = [ [[package]] name = "unicode-bidi" -version = "0.3.12" +version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d502c968c6a838ead8e69b2ee18ec708802f99db92a0d156705ec9ef801993b" +checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" diff --git a/src/lib.rs b/src/lib.rs index 3232002..612ec90 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ pub mod token_sampler; pub mod tokenizer; pub mod transformer; pub mod unpickler; +pub mod weight_compression; #[cfg(feature = "server")] #[macro_use] extern crate rocket; diff --git a/src/semaphore.rs b/src/semaphore.rs new file mode 100644 index 0000000..b69a7aa --- /dev/null +++ b/src/semaphore.rs @@ -0,0 +1,38 @@ +// There is no semaphore in Rust standard library. wat?? +// So I've made a simple one I can use out of a mutex and condition variable.. 
diff --git a/src/tensor.rs b/src/tensor.rs
index f026161..803c58b 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -46,6 +46,7 @@ pub struct TensorBuilder {
 
 #[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
 pub enum TensorDType {
+    K4BitQuantization,
     Float16,
     Float32,
 }
@@ -70,10 +71,17 @@ pub enum TensorError {
 }
 
 impl TensorDType {
-    pub fn bytes_per_item(&self) -> usize {
+    pub fn bytes_for_nvalues(&self, nvalues: usize) -> usize {
         match self {
-            Self::Float16 => 2,
-            Self::Float32 => 4,
+            Self::K4BitQuantization => {
+                if nvalues % 2 == 1 {
+                    nvalues / 2 + 1
+                } else {
+                    nvalues / 2
+                }
+            }
+            Self::Float16 => nvalues * 2,
+            Self::Float32 => nvalues * 4,
         }
     }
 }
@@ -81,6 +89,11 @@ impl TensorDType {
 #[derive(Debug)]
 pub struct Tensor {
     data: *mut u8,
+
+    // for quantization, only used if dtype == TensorDType::K4BitQuantization
+    // Contains 16 values per row in f16 (i.e. 32 bytes per row)
+    q4_data: *mut u8,
+
     #[cfg(feature = "opencl")]
     opencl_data: Arc<RwLock<Option<OpenCLTensor>>>,
     #[cfg(feature = "opencl")]
@@ -88,6 +101,7 @@
     dtype: TensorDType,
     layout: Layout,
+    q4_layout: Layout,
     rows: i64,
     cols: i64,
     // Every matrix is allocated so that cols are rounded to the next multiple of 32.
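For the new `K4BitQuantization` dtype, `bytes_for_nvalues` above budgets one byte per two values (rounding odd counts up), and the `q4_data` comment reserves 16 f16 entries (32 bytes) of per-row metadata. A small sketch of that arithmetic; the nibble order is an assumption of mine, since the actual packing is left `unimplemented!()` in the hunks below:

```rust
// Sketch of the K4BitQuantization size math (illustrative only).
fn k4_bytes_for_nvalues(nvalues: usize) -> usize {
    // Two 4-bit indices per byte, odd counts round up -- same result as
    // bytes_for_nvalues() in the hunk above.
    (nvalues + 1) / 2
}

// Hypothetical packing of two 4-bit indices into one byte (low nibble first).
fn pack_pair(lo: u8, hi: u8) -> u8 {
    debug_assert!(lo < 16 && hi < 16);
    (hi << 4) | lo
}

fn unpack_pair(b: u8) -> (u8, u8) {
    (b & 0x0f, b >> 4)
}

fn main() {
    assert_eq!(k4_bytes_for_nvalues(4096), 2048);
    assert_eq!(k4_bytes_for_nvalues(4097), 2049);
    // Each row additionally carries a 16-entry f16 codebook: 16 * 2 = 32 bytes in q4_data.
    let (a, b) = unpack_pair(pack_pair(3, 15));
    assert_eq!((a, b), (3, 15));
}
```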
@@ -117,8 +131,16 @@ impl Clone for Tensor {
             std::ptr::copy_nonoverlapping(
                 self.data,
                 new_tensor.data,
-                (self.rows * self.capacity_cols * self.dtype.bytes_per_item() as i64) as usize,
+                (self.rows * self.dtype.bytes_for_nvalues(self.capacity_cols as usize) as i64)
+                    as usize,
             );
+            if !self.q4_data.is_null() {
+                std::ptr::copy_nonoverlapping(
+                    self.q4_data,
+                    new_tensor.q4_data,
+                    self.q4_layout.size(),
+                );
+            }
             new_tensor
         }
     }
@@ -141,6 +163,9 @@ impl Drop for Tensor {
                     .fetch_sub(self.layout.size(), std::sync::atomic::Ordering::Relaxed);
                 std::alloc::dealloc(self.data, self.layout);
             }
+            if !self.q4_data.is_null() {
+                std::alloc::dealloc(self.q4_data, self.q4_layout);
+            }
         }
     }
 }
@@ -164,11 +189,20 @@ impl WrappedPtr {
 
 fn compute_capacity_cols(dtype: TensorDType, cols: i64) -> i64 {
     match dtype {
+        TensorDType::K4BitQuantization => compute_capacity_cols_k4(cols),
         TensorDType::Float16 => compute_capacity_cols_f16(cols),
         TensorDType::Float32 => compute_capacity_cols_f32(cols),
     }
 }
 
+fn compute_capacity_cols_k4(cols: i64) -> i64 {
+    if cols % 64 == 0 {
+        cols
+    } else {
+        cols + 64 - cols % 64
+    }
+}
+
 fn compute_capacity_cols_f32(cols: i64) -> i64 {
     if cols % 8 == 0 {
         cols
@@ -277,6 +311,7 @@ impl Tensor {
         let idx = row * self.capacity_cols + col;
         match self.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float16 => {
                 let val: f16 = unsafe { *(self.data.add(idx as usize * 2) as *const f16) };
                 val.to_f32()
@@ -294,6 +329,7 @@ impl Tensor {
         self.assume_on_cpu();
         let idx = row * self.capacity_cols + col;
         match self.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float16 => {
                 let val: f16 = f16::from_f32(val);
                 unsafe { *(self.data.add(idx as usize * 2) as *mut f16) = val };
@@ -323,12 +359,14 @@ impl Tensor {
     pub fn empty() -> Self {
         Self {
             data: std::ptr::null_mut(),
+            q4_data: std::ptr::null_mut(),
             #[cfg(feature = "opencl")]
             opencl_data: Arc::new(RwLock::new(None)),
             #[cfg(feature = "opencl")]
             waiting_for_data: None,
             dtype: TensorDType::Float16,
-            layout: Layout::from_size_align(0, 0).unwrap(),
+            layout: Layout::from_size_align(1, 1).unwrap(),
+            q4_layout: Layout::from_size_align(1, 1).unwrap(),
             rows: 0,
             cols: 0,
             capacity_cols: 0,
@@ -346,8 +384,7 @@ impl Tensor {
         // Rouns up cols to 8
         let capacity_cols = compute_capacity_cols(dtype, cols);
         let nitems = rows * capacity_cols;
-        let layout =
-            Layout::from_size_align((nitems as usize) * dtype.bytes_per_item(), 32).unwrap();
+        let layout = Layout::from_size_align(dtype.bytes_for_nvalues(nitems as usize), 32).unwrap();
         let data = unsafe { std::alloc::alloc(layout) };
         if data.is_null() {
             panic!("Failed to allocate tensor");
@@ -360,6 +397,7 @@ impl Tensor {
         for row in 0..rows {
             let idx = row * capacity_cols + extra_col;
             match dtype {
+                TensorDType::K4BitQuantization => unimplemented!(),
                 TensorDType::Float16 => {
                     let val: f16 = f16::from_f32(0.0);
                     unsafe { *(data.add(idx as usize * 2) as *mut f16) = val };
@@ -373,6 +411,7 @@ impl Tensor {
 
         Self {
             data,
+            q4_data: std::ptr::null_mut(),
             #[cfg(feature = "opencl")]
             opencl_data: Arc::new(RwLock::new(None)),
             #[cfg(feature = "opencl")]
@@ -382,6 +421,7 @@ impl Tensor {
             cols,
             capacity_cols,
             layout,
+            q4_layout: Layout::from_size_align(1, 1).unwrap(),
         }
     }
 
@@ -889,6 +929,7 @@ impl Tensor {
         }
 
         match src.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float32 => {
                 // not actual cache line size, but this represents 8 floats which is the number we can
                 // operate with AVX2
@@ -1043,6 +1084,7 @@ impl Tensor {
             return result;
         }
         match other.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float32 => self.to_f32(),
             TensorDType::Float16 => self.to_f16(),
         }
@@ -1053,6 +1095,7 @@
             return self;
         }
         match other.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float32 => self.to_f32(),
             TensorDType::Float16 => self.to_f16(),
         }
@@ -1060,6 +1103,7 @@
 
     pub fn into_dtype(self, dtype: TensorDType) -> Tensor {
         match dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float32 => self.to_f32(),
             TensorDType::Float16 => self.to_f16(),
         }
@@ -1114,6 +1158,7 @@
         }
 
         match src.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float32 => {
                 const ITEMS_PER_LINE: usize = 8;
 
@@ -1847,8 +1892,7 @@
         }
         let capacity_cols = compute_capacity_cols(dtype, cols);
         let nitems = rows * capacity_cols;
-        let layout =
-            Layout::from_size_align((nitems as usize) * dtype.bytes_per_item(), 32).unwrap();
+        let layout = Layout::from_size_align(dtype.bytes_for_nvalues(nitems as usize), 32).unwrap();
         let data = unsafe { std::alloc::alloc_zeroed(layout) };
         if data.is_null() {
             panic!("Failed to allocate tensor");
@@ -1856,6 +1900,7 @@
         TENSORS_BYTES_ALLOCATED.fetch_add(layout.size(), std::sync::atomic::Ordering::Relaxed);
         Self {
             data,
+            q4_data: std::ptr::null_mut(),
             #[cfg(feature = "opencl")]
             opencl_data: Arc::new(RwLock::new(None)),
             #[cfg(feature = "opencl")]
@@ -1865,6 +1910,7 @@
             cols,
             capacity_cols,
             layout,
+            q4_layout: Layout::from_size_align(1, 1).unwrap(),
         }
     }
 
@@ -1880,12 +1926,14 @@
         unsafe {
             std::ptr::copy_nonoverlapping(
                 self.data.add(
-                    (row * self.capacity_cols * self.dtype.bytes_per_item() as i64) as usize,
+                    (row * self.dtype.bytes_for_nvalues(self.capacity_cols as usize) as i64)
+                        as usize,
                 ),
                 result.data.add(
-                    (row * result.capacity_cols * self.dtype.bytes_per_item() as i64) as usize,
+                    (row * self.dtype.bytes_for_nvalues(result.capacity_cols as usize) as i64)
+                        as usize,
                 ),
-                cols * self.dtype.bytes_per_item(),
+                self.dtype.bytes_for_nvalues(cols),
             );
         }
     }
@@ -1908,6 +1956,7 @@
         result.rows = rows;
         result.cols = cols;
         match self.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float16 => {
                 let mut tgt_row: usize = 0;
                 let mut tgt_col: usize = 0;
@@ -2154,9 +2203,9 @@
         unsafe {
             std::ptr::copy_nonoverlapping(
                 self.data
-                    .add((row * self.capacity_cols) as usize * self.dtype.bytes_per_item()),
+                    .add(row as usize * self.dtype.bytes_for_nvalues(self.capacity_cols as usize)),
                 result.data,
-                self.cols as usize * self.dtype.bytes_per_item(),
+                self.dtype.bytes_for_nvalues(self.cols as usize),
             );
         }
         result
@@ -2187,16 +2236,16 @@
 
         let mut f = std::fs::File::open(&path).unwrap();
         f.seek(std::io::SeekFrom::Start(
-            (self.offset as u64) * self.dtype.bytes_per_item() as u64,
+            self.dtype.bytes_for_nvalues(self.offset as usize) as u64,
         ))?;
         let mut cursor: usize = 0;
-        let mut buf: Vec<u8> = vec![0; self.cols as usize * self.dtype.bytes_per_item()];
+        let mut buf: Vec<u8> = vec![0; self.dtype.bytes_for_nvalues(self.cols as usize)];
         for _row in 0..self.rows {
             f.read_exact(&mut buf)?;
             unsafe {
                 std::ptr::copy_nonoverlapping(buf.as_ptr(), tensor.data.add(cursor), buf.len());
             }
-            cursor += tensor.capacity_cols as usize * self.dtype.bytes_per_item();
+            cursor += self.dtype.bytes_for_nvalues(tensor.capacity_cols as usize);
         }
         Ok(tensor.to_f32())
     }
@@ -2251,10 +2300,10 @@
                 .join("data")
                 .join(&builder.src_path);
             buf.truncate(0);
-            buf.resize(builder.cols as usize * builder.dtype.bytes_per_item(), 0);
+            buf.resize(builder.dtype.bytes_for_nvalues(builder.cols as usize), 0);
             let mut f = std::fs::File::open(&path).unwrap();
             f.seek(std::io::SeekFrom::Start(
-                (builder.offset as u64) * builder.dtype.bytes_per_item() as u64,
+                builder.dtype.bytes_for_nvalues(builder.offset as usize) as u64,
             ))?;
             for row in 0..builder.rows {
                 match f.read_exact(&mut buf) {
@@ -2275,10 +2324,9 @@
                 unsafe {
                     std::ptr::copy_nonoverlapping(
                         buf.as_ptr(),
-                        tensor.data.add(
-                            ((row * tensor.capacity_cols + col_offset) as usize)
-                                * builder.dtype.bytes_per_item(),
-                        ),
+                        tensor.data.add(builder.dtype.bytes_for_nvalues(
+                            (row * tensor.capacity_cols + col_offset) as usize,
+                        )),
                         buf.len(),
                     );
                 }
@@ -2326,10 +2374,10 @@
                 .join("data")
                 .join(&builder.src_path);
             buf.truncate(0);
-            buf.resize(builder.cols as usize * builder.dtype.bytes_per_item(), 0);
+            buf.resize(builder.dtype.bytes_for_nvalues(builder.cols as usize), 0);
             let mut f = std::fs::File::open(&path).unwrap();
             f.seek(std::io::SeekFrom::Start(
-                (builder.offset as u64) * builder.dtype.bytes_per_item() as u64,
+                builder.dtype.bytes_for_nvalues(builder.offset as usize) as u64,
             ))?;
             for row in 0..builder.rows {
                 match f.read_exact(&mut buf) {
@@ -2350,10 +2398,9 @@
                 unsafe {
                     std::ptr::copy_nonoverlapping(
                         buf.as_ptr(),
-                        tensor.data.add(
-                            (((row + row_offset) * tensor.capacity_cols) as usize)
-                                * builder.dtype.bytes_per_item(),
-                        ),
+                        tensor.data.add(builder.dtype.bytes_for_nvalues(
+                            ((row + row_offset) * tensor.capacity_cols) as usize,
+                        )),
                         buf.len(),
                     );
                 }
diff --git a/src/token_sampler.rs b/src/token_sampler.rs
index 138aaf3..e4072c1 100644
--- a/src/token_sampler.rs
+++ b/src/token_sampler.rs
@@ -158,7 +158,11 @@ impl TokenSampler {
             total_p += v.1;
         }
         let mut rng = rand::thread_rng();
-        let p: f32 = rng.gen_range(0.0..=total_p);
+        let p: f32 = if total_p > 0.0 {
+            rng.gen_range(0.0..=total_p)
+        } else {
+            0.0
+        };
         p_accum = 0.0;
         for v in logitsf.into_iter() {
             p_accum += v.1;
diff --git a/src/transformer.rs b/src/transformer.rs
index feb99d1..f6f823d 100644
--- a/src/transformer.rs
+++ b/src/transformer.rs
@@ -415,6 +415,7 @@ impl RMSNorm {
             FromPiecesDirection::Rows,
         )?
         .to_f32();
+
         Ok(Self {
             eps,
             weight: weights,
diff --git a/src/weight_compression.rs b/src/weight_compression.rs
new file mode 100644
index 0000000..92f59f3
--- /dev/null
+++ b/src/weight_compression.rs
@@ -0,0 +1,65 @@
+use crate::tensor::Tensor;
+use rand::{thread_rng, Rng};
+use rayon::prelude::*;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{Arc, RwLock};
+
+pub fn quantize(tensor: &Tensor) -> Tensor {
+    /*
+     * This is a simplistic rounding quantizer. It splits each row in a tensor to 16 buckets and
+     * takes the average value in said buckets as the quantized weight.
+     */
+    let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype());
+    for row in 0..tensor.rows() {
+        let mut values: Vec<f32> = Vec::with_capacity(tensor.cols() as usize);
+        if row % 500 == 0 {
+            println!("{}", row,);
+        }
+        values.truncate(0);
+        let mut mi: f32 = std::f32::MAX;
+        let mut ma: f32 = std::f32::MIN;
+
+        for col in 0..tensor.cols() {
+            let val = tensor.get_f32(row, col);
+            if val < mi {
+                mi = val;
+            }
+            if val > ma {
+                ma = val;
+            }
+            values.push(val);
+        }
+        values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+        let mut allowed_values: Vec<f32> = Vec::with_capacity(16);
+        let mut rng = thread_rng();
+        for i in 0..16 {
+            let start_idx = i * values.len() / 16;
+            let end_idx = (i + 1) * values.len() / 16;
+
+            let mut avg = 0.0;
+            for j in start_idx..end_idx {
+                avg += values[j];
+            }
+            avg /= (end_idx - start_idx) as f32;
+            allowed_values.push(avg);
+        }
+        allowed_values[0] = mi;
+        allowed_values[15] = ma;
+        allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+
+        for col in 0..tensor.cols() {
+            let val = tensor.get_f32(row, col);
+            let mut best = 0;
+            let mut best_dist = std::f32::MAX;
+            for i in 0..16 {
+                let dist = (val - allowed_values[i] as f32).abs();
+                if dist < best_dist {
+                    best = i;
+                    best_dist = dist;
+                }
+            }
+            result.set_f32(row, col, allowed_values[best] as f32);
+        }
+    }
+    result
+}
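The comment in `quantize()` describes the scheme: per row, sort the weights, average each of 16 equal-size buckets to form a 16-entry codebook (pinning the first and last entries to the row's min and max), then snap every weight to its nearest codebook entry. A self-contained sketch of the same idea on a plain slice, independent of the crate's `Tensor` type (illustrative only; assumes at least 16 values per row):

```rust
// Standalone sketch of the 16-bucket rounding quantizer used above.
fn quantize_row(row: &[f32]) -> Vec<f32> {
    debug_assert!(row.len() >= 16);
    let mut sorted: Vec<f32> = row.to_vec();
    sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());

    // 16 allowed values: the average of each equal-size bucket, with the
    // extremes pinned to the row's min and max.
    let mut allowed: Vec<f32> = (0..16)
        .map(|i| {
            let start = i * sorted.len() / 16;
            let end = (i + 1) * sorted.len() / 16;
            sorted[start..end].iter().sum::<f32>() / (end - start) as f32
        })
        .collect();
    allowed[0] = sorted[0];
    allowed[15] = sorted[sorted.len() - 1];

    // Snap every weight to the nearest allowed value.
    row.iter()
        .map(|&v| {
            *allowed
                .iter()
                .min_by(|a, b| (v - **a).abs().partial_cmp(&(v - **b).abs()).unwrap())
                .unwrap()
        })
        .collect()
}
```

With 16 allowed values per row, each weight only needs a 4-bit index plus the 32-byte f16 codebook that the new `q4_data` field in `tensor.rs` is sized for.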