Add skeleton code for 4-bit quantization.

The type is now recognized and I have a very simple quantizer too, but no
operations are done on it yet.
Branch: master
Author: Mikko Juola (3 years ago)
Parent: 26f343ad15
Commit: f6249e8d9f

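For orientation, here is a minimal, hypothetical sketch (not part of this commit) of the storage that the new K4BitQuantization type appears to be aiming at, inferred from the q4_data comment and bytes_for_nvalues in the tensor diff below: two 4-bit weights packed per byte, plus 16 bucket values per row kept as f16. The helper name and the exact layout are assumptions.

// Hypothetical size arithmetic for a 4-bit quantized row (assumed layout,
// inferred from the diff below; not code from this commit).
fn q4_row_bytes(cols: usize) -> usize {
    let packed = (cols + 1) / 2; // two 4-bit weights per byte
    let bucket_table = 16 * 2; // 16 bucket values per row as f16 = 32 bytes
    packed + bucket_table
}

fn main() {
    // A 4096-wide row would take 2048 + 32 bytes instead of 8192 bytes as plain f16.
    assert_eq!(q4_row_bytes(4096), 2080);
}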
Cargo.lock (generated)

@@ -125,6 +125,12 @@ version = "1.3.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
 
+[[package]]
+name = "bitflags"
+version = "2.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "487f1e0fcbe47deb8b0574e646def1c903389d95241dd1bbcc6ce4a715dfc0c1"
+
 [[package]]
 name = "block-buffer"
 version = "0.9.0"
@@ -215,7 +221,7 @@ version = "3.2.23"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "71655c45cb9845d3270c9d6df84ebe72b4dad3c2ba3f7023ad47c144e4e473a5"
 dependencies = [
- "bitflags",
+ "bitflags 1.3.2",
  "clap_lex 0.2.4",
  "indexmap",
  "textwrap",
@@ -223,11 +229,11 @@ dependencies = [
 
 [[package]]
 name = "clap"
-version = "4.1.10"
+version = "4.1.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ce38afc168d8665cfc75c7b1dd9672e50716a137f433f070991619744a67342a"
+checksum = "42dfd32784433290c51d92c438bb72ea5063797fc3cc9a21a8c4346bebbb2098"
 dependencies = [
- "bitflags",
+ "bitflags 2.0.2",
  "clap_derive",
  "clap_lex 0.3.3",
  "is-terminal",
@@ -476,7 +482,7 @@ version = "0.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d1053e9d5d5aade9bcedb5ab53b78df2b56ff9408a3138ce77eaaef87f932373"
 dependencies = [
- "bitflags",
+ "bitflags 1.3.2",
  "proc-macro2 0.4.30",
  "quote 0.6.13",
  "syn 0.15.44",
@@ -749,9 +755,9 @@ dependencies = [
 
 [[package]]
 name = "io-lifetimes"
-version = "1.0.7"
+version = "1.0.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "76e86b86ae312accbf05ade23ce76b625e0e47a255712b7414037385a1c05380"
+checksum = "09270fd4fa1111bc614ed2246c7ef56239a3063d5be0d1ec3b589c505d400aeb"
 dependencies = [
  "hermit-abi 0.3.1",
  "libc",
@@ -760,9 +766,9 @@ dependencies = [
 
 [[package]]
 name = "is-terminal"
-version = "0.4.4"
+version = "0.4.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "21b6b32576413a8e69b90e952e4a026476040d81017b80445deda5f2d3921857"
+checksum = "8687c819457e979cc940d09cb16e42a1bf70aa6b60a549de6d3a62a0ee90c69e"
 dependencies = [
  "hermit-abi 0.3.1",
  "io-lifetimes",
@@ -954,7 +960,7 @@ version = "0.11.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "66d46a57dfd1bb6fbee28986e92d39db9b941719dcf6b539c64242006d3c030d"
 dependencies = [
- "bitflags",
+ "bitflags 1.3.2",
  "cl-sys",
  "enum_primitive",
  "num-complex",
@@ -993,9 +999,9 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
 
 [[package]]
 name = "os_str_bytes"
-version = "6.4.1"
+version = "6.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee"
+checksum = "ceedf44fb00f2d1984b0bc98102627ce622e083e49a5bacdb3e514fa4238e267"
 
 [[package]]
 name = "pear"
@@ -1261,7 +1267,7 @@ version = "0.2.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
 dependencies = [
- "bitflags",
+ "bitflags 1.3.2",
 ]
 
 [[package]]
@@ -1286,7 +1292,7 @@ name = "rllama"
 version = "0.3.0"
 dependencies = [
  "approx",
- "clap 4.1.10",
+ "clap 4.1.11",
  "colored",
  "criterion",
  "embedded-profiling",
@@ -1371,11 +1377,11 @@ dependencies = [
 
 [[package]]
 name = "rustix"
-version = "0.36.9"
+version = "0.36.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fd5c6ff11fecd55b40746d1995a02f2eb375bf8c00d192d521ee09f42bef37bc"
+checksum = "db4165c9963ab29e422d6c26fbc1d37f15bace6b2810221f9d925023480fcf0e"
 dependencies = [
- "bitflags",
+ "bitflags 1.3.2",
  "errno",
  "io-lifetimes",
  "libc",
@@ -1418,22 +1424,22 @@ checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed"
 
 [[package]]
 name = "serde"
-version = "1.0.157"
+version = "1.0.158"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "707de5fcf5df2b5788fca98dd7eab490bc2fd9b7ef1404defc462833b83f25ca"
+checksum = "771d4d9c4163ee138805e12c710dd365e4f44be8be0503cb1bb9eb989425d9c9"
 dependencies = [
  "serde_derive",
 ]
 
 [[package]]
 name = "serde_derive"
-version = "1.0.157"
+version = "1.0.158"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "78997f4555c22a7971214540c4a661291970619afd56de19f77e0de86296e1e5"
+checksum = "e801c1712f48475582b7696ac71e0ca34ebb30e09338425384269d9717c62cad"
 dependencies = [
  "proc-macro2 1.0.52",
  "quote 1.0.26",
- "syn 2.0.0",
+ "syn 2.0.4",
 ]
 
 [[package]]
@@ -1508,9 +1514,9 @@ dependencies = [
 
 [[package]]
 name = "syn"
-version = "2.0.0"
+version = "2.0.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4cff13bb1732bccfe3b246f3fdb09edfd51c01d6f5299b7ccd9457c2e4e37774"
+checksum = "2c622ae390c9302e214c31013517c2061ecb2699935882c60a9b37f82f8625ae"
 dependencies = [
  "proc-macro2 1.0.52",
  "quote 1.0.26",
@@ -1562,7 +1568,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
 dependencies = [
  "proc-macro2 1.0.52",
  "quote 1.0.26",
- "syn 2.0.0",
+ "syn 2.0.4",
 ]
 
 [[package]]
@@ -1639,9 +1645,9 @@ dependencies = [
 
 [[package]]
 name = "unicode-bidi"
-version = "0.3.12"
+version = "0.3.13"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7d502c968c6a838ead8e69b2ee18ec708802f99db92a0d156705ec9ef801993b"
+checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460"
 
 [[package]]
 name = "unicode-ident"

@@ -13,6 +13,7 @@ pub mod token_sampler;
 pub mod tokenizer;
 pub mod transformer;
 pub mod unpickler;
+pub mod weight_compression;
 #[cfg(feature = "server")]
 #[macro_use]
 extern crate rocket;

@@ -0,0 +1,38 @@
+// There is no semaphore in Rust standard library. wat??
+// So I've made a simple one I can use out of a mutex and condition variable..
+use std::sync::{Arc, Condvar, Mutex, MutexGuard};
+
+#[derive(Clone)]
+pub struct Semaphore {
+    count: Arc<Mutex<usize>>,
+    waiters: Arc<Condvar>,
+}
+
+pub struct SemaphoreGuard<'a> {
+    mutex_guard: MutexGuard<'a, usize>,
+}
+
+impl<'a> Drop for SemaphoreGuard<'a> {
+    fn drop(&mut self) {
+        *self.mutex_guard += 1;
+    }
+}
+
+impl Semaphore {
+    pub fn new(count: usize) -> Semaphore {
+        Semaphore {
+            count: Arc::new(Mutex::new(count)),
+            waiters: Arc::new(Condvar::new()),
+        }
+    }
+
+    pub fn acquire(&self) -> SemaphoreGuard {
+        let mut count = self.count.lock().unwrap();
+        while *count == 0 {
+            count = self.waiters.wait(count).unwrap();
+        }
+        *count -= 1;
+        SemaphoreGuard { mutex_guard: count }
+    }
+}
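A usage sketch of my own (not in the commit): acquire() hands out a permit and the guard returns it when dropped. Note that, as written, the guard also keeps the inner Mutex locked for its whole lifetime, so other acquire() calls block until the guard is dropped.

// Hypothetical usage sketch, not part of the commit.
fn semaphore_demo() {
    let sem = Semaphore::new(1);
    {
        let _permit = sem.acquire(); // count goes 1 -> 0
        // ... do work while holding the permit ...
    } // guard dropped here: count goes back to 1
    let _again = sem.acquire(); // succeeds because the permit was returned
}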

@@ -46,6 +46,7 @@ pub struct TensorBuilder
 
 #[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
 pub enum TensorDType {
+    K4BitQuantization,
     Float16,
     Float32,
 }
@@ -70,10 +71,17 @@ pub enum TensorError {
 }
 
 impl TensorDType {
-    pub fn bytes_per_item(&self) -> usize {
+    pub fn bytes_for_nvalues(&self, nvalues: usize) -> usize {
         match self {
-            Self::Float16 => 2,
-            Self::Float32 => 4,
+            Self::K4BitQuantization => {
+                if nvalues % 2 == 1 {
+                    nvalues / 2 + 1
+                } else {
+                    nvalues / 2
+                }
+            }
+            Self::Float16 => nvalues * 2,
+            Self::Float32 => nvalues * 4,
         }
     }
 }
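The per-item size method becomes a per-count one because a single 4-bit value does not occupy a whole byte. A few assertions of my own spell out the arithmetic, assuming they run next to the TensorDType definition above:

// bytes_for_nvalues() as defined above: 4-bit values pack two per byte and
// odd counts round up; f16 and f32 stay at 2 and 4 bytes per value.
fn bytes_for_nvalues_examples() {
    assert_eq!(TensorDType::K4BitQuantization.bytes_for_nvalues(64), 32);
    assert_eq!(TensorDType::K4BitQuantization.bytes_for_nvalues(7), 4);
    assert_eq!(TensorDType::Float16.bytes_for_nvalues(64), 128);
    assert_eq!(TensorDType::Float32.bytes_for_nvalues(64), 256);
}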
@@ -81,6 +89,11 @@ impl TensorDType {
 #[derive(Debug)]
 pub struct Tensor {
     data: *mut u8,
+    // for quantization, only used if dtype == TensorDType::K4BitQuantization
+    // Contains 16 values per row in f16 (i.e. 32 bytes per row)
+    q4_data: *mut u8,
     #[cfg(feature = "opencl")]
     opencl_data: Arc<RwLock<Option<OpenCLTensor>>>,
     #[cfg(feature = "opencl")]
@@ -88,6 +101,7 @@ pub struct Tensor {
     dtype: TensorDType,
     layout: Layout,
+    q4_layout: Layout,
     rows: i64,
     cols: i64,
     // Every matrix is allocated so that cols are rounded to the next multiple of 32.
@@ -117,8 +131,16 @@ impl Clone for Tensor {
             std::ptr::copy_nonoverlapping(
                 self.data,
                 new_tensor.data,
-                (self.rows * self.capacity_cols * self.dtype.bytes_per_item() as i64) as usize,
+                (self.rows * self.dtype.bytes_for_nvalues(self.capacity_cols as usize) as i64)
+                    as usize,
             );
+            if !self.q4_data.is_null() {
+                std::ptr::copy_nonoverlapping(
+                    self.q4_data,
+                    new_tensor.q4_data,
+                    self.q4_layout.size(),
+                );
+            }
             new_tensor
         }
     }
@@ -141,6 +163,9 @@ impl Drop for Tensor {
                     .fetch_sub(self.layout.size(), std::sync::atomic::Ordering::Relaxed);
                 std::alloc::dealloc(self.data, self.layout);
             }
+            if !self.q4_data.is_null() {
+                std::alloc::dealloc(self.q4_data, self.q4_layout);
+            }
         }
     }
 }
@@ -164,11 +189,20 @@ impl WrappedPtr {
 
 fn compute_capacity_cols(dtype: TensorDType, cols: i64) -> i64 {
     match dtype {
+        TensorDType::K4BitQuantization => compute_capacity_cols_k4(cols),
         TensorDType::Float16 => compute_capacity_cols_f16(cols),
         TensorDType::Float32 => compute_capacity_cols_f32(cols),
     }
 }
 
+fn compute_capacity_cols_k4(cols: i64) -> i64 {
+    if cols % 64 == 0 {
+        cols
+    } else {
+        cols + 64 - cols % 64
+    }
+}
+
 fn compute_capacity_cols_f32(cols: i64) -> i64 {
     if cols % 8 == 0 {
         cols
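For the 4-bit type the column capacity is rounded up to a multiple of 64 (the f32 path shown above rounds to 8), which keeps every padded row at a whole number of bytes. A quick check of my own, assuming it sits in the same module as the functions above:

// compute_capacity_cols_k4 rounds the column count up to the next multiple
// of 64; combined with bytes_for_nvalues, a 100-column row pads to 128
// columns and therefore 64 bytes of packed 4-bit values.
fn capacity_examples() {
    assert_eq!(compute_capacity_cols_k4(64), 64);
    assert_eq!(compute_capacity_cols_k4(100), 128);
    assert_eq!(TensorDType::K4BitQuantization.bytes_for_nvalues(128), 64);
}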
@@ -277,6 +311,7 @@ impl Tensor {
         let idx = row * self.capacity_cols + col;
         match self.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float16 => {
                 let val: f16 = unsafe { *(self.data.add(idx as usize * 2) as *const f16) };
                 val.to_f32()
@@ -294,6 +329,7 @@ impl Tensor {
         self.assume_on_cpu();
         let idx = row * self.capacity_cols + col;
         match self.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float16 => {
                 let val: f16 = f16::from_f32(val);
                 unsafe { *(self.data.add(idx as usize * 2) as *mut f16) = val };
@@ -323,12 +359,14 @@ impl Tensor {
     pub fn empty() -> Self {
         Self {
             data: std::ptr::null_mut(),
+            q4_data: std::ptr::null_mut(),
             #[cfg(feature = "opencl")]
             opencl_data: Arc::new(RwLock::new(None)),
             #[cfg(feature = "opencl")]
             waiting_for_data: None,
             dtype: TensorDType::Float16,
-            layout: Layout::from_size_align(0, 0).unwrap(),
+            layout: Layout::from_size_align(1, 1).unwrap(),
+            q4_layout: Layout::from_size_align(1, 1).unwrap(),
             rows: 0,
             cols: 0,
             capacity_cols: 0,
@@ -346,8 +384,7 @@ impl Tensor {
         // Rounds up cols to 8
         let capacity_cols = compute_capacity_cols(dtype, cols);
         let nitems = rows * capacity_cols;
-        let layout =
-            Layout::from_size_align((nitems as usize) * dtype.bytes_per_item(), 32).unwrap();
+        let layout = Layout::from_size_align(dtype.bytes_for_nvalues(nitems as usize), 32).unwrap();
         let data = unsafe { std::alloc::alloc(layout) };
         if data.is_null() {
             panic!("Failed to allocate tensor");
@@ -360,6 +397,7 @@ impl Tensor {
         for row in 0..rows {
             let idx = row * capacity_cols + extra_col;
             match dtype {
+                TensorDType::K4BitQuantization => unimplemented!(),
                 TensorDType::Float16 => {
                     let val: f16 = f16::from_f32(0.0);
                     unsafe { *(data.add(idx as usize * 2) as *mut f16) = val };
@@ -373,6 +411,7 @@ impl Tensor {
         Self {
             data,
+            q4_data: std::ptr::null_mut(),
             #[cfg(feature = "opencl")]
             opencl_data: Arc::new(RwLock::new(None)),
             #[cfg(feature = "opencl")]
@@ -382,6 +421,7 @@ impl Tensor {
             cols,
             capacity_cols,
             layout,
+            q4_layout: Layout::from_size_align(1, 1).unwrap(),
         }
     }
@@ -889,6 +929,7 @@ impl Tensor {
         }
 
         match src.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float32 => {
                 // not actual cache line size, but this represents 8 floats which is the number we can
                 // operate with AVX2
@@ -1043,6 +1084,7 @@ impl Tensor {
             return result;
         }
         match other.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float32 => self.to_f32(),
             TensorDType::Float16 => self.to_f16(),
         }
@@ -1053,6 +1095,7 @@ impl Tensor {
             return self;
         }
         match other.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float32 => self.to_f32(),
             TensorDType::Float16 => self.to_f16(),
         }
@@ -1060,6 +1103,7 @@ impl Tensor {
     pub fn into_dtype(self, dtype: TensorDType) -> Tensor {
         match dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float32 => self.to_f32(),
             TensorDType::Float16 => self.to_f16(),
         }
@@ -1114,6 +1158,7 @@ impl Tensor {
         }
 
         match src.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float32 => {
                 const ITEMS_PER_LINE: usize = 8;
@@ -1847,8 +1892,7 @@ impl Tensor {
         }
         let capacity_cols = compute_capacity_cols(dtype, cols);
         let nitems = rows * capacity_cols;
-        let layout =
-            Layout::from_size_align((nitems as usize) * dtype.bytes_per_item(), 32).unwrap();
+        let layout = Layout::from_size_align(dtype.bytes_for_nvalues(nitems as usize), 32).unwrap();
         let data = unsafe { std::alloc::alloc_zeroed(layout) };
         if data.is_null() {
             panic!("Failed to allocate tensor");
@@ -1856,6 +1900,7 @@ impl Tensor {
         TENSORS_BYTES_ALLOCATED.fetch_add(layout.size(), std::sync::atomic::Ordering::Relaxed);
         Self {
             data,
+            q4_data: std::ptr::null_mut(),
             #[cfg(feature = "opencl")]
             opencl_data: Arc::new(RwLock::new(None)),
             #[cfg(feature = "opencl")]
@@ -1865,6 +1910,7 @@ impl Tensor {
             cols,
             capacity_cols,
             layout,
+            q4_layout: Layout::from_size_align(1, 1).unwrap(),
         }
     }
@@ -1880,12 +1926,14 @@ impl Tensor {
             unsafe {
                 std::ptr::copy_nonoverlapping(
                     self.data.add(
-                        (row * self.capacity_cols * self.dtype.bytes_per_item() as i64) as usize,
+                        (row * self.dtype.bytes_for_nvalues(self.capacity_cols as usize) as i64)
+                            as usize,
                     ),
                     result.data.add(
-                        (row * result.capacity_cols * self.dtype.bytes_per_item() as i64) as usize,
+                        (row * self.dtype.bytes_for_nvalues(result.capacity_cols as usize) as i64)
+                            as usize,
                     ),
-                    cols * self.dtype.bytes_per_item(),
+                    self.dtype.bytes_for_nvalues(cols),
                 );
             }
         }
@@ -1908,6 +1956,7 @@ impl Tensor {
         result.rows = rows;
         result.cols = cols;
         match self.dtype {
+            TensorDType::K4BitQuantization => unimplemented!(),
             TensorDType::Float16 => {
                 let mut tgt_row: usize = 0;
                 let mut tgt_col: usize = 0;
@@ -2154,9 +2203,9 @@ impl Tensor {
         unsafe {
             std::ptr::copy_nonoverlapping(
                 self.data
-                    .add((row * self.capacity_cols) as usize * self.dtype.bytes_per_item()),
+                    .add(row as usize * self.dtype.bytes_for_nvalues(self.capacity_cols as usize)),
                 result.data,
-                self.cols as usize * self.dtype.bytes_per_item(),
+                self.dtype.bytes_for_nvalues(self.cols as usize),
             );
         }
         result
@@ -2187,16 +2236,16 @@ impl TensorBuilder {
         let mut f = std::fs::File::open(&path).unwrap();
         f.seek(std::io::SeekFrom::Start(
-            (self.offset as u64) * self.dtype.bytes_per_item() as u64,
+            self.dtype.bytes_for_nvalues(self.offset as usize) as u64,
         ))?;
         let mut cursor: usize = 0;
-        let mut buf: Vec<u8> = vec![0; self.cols as usize * self.dtype.bytes_per_item()];
+        let mut buf: Vec<u8> = vec![0; self.dtype.bytes_for_nvalues(self.cols as usize)];
         for _row in 0..self.rows {
             f.read_exact(&mut buf)?;
             unsafe {
                 std::ptr::copy_nonoverlapping(buf.as_ptr(), tensor.data.add(cursor), buf.len());
             }
-            cursor += tensor.capacity_cols as usize * self.dtype.bytes_per_item();
+            cursor += self.dtype.bytes_for_nvalues(tensor.capacity_cols as usize);
         }
         Ok(tensor.to_f32())
     }
@@ -2251,10 +2300,10 @@ impl TensorBuilder {
             .join("data")
             .join(&builder.src_path);
         buf.truncate(0);
-        buf.resize(builder.cols as usize * builder.dtype.bytes_per_item(), 0);
+        buf.resize(builder.dtype.bytes_for_nvalues(builder.cols as usize), 0);
         let mut f = std::fs::File::open(&path).unwrap();
         f.seek(std::io::SeekFrom::Start(
-            (builder.offset as u64) * builder.dtype.bytes_per_item() as u64,
+            builder.dtype.bytes_for_nvalues(builder.offset as usize) as u64,
         ))?;
         for row in 0..builder.rows {
             match f.read_exact(&mut buf) {
@@ -2275,10 +2324,9 @@ impl TensorBuilder {
                     unsafe {
                         std::ptr::copy_nonoverlapping(
                             buf.as_ptr(),
-                            tensor.data.add(
-                                ((row * tensor.capacity_cols + col_offset) as usize)
-                                    * builder.dtype.bytes_per_item(),
-                            ),
+                            tensor.data.add(builder.dtype.bytes_for_nvalues(
+                                (row * tensor.capacity_cols + col_offset) as usize,
+                            )),
                             buf.len(),
                         );
                     }
@@ -2326,10 +2374,10 @@ impl TensorBuilder {
             .join("data")
             .join(&builder.src_path);
         buf.truncate(0);
-        buf.resize(builder.cols as usize * builder.dtype.bytes_per_item(), 0);
+        buf.resize(builder.dtype.bytes_for_nvalues(builder.cols as usize), 0);
         let mut f = std::fs::File::open(&path).unwrap();
         f.seek(std::io::SeekFrom::Start(
-            (builder.offset as u64) * builder.dtype.bytes_per_item() as u64,
+            builder.dtype.bytes_for_nvalues(builder.offset as usize) as u64,
        ))?;
         for row in 0..builder.rows {
             match f.read_exact(&mut buf) {
@@ -2350,10 +2398,9 @@ impl TensorBuilder {
                     unsafe {
                         std::ptr::copy_nonoverlapping(
                             buf.as_ptr(),
-                            tensor.data.add(
-                                (((row + row_offset) * tensor.capacity_cols) as usize)
-                                    * builder.dtype.bytes_per_item(),
-                            ),
+                            tensor.data.add(builder.dtype.bytes_for_nvalues(
+                                ((row + row_offset) * tensor.capacity_cols) as usize,
+                            )),
                             buf.len(),
                         );
                     }

@@ -158,7 +158,11 @@ impl TokenSampler {
             total_p += v.1;
         }
         let mut rng = rand::thread_rng();
-        let p: f32 = rng.gen_range(0.0..=total_p);
+        let p: f32 = if total_p > 0.0 {
+            rng.gen_range(0.0..=total_p)
+        } else {
+            0.0
+        };
         p_accum = 0.0;
         for v in logitsf.into_iter() {
             p_accum += v.1;
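The new branch avoids drawing from a degenerate range when the filtered probabilities sum to zero; with p = 0.0 the cumulative-sum loop simply settles on the first candidate it accepts. A standalone sketch of the same guarded pick, illustrative only and not the project's sampler:

use rand::Rng;

// Guarded weighted pick over (id, weight) pairs: when the total mass is
// zero, fall back to p = 0.0 instead of calling gen_range on it.
fn pick_weighted(weights: &[(usize, f32)]) -> Option<usize> {
    let total_p: f32 = weights.iter().map(|v| v.1).sum();
    let mut rng = rand::thread_rng();
    let p: f32 = if total_p > 0.0 {
        rng.gen_range(0.0..=total_p)
    } else {
        0.0
    };
    let mut p_accum = 0.0;
    for &(id, w) in weights {
        p_accum += w;
        if p_accum >= p {
            return Some(id);
        }
    }
    None
}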

@@ -415,6 +415,7 @@ impl RMSNorm {
                 FromPiecesDirection::Rows,
             )?
             .to_f32();
+
         Ok(Self {
             eps,
             weight: weights,

@@ -0,0 +1,65 @@
+use crate::tensor::Tensor;
+use rand::{thread_rng, Rng};
+use rayon::prelude::*;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{Arc, RwLock};
+
+pub fn quantize(tensor: &Tensor) -> Tensor {
+    /*
+     * This is a simplistic rounding quantizer. It splits each row in a tensor to 16 buckets and
+     * takes the average value in said buckets as the quantized weight.
+     */
+    let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype());
+    for row in 0..tensor.rows() {
+        let mut values: Vec<f32> = Vec::with_capacity(tensor.cols() as usize);
+        if row % 500 == 0 {
+            println!("{}", row,);
+        }
+        values.truncate(0);
+        let mut mi: f32 = std::f32::MAX;
+        let mut ma: f32 = std::f32::MIN;
+        for col in 0..tensor.cols() {
+            let val = tensor.get_f32(row, col);
+            if val < mi {
+                mi = val;
+            }
+            if val > ma {
+                ma = val;
+            }
+            values.push(val);
+        }
+        values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+        let mut allowed_values: Vec<f32> = Vec::with_capacity(16);
+        let mut rng = thread_rng();
+        for i in 0..16 {
+            let start_idx = i * values.len() / 16;
+            let end_idx = (i + 1) * values.len() / 16;
+            let mut avg = 0.0;
+            for j in start_idx..end_idx {
+                avg += values[j];
+            }
+            avg /= (end_idx - start_idx) as f32;
+            allowed_values.push(avg);
+        }
+        allowed_values[0] = mi;
+        allowed_values[15] = ma;
+        allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
+        for col in 0..tensor.cols() {
+            let val = tensor.get_f32(row, col);
+            let mut best = 0;
+            let mut best_dist = std::f32::MAX;
+            for i in 0..16 {
+                let dist = (val - allowed_values[i] as f32).abs();
+                if dist < best_dist {
+                    best = i;
+                    best_dist = dist;
+                }
+            }
+            result.set_f32(row, col, allowed_values[best] as f32);
+        }
+    }
+    result
+}
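As the commit message says, the quantizer only snaps each weight to one of 16 per-row values and writes the result back as ordinary floats; nothing is packed into 4 bits yet. A hypothetical follow-up step, sketched under the assumption that the q4 layout will store a 16-entry table per row plus one 4-bit index per weight (the struct, helper names and nibble order are mine, not the project's):

/// Hypothetical packed row format (not part of this commit): 16 bucket values
/// kept as f32 here for simplicity (the tensor code suggests f16), plus two
/// 4-bit bucket indices per byte.
struct PackedRow {
    bucket_values: [f32; 16],
    packed_indices: Vec<u8>,
}

fn pack_row(values: &[f32], allowed_values: &[f32; 16]) -> PackedRow {
    let mut packed_indices = vec![0u8; (values.len() + 1) / 2];
    for (i, &val) in values.iter().enumerate() {
        // Nearest allowed value, the same search as in quantize() above.
        let mut best = 0u8;
        let mut best_dist = f32::MAX;
        for (j, &a) in allowed_values.iter().enumerate() {
            let dist = (val - a).abs();
            if dist < best_dist {
                best = j as u8;
                best_dist = dist;
            }
        }
        // Two 4-bit indices per byte; low nibble first (an assumption).
        packed_indices[i / 2] |= best << ((i % 2) * 4);
    }
    PackedRow {
        bucket_values: *allowed_values,
        packed_indices,
    }
}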