Make the output colored. This is essential to be taken seriously.

Also did some clippy happiness changes.
broken-opencl-code
Mikko Juola 3 years ago
parent cd28aba5e2
commit f103871bc0

@ -4,6 +4,7 @@ use crate::tokenizer::{TokenId, Tokenizer};
use crate::transformer::Transformer;
use crate::unpickler;
use clap::Parser;
use colored::Colorize;
use std::io::{Read, Write};
#[derive(Parser)]
@ -31,13 +32,27 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let model_path = cli.model_path;
let tokenizer_path = cli.tokenizer_path;
let mut be_quiet: bool = false;
if !colored::control::SHOULD_COLORIZE.should_colorize() {
be_quiet = true;
}
// Custom println-like macro that respects be_quiet
macro_rules! pln {
($($arg:tt)*) => {
if !be_quiet {
std::println!($($arg)*);
}
};
}
let prompt: String = match (cli.prompt, cli.prompt_file) {
(Some(prompt), None) => {
println!("Using prompt: {}", prompt);
pln!("Using prompt: {}", prompt);
prompt
}
(None, Some(prompt_file)) => {
println!("Using prompt file: {}", prompt_file);
pln!("Using prompt file: {}", prompt_file);
let mut fs = std::fs::File::open(prompt_file)?;
let mut bs = Vec::new();
fs.read_to_end(&mut bs)?;
@ -45,14 +60,14 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
String::from_utf8(bs)?
}
_ => {
println!("Please provide either a prompt or a prompt file.");
return Ok(());
eprintln!("Please provide either a prompt or a prompt file.");
return Err("Please provide either a prompt or a prompt file.".into());
}
};
println!("Starting up. Loading tokenizer from {}...", tokenizer_path);
pln!("Starting up. Loading tokenizer from {}...", tokenizer_path);
let tok = Tokenizer::load(tokenizer_path.as_str())?;
println!("Tokenizer loaded. Loading model from {}...", model_path);
pln!("Tokenizer loaded. Loading model from {}...", model_path);
let mut fs = std::fs::File::open(model_path.as_str())?;
let mut bs = Vec::new();
fs.read_to_end(&mut bs)?;
@ -66,25 +81,27 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
.join("/")
+ "/data/";
let result = unpickler::unpickle(&bs)?;
println!("Loading embeddings from {}...", model_data_dir);
pln!("Loading embeddings from {}...", model_data_dir);
let emb = Embedding::from_unpickled(&result, model_data_dir.clone())?;
println!("Loading transformer weights from {}...", model_data_dir);
let max_seq_len = 512;
pln!("Loading transformer weights from {}...", model_data_dir);
let tr = Transformer::from_unpickled(
&result,
emb,
4096,
32,
32,
512,
max_seq_len,
1e-6,
32,
128,
model_data_dir,
)?;
println!("All is loaded. Starting inference.");
pln!("All is loaded. Starting inference.");
let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt);
let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
let mut prev_pos = 0;
let mut token_sampler = TokenSampler::new().temperature(0.8).top_p(0.9).top_k(50);
@ -98,30 +115,65 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
token_sampler = token_sampler.top_k(top_k as usize);
}
println!("Temperature: {}", token_sampler.get_temperature());
println!("Top P: {}", token_sampler.get_top_p());
println!("Top K: {}", token_sampler.get_top_k());
pln!("---");
pln!("Temperature: {}", token_sampler.get_temperature());
pln!("Top P: {}", token_sampler.get_top_p());
pln!("Top K: {}", token_sampler.get_top_k());
pln!("---");
pln!(
"{}",
" This is the color of the initial prompt".truecolor(128, 128, 255)
);
pln!(
"{}",
" This is the color of the generated text while full context is available"
.truecolor(128, 255, 128)
);
pln!(
"{}",
" Remaining text is in this color".truecolor(255, 128, 128)
);
pln!("---");
print!("{}", prompt.as_str().truecolor(128, 128, 255));
let _ = std::io::stdout().flush();
let mut caches = tr.make_caches();
let mut first: bool = true;
let mut shifts: usize = 0;
loop {
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
if toks_id.len() >= max_seq_len {
toks_id = toks_id[1..].to_vec();
prev_pos -= 1;
caches.shift_left(1);
shifts += 1;
// TODO: it seems that text beyond context is just broken.
// Maybe I cannot just go and shift it.
}
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches, shifts);
let highest_pred_idx = token_sampler.sample(&preds);
toks_id.push(highest_pred_idx as TokenId);
let mut tok_str: String = "".to_string();
for tok_id in toks_id[prev_pos + 1..].iter() {
for (tok_idx, tok_id) in toks_id[prev_pos + 1..].iter().enumerate() {
if *tok_id == 1 {
continue;
}
let mut tok_str: String = "".to_string();
let tok = tok.id_to_str(*tok_id);
if tok == "<0x0A>" {
tok_str += "\n";
} else {
tok_str += tok.replace('▁', " ").as_str();
}
if first && tok_idx < toks_id.len() - 1 {
// intentionally left empty
} else if shifts == 0 {
print!("{}", tok_str.truecolor(128, 255, 128));
} else {
print!("{}", tok_str.truecolor(255, 128, 128));
}
}
print!("{}", tok_str);
let _ = std::io::stdout().flush();
prev_pos = toks_id.len() - 1;
first = false;
}
}

@ -865,8 +865,8 @@ impl Tensor {
unsafe {
let result = Tensor::uninitialized(self.rows, 1, self.dtype);
let capacity_cols: i64 = self.capacity_cols as i64;
let result_capacity_cols = result.capacity_cols as i64;
let capacity_cols: i64 = self.capacity_cols;
let result_capacity_cols: i64 = result.capacity_cols;
let col_its: usize = if self.cols % 8 == 0 {
(self.cols / 8) as usize
} else {
@ -902,6 +902,8 @@ impl Tensor {
}
// Computes matrix multiplication assuming left side has number of rows as 1
#[allow(clippy::erasing_op)]
#[allow(clippy::identity_op)]
pub fn vector_matrix_mul(&self, other: &Tensor) -> Tensor {
if self.cols != other.rows {
panic!(
@ -938,6 +940,9 @@ impl Tensor {
let row = row8 * 8;
let left = _mm256_loadu_ps(left_data.add(row));
let mut r = [0.0f32; 8];
// i hate you clippy because you ask me
// to make code more unreadable
#[allow(clippy::needless_range_loop)]
for i in 0..8 {
if row + i < other.rows as usize {
r[i] = *right_data.add((row + i) * other_capacity_cols + col);

@ -62,6 +62,33 @@ impl AttentionCache {
}
AttentionCache { cache_k, cache_v }
}
/// Shift both the K and V caches left along the sequence dimension by
/// `shifts` positions, discarding the oldest entries.
///
/// Column `seq_idx` receives the value that was at `seq_idx + shifts`;
/// source indices past the end are clamped to the last column. This is
/// exactly the final state produced by `shifts` repeated one-step shifts
/// (the previous implementation), but computed in a single O(rows * cols)
/// pass instead of O(shifts * rows * cols), and with each layer's RwLock
/// pair acquired only once instead of once per step.
fn shift_left(&mut self, shifts: usize) {
    if shifts == 0 {
        return; // nothing to move
    }
    for idx in 0..self.cache_k.len() {
        let mut k = self.cache_k[idx].write().unwrap();
        let mut v = self.cache_v[idx].write().unwrap();
        let k_rows = k.rows();
        let k_cols = k.cols();
        for head_idx in 0..k_rows {
            // Ascending seq_idx means every source column (always > dest)
            // is read before it is overwritten, so values are original.
            for seq_idx in 0..k_cols - 1 {
                // Clamp so shifts >= cols degenerates to repeating the
                // last column, matching the old step-by-step behavior.
                let src = (seq_idx + shifts as _).min(k_cols - 1);
                let kval = k.get_f32(head_idx, src);
                let vval = v.get_f32(head_idx, src);
                k.set_f32(head_idx, seq_idx, kval);
                v.set_f32(head_idx, seq_idx, vval);
            }
        }
    }
}
}
impl TransformerCaches {
    /// Propagate a left shift of `shifts` sequence positions to the
    /// attention cache of every transformer layer.
    pub fn shift_left(&mut self, shifts: usize) {
        self.layer_caches
            .iter_mut()
            .for_each(|layer_cache| layer_cache.shift_left(shifts));
    }
}
pub struct RMSNorm {
@ -122,7 +149,9 @@ impl Transformer {
let output = Tensor::from_unpickled(unpickled, "output.weight", data_dir)?.to_f32();
Ok(Transformer {
freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len * 2, 10000.0),
// TODO: maybe rotary embedding can be computed on the fly if the sequence gets really long. I just
// slapped * 20 on max seq len here.
freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len * 20, 10000.0),
emb,
dim,
n_layers,
@ -157,6 +186,7 @@ impl Transformer {
tokens: &[TokenId],
start_pos: usize,
caches: &mut TransformerCaches,
shifts: usize,
) -> Tensor {
assert!(caches.layer_caches.len() == self.n_layers);
let mask: Option<Tensor> = if tokens.len() > 1 {
@ -185,6 +215,7 @@ impl Transformer {
&self.freqs_cis,
&mask,
&mut caches.layer_caches[idx],
shifts,
);
}
let out = self.norm.forward(&emb_tensor);
@ -234,11 +265,17 @@ impl TransformerBlock {
freqs_cis: &FreqsCis,
mask: &Option<Tensor>,
attention_cache: &mut AttentionCache,
shifts: usize,
) -> Tensor {
let attnorm_out = self.attention_norm.forward(x);
let att_out = self
.attn
.forward(&attnorm_out, start_pos, freqs_cis, mask, attention_cache);
let att_out = self.attn.forward(
&attnorm_out,
start_pos,
freqs_cis,
mask,
attention_cache,
shifts,
);
let h = x.add(&att_out);
let att_out = self.ffn_norm.forward(&h);
let att_out = self.feed_forward.forward(&att_out).transpose();
@ -300,8 +337,8 @@ impl FeedForward {
pub fn forward(&self, x: &Tensor) -> Tensor {
let (w1_out, w3_out) = rayon::join(
|| self.w1.matrix_mul_transposed(&x),
|| self.w3.matrix_mul_transposed(&x),
|| self.w1.matrix_mul_transposed(x),
|| self.w3.matrix_mul_transposed(x),
);
let w1_out = w1_out.silu();
let w1w3_out = w1_out.hadamard_product(&w3_out).transpose();
@ -367,6 +404,7 @@ impl Attention {
freqs_cis: &FreqsCis,
mask: &Option<Tensor>,
attention_cache: &mut AttentionCache,
shifts: usize,
) -> Tensor {
let seq_len = x.rows();
let (xq_out, (xk_out, xv_out)) = rayon::join(
@ -394,8 +432,13 @@ impl Attention {
.row(idx)
.view(self.n_local_heads as i64, self.head_dim as i64);
let (xq_row, xk_row) =
apply_rotary_emb(&xq_row, &xk_row, freqs_cis, idx as usize, start_pos);
let (xq_row, xk_row) = apply_rotary_emb(
&xq_row,
&xk_row,
freqs_cis,
idx as usize,
start_pos + shifts,
);
xq_views.push(xq_row);
xk_views.push(xk_row);
@ -464,8 +507,7 @@ impl Attention {
}
let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect();
let xq_row = Tensor::concat(&concat_vec2).view(1, 4096);
let result = xq_row.matrix_mul_transposed(&self.wo);
result
xq_row.matrix_mul_transposed(&self.wo)
})
.collect();

Loading…
Cancel
Save