From 26d5309cf7b6d2e5bd437622808a5ab0fd951536 Mon Sep 17 00:00:00 2001
From: Mikko Juola
Date: Sat, 11 Mar 2023 21:50:59 -0800
Subject: [PATCH] Add support for bigger models.

I've tested it with the 13B LLaMA model and it seems to work.

There was a bug in the unpickler that skipped over tuples of size 1. I had
written a bunch of code assuming that bug was not there; I fixed that code
and removed some of the unpickling code.

I added functions to tensor.rs to be able to construct tensors out of
multiple files.
---
 README.md          |  19 ++--
 src/embedding.rs   |  33 ++++---
 src/rllama_main.rs |  52 ++++++----
 src/tensor.rs      | 241 +++++++++++++++++++++++++++++++++++++++++++--
 src/transformer.rs |  51 +++++++---
 src/unpickler.rs   | 104 ++++++++++---------
 6 files changed, 384 insertions(+), 116 deletions(-)

diff --git a/README.md b/README.md
index a5b5f9a..c4269a1 100644
--- a/README.md
+++ b/README.md
@@ -8,11 +8,16 @@ As of writing of this, this can run LLaMA-7B at around ~1 token per second, on
 a Ryzen 3950X using something like 1.5 threads because I haven't yet properly
 figured out how to multithread this.
 
-It uses AVX2 intrinsics to speed up itself. Therefore, you need an x86-family
+I've also managed to run LLaMA-13B, which just barely fits in my 64-gig machine
+with 32-bit float weights everywhere.
+
+I have not tried the bigger models yet.
+
+This uses AVX2 intrinsics to speed itself up. Therefore, you need an x86-family
 CPU to run this.
 
-It has a Python unpickler that understands the `.pth` files used by PyTorch.
-Well sort of, it doesn't unzip them automatically (see below).
+It also has a Python unpickler that understands the `.pth` files used by
+PyTorch. Well, almost: it doesn't unzip them automatically (see below).
 
 # How to run
 
@@ -27,16 +32,18 @@ decompress it.
 $ cd LLaMA
 $ cd 7B
 $ unzip consolidated.00.pth
+# Only necessary for LLaMA-7B; rllama currently expects .00, .01, .02 etc. in the directory names
+$ mv consolidated consolidated.00
 ```
 
 You should then be ready to generate some text.
 
 ```shell
-cargo run --release -- --tokenizer-model /path/to/tokenizer.model --model-path /path/to/LLaMA/7B/consolidated/data.pkl --prompt "The meaning of life is"
+cargo run --release -- --tokenizer-model /path/to/tokenizer.model --model-path /path/to/LLaMA/7B --param-path /path/to/LLaMA/7B/params.json --prompt "The meaning of life is"
 ```
 
-Right now it seems to use around ~25 gigabytes of memory. Internally all
-weights are cast to 32-bit floats.
+Right now it seems to use around ~25 gigabytes of memory for 7B and around ~50
+gigabytes for 13B. Internally all weights are cast to 32-bit floats.
 
 You can use `--temperature`, `--top-p` and `--top-k` to adjust token sampler
 settings.
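For reference, the directory layout the new multi-part loading code expects after unzipping looks roughly like this (a sketch using LLaMA-13B, which ships in two parts; the paths are illustrative, and rllama simply probes `consolidated.00`, `consolidated.01`, ... until a part is missing):

```shell
$ ls /path/to/LLaMA/13B
consolidated.00  consolidated.01  params.json
$ ls /path/to/LLaMA/13B/consolidated.00
data  data.pkl
```

Each part directory contains its own `data.pkl` plus a `data/` directory holding the raw storage files the pickle refers to; the loaders below read `consolidated.XX/data.pkl` and then pull tensor contents from `consolidated.XX/data/`.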
diff --git a/src/embedding.rs b/src/embedding.rs index dccbf6e..5c08c37 100644 --- a/src/embedding.rs +++ b/src/embedding.rs @@ -1,4 +1,4 @@ -use crate::tensor::Tensor; +use crate::tensor::{FromPiecesDirection, Tensor, TensorBuilder}; use crate::unpickler; use crate::unpickler::*; use std::collections::BTreeMap; @@ -10,24 +10,29 @@ pub struct Embedding { impl Embedding { pub fn from_unpickled>( - unpickled: &unpickler::Value, + unpickled: &[unpickler::Value], data_dir: P, ) -> Result { let data_dir: &Path = data_dir.as_ref(); - let val = match unpickled.get_str_key("tok_embeddings.weight") { - Some(val) => val, - None => { - return Err(UnpicklingError::MissingField( - "tok_embeddings.weight".to_string(), - )) - } - }; - let tensor = val - .to_tensor_builder() - .ok_or(UnpicklingError::InvalidTensorData)?; - let tensor = tensor.load(data_dir)?; + let mut builders: Vec = vec![]; + for unpickle in unpickled.iter() { + let val = match unpickle.get_str_key("tok_embeddings.weight") { + Some(val) => val, + None => { + return Err(UnpicklingError::MissingField( + "tok_embeddings.weight".to_string(), + )) + } + }; + builders.push( + val.to_tensor_builder() + .ok_or(UnpicklingError::InvalidTensorData)?, + ); + } + let tensor = + TensorBuilder::load_from_pieces(&builders, data_dir, FromPiecesDirection::Cols)?; let num_embeddings = tensor.rows(); let mut table: BTreeMap = BTreeMap::new(); diff --git a/src/rllama_main.rs b/src/rllama_main.rs index 920f56c..7ee7f21 100644 --- a/src/rllama_main.rs +++ b/src/rllama_main.rs @@ -3,10 +3,12 @@ use crate::token_sampler::TokenSampler; use crate::tokenizer::{TokenId, Tokenizer}; use crate::transformer::Transformer; use crate::unpickler; +use crate::unpickler::Value; use clap::Parser; use colored::Colorize; use serde::{Deserialize, Serialize}; use std::io::{Read, Write}; +use std::path::PathBuf; #[derive(Parser)] #[command(author, version, about, long_about = None)] @@ -94,37 +96,53 @@ pub fn main() -> Result<(), Box> { pln!("Starting up. Loading tokenizer from {}...", tokenizer_path); let tok = Tokenizer::load(tokenizer_path.as_str())?; pln!("Tokenizer loaded. Loading model from {}...", model_path); - let mut fs = std::fs::File::open(model_path.as_str())?; - let mut bs = Vec::new(); - fs.read_to_end(&mut bs)?; - std::mem::drop(fs); - // We chop off file name from model_path and append "data/" - let model_data_dir = model_path - .split('/') - .take(model_path.split('/').count() - 1) - .collect::>() - .join("/") - + "/data/"; - let result = unpickler::unpickle(&bs)?; - pln!("Loading embeddings from {}...", model_data_dir); - let emb = Embedding::from_unpickled(&result, model_data_dir.clone())?; + let mut unpickle_results: Vec = vec![]; + + let mut part: usize = 0; + loop { + let model_path: PathBuf = model_path.clone().into(); + let base_path = model_path.join(format!("consolidated.{:02}", part)); + // The data file is in consolidated.XX/data.pkl where XX is the part number. 
+ let full_path = base_path.join("data.pkl"); + let mut fs = match std::fs::File::open(&full_path) { + Ok(fs) => fs, + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + break; + } else { + return Err(err.into()); + } + } + }; + let mut bs = Vec::new(); + fs.read_to_end(&mut bs)?; + std::mem::drop(fs); + pln!("Read data.pkl from path {}", full_path.display()); + + let result = unpickler::unpickle(&bs)?; + unpickle_results.push(result); + part += 1; + } + + pln!("Loading embeddings from {}...", model_path); + let emb = Embedding::from_unpickled(&unpickle_results, model_path.clone())?; let max_seq_len = match cli.max_seq_len { Some(max_seq_len) => max_seq_len, None => 1024, }; - pln!("Loading transformer weights from {}...", model_data_dir); + pln!("Loading transformer weights from {}...", model_path); let tr = Transformer::from_unpickled( - &result, + &unpickle_results, emb, params.dim, params.n_layers, params.n_heads, max_seq_len, params.norm_eps, - model_data_dir, + model_path, )?; pln!("All is loaded. Starting inference."); diff --git a/src/tensor.rs b/src/tensor.rs index c3a4302..d1f94dd 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -5,7 +5,7 @@ use rand::Rng; use rayon::prelude::*; use std::alloc::Layout; use std::arch::x86_64::*; -use std::io::Read; +use std::io::{Read, Seek}; use std::path::{Path, PathBuf}; use thiserror::Error; @@ -17,6 +17,7 @@ pub struct TensorBuilder { pub(crate) rows: i64, pub(crate) cols: i64, pub(crate) nitems: i64, + pub(crate) offset: i64, } #[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] @@ -29,12 +30,20 @@ pub enum TensorDType { pub enum TensorError { #[error("IO error: {0}")] IOError(#[from] std::io::Error), + #[error("IOError while reading tensor: {0} {1}")] + TensorBuilderReadError(std::io::Error, String), #[error("Invalid stride: {0}")] InvalidStride(i64), + #[error("Tried to build a tensor from zero files")] + TensorBuilderEmpty, + #[error("Tried to build a tensor from multiple files but the number of rows do not agree between the files. {0} != {1}")] + TensorBuilderRowsMismatch(i64, i64), + #[error("Tried to build a tensor from multiple files but the data types do not agree between the files. 
{0:?} != {1:?}")] + TensorBuilderDTypeMismatch(TensorDType, TensorDType), } impl TensorDType { - fn bytes_per_item(&self) -> usize { + pub fn bytes_per_item(&self) -> usize { match self { Self::Float16 => 2, Self::Float32 => 4, @@ -118,6 +127,28 @@ impl Tensor { Ok(val) } + pub fn from_unpickled_pieces, S: AsRef>( + unpickled: &[unpickler::Value], + name: S, + data_dir: P, + direction: FromPiecesDirection, + ) -> Result { + let data_dir: &Path = data_dir.as_ref(); + let name: &str = name.as_ref(); + let mut builders = Vec::new(); + for unpickle in unpickled.iter() { + let val = unpickle + .get_str_key(name) + .ok_or(UnpicklingError::MissingField(name.to_string()))?; + let val = val + .to_tensor_builder() + .ok_or(UnpicklingError::InvalidTensorData)?; + builders.push(val); + } + let val = TensorBuilder::load_from_pieces(&builders, data_dir, direction)?; + Ok(val) + } + pub fn rows(&self) -> i64 { self.rows } @@ -412,10 +443,16 @@ impl Tensor { pub fn hadamard_product_broadcast(&self, other: &Tensor) -> Tensor { if self.cols != other.cols { - panic!("Invalid hadamard product broadcast"); + panic!( + "Invalid hadamard product broadcast: {}x{} vs {}x{}", + self.rows, self.cols, other.rows, other.cols + ); } if other.rows != 1 { - panic!("Invalid hadamard product broadcast"); + panic!( + "Invalid hadamard product broadcast: {}x{} vs {}x{}", + self.rows, self.cols, other.rows, other.cols + ); } let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { @@ -1036,7 +1073,10 @@ impl Tensor { pub fn view(&self, rows: i64, cols: i64) -> Tensor { if rows * cols != self.rows * self.cols { - panic!("Invalid tensor view"); + panic!( + "Invalid tensor view, requested {}x{} but tensor is {}x{}", + rows, cols, self.rows, self.cols + ); } if rows == self.rows { return self.clone(); @@ -1139,6 +1179,16 @@ impl Tensor { } } +/// When we load multiple tensors, should we slap them together row by row, or column by column? +/// +/// E.g. If we have 32x4 and 32x4 then Rows --> 64x4 +/// If we have 32x4 and 32x4 then Cols --> 32x8 +#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd, Debug)] +pub enum FromPiecesDirection { + Rows, + Cols, +} + impl TensorBuilder { pub fn load>(&self, data_dir: P) -> Result { let data_dir: &Path = data_dir.as_ref(); @@ -1146,21 +1196,194 @@ impl TensorBuilder { return Err(TensorError::InvalidStride(self.stride)); } let tensor = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; - assert_eq!(self.dtype, TensorDType::Float16); - let path = data_dir.join(&self.src_path); + let path = data_dir + .join(format!("consolidated.{:02}", 0)) + .join("data") + .join(&self.src_path); let mut f = std::fs::File::open(&path).unwrap(); + f.seek(std::io::SeekFrom::Start( + (self.offset as u64) * self.dtype.bytes_per_item() as u64, + ))?; let mut cursor: usize = 0; - let mut buf: Vec = vec![0; self.cols as usize * 2]; + let mut buf: Vec = vec![0; self.cols as usize * self.dtype.bytes_per_item()]; for _row in 0..self.rows { f.read_exact(&mut buf)?; unsafe { std::ptr::copy_nonoverlapping(buf.as_ptr(), tensor.data.add(cursor), buf.len()); } - cursor += tensor.capacity_cols as usize * 2; + cursor += tensor.capacity_cols as usize * self.dtype.bytes_per_item(); } Ok(tensor.to_f32()) } + + /// Loads a tensor from multiple TensorBuilders; used to load a tensor from multiple files + /// which is what the larger LLaMA models do. 
+ pub fn load_from_pieces>( + builders: &[Self], + data_dir: P, + direction: FromPiecesDirection, + ) -> Result { + let data_dir: &Path = data_dir.as_ref(); + if builders.is_empty() { + return Err(TensorError::TensorBuilderEmpty); + } + + fn load_from_pieces_cols( + builders: &[TensorBuilder], + data_dir: &Path, + ) -> Result { + let mut total_cols: i64 = 0; + let expected_rows: i64 = builders[0].rows; + let expected_dtype: TensorDType = builders[0].dtype; + + // Do some checking before we attempt loading. + for builder in builders.iter() { + total_cols += builder.cols; + if builder.stride < 1 { + return Err(TensorError::InvalidStride(builder.stride)); + } + if builder.rows != expected_rows { + return Err(TensorError::TensorBuilderRowsMismatch( + builder.rows, + expected_rows, + )); + } + if builder.dtype != expected_dtype { + return Err(TensorError::TensorBuilderDTypeMismatch( + builder.dtype, + expected_dtype, + )); + } + } + + let tensor = + unsafe { Tensor::uninitialized(expected_rows, total_cols, builders[0].dtype) }; + let mut buf: Vec = vec![]; + let mut col_offset = 0; + for (idx, builder) in builders.iter().enumerate() { + let path = data_dir + .join(format!("consolidated.{:02}", idx)) + .join("data") + .join(&builder.src_path); + buf.truncate(0); + buf.resize(builder.cols as usize * builder.dtype.bytes_per_item(), 0); + let mut f = std::fs::File::open(&path).unwrap(); + f.seek(std::io::SeekFrom::Start( + (builder.offset as u64) * builder.dtype.bytes_per_item() as u64, + ))?; + for row in 0..builder.rows { + match f.read_exact(&mut buf) { + Ok(_) => {} + Err(err) => { + return Err(TensorError::TensorBuilderReadError( + err, + format!( + "path={:?} row={} expected_len={} offset={}", + path, + row, + buf.len(), + builder.offset + ), + )); + } + }; + unsafe { + std::ptr::copy_nonoverlapping( + buf.as_ptr(), + tensor.data.add( + ((row * tensor.capacity_cols + col_offset) as usize) + * builder.dtype.bytes_per_item(), + ), + buf.len(), + ); + } + } + col_offset += builder.cols; + } + Ok(tensor.to_f32()) + } + + fn load_from_pieces_rows( + builders: &[TensorBuilder], + data_dir: &Path, + ) -> Result { + let mut total_rows: i64 = 0; + let expected_cols: i64 = builders[0].cols; + let expected_dtype: TensorDType = builders[0].dtype; + + // Do some checking before we attempt loading. 
+ for builder in builders.iter() { + total_rows += builder.rows; + if builder.stride < 1 { + return Err(TensorError::InvalidStride(builder.stride)); + } + if builder.cols != expected_cols { + return Err(TensorError::TensorBuilderRowsMismatch( + builder.cols, + expected_cols, + )); + } + if builder.dtype != expected_dtype { + return Err(TensorError::TensorBuilderDTypeMismatch( + builder.dtype, + expected_dtype, + )); + } + } + + let tensor = + unsafe { Tensor::uninitialized(total_rows, expected_cols, builders[0].dtype) }; + let mut buf: Vec = vec![]; + let mut row_offset: i64 = 0; + for (idx, builder) in builders.iter().enumerate() { + let path = data_dir + .join(format!("consolidated.{:02}", idx)) + .join("data") + .join(&builder.src_path); + buf.truncate(0); + buf.resize(builder.cols as usize * builder.dtype.bytes_per_item(), 0); + let mut f = std::fs::File::open(&path).unwrap(); + f.seek(std::io::SeekFrom::Start( + (builder.offset as u64) * builder.dtype.bytes_per_item() as u64, + ))?; + for row in 0..builder.rows { + match f.read_exact(&mut buf) { + Ok(_) => {} + Err(err) => { + return Err(TensorError::TensorBuilderReadError( + err, + format!( + "path={:?} row={} expected_len={} offset={}", + path, + row, + buf.len(), + builder.offset + ), + )); + } + }; + unsafe { + std::ptr::copy_nonoverlapping( + buf.as_ptr(), + tensor.data.add( + (((row + row_offset) * tensor.capacity_cols) as usize) + * builder.dtype.bytes_per_item(), + ), + buf.len(), + ); + } + } + row_offset += builder.rows; + } + Ok(tensor.to_f32()) + } + + match direction { + FromPiecesDirection::Rows => load_from_pieces_rows(builders, data_dir), + FromPiecesDirection::Cols => load_from_pieces_cols(builders, data_dir), + } + } } #[cfg(test)] diff --git a/src/transformer.rs b/src/transformer.rs index dbed42e..1eeb0f5 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -1,5 +1,5 @@ use crate::embedding::Embedding; -use crate::tensor::{Tensor, TensorDType}; +use crate::tensor::{FromPiecesDirection, Tensor, TensorDType}; use crate::tokenizer::TokenId; use crate::unpickler; use crate::unpickler::UnpicklingError; @@ -114,7 +114,7 @@ pub struct FeedForward { impl Transformer { #[allow(clippy::too_many_arguments)] pub fn from_unpickled>( - unpickled: &unpickler::Value, + unpickled: &[unpickler::Value], emb: Embedding, dim: usize, n_layers: usize, @@ -150,7 +150,13 @@ impl Transformer { std::mem::drop(progress_bar); let norm = RMSNorm::from_unpickled(unpickled, "norm.weight".to_string(), eps, data_dir)?; - let output = Tensor::from_unpickled(unpickled, "output.weight", data_dir)?.to_f32(); + let output = Tensor::from_unpickled_pieces( + unpickled, + "output.weight", + data_dir, + FromPiecesDirection::Rows, + )? + .to_f32(); Ok(Transformer { freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len, 10000.0), @@ -227,7 +233,7 @@ impl Transformer { impl TransformerBlock { pub fn from_unpickled>( - unpickled: &unpickler::Value, + unpickled: &[unpickler::Value], layer_id: usize, eps: f64, n_local_heads: usize, @@ -279,13 +285,19 @@ impl TransformerBlock { impl RMSNorm { pub fn from_unpickled>( - unpickled: &unpickler::Value, + unpickled: &[unpickler::Value], name: String, eps: f64, data_dir: P, ) -> Result { let data_dir: &Path = data_dir.as_ref(); - let weights = Tensor::from_unpickled(unpickled, name, data_dir)?.to_f32(); + let weights = Tensor::from_unpickled_pieces( + &unpickled[0..=0], + name.clone(), + data_dir, + FromPiecesDirection::Rows, + )? 
+ .to_f32(); Ok(Self { eps, weight: weights, @@ -301,28 +313,31 @@ impl RMSNorm { impl FeedForward { pub fn from_unpickled>( - unpickled: &unpickler::Value, + unpickled: &[unpickler::Value], layer_id: usize, data_dir: P, ) -> Result { let data_dir: &Path = data_dir.as_ref(); - let w1 = Tensor::from_unpickled( + let w1 = Tensor::from_unpickled_pieces( unpickled, format!("layers.{}.feed_forward.w1.weight", layer_id), data_dir, + FromPiecesDirection::Rows, )? .to_f32(); - let w2 = Tensor::from_unpickled( + let w2 = Tensor::from_unpickled_pieces( unpickled, format!("layers.{}.feed_forward.w2.weight", layer_id), data_dir, + FromPiecesDirection::Cols, )? .to_f32(); - let w3 = Tensor::from_unpickled( + let w3 = Tensor::from_unpickled_pieces( unpickled, format!("layers.{}.feed_forward.w3.weight", layer_id), data_dir, + FromPiecesDirection::Rows, )? .to_f32(); @@ -348,7 +363,7 @@ impl FeedForward { impl Attention { pub fn from_unpickled>( - unpickled: &unpickler::Value, + unpickled: &[unpickler::Value], layer_id: usize, n_local_heads: usize, head_dim: usize, @@ -356,28 +371,32 @@ impl Attention { ) -> Result { let data_dir: &Path = data_dir.as_ref(); - let wq = Tensor::from_unpickled( + let wq = Tensor::from_unpickled_pieces( unpickled, format!("layers.{}.attention.wq.weight", layer_id), data_dir, + FromPiecesDirection::Rows, )? .to_f32(); - let wk = Tensor::from_unpickled( + let wk = Tensor::from_unpickled_pieces( unpickled, format!("layers.{}.attention.wk.weight", layer_id), data_dir, + FromPiecesDirection::Rows, )? .to_f32(); - let wv = Tensor::from_unpickled( + let wv = Tensor::from_unpickled_pieces( unpickled, format!("layers.{}.attention.wv.weight", layer_id), data_dir, + FromPiecesDirection::Rows, )? .to_f32(); - let wo = Tensor::from_unpickled( + let wo = Tensor::from_unpickled_pieces( unpickled, format!("layers.{}.attention.wo.weight", layer_id), data_dir, + FromPiecesDirection::Cols, )? 
.to_f32(); @@ -494,7 +513,7 @@ impl Attention { concat_vec.push(output.row(idx)); } let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect(); - let xq_row = Tensor::concat(&concat_vec2).view(1, 4096); + let xq_row = Tensor::concat(&concat_vec2).view(1, self.wo.rows()); xq_row.matrix_mul_transposed(&self.wo) }) .collect(); diff --git a/src/unpickler.rs b/src/unpickler.rs index df808b2..84a167f 100644 --- a/src/unpickler.rs +++ b/src/unpickler.rs @@ -108,54 +108,11 @@ impl Value { fn to_tensor_builder2(&self, args: &[Value]) -> Option { if args.len() == 6 { Self::to_tensor_builder2_6items(args) - } else if args.len() == 4 { - Self::to_tensor_builder2_4items(args) } else { None } } - fn to_tensor_builder2_4items(args: &[Value]) -> Option { - let storagev: &Value = args[0].get_persistent_id()?; - let storage_args: &[Value] = storagev.get_tuple()?; - let storage_mark: &str = storage_args[0].get_str()?; - if storage_mark != "storage" { - return None; - } - - let (storage_module, storage_type) = storage_args[1].get_global()?; - if storage_module != "torch" { - return None; - } - let dtype: TensorDType = match storage_type { - "HalfStorage" => TensorDType::Float16, - _ => return None, - }; - let storage_filename: &str = storage_args[2].get_str()?; - let nitems: i64 = storage_args[4].get_int64()?; - - let offset: i64 = args[1].get_int64()?; - if offset != 0 { - return None; - } - - let rows: i64 = 1; - let cols: i64 = nitems; - let row_stride: i64 = cols; - if row_stride != cols { - return None; - } - - Some(TensorBuilder { - src_path: PathBuf::from(storage_filename), - dtype, - stride: row_stride, - rows, - cols, - nitems, - }) - } - fn to_tensor_builder2_6items(args: &[Value]) -> Option { let storagev: &Value = args[0].get_persistent_id()?; let storage_args: &[Value] = storagev.get_tuple()?; @@ -170,36 +127,52 @@ impl Value { } let dtype: TensorDType = match storage_type { "HalfStorage" => TensorDType::Float16, - _ => return None, + _ => { + println!("1"); + return None; + } }; let storage_filename: &str = storage_args[2].get_str()?; let nitems: i64 = storage_args[4].get_int64()?; let offset: i64 = args[1].get_int64()?; - if offset != 0 { - return None; - } let shape: &[Value] = args[2].get_tuple()?; let stride: &[Value] = args[3].get_tuple()?; - if shape.len() != 2 { + if shape.len() != 2 && shape.len() != 1 { + println!("2"); return None; } - if stride.len() != 2 { + if stride.len() != 2 && stride.len() != 1 { + println!("3"); return None; } - let rows: i64 = shape[0].get_int64()?; - let cols: i64 = shape[1].get_int64()?; + let (rows, cols) = if shape.len() == 2 { + (shape[0].get_int64()?, shape[1].get_int64()?) + } else { + let cols = shape[0].get_int64()?; + (1, cols) + }; - let row_stride: i64 = stride[0].get_int64()?; - let col_stride: i64 = stride[1].get_int64()?; + let (row_stride, col_stride) = if stride.len() == 1 { + let (r, c) = (stride[0].get_int64()?, 1); + if r != 1 { + println!("4"); + return None; + } + (r, c) + } else { + (stride[0].get_int64()?, stride[1].get_int64()?) 
+ }; if col_stride != 1 { + println!("5"); return None; } - if row_stride != cols { + if row_stride != cols && stride.len() == 2 { + println!("6"); return None; } @@ -210,6 +183,7 @@ impl Value { rows, cols, nitems, + offset, }) /* Args should look like this (took random example from debug print) : @@ -529,6 +503,7 @@ pub fn unpickle(bytes: &[u8]) -> Result { )); } tuple.push(stack.pop().unwrap()); + stack.push(Value::Tuple(tuple)); bytes = &bytes[1..]; continue; } @@ -604,6 +579,27 @@ pub fn unpickle(bytes: &[u8]) -> Result { bytes = &bytes[1..]; continue; } + if frame_opcode == 106 { + // long_binget + if bytes.len() < 5 { + return Err(UnpicklingError::UnpicklingError( + "Unexpected end of data while handling LONG_BINGET".to_string(), + )); + } + let idx = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + match memo.get(&(idx as u32)) { + None => { + return Err(UnpicklingError::UnpicklingError( + "LONG_BINGET index out of range".to_string(), + )); + } + Some(memo_value) => { + stack.push(memo_value.clone()); + } + } + bytes = &bytes[5..]; + continue; + } if frame_opcode == 46 { // stop // bytes = &bytes[1..];