Make the output colored. This is essential to be taken seriously.

Also made a few changes to keep clippy happy.
broken-opencl-code
Mikko Juola 3 years ago
parent cd28aba5e2
commit f103871bc0

@ -4,6 +4,7 @@ use crate::tokenizer::{TokenId, Tokenizer};
use crate::transformer::Transformer; use crate::transformer::Transformer;
use crate::unpickler; use crate::unpickler;
use clap::Parser; use clap::Parser;
use colored::Colorize;
use std::io::{Read, Write}; use std::io::{Read, Write};
#[derive(Parser)] #[derive(Parser)]
@ -31,13 +32,27 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let model_path = cli.model_path; let model_path = cli.model_path;
let tokenizer_path = cli.tokenizer_path; let tokenizer_path = cli.tokenizer_path;
let mut be_quiet: bool = false;
if !colored::control::SHOULD_COLORIZE.should_colorize() {
be_quiet = true;
}
// Custom println-like macro that respects be_quiet
macro_rules! pln {
($($arg:tt)*) => {
if !be_quiet {
std::println!($($arg)*);
}
};
}
let prompt: String = match (cli.prompt, cli.prompt_file) { let prompt: String = match (cli.prompt, cli.prompt_file) {
(Some(prompt), None) => { (Some(prompt), None) => {
println!("Using prompt: {}", prompt); pln!("Using prompt: {}", prompt);
prompt prompt
} }
(None, Some(prompt_file)) => { (None, Some(prompt_file)) => {
println!("Using prompt file: {}", prompt_file); pln!("Using prompt file: {}", prompt_file);
let mut fs = std::fs::File::open(prompt_file)?; let mut fs = std::fs::File::open(prompt_file)?;
let mut bs = Vec::new(); let mut bs = Vec::new();
fs.read_to_end(&mut bs)?; fs.read_to_end(&mut bs)?;
@ -45,14 +60,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
String::from_utf8(bs)? String::from_utf8(bs)?
} }
_ => { _ => {
println!("Please provide either a prompt or a prompt file."); eprintln!("Please provide either a prompt or a prompt file.");
return Ok(()); return Err("Please provide either a prompt or a prompt file.".into());
} }
}; };
println!("Starting up. Loading tokenizer from {}...", tokenizer_path); pln!("Starting up. Loading tokenizer from {}...", tokenizer_path);
let tok = Tokenizer::load(tokenizer_path.as_str())?; let tok = Tokenizer::load(tokenizer_path.as_str())?;
println!("Tokenizer loaded. Loading model from {}...", model_path); pln!("Tokenizer loaded. Loading model from {}...", model_path);
let mut fs = std::fs::File::open(model_path.as_str())?; let mut fs = std::fs::File::open(model_path.as_str())?;
let mut bs = Vec::new(); let mut bs = Vec::new();
fs.read_to_end(&mut bs)?; fs.read_to_end(&mut bs)?;
@ -66,25 +81,27 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
.join("/") .join("/")
+ "/data/"; + "/data/";
let result = unpickler::unpickle(&bs)?; let result = unpickler::unpickle(&bs)?;
println!("Loading embeddings from {}...", model_data_dir); pln!("Loading embeddings from {}...", model_data_dir);
let emb = Embedding::from_unpickled(&result, model_data_dir.clone())?; let emb = Embedding::from_unpickled(&result, model_data_dir.clone())?;
println!("Loading transformer weights from {}...", model_data_dir); let max_seq_len = 512;
pln!("Loading transformer weights from {}...", model_data_dir);
let tr = Transformer::from_unpickled( let tr = Transformer::from_unpickled(
&result, &result,
emb, emb,
4096, 4096,
32, 32,
32, 32,
512, max_seq_len,
1e-6, 1e-6,
32, 32,
128, 128,
model_data_dir, model_data_dir,
)?; )?;
println!("All is loaded. Starting inference."); pln!("All is loaded. Starting inference.");
let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt); let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
let mut prev_pos = 0; let mut prev_pos = 0;
let mut token_sampler = TokenSampler::new().temperature(0.8).top_p(0.9).top_k(50); let mut token_sampler = TokenSampler::new().temperature(0.8).top_p(0.9).top_k(50);
@ -98,30 +115,65 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
token_sampler = token_sampler.top_k(top_k as usize); token_sampler = token_sampler.top_k(top_k as usize);
} }
println!("Temperature: {}", token_sampler.get_temperature()); pln!("---");
println!("Top P: {}", token_sampler.get_top_p()); pln!("Temperature: {}", token_sampler.get_temperature());
println!("Top K: {}", token_sampler.get_top_k()); pln!("Top P: {}", token_sampler.get_top_p());
pln!("Top K: {}", token_sampler.get_top_k());
pln!("---");
pln!(
"{}",
" This is the color of the initial prompt".truecolor(128, 128, 255)
);
pln!(
"{}",
" This is the color of the generated text while full context is available"
.truecolor(128, 255, 128)
);
pln!(
"{}",
" Remaining text is in this color".truecolor(255, 128, 128)
);
pln!("---");
print!("{}", prompt.as_str().truecolor(128, 128, 255));
let _ = std::io::stdout().flush();
let mut caches = tr.make_caches(); let mut caches = tr.make_caches();
let mut first: bool = true;
let mut shifts: usize = 0;
loop { loop {
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches); if toks_id.len() >= max_seq_len {
toks_id = toks_id[1..].to_vec();
prev_pos -= 1;
caches.shift_left(1);
shifts += 1;
// TODO: text generated beyond the context window appears to be broken.
// Simply shifting the tokens and caches left may not be a valid approach.
}
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches, shifts);
let highest_pred_idx = token_sampler.sample(&preds); let highest_pred_idx = token_sampler.sample(&preds);
toks_id.push(highest_pred_idx as TokenId); toks_id.push(highest_pred_idx as TokenId);
let mut tok_str: String = "".to_string(); for (tok_idx, tok_id) in toks_id[prev_pos + 1..].iter().enumerate() {
for tok_id in toks_id[prev_pos + 1..].iter() {
if *tok_id == 1 { if *tok_id == 1 {
continue; continue;
} }
let mut tok_str: String = "".to_string();
let tok = tok.id_to_str(*tok_id); let tok = tok.id_to_str(*tok_id);
if tok == "<0x0A>" { if tok == "<0x0A>" {
tok_str += "\n"; tok_str += "\n";
} else { } else {
tok_str += tok.replace('▁', " ").as_str(); tok_str += tok.replace('▁', " ").as_str();
} }
if first && tok_idx < toks_id.len() - 1 {
// intentionally left empty
} else if shifts == 0 {
print!("{}", tok_str.truecolor(128, 255, 128));
} else {
print!("{}", tok_str.truecolor(255, 128, 128));
}
} }
print!("{}", tok_str);
let _ = std::io::stdout().flush(); let _ = std::io::stdout().flush();
prev_pos = toks_id.len() - 1; prev_pos = toks_id.len() - 1;
first = false;
} }
} }

@ -865,8 +865,8 @@ impl Tensor {
unsafe { unsafe {
let result = Tensor::uninitialized(self.rows, 1, self.dtype); let result = Tensor::uninitialized(self.rows, 1, self.dtype);
let capacity_cols: i64 = self.capacity_cols as i64; let capacity_cols: i64 = self.capacity_cols;
let result_capacity_cols = result.capacity_cols as i64; let result_capacity_cols: i64 = result.capacity_cols;
let col_its: usize = if self.cols % 8 == 0 { let col_its: usize = if self.cols % 8 == 0 {
(self.cols / 8) as usize (self.cols / 8) as usize
} else { } else {
@ -902,6 +902,8 @@ impl Tensor {
} }
// Computes matrix multiplication assuming left side has number of rows as 1 // Computes matrix multiplication assuming left side has number of rows as 1
#[allow(clippy::erasing_op)]
#[allow(clippy::identity_op)]
pub fn vector_matrix_mul(&self, other: &Tensor) -> Tensor { pub fn vector_matrix_mul(&self, other: &Tensor) -> Tensor {
if self.cols != other.rows { if self.cols != other.rows {
panic!( panic!(
@ -938,6 +940,9 @@ impl Tensor {
let row = row8 * 8; let row = row8 * 8;
let left = _mm256_loadu_ps(left_data.add(row)); let left = _mm256_loadu_ps(left_data.add(row));
let mut r = [0.0f32; 8]; let mut r = [0.0f32; 8];
// i hate you clippy because you ask me
// to make code more unreadable
#[allow(clippy::needless_range_loop)]
for i in 0..8 { for i in 0..8 {
if row + i < other.rows as usize { if row + i < other.rows as usize {
r[i] = *right_data.add((row + i) * other_capacity_cols + col); r[i] = *right_data.add((row + i) * other_capacity_cols + col);

@ -62,6 +62,33 @@ impl AttentionCache {
} }
AttentionCache { cache_k, cache_v } AttentionCache { cache_k, cache_v }
} }
/// Slides the contents of every cached key/value tensor one column to the
/// left, repeated `shifts` times.
///
/// On each pass, column `c` of a tensor receives the value that was stored in
/// column `c + 1`; the final column is left as it was. The row and column
/// counts are taken from the key tensor and applied to the value tensor too.
///
/// Panics if acquiring either write guard fails.
fn shift_left(&mut self, shifts: usize) {
    let n_entries = self.cache_k.len();
    for _pass in 0..shifts {
        for entry in 0..n_entries {
            // Keys are locked before values, one cache entry at a time.
            let mut keys = self.cache_k[entry].write().unwrap();
            let mut values = self.cache_v[entry].write().unwrap();
            let (n_rows, n_cols) = (keys.rows(), keys.cols());
            for row in 0..n_rows {
                for dst in 0..n_cols - 1 {
                    // Pull the right-hand neighbour into the current column.
                    let src = dst + 1;
                    let moved_key = keys.get_f32(row, src);
                    let moved_value = values.get_f32(row, src);
                    keys.set_f32(row, dst, moved_key);
                    values.set_f32(row, dst, moved_value);
                }
            }
        }
    }
}
}
impl TransformerCaches {
/// Applies a left shift of `shifts` positions to the cache of every layer
/// by delegating to each layer cache's own `shift_left`.
pub fn shift_left(&mut self, shifts: usize) {
    self.layer_caches
        .iter_mut()
        .for_each(|layer_cache| layer_cache.shift_left(shifts));
}
} }
pub struct RMSNorm { pub struct RMSNorm {
@ -122,7 +149,9 @@ impl Transformer {
let output = Tensor::from_unpickled(unpickled, "output.weight", data_dir)?.to_f32(); let output = Tensor::from_unpickled(unpickled, "output.weight", data_dir)?.to_f32();
Ok(Transformer { Ok(Transformer {
freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len * 2, 10000.0), // TODO: maybe rotary embedding can be computed on the fly if the sequence gets really long. I just
// slapped * 20 on max seq len here.
freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len * 20, 10000.0),
emb, emb,
dim, dim,
n_layers, n_layers,
@ -157,6 +186,7 @@ impl Transformer {
tokens: &[TokenId], tokens: &[TokenId],
start_pos: usize, start_pos: usize,
caches: &mut TransformerCaches, caches: &mut TransformerCaches,
shifts: usize,
) -> Tensor { ) -> Tensor {
assert!(caches.layer_caches.len() == self.n_layers); assert!(caches.layer_caches.len() == self.n_layers);
let mask: Option<Tensor> = if tokens.len() > 1 { let mask: Option<Tensor> = if tokens.len() > 1 {
@ -185,6 +215,7 @@ impl Transformer {
&self.freqs_cis, &self.freqs_cis,
&mask, &mask,
&mut caches.layer_caches[idx], &mut caches.layer_caches[idx],
shifts,
); );
} }
let out = self.norm.forward(&emb_tensor); let out = self.norm.forward(&emb_tensor);
@ -234,11 +265,17 @@ impl TransformerBlock {
freqs_cis: &FreqsCis, freqs_cis: &FreqsCis,
mask: &Option<Tensor>, mask: &Option<Tensor>,
attention_cache: &mut AttentionCache, attention_cache: &mut AttentionCache,
shifts: usize,
) -> Tensor { ) -> Tensor {
let attnorm_out = self.attention_norm.forward(x); let attnorm_out = self.attention_norm.forward(x);
let att_out = self let att_out = self.attn.forward(
.attn &attnorm_out,
.forward(&attnorm_out, start_pos, freqs_cis, mask, attention_cache); start_pos,
freqs_cis,
mask,
attention_cache,
shifts,
);
let h = x.add(&att_out); let h = x.add(&att_out);
let att_out = self.ffn_norm.forward(&h); let att_out = self.ffn_norm.forward(&h);
let att_out = self.feed_forward.forward(&att_out).transpose(); let att_out = self.feed_forward.forward(&att_out).transpose();
@ -300,8 +337,8 @@ impl FeedForward {
pub fn forward(&self, x: &Tensor) -> Tensor { pub fn forward(&self, x: &Tensor) -> Tensor {
let (w1_out, w3_out) = rayon::join( let (w1_out, w3_out) = rayon::join(
|| self.w1.matrix_mul_transposed(&x), || self.w1.matrix_mul_transposed(x),
|| self.w3.matrix_mul_transposed(&x), || self.w3.matrix_mul_transposed(x),
); );
let w1_out = w1_out.silu(); let w1_out = w1_out.silu();
let w1w3_out = w1_out.hadamard_product(&w3_out).transpose(); let w1w3_out = w1_out.hadamard_product(&w3_out).transpose();
@ -367,6 +404,7 @@ impl Attention {
freqs_cis: &FreqsCis, freqs_cis: &FreqsCis,
mask: &Option<Tensor>, mask: &Option<Tensor>,
attention_cache: &mut AttentionCache, attention_cache: &mut AttentionCache,
shifts: usize,
) -> Tensor { ) -> Tensor {
let seq_len = x.rows(); let seq_len = x.rows();
let (xq_out, (xk_out, xv_out)) = rayon::join( let (xq_out, (xk_out, xv_out)) = rayon::join(
@ -394,8 +432,13 @@ impl Attention {
.row(idx) .row(idx)
.view(self.n_local_heads as i64, self.head_dim as i64); .view(self.n_local_heads as i64, self.head_dim as i64);
let (xq_row, xk_row) = let (xq_row, xk_row) = apply_rotary_emb(
apply_rotary_emb(&xq_row, &xk_row, freqs_cis, idx as usize, start_pos); &xq_row,
&xk_row,
freqs_cis,
idx as usize,
start_pos + shifts,
);
xq_views.push(xq_row); xq_views.push(xq_row);
xk_views.push(xk_row); xk_views.push(xk_row);
@ -464,8 +507,7 @@ impl Attention {
} }
let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect(); let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect();
let xq_row = Tensor::concat(&concat_vec2).view(1, 4096); let xq_row = Tensor::concat(&concat_vec2).view(1, 4096);
let result = xq_row.matrix_mul_transposed(&self.wo); xq_row.matrix_mul_transposed(&self.wo)
result
}) })
.collect(); .collect();

Loading…
Cancel
Save