Read parameters from model's JSON file instead of hard-coding them, make max sequence length configurable.

broken-opencl-code
Mikko Juola 3 years ago
parent f103871bc0
commit 18ef805458
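
The model parameters now come from the params.json that ships alongside the LLaMA weights instead of being hard-coded. For reference, a sketch of the expected file using the 7B values (these match the constants this commit removes; the original Meta release ships vocab_size as -1):

    {
        "dim": 4096,
        "multiple_of": 256,
        "n_heads": 32,
        "n_layers": 32,
        "norm_eps": 1e-06,
        "vocab_size": -1
    }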

Cargo.lock (generated)

@@ -800,6 +800,8 @@ dependencies = [
  "protobuf-parse",
  "rand",
  "rayon",
+ "serde",
+ "serde_json",
  "thiserror",
 ]

Cargo.toml
@@ -22,6 +22,8 @@ rayon = "1.7"
 clap = { version = "4.1", features = ["derive"] }
 indicatif = "0.17"
 colored = "2"
+serde = { version = "1", features = ["derive"] }
+serde_json = "1"

 # We need protobuf compiler
 [build-dependencies]

@@ -5,6 +5,7 @@ use crate::transformer::Transformer;
 use crate::unpickler;
 use clap::Parser;
 use colored::Colorize;
+use serde::{Deserialize, Serialize};
 use std::io::{Read, Write};

 #[derive(Parser)]
@@ -14,11 +15,17 @@ struct Cli {
     model_path: String,
     #[arg(long)]
     tokenizer_path: String,
+    #[arg(long)]
+    param_path: String,
     #[arg(long)]
     prompt: Option<String>,
     #[arg(long)]
     prompt_file: Option<String>,
+    #[arg(long)]
+    max_seq_len: Option<usize>,
     #[arg(long)]
     temperature: Option<f32>,
     #[arg(long)]
@@ -27,10 +34,21 @@ struct Cli {
     top_k: Option<i32>,
 }

+#[derive(Clone, Serialize, Deserialize)]
+struct ModelParams {
+    dim: usize,
+    multiple_of: usize,
+    n_heads: usize,
+    n_layers: usize,
+    norm_eps: f64,
+    vocab_size: i64,
+}
+
 pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     let cli = Cli::parse();
     let model_path = cli.model_path;
     let tokenizer_path = cli.tokenizer_path;
+    let param_path = cli.param_path;
     let mut be_quiet: bool = false;
     if !colored::control::SHOULD_COLORIZE.should_colorize() {
@@ -46,6 +64,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
         };
     }

+    // Read ModelParams from param_path; we expect it to be JSON.
+    let mut fs = std::fs::File::open(&param_path)?;
+    let mut bs = Vec::new();
+    fs.read_to_end(&mut bs)?;
+    std::mem::drop(fs);
+    let params: ModelParams = serde_json::from_slice(&bs)?;
+    pln!("Loaded model parameters from {}.", param_path);
+
     let prompt: String = match (cli.prompt, cli.prompt_file) {
         (Some(prompt), None) => {
             pln!("Using prompt: {}", prompt);
@@ -84,19 +110,20 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     pln!("Loading embeddings from {}...", model_data_dir);
     let emb = Embedding::from_unpickled(&result, model_data_dir.clone())?;

-    let max_seq_len = 512;
+    let max_seq_len = match cli.max_seq_len {
+        Some(max_seq_len) => max_seq_len,
+        None => 1024,
+    };

     pln!("Loading transformer weights from {}...", model_data_dir);
     let tr = Transformer::from_unpickled(
         &result,
         emb,
-        4096,
-        32,
-        32,
+        params.dim,
+        params.n_layers,
+        params.n_heads,
         max_seq_len,
-        1e-6,
-        32,
-        128,
+        params.norm_eps,
         model_data_dir,
     )?;

     pln!("All is loaded. Starting inference.");
@@ -116,6 +143,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     }

     pln!("---");
+    pln!(" dim: {}", params.dim);
+    pln!(" multiple_of: {}", params.multiple_of);
+    pln!(" n_heads: {}", params.n_heads);
+    pln!(" n_layers: {}", params.n_layers);
+    pln!(" norm_eps: {}", params.norm_eps);
+    pln!(" vocab_size: {}", params.vocab_size);
+    pln!("---");
+    pln!("Max sequence length: {}", max_seq_len);
     pln!("Temperature: {}", token_sampler.get_temperature());
     pln!("Top P: {}", token_sampler.get_top_p());
     pln!("Top K: {}", token_sampler.get_top_k());
@@ -126,12 +161,7 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     );
     pln!(
         "{}",
-        " This is the color of the generated text while full context is available"
-            .truecolor(128, 255, 128)
-    );
-    pln!(
-        "{}",
-        " Remaining text is in this color".truecolor(255, 128, 128)
+        " This is the color of the generated text".truecolor(128, 255, 128)
     );
     pln!("---");
     print!("{}", prompt.as_str().truecolor(128, 128, 255));
@@ -139,17 +169,8 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
     let mut caches = tr.make_caches();
     let mut first: bool = true;
-    let mut shifts: usize = 0;
-    loop {
-        if toks_id.len() >= max_seq_len {
-            toks_id = toks_id[1..].to_vec();
-            prev_pos -= 1;
-            caches.shift_left(1);
-            shifts += 1;
-            // TODO: it seems that text beyond context is just broken.
-            // Maybe I cannot just go and shift it.
-        }
-        let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches, shifts);
+    while toks_id.len() < max_seq_len {
+        let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
         let highest_pred_idx = token_sampler.sample(&preds);
         toks_id.push(highest_pred_idx as TokenId);
@@ -166,14 +187,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
             }
             if first && tok_idx < toks_id.len() - 1 {
                 // intentionally left empty
-            } else if shifts == 0 {
-                print!("{}", tok_str.truecolor(128, 255, 128));
             } else {
-                print!("{}", tok_str.truecolor(255, 128, 128));
+                print!("{}", tok_str.truecolor(128, 255, 128));
             }
         }
         let _ = std::io::stdout().flush();
         prev_pos = toks_id.len() - 1;
         first = false;
     }
     println!("");
     Ok(())
 }
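
Since clap derives kebab-case long flags from the field names, a run with the new options might look like this (binary name and paths are illustrative):

    rllama --model-path /path/to/LLaMA/7B \
           --param-path /path/to/LLaMA/7B/params.json \
           --tokenizer-path /path/to/LLaMA/tokenizer.model \
           --max-seq-len 2048 \
           --prompt "Hello"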

@@ -121,10 +121,14 @@ impl Transformer {
         n_heads: usize,
         max_seq_len: usize,
         eps: f64,
-        n_local_heads: usize,
-        head_dim: usize,
         data_dir: P,
     ) -> Result<Transformer, UnpicklingError> {
+        assert_eq!(dim % n_heads, 0);
+        let head_dim = dim / n_heads;
+        let n_local_heads = n_heads; // I think n_local_heads is an artifact of the original
+                                     // implementation that used multi-GPU in the Facebook repo.
+                                     // Should delete it later.
         let data_dir: &Path = data_dir.as_ref();

         let progress_bar = ProgressBar::new(n_layers as u64);
@@ -149,9 +153,7 @@ impl Transformer {
         let output = Tensor::from_unpickled(unpickled, "output.weight", data_dir)?.to_f32();

         Ok(Transformer {
-            // TODO: maybe rotary embedding can be computed on the fly if the sequence gets really long. I just
-            // slapped * 20 on max seq len here.
-            freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len * 20, 10000.0),
+            freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len, 10000.0),
             emb,
             dim,
             n_layers,
@@ -186,7 +188,6 @@ impl Transformer {
         tokens: &[TokenId],
         start_pos: usize,
         caches: &mut TransformerCaches,
-        shifts: usize,
     ) -> Tensor {
         assert!(caches.layer_caches.len() == self.n_layers);
         let mask: Option<Tensor> = if tokens.len() > 1 {
@@ -215,7 +216,6 @@ impl Transformer {
                 &self.freqs_cis,
                 &mask,
                 &mut caches.layer_caches[idx],
-                shifts,
             );
         }
         let out = self.norm.forward(&emb_tensor);
@@ -265,17 +265,11 @@ impl TransformerBlock {
         freqs_cis: &FreqsCis,
         mask: &Option<Tensor>,
         attention_cache: &mut AttentionCache,
-        shifts: usize,
     ) -> Tensor {
         let attnorm_out = self.attention_norm.forward(x);
-        let att_out = self.attn.forward(
-            &attnorm_out,
-            start_pos,
-            freqs_cis,
-            mask,
-            attention_cache,
-            shifts,
-        );
+        let att_out = self
+            .attn
+            .forward(&attnorm_out, start_pos, freqs_cis, mask, attention_cache);
         let h = x.add(&att_out);
         let att_out = self.ffn_norm.forward(&h);
         let att_out = self.feed_forward.forward(&att_out).transpose();
@@ -404,7 +398,6 @@ impl Attention {
         freqs_cis: &FreqsCis,
         mask: &Option<Tensor>,
         attention_cache: &mut AttentionCache,
-        shifts: usize,
     ) -> Tensor {
         let seq_len = x.rows();
         let (xq_out, (xk_out, xv_out)) = rayon::join(
@@ -432,13 +425,8 @@ impl Attention {
                 .row(idx)
                 .view(self.n_local_heads as i64, self.head_dim as i64);

-            let (xq_row, xk_row) = apply_rotary_emb(
-                &xq_row,
-                &xk_row,
-                freqs_cis,
-                idx as usize,
-                start_pos + shifts,
-            );
+            let (xq_row, xk_row) =
+                apply_rotary_emb(&xq_row, &xk_row, freqs_cis, idx as usize, start_pos);
             xq_views.push(xq_row);
             xk_views.push(xk_row);
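
(The removal of start_pos + shifts here and of the * 20 fudge on the freqs_cis table earlier go together: generation now stops at max_seq_len, so no position can ever index past the precomputed table. For reference, a sketch of the standard RoPE table that compute_freqs_cis builds; the real FreqsCis type lives elsewhere in this file, this only shows the math:

    // Precompute (cos, sin) rotation pairs for every position and half-dimension.
    // head_dim is the per-head dimension, theta the base (10000.0 above).
    fn freqs_cis_sketch(head_dim: usize, max_seq_len: usize, theta: f64) -> Vec<Vec<(f64, f64)>> {
        (0..max_seq_len)
            .map(|pos| {
                (0..head_dim / 2)
                    .map(|i| {
                        // freq_i = theta^(-2i / head_dim); angle = position * freq_i
                        let freq = 1.0 / theta.powf(2.0 * i as f64 / head_dim as f64);
                        let angle = pos as f64 * freq;
                        (angle.cos(), angle.sin())
                    })
                    .collect()
            })
            .collect()
    }
)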
