Read parameters from the model's JSON file instead of hard-coding them; make the max sequence length configurable.

broken-opencl-code
Mikko Juola 3 years ago
parent f103871bc0
commit 18ef805458

Cargo.lock (generated)

@@ -800,6 +800,8 @@ dependencies = [
  "protobuf-parse",
  "rand",
  "rayon",
+ "serde",
+ "serde_json",
  "thiserror",
 ]

Cargo.toml

@@ -22,6 +22,8 @@ rayon = "1.7"
 clap = { version = "4.1", features = ["derive"] }
 indicatif = "0.17"
 colored = "2"
+serde = { version = "1", features = ["derive"] }
+serde_json = "1"

 # We need protobuf compiler
 [build-dependencies]

@@ -5,6 +5,7 @@ use crate::transformer::Transformer;
 use crate::unpickler;
 use clap::Parser;
 use colored::Colorize;
+use serde::{Deserialize, Serialize};
 use std::io::{Read, Write};

 #[derive(Parser)]
@@ -14,11 +15,17 @@ struct Cli {
     model_path: String,
     #[arg(long)]
     tokenizer_path: String,
+    #[arg(long)]
+    param_path: String,
     #[arg(long)]
     prompt: Option<String>,
     #[arg(long)]
     prompt_file: Option<String>,
     #[arg(long)]
+    max_seq_len: Option<usize>,
+    #[arg(long)]
     temperature: Option<f32>,
     #[arg(long)]
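With clap's derive API these fields surface as kebab-case flags, so the two new options become --param-path and --max-seq-len. An invocation might then look like the following (the binary name and the paths are illustrative assumptions, not taken from this commit):

    rllama --model-path /data/LLaMA/7B \
           --tokenizer-path /data/LLaMA/tokenizer.model \
           --param-path /data/LLaMA/7B/params.json \
           --max-seq-len 2048 \
           --prompt "The capital of Finland is"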
@@ -27,10 +34,21 @@ struct Cli {
     top_k: Option<i32>,
 }

+#[derive(Clone, Serialize, Deserialize)]
+struct ModelParams {
+    dim: usize,
+    multiple_of: usize,
+    n_heads: usize,
+    n_layers: usize,
+    norm_eps: f64,
+    vocab_size: i64,
+}
+
 pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     let cli = Cli::parse();
     let model_path = cli.model_path;
     let tokenizer_path = cli.tokenizer_path;
+    let param_path = cli.param_path;

     let mut be_quiet: bool = false;
     if !colored::control::SHOULD_COLORIZE.should_colorize() {
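This struct mirrors the params.json file that ships next to each LLaMA checkpoint. As a point of reference, the widely published 7B settings (an assumption for illustration, not something this diff states) would deserialize from:

    {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": -1}

Note that vocab_size is declared as i64 rather than usize: Meta's file stores -1 there, so an unsigned type would fail to parse.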
@@ -46,6 +64,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
         };
     }

+    // Read ModelParams from param_path, we expect it to be JSON
+    let mut fs = std::fs::File::open(&param_path)?;
+    let mut bs = Vec::new();
+    fs.read_to_end(&mut bs)?;
+    std::mem::drop(fs);
+    let params: ModelParams = serde_json::from_slice(&bs)?;
+    pln!("Loaded model parameters from {}.", param_path);
+
     let prompt: String = match (cli.prompt, cli.prompt_file) {
         (Some(prompt), None) => {
             pln!("Using prompt: {}", prompt);
@@ -84,19 +110,20 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     pln!("Loading embeddings from {}...", model_data_dir);
     let emb = Embedding::from_unpickled(&result, model_data_dir.clone())?;

-    let max_seq_len = 512;
+    let max_seq_len = match cli.max_seq_len {
+        Some(max_seq_len) => max_seq_len,
+        None => 1024,
+    };

     pln!("Loading transformer weights from {}...", model_data_dir);
     let tr = Transformer::from_unpickled(
         &result,
         emb,
-        4096,
-        32,
-        32,
+        params.dim,
+        params.n_layers,
+        params.n_heads,
         max_seq_len,
-        1e-6,
-        32,
-        128,
+        params.norm_eps,
         model_data_dir,
     )?;
     pln!("All is loaded. Starting inference.");
@@ -116,6 +143,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     }

     pln!("---");
+    pln!(" dim: {}", params.dim);
+    pln!(" multiple_of: {}", params.multiple_of);
+    pln!(" n_heads: {}", params.n_heads);
+    pln!(" n_layers: {}", params.n_layers);
+    pln!(" norm_eps: {}", params.norm_eps);
+    pln!(" vocab_size: {}", params.vocab_size);
+    pln!("---");
+    pln!("Max sequence length: {}", max_seq_len);
     pln!("Temperature: {}", token_sampler.get_temperature());
     pln!("Top P: {}", token_sampler.get_top_p());
     pln!("Top K: {}", token_sampler.get_top_k());
@@ -126,12 +161,7 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     );
     pln!(
         "{}",
-        " This is the color of the generated text while full context is available"
-            .truecolor(128, 255, 128)
-    );
-    pln!(
-        "{}",
-        " Remaining text is in this color".truecolor(255, 128, 128)
+        " This is the color of the generated text".truecolor(128, 255, 128)
     );
     pln!("---");
     print!("{}", prompt.as_str().truecolor(128, 128, 255));
@@ -139,17 +169,8 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {

     let mut caches = tr.make_caches();
     let mut first: bool = true;
-    let mut shifts: usize = 0;
-    loop {
-        if toks_id.len() >= max_seq_len {
-            toks_id = toks_id[1..].to_vec();
-            prev_pos -= 1;
-            caches.shift_left(1);
-            shifts += 1;
-            // TODO: it seems that text beyond context is just broken.
-            // Maybe I cannot just go and shift it.
-        }
-        let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches, shifts);
+    while toks_id.len() < max_seq_len {
+        let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);

         let highest_pred_idx = token_sampler.sample(&preds);
         toks_id.push(highest_pred_idx as TokenId);
@@ -166,14 +187,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
             }
             if first && tok_idx < toks_id.len() - 1 {
                 // intentionally left empty
-            } else if shifts == 0 {
-                print!("{}", tok_str.truecolor(128, 255, 128));
             } else {
-                print!("{}", tok_str.truecolor(255, 128, 128));
+                print!("{}", tok_str.truecolor(128, 255, 128));
             }
         }
         let _ = std::io::stdout().flush();
         prev_pos = toks_id.len() - 1;
         first = false;
     }
+    println!("");
+    Ok(())
 }

src/transformer.rs

@@ -121,10 +121,14 @@ impl Transformer {
         n_heads: usize,
         max_seq_len: usize,
         eps: f64,
-        n_local_heads: usize,
-        head_dim: usize,
         data_dir: P,
     ) -> Result<Transformer, UnpicklingError> {
+        assert_eq!(dim % n_heads, 0);
+        let head_dim = dim / n_heads;
+        let n_local_heads = n_heads; // I think the local heads is an artifact of the original
+                                     // implementation that used multi-GPU in the Facebook repo.
+                                     // Should delete it later.
         let data_dir: &Path = data_dir.as_ref();

         let progress_bar = ProgressBar::new(n_layers as u64);
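The derived values reproduce exactly the constants the call site used to pass for the 7B model: 4096 dimensions over 32 heads gives 128-wide heads. A quick sanity check under those figures (the ones previously hard-coded, per the main-file hunk above):

    let (dim, n_heads) = (4096usize, 32usize); // old hard-coded 7B arguments
    assert_eq!(dim % n_heads, 0);
    assert_eq!(dim / n_heads, 128); // head_dim, formerly the hard-coded 128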
@@ -149,9 +153,7 @@
         let output = Tensor::from_unpickled(unpickled, "output.weight", data_dir)?.to_f32();

         Ok(Transformer {
-            // TODO: maybe rotary embedding can be computed on the fly if the sequence gets really long. I just
-            // slapped * 20 on max seq len here.
-            freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len * 20, 10000.0),
+            freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len, 10000.0),
             emb,
             dim,
             n_layers,
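With the `* 20` padding gone, the rotary table is sized exactly to max_seq_len, which is safe now that generation stops at max_seq_len instead of shifting past it. For orientation, here is a sketch of what a RoPE precomputation like compute_freqs_cis conventionally does, following the standard formulation rather than this repo's actual FreqsCis layout (names and the return shape are assumptions):

    // For head dimension d, pair i rotates at frequency theta^(-2i/d);
    // position t then contributes the angle t * freq, stored as (cos, sin),
    // i.e. the real and imaginary parts of e^(i * angle).
    fn rope_table(head_dim: usize, max_seq_len: usize, theta: f64) -> Vec<Vec<(f64, f64)>> {
        let freqs: Vec<f64> = (0..head_dim / 2)
            .map(|i| 1.0 / theta.powf(2.0 * i as f64 / head_dim as f64))
            .collect();
        (0..max_seq_len)
            .map(|t| {
                freqs
                    .iter()
                    .map(|f| {
                        let angle = t as f64 * f;
                        (angle.cos(), angle.sin())
                    })
                    .collect()
            })
            .collect()
    }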
@@ -186,7 +188,6 @@
         tokens: &[TokenId],
         start_pos: usize,
         caches: &mut TransformerCaches,
-        shifts: usize,
     ) -> Tensor {
         assert!(caches.layer_caches.len() == self.n_layers);
         let mask: Option<Tensor> = if tokens.len() > 1 {
@@ -215,7 +216,6 @@
                 &self.freqs_cis,
                 &mask,
                 &mut caches.layer_caches[idx],
-                shifts,
             );
         }
         let out = self.norm.forward(&emb_tensor);
@@ -265,17 +265,11 @@ impl TransformerBlock {
         freqs_cis: &FreqsCis,
         mask: &Option<Tensor>,
         attention_cache: &mut AttentionCache,
-        shifts: usize,
     ) -> Tensor {
         let attnorm_out = self.attention_norm.forward(x);
-        let att_out = self.attn.forward(
-            &attnorm_out,
-            start_pos,
-            freqs_cis,
-            mask,
-            attention_cache,
-            shifts,
-        );
+        let att_out = self
+            .attn
+            .forward(&attnorm_out, start_pos, freqs_cis, mask, attention_cache);
         let h = x.add(&att_out);
         let att_out = self.ffn_norm.forward(&h);
         let att_out = self.feed_forward.forward(&att_out).transpose();
@@ -404,7 +398,6 @@ impl Attention {
         freqs_cis: &FreqsCis,
         mask: &Option<Tensor>,
         attention_cache: &mut AttentionCache,
-        shifts: usize,
     ) -> Tensor {
         let seq_len = x.rows();
         let (xq_out, (xk_out, xv_out)) = rayon::join(
@@ -432,13 +425,8 @@
                     .row(idx)
                     .view(self.n_local_heads as i64, self.head_dim as i64);

-                let (xq_row, xk_row) = apply_rotary_emb(
-                    &xq_row,
-                    &xk_row,
-                    freqs_cis,
-                    idx as usize,
-                    start_pos + shifts,
-                );
+                let (xq_row, xk_row) =
+                    apply_rotary_emb(&xq_row, &xk_row, freqs_cis, idx as usize, start_pos);
                 xq_views.push(xq_row);
                 xk_views.push(xk_row);
