From 18ef805458d8ef26efb5f0bfd4904fd17cb66f64 Mon Sep 17 00:00:00 2001
From: Mikko Juola
Date: Sat, 11 Mar 2023 10:44:06 -0800
Subject: [PATCH] Read parameters from model's JSON file instead of
 hard-coding them, make max sequence length configurable.
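
The values that used to be hard-coded for the 7B model (dim 4096, 32
layers, 32 heads, norm_eps 1e-6) are now read from a JSON file passed
with --param-path, and head_dim/n_local_heads are derived from dim and
n_heads instead of being passed in. As a sketch of the expected file
shape: the dim/n_heads/n_layers/norm_eps values below are the ones this
patch deletes from the code, while multiple_of and vocab_size are
illustrative, as found in the params.json that ships with the LLaMA
checkpoints:

    {
        "dim": 4096,
        "multiple_of": 256,
        "n_heads": 32,
        "n_layers": 32,
        "norm_eps": 1e-06,
        "vocab_size": -1
    }

Example invocation with the new flags (all paths are placeholders):

    rllama --model-path /path/to/LLaMA/7B \
           --tokenizer-path /path/to/tokenizer.model \
           --param-path /path/to/LLaMA/7B/params.json \
           --max-seq-len 2048 \
           --prompt "Hello"

--max-seq-len is optional; it now defaults to 1024 (previously the
limit was hard-coded to 512).
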
Starting inference."); @@ -116,6 +143,14 @@ pub fn main() -> Result<(), Box> { } pln!("---"); + pln!(" dim: {}", params.dim); + pln!(" multiple_of: {}", params.multiple_of); + pln!(" n_heads: {}", params.n_heads); + pln!(" n_layers: {}", params.n_layers); + pln!(" norm_eps: {}", params.norm_eps); + pln!(" vocab_size: {}", params.vocab_size); + pln!("---"); + pln!("Max sequence length: {}", max_seq_len); pln!("Temperature: {}", token_sampler.get_temperature()); pln!("Top P: {}", token_sampler.get_top_p()); pln!("Top K: {}", token_sampler.get_top_k()); @@ -126,12 +161,7 @@ pub fn main() -> Result<(), Box> { ); pln!( "{}", - " This is the color of the generated text while full context is available" - .truecolor(128, 255, 128) - ); - pln!( - "{}", - " Remaining text is in this color".truecolor(255, 128, 128) + " This is the color of the generated text".truecolor(128, 255, 128) ); pln!("---"); print!("{}", prompt.as_str().truecolor(128, 128, 255)); @@ -139,17 +169,8 @@ pub fn main() -> Result<(), Box> { let mut caches = tr.make_caches(); let mut first: bool = true; - let mut shifts: usize = 0; - loop { - if toks_id.len() >= max_seq_len { - toks_id = toks_id[1..].to_vec(); - prev_pos -= 1; - caches.shift_left(1); - shifts += 1; - // TODO: it seems that text beyond context is just broken. - // Maybe I cannot just go and shift it. - } - let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches, shifts); + while toks_id.len() < max_seq_len { + let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches); let highest_pred_idx = token_sampler.sample(&preds); toks_id.push(highest_pred_idx as TokenId); @@ -166,14 +187,14 @@ pub fn main() -> Result<(), Box> { } if first && tok_idx < toks_id.len() - 1 { // intentionally left empty - } else if shifts == 0 { - print!("{}", tok_str.truecolor(128, 255, 128)); } else { - print!("{}", tok_str.truecolor(255, 128, 128)); + print!("{}", tok_str.truecolor(128, 255, 128)); } } let _ = std::io::stdout().flush(); prev_pos = toks_id.len() - 1; first = false; } + println!(""); + Ok(()) } diff --git a/src/transformer.rs b/src/transformer.rs index bfbba5e..dbed42e 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -121,10 +121,14 @@ impl Transformer { n_heads: usize, max_seq_len: usize, eps: f64, - n_local_heads: usize, - head_dim: usize, data_dir: P, ) -> Result { + assert_eq!(dim % n_heads, 0); + let head_dim = dim / n_heads; + let n_local_heads = n_heads; // I think the local heads is an artifact of the original + // implementation that used multi-GPU in the Facebook repo. + // Should delete it later. + let data_dir: &Path = data_dir.as_ref(); let progress_bar = ProgressBar::new(n_layers as u64); @@ -149,9 +153,7 @@ impl Transformer { let output = Tensor::from_unpickled(unpickled, "output.weight", data_dir)?.to_f32(); Ok(Transformer { - // TODO: maybe rotary embedding can be computed on the fly if the sequence gets really long. I just - // slapped * 20 on max seq len here. 
-            freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len * 20, 10000.0),
+            freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len, 10000.0),
             emb,
             dim,
             n_layers,
@@ -186,7 +188,6 @@ impl Transformer {
         tokens: &[TokenId],
         start_pos: usize,
         caches: &mut TransformerCaches,
-        shifts: usize,
     ) -> Tensor {
         assert!(caches.layer_caches.len() == self.n_layers);
         let mask: Option<Tensor> = if tokens.len() > 1 {
@@ -215,7 +216,6 @@ impl Transformer {
                 &self.freqs_cis,
                 &mask,
                 &mut caches.layer_caches[idx],
-                shifts,
             );
         }
         let out = self.norm.forward(&emb_tensor);
@@ -265,17 +265,11 @@ impl TransformerBlock {
         freqs_cis: &FreqsCis,
         mask: &Option<Tensor>,
         attention_cache: &mut AttentionCache,
-        shifts: usize,
     ) -> Tensor {
         let attnorm_out = self.attention_norm.forward(x);
-        let att_out = self.attn.forward(
-            &attnorm_out,
-            start_pos,
-            freqs_cis,
-            mask,
-            attention_cache,
-            shifts,
-        );
+        let att_out = self
+            .attn
+            .forward(&attnorm_out, start_pos, freqs_cis, mask, attention_cache);
         let h = x.add(&att_out);
         let att_out = self.ffn_norm.forward(&h);
         let att_out = self.feed_forward.forward(&att_out).transpose();
@@ -404,7 +398,6 @@ impl Attention {
         freqs_cis: &FreqsCis,
         mask: &Option<Tensor>,
         attention_cache: &mut AttentionCache,
-        shifts: usize,
     ) -> Tensor {
         let seq_len = x.rows();
         let (xq_out, (xk_out, xv_out)) = rayon::join(
@@ -432,13 +425,8 @@ impl Attention {
                 .row(idx)
                 .view(self.n_local_heads as i64, self.head_dim as i64);
 
-            let (xq_row, xk_row) = apply_rotary_emb(
-                &xq_row,
-                &xk_row,
-                freqs_cis,
-                idx as usize,
-                start_pos + shifts,
-            );
+            let (xq_row, xk_row) =
+                apply_rotary_emb(&xq_row, &xk_row, freqs_cis, idx as usize, start_pos);
             xq_views.push(xq_row);
             xk_views.push(xk_row);