Improve transposed matrix multiplication further; this gives roughly a 10%-20% additional speedup by improving the ratio of memory loads to arithmetic instructions.

master · Mikko Juola, 3 years ago
parent 61bc42b728 · commit a1970b8a9c
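
What changed, conceptually: the old AVX2 kernel computed four output rows against a single output column, so each step did 4 FMAs per 5 vector loads (four `src` loads feeding one FMA each, one `other` load feeding four). The new kernel computes a 4×4 output tile, so 8 loads feed 16 FMAs and every loaded vector is reused four times. A minimal scalar sketch of the same blocking idea (illustrative only: the function name and flat row-major layout are assumptions, and the real kernel works on `__m256` lanes with `capacity_cols` padding):

```rust
/// Scalar sketch of 4x4 register blocking for c = a * b^T.
/// a is m x k row-major, b is n x k row-major (the transposed
/// operand), c is m x n row-major. Dimensions are assumed to be
/// multiples of 4 to keep the sketch short; the real kernel
/// guards the tail rows and columns.
fn matmul_transposed_4x4(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, k: usize) {
    for i in (0..m).step_by(4) {
        for j in (0..n).step_by(4) {
            let mut acc = [[0.0f32; 4]; 4];
            for p in 0..k {
                // 8 loads feed 16 multiply-adds: each loaded value
                // is reused 4 times, vs. at most 4 reuses for one
                // of the 5 loads in the old single-column loop.
                let av = [a[i * k + p], a[(i + 1) * k + p], a[(i + 2) * k + p], a[(i + 3) * k + p]];
                let bv = [b[j * k + p], b[(j + 1) * k + p], b[(j + 2) * k + p], b[(j + 3) * k + p]];
                for r in 0..4 {
                    for s in 0..4 {
                        acc[r][s] += av[r] * bv[s];
                    }
                }
            }
            for r in 0..4 {
                for s in 0..4 {
                    c[(i + r) * n + (j + s)] = acc[r][s];
                }
            }
        }
    }
}

fn main() {
    // Identity times identity (b = I, so b^T = I as well).
    let i4: Vec<f32> = (0..16).map(|x| if x % 5 == 0 { 1.0 } else { 0.0 }).collect();
    let mut c = vec![0.0f32; 16];
    matmul_transposed_4x4(&i4, &i4, &mut c, 4, 4, 4);
    assert_eq!(c, i4);
}
```

Arithmetic intensity thus goes from 4 FMAs per 5 loads to 16 FMAs per 8 loads, a 2.5× improvement at the inner loop, which is plausibly where the observed 10%-20% end-to-end gain comes from once memory bandwidth and the rest of the pipeline are factored in.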

@@ -152,10 +152,7 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
pln!("Loading embeddings from {}...", model_path);
let emb = Embedding::from_unpickled(&unpickle_results, model_path.clone())?;
let max_seq_len = match cli.max_seq_len {
Some(max_seq_len) => max_seq_len,
None => 1024,
};
let max_seq_len = cli.max_seq_len.unwrap_or(1024);
let data_settings = {
#[cfg(feature = "opencl")]
@@ -291,11 +288,9 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
break;
}
}
println!("");
if stop_seen {
if !be_quiet {
println!("Stop token seen. Stopping.");
}
println!();
if stop_seen && !be_quiet {
println!("Stop token seen. Stopping.");
}
if !be_quiet {
println!("---");

@@ -1072,6 +1072,8 @@ impl Tensor {
let src_cols: usize = src.cols as usize;
let self_rows: usize = self.rows as usize;
let self_cols: usize = self.cols as usize;
let _other_cols: usize = other.cols as usize;
let other_rows: usize = other.rows as usize;
let other_cols_capacity: usize = other.capacity_cols as usize;
let src_cols_capacity: usize = src.capacity_cols as usize;
let self_cols_capacity: usize = self.capacity_cols as usize;
@@ -1086,6 +1088,11 @@ impl Tensor {
} else {
self_rows / 4 + 1
};
let self_cols_its = if self_cols % 4 == 0 {
self_cols / 4
} else {
self_cols / 4 + 1
};
unsafe {
for row in 0..row_its {
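
`self_cols_its` follows the same pattern as the existing `row_its` above it: a ceiling division by 4, so a partial final tile of columns still gets one loop iteration. A quick illustrative check of the equivalence:

```rust
fn main() {
    for cols in 1..=12usize {
        let its = if cols % 4 == 0 { cols / 4 } else { cols / 4 + 1 };
        // Identical to the usual ceiling-division idiom.
        assert_eq!(its, (cols + 3) / 4);
    }
}
```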
@@ -1093,18 +1100,66 @@ impl Tensor {
let row1 = row * 4 + 1;
let row2 = row * 4 + 2;
let row3 = row * 4 + 3;
for col in 0..self_cols {
let mut targets8: [__m256; 4] = [
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
for col in 0..self_cols_its {
let col0 = col * 4;
let col1 = col * 4 + 1;
let col2 = col * 4 + 2;
let col3 = col * 4 + 3;
let mut targets8: [[__m256; 4]; 4] = [
[
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
],
[
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
],
[
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
],
[
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
],
];
for p in 0..src_cols_its {
let other8: __m256 = _mm256_loadu_ps(
let other8_0: __m256 = _mm256_loadu_ps(
other_data
.add(col * other_cols_capacity + p * ITEMS_PER_CACHE_LINE),
.add(col0 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE),
);
let other8_1: __m256 =
if col1 < other_rows {
_mm256_loadu_ps(other_data.add(
col1 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
))
} else {
_mm256_setzero_ps()
};
let other8_2: __m256 =
if col2 < other_rows {
_mm256_loadu_ps(other_data.add(
col2 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
))
} else {
_mm256_setzero_ps()
};
let other8_3: __m256 =
if col3 < other_rows {
_mm256_loadu_ps(other_data.add(
col3 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
))
} else {
_mm256_setzero_ps()
};
let src8_0: __m256 = _mm256_loadu_ps(
src_data
.add(row0 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE),
@@ -1133,24 +1188,61 @@ impl Tensor {
} else {
_mm256_setzero_ps()
};
targets8[0] = _mm256_fmadd_ps(src8_0, other8, targets8[0]);
targets8[1] = _mm256_fmadd_ps(src8_1, other8, targets8[1]);
targets8[2] = _mm256_fmadd_ps(src8_2, other8, targets8[2]);
targets8[3] = _mm256_fmadd_ps(src8_3, other8, targets8[3]);
targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
}
let target0: f32 = horizontal_sum(targets8[0]);
let target1: f32 = horizontal_sum(targets8[1]);
let target2: f32 = horizontal_sum(targets8[2]);
let target3: f32 = horizontal_sum(targets8[3]);
*tgt_data.add(row0 * self_cols_capacity + col) = target0;
let target00: f32 = horizontal_sum(targets8[0][0]);
let target01: f32 = horizontal_sum(targets8[0][1]);
let target02: f32 = horizontal_sum(targets8[0][2]);
let target03: f32 = horizontal_sum(targets8[0][3]);
let target10: f32 = horizontal_sum(targets8[1][0]);
let target11: f32 = horizontal_sum(targets8[1][1]);
let target12: f32 = horizontal_sum(targets8[1][2]);
let target13: f32 = horizontal_sum(targets8[1][3]);
let target20: f32 = horizontal_sum(targets8[2][0]);
let target21: f32 = horizontal_sum(targets8[2][1]);
let target22: f32 = horizontal_sum(targets8[2][2]);
let target23: f32 = horizontal_sum(targets8[2][3]);
let target30: f32 = horizontal_sum(targets8[3][0]);
let target31: f32 = horizontal_sum(targets8[3][1]);
let target32: f32 = horizontal_sum(targets8[3][2]);
let target33: f32 = horizontal_sum(targets8[3][3]);
*tgt_data.add(row0 * self_cols_capacity + col0) += target00;
*tgt_data.add(row0 * self_cols_capacity + col1) += target10;
*tgt_data.add(row0 * self_cols_capacity + col2) += target20;
*tgt_data.add(row0 * self_cols_capacity + col3) += target30;
if row1 < self_rows {
*tgt_data.add(row1 * self_cols_capacity + col) = target1;
*tgt_data.add(row1 * self_cols_capacity + col0) += target01;
*tgt_data.add(row1 * self_cols_capacity + col1) += target11;
*tgt_data.add(row1 * self_cols_capacity + col2) += target21;
*tgt_data.add(row1 * self_cols_capacity + col3) += target31;
}
if row2 < self_rows {
*tgt_data.add(row2 * self_cols_capacity + col) = target2;
*tgt_data.add(row2 * self_cols_capacity + col0) += target02;
*tgt_data.add(row2 * self_cols_capacity + col1) += target12;
*tgt_data.add(row2 * self_cols_capacity + col2) += target22;
*tgt_data.add(row2 * self_cols_capacity + col3) += target32;
}
if row3 < self_rows {
*tgt_data.add(row3 * self_cols_capacity + col) = target3;
*tgt_data.add(row3 * self_cols_capacity + col0) += target03;
*tgt_data.add(row3 * self_cols_capacity + col1) += target13;
*tgt_data.add(row3 * self_cols_capacity + col2) += target23;
*tgt_data.add(row3 * self_cols_capacity + col3) += target33;
}
}
}
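
Tail handling in the new kernel: loads past `other_rows` are replaced with zero vectors instead of being skipped, row stores are guarded by `rowN < self_rows`, but column stores run out to `col3` unconditionally. That is safe only because every row is padded to `capacity_cols`, so out-of-range column writes land in padding that is never read back. The stores also changed from `=` to `+=`, which assumes the destination tensor starts zeroed. A small sketch of the padding trick (the helper and its names are hypothetical):

```rust
// Sketch: store a full 4-wide tile into a row padded to a multiple
// of 4, so the column tail needs no per-element bounds check.
fn store_tile_row(row: &mut [f32], col0: usize, vals: [f32; 4]) {
    // row.len() is the padded capacity, so col0 + 3 stays in bounds
    // even when the logical column count is smaller; the extra
    // values land in padding.
    for (i, v) in vals.into_iter().enumerate() {
        row[col0 + i] += v;
    }
}

fn main() {
    let logical_cols = 6;
    let capacity = 8; // logical 6 columns, padded to a multiple of 4
    let mut row = vec![0.0f32; capacity];
    store_tile_row(&mut row, 4, [1.0, 2.0, 3.0, 4.0]);
    // Columns 6 and 7 are padding; only 0..logical_cols are read back.
    assert_eq!(&row[..logical_cols], &[0.0, 0.0, 0.0, 0.0, 1.0, 2.0]);
}
```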
@@ -1222,8 +1314,8 @@ impl Tensor {
];
let self_data: *const f32 = self.data as *const f32;
let other_data: *const f32 = other.data as *const f32;
let tgt_data: *mut f32 = result.data as *mut f32;
let ncols_capacity: usize = result.capacity_cols as usize;
let _tgt_data: *mut f32 = result.data as *mut f32;
let _ncols_capacity: usize = result.capacity_cols as usize;
for row in 0..row_its {
let row: i64 = row as i64;
sum8s[0] = _mm256_setzero_ps();
@@ -1670,8 +1762,8 @@ impl Tensor {
let self_data: *const f16 = self.data as *const f16;
let tgt_data: *mut f32 = result.data as *mut f32;
let tgt_capacity_cols = result.capacity_cols as i64;
let self_capacity_cols = self.capacity_cols as i64;
let tgt_capacity_cols = result.capacity_cols;
let self_capacity_cols = self.capacity_cols;
for row in 0..self.rows {
for col in 0..cols_it {
let col = col * 8;
@@ -1719,8 +1811,8 @@ impl Tensor {
let result = Tensor::uninitialized(self.rows, self.cols, TensorDType::Float16);
let self_data: *const f32 = self.data as *const f32;
let tgt_data: *mut f16 = result.data as *mut f16;
let tgt_capacity_cols = result.capacity_cols as i64;
let self_capacity_cols = self.capacity_cols as i64;
let tgt_capacity_cols = result.capacity_cols;
let self_capacity_cols = self.capacity_cols;
for row in 0..self.rows {
for col in 0..cols_it {
@@ -1973,9 +2065,9 @@ mod tests {
fn mat_mul_transposed_agrees_with_regular_mat_mul() {
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let a = rng.gen_range(8..64);
let b = rng.gen_range(8..64);
let r = rng.gen_range(8..64);
let a = rng.gen_range(1..=128);
let b = rng.gen_range(1..=128);
let r = rng.gen_range(1..=128);
// Make matrixes AxR and RxB
let a = Tensor::random(a, r, TensorDType::Float32);
@@ -1990,7 +2082,7 @@ impl Tensor {
for row in 0..c.rows {
for col in 0..c.cols {
assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-5);
assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-3);
}
}
}
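
Loosening the tolerance from 1e-5 to 1e-3 is expected rather than a red flag: the tiled FMA kernel sums partial products in a different order than the reference multiply, and with dimensions now ranging up to 128 (and down to 1, exercising all the tail paths above) the accumulated rounding difference grows. A minimal illustration of order-dependent float summation, using the same approx crate macro as the test:

```rust
use approx::assert_relative_eq;

fn main() {
    // Summing the same f32 values in two different orders gives
    // slightly different results; a tolerance absorbs the gap.
    let xs: Vec<f32> = (1..=128).map(|i| 1.0 / i as f32).collect();
    let forward: f32 = xs.iter().sum();
    let backward: f32 = xs.iter().rev().sum();
    assert_relative_eq!(forward, backward, epsilon = 1e-3);
}
```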

@@ -68,7 +68,7 @@ impl TokenSampler {
pub fn sample(
&self,
logits: &Tensor,
tokenizer: &Tokenizer,
_tokenizer: &Tokenizer,
existing_tokens: &[TokenId],
) -> (TokenId, f32) {
let mut times_used: BTreeMap<TokenId, usize> = BTreeMap::new();
@@ -119,7 +119,7 @@ impl TokenSampler {
None => {
// Sort NaNs to bottom
if b.1.is_nan() {
return std::cmp::Ordering::Less;
std::cmp::Ordering::Less
} else if a.1.is_nan() {
return std::cmp::Ordering::Greater;
} else {
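
The sampler hunk only drops a clippy-flagged `return`; the comparator still sorts NaN scores to the bottom. A standalone sketch of that pattern for a descending sort (names are illustrative):

```rust
use std::cmp::Ordering;

fn main() {
    let mut scores = vec![0.5f32, f32::NAN, 0.9, 0.1];
    scores.sort_by(|a, b| match b.partial_cmp(a) {
        // Normal case: descending order by score.
        Some(ord) => ord,
        // partial_cmp is None only when a NaN is involved;
        // treat NaN as smallest so it sinks to the bottom.
        None => match (a.is_nan(), b.is_nan()) {
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            _ => Ordering::Equal,
        },
    });
    assert_eq!(scores[0], 0.9);
    assert!(scores[3].is_nan());
}
```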

@@ -133,7 +133,7 @@ impl Tokenizer {
let mut skip_s: &str = "";
// Specially recognize newline. Otherwise it matches something we don't actually
// want.
if s.starts_with("\n") {
if s.starts_with('\n') {
if self.str_to_id("<0x0A>").is_some() {
best_candidate = "<0x0A>";
best_candidate_len = best_candidate.len();

@@ -354,7 +354,7 @@ impl RMSNorm {
let data_dir: &Path = data_dir.as_ref();
let weights = Tensor::from_unpickled_pieces(
&unpickled[0..=0],
name.clone(),
name,
data_dir,
FromPiecesDirection::Rows,
)?
@@ -448,11 +448,11 @@ impl FeedForward {
#[cfg(not(feature = "opencl"))]
if w1w3_out.rows() == 1 {
return self
self
.w2
.matrix_vector_mul_transposed_multithreaded(&w1w3_out);
.matrix_vector_mul_transposed_multithreaded(&w1w3_out)
} else {
return self.w2.matrix_mul_transposed(&w1w3_out);
self.w2.matrix_mul_transposed(&w1w3_out)
}
#[cfg(feature = "opencl")]
{
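
The FeedForward hunk is the same `needless_return` cleanup in a two-branch tail position: in Rust the `if/else` is itself the function's value, so neither arm needs `return` (nor the trailing semicolon the old code had). A trivial illustration:

```rust
// The if/else is an expression; its arms are the return value.
fn pick_kernel(rows: usize) -> &'static str {
    if rows == 1 {
        "matrix-vector (multithreaded)"
    } else {
        "matrix-matrix transposed"
    }
}

fn main() {
    assert_eq!(pick_kernel(1), "matrix-vector (multithreaded)");
    assert_eq!(pick_kernel(8), "matrix-matrix transposed");
}
```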

@@ -587,7 +587,7 @@ pub fn unpickle(bytes: &[u8]) -> Result<Value, UnpicklingError> {
));
}
let idx = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
match memo.get(&(idx as u32)) {
match memo.get(&{ idx }) {
None => {
return Err(UnpicklingError::UnpicklingError(
"LONG_BINGET index out of range".to_string(),
