Improve transposed matrix multiplication further; this yields roughly a 10%–20% additional speedup by improving the memory-load-to-instruction ratio.

master
Mikko Juola 3 years ago
parent 61bc42b728
commit a1970b8a9c

@ -152,10 +152,7 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
pln!("Loading embeddings from {}...", model_path); pln!("Loading embeddings from {}...", model_path);
let emb = Embedding::from_unpickled(&unpickle_results, model_path.clone())?; let emb = Embedding::from_unpickled(&unpickle_results, model_path.clone())?;
let max_seq_len = match cli.max_seq_len { let max_seq_len = cli.max_seq_len.unwrap_or(1024);
Some(max_seq_len) => max_seq_len,
None => 1024,
};
let data_settings = { let data_settings = {
#[cfg(feature = "opencl")] #[cfg(feature = "opencl")]
@ -291,11 +288,9 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
break; break;
} }
} }
println!(""); println!();
if stop_seen { if stop_seen && !be_quiet {
if !be_quiet { println!("Stop token seen. Stopping.");
println!("Stop token seen. Stopping.");
}
} }
if !be_quiet { if !be_quiet {
println!("---"); println!("---");

@ -1072,6 +1072,8 @@ impl Tensor {
let src_cols: usize = src.cols as usize; let src_cols: usize = src.cols as usize;
let self_rows: usize = self.rows as usize; let self_rows: usize = self.rows as usize;
let self_cols: usize = self.cols as usize; let self_cols: usize = self.cols as usize;
let _other_cols: usize = other.cols as usize;
let other_rows: usize = other.rows as usize;
let other_cols_capacity: usize = other.capacity_cols as usize; let other_cols_capacity: usize = other.capacity_cols as usize;
let src_cols_capacity: usize = src.capacity_cols as usize; let src_cols_capacity: usize = src.capacity_cols as usize;
let self_cols_capacity: usize = self.capacity_cols as usize; let self_cols_capacity: usize = self.capacity_cols as usize;
@ -1086,6 +1088,11 @@ impl Tensor {
} else { } else {
self_rows / 4 + 1 self_rows / 4 + 1
}; };
let self_cols_its = if self_cols % 4 == 0 {
self_cols / 4
} else {
self_cols / 4 + 1
};
unsafe { unsafe {
for row in 0..row_its { for row in 0..row_its {
@ -1093,18 +1100,66 @@ impl Tensor {
let row1 = row * 4 + 1; let row1 = row * 4 + 1;
let row2 = row * 4 + 2; let row2 = row * 4 + 2;
let row3 = row * 4 + 3; let row3 = row * 4 + 3;
for col in 0..self_cols { for col in 0..self_cols_its {
let mut targets8: [__m256; 4] = [ let col0 = col * 4;
_mm256_setzero_ps(), let col1 = col * 4 + 1;
_mm256_setzero_ps(), let col2 = col * 4 + 2;
_mm256_setzero_ps(), let col3 = col * 4 + 3;
_mm256_setzero_ps(), let mut targets8: [[__m256; 4]; 4] = [
[
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
],
[
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
],
[
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
],
[
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
],
]; ];
for p in 0..src_cols_its { for p in 0..src_cols_its {
let other8: __m256 = _mm256_loadu_ps( let other8_0: __m256 = _mm256_loadu_ps(
other_data other_data
.add(col * other_cols_capacity + p * ITEMS_PER_CACHE_LINE), .add(col0 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE),
); );
let other8_1: __m256 =
if col1 < other_rows {
_mm256_loadu_ps(other_data.add(
col1 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
))
} else {
_mm256_setzero_ps()
};
let other8_2: __m256 =
if col2 < other_rows {
_mm256_loadu_ps(other_data.add(
col2 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
))
} else {
_mm256_setzero_ps()
};
let other8_3: __m256 =
if col3 < other_rows {
_mm256_loadu_ps(other_data.add(
col3 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
))
} else {
_mm256_setzero_ps()
};
let src8_0: __m256 = _mm256_loadu_ps( let src8_0: __m256 = _mm256_loadu_ps(
src_data src_data
.add(row0 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE), .add(row0 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE),
@ -1133,24 +1188,61 @@ impl Tensor {
} else { } else {
_mm256_setzero_ps() _mm256_setzero_ps()
}; };
targets8[0] = _mm256_fmadd_ps(src8_0, other8, targets8[0]); targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
targets8[1] = _mm256_fmadd_ps(src8_1, other8, targets8[1]); targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
targets8[2] = _mm256_fmadd_ps(src8_2, other8, targets8[2]); targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
targets8[3] = _mm256_fmadd_ps(src8_3, other8, targets8[3]); targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
} }
let target0: f32 = horizontal_sum(targets8[0]); let target00: f32 = horizontal_sum(targets8[0][0]);
let target1: f32 = horizontal_sum(targets8[1]); let target01: f32 = horizontal_sum(targets8[0][1]);
let target2: f32 = horizontal_sum(targets8[2]); let target02: f32 = horizontal_sum(targets8[0][2]);
let target3: f32 = horizontal_sum(targets8[3]); let target03: f32 = horizontal_sum(targets8[0][3]);
*tgt_data.add(row0 * self_cols_capacity + col) = target0; let target10: f32 = horizontal_sum(targets8[1][0]);
let target11: f32 = horizontal_sum(targets8[1][1]);
let target12: f32 = horizontal_sum(targets8[1][2]);
let target13: f32 = horizontal_sum(targets8[1][3]);
let target20: f32 = horizontal_sum(targets8[2][0]);
let target21: f32 = horizontal_sum(targets8[2][1]);
let target22: f32 = horizontal_sum(targets8[2][2]);
let target23: f32 = horizontal_sum(targets8[2][3]);
let target30: f32 = horizontal_sum(targets8[3][0]);
let target31: f32 = horizontal_sum(targets8[3][1]);
let target32: f32 = horizontal_sum(targets8[3][2]);
let target33: f32 = horizontal_sum(targets8[3][3]);
*tgt_data.add(row0 * self_cols_capacity + col0) += target00;
*tgt_data.add(row0 * self_cols_capacity + col1) += target10;
*tgt_data.add(row0 * self_cols_capacity + col2) += target20;
*tgt_data.add(row0 * self_cols_capacity + col3) += target30;
if row1 < self_rows { if row1 < self_rows {
*tgt_data.add(row1 * self_cols_capacity + col) = target1; *tgt_data.add(row1 * self_cols_capacity + col0) += target01;
*tgt_data.add(row1 * self_cols_capacity + col1) += target11;
*tgt_data.add(row1 * self_cols_capacity + col2) += target21;
*tgt_data.add(row1 * self_cols_capacity + col3) += target31;
} }
if row2 < self_rows { if row2 < self_rows {
*tgt_data.add(row2 * self_cols_capacity + col) = target2; *tgt_data.add(row2 * self_cols_capacity + col0) += target02;
*tgt_data.add(row2 * self_cols_capacity + col1) += target12;
*tgt_data.add(row2 * self_cols_capacity + col2) += target22;
*tgt_data.add(row2 * self_cols_capacity + col3) += target32;
} }
if row3 < self_rows { if row3 < self_rows {
*tgt_data.add(row3 * self_cols_capacity + col) = target3; *tgt_data.add(row3 * self_cols_capacity + col0) += target03;
*tgt_data.add(row3 * self_cols_capacity + col1) += target13;
*tgt_data.add(row3 * self_cols_capacity + col2) += target23;
*tgt_data.add(row3 * self_cols_capacity + col3) += target33;
} }
} }
} }
@ -1222,8 +1314,8 @@ impl Tensor {
]; ];
let self_data: *const f32 = self.data as *const f32; let self_data: *const f32 = self.data as *const f32;
let other_data: *const f32 = other.data as *const f32; let other_data: *const f32 = other.data as *const f32;
let tgt_data: *mut f32 = result.data as *mut f32; let _tgt_data: *mut f32 = result.data as *mut f32;
let ncols_capacity: usize = result.capacity_cols as usize; let _ncols_capacity: usize = result.capacity_cols as usize;
for row in 0..row_its { for row in 0..row_its {
let row: i64 = row as i64; let row: i64 = row as i64;
sum8s[0] = _mm256_setzero_ps(); sum8s[0] = _mm256_setzero_ps();
@ -1670,8 +1762,8 @@ impl Tensor {
let self_data: *const f16 = self.data as *const f16; let self_data: *const f16 = self.data as *const f16;
let tgt_data: *mut f32 = result.data as *mut f32; let tgt_data: *mut f32 = result.data as *mut f32;
let tgt_capacity_cols = result.capacity_cols as i64; let tgt_capacity_cols = result.capacity_cols;
let self_capacity_cols = self.capacity_cols as i64; let self_capacity_cols = self.capacity_cols;
for row in 0..self.rows { for row in 0..self.rows {
for col in 0..cols_it { for col in 0..cols_it {
let col = col * 8; let col = col * 8;
@ -1719,8 +1811,8 @@ impl Tensor {
let result = Tensor::uninitialized(self.rows, self.cols, TensorDType::Float16); let result = Tensor::uninitialized(self.rows, self.cols, TensorDType::Float16);
let self_data: *const f32 = self.data as *const f32; let self_data: *const f32 = self.data as *const f32;
let tgt_data: *mut f16 = result.data as *mut f16; let tgt_data: *mut f16 = result.data as *mut f16;
let tgt_capacity_cols = result.capacity_cols as i64; let tgt_capacity_cols = result.capacity_cols;
let self_capacity_cols = self.capacity_cols as i64; let self_capacity_cols = self.capacity_cols;
for row in 0..self.rows { for row in 0..self.rows {
for col in 0..cols_it { for col in 0..cols_it {
@ -1973,9 +2065,9 @@ mod tests {
fn mat_mul_transposed_agrees_with_regular_mat_mul() { fn mat_mul_transposed_agrees_with_regular_mat_mul() {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let a = rng.gen_range(8..64); let a = rng.gen_range(1..=128);
let b = rng.gen_range(8..64); let b = rng.gen_range(1..=128);
let r = rng.gen_range(8..64); let r = rng.gen_range(1..=128);
// Make matrixes AxR and RxB // Make matrixes AxR and RxB
let a = Tensor::random(a, r, TensorDType::Float32); let a = Tensor::random(a, r, TensorDType::Float32);
@ -1990,7 +2082,7 @@ mod tests {
for row in 0..c.rows { for row in 0..c.rows {
for col in 0..c.cols { for col in 0..c.cols {
assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-5); assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-3);
} }
} }
} }

@ -68,7 +68,7 @@ impl TokenSampler {
pub fn sample( pub fn sample(
&self, &self,
logits: &Tensor, logits: &Tensor,
tokenizer: &Tokenizer, _tokenizer: &Tokenizer,
existing_tokens: &[TokenId], existing_tokens: &[TokenId],
) -> (TokenId, f32) { ) -> (TokenId, f32) {
let mut times_used: BTreeMap<TokenId, usize> = BTreeMap::new(); let mut times_used: BTreeMap<TokenId, usize> = BTreeMap::new();
@ -119,7 +119,7 @@ impl TokenSampler {
None => { None => {
// Sort NaNs to bottom // Sort NaNs to bottom
if b.1.is_nan() { if b.1.is_nan() {
return std::cmp::Ordering::Less; std::cmp::Ordering::Less
} else if a.1.is_nan() { } else if a.1.is_nan() {
return std::cmp::Ordering::Greater; return std::cmp::Ordering::Greater;
} else { } else {

@ -133,7 +133,7 @@ impl Tokenizer {
let mut skip_s: &str = ""; let mut skip_s: &str = "";
// Specially recognize newline. Otherwise it matches something we don't actually // Specially recognize newline. Otherwise it matches something we don't actually
// want. // want.
if s.starts_with("\n") { if s.starts_with('\n') {
if self.str_to_id("<0x0A>").is_some() { if self.str_to_id("<0x0A>").is_some() {
best_candidate = "<0x0A>"; best_candidate = "<0x0A>";
best_candidate_len = best_candidate.len(); best_candidate_len = best_candidate.len();

@ -354,7 +354,7 @@ impl RMSNorm {
let data_dir: &Path = data_dir.as_ref(); let data_dir: &Path = data_dir.as_ref();
let weights = Tensor::from_unpickled_pieces( let weights = Tensor::from_unpickled_pieces(
&unpickled[0..=0], &unpickled[0..=0],
name.clone(), name,
data_dir, data_dir,
FromPiecesDirection::Rows, FromPiecesDirection::Rows,
)? )?
@ -448,11 +448,11 @@ impl FeedForward {
#[cfg(not(feature = "opencl"))] #[cfg(not(feature = "opencl"))]
if w1w3_out.rows() == 1 { if w1w3_out.rows() == 1 {
return self self
.w2 .w2
.matrix_vector_mul_transposed_multithreaded(&w1w3_out); .matrix_vector_mul_transposed_multithreaded(&w1w3_out)
} else { } else {
return self.w2.matrix_mul_transposed(&w1w3_out); self.w2.matrix_mul_transposed(&w1w3_out)
} }
#[cfg(feature = "opencl")] #[cfg(feature = "opencl")]
{ {

@ -587,7 +587,7 @@ pub fn unpickle(bytes: &[u8]) -> Result<Value, UnpicklingError> {
)); ));
} }
let idx = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); let idx = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
match memo.get(&(idx as u32)) { match memo.get(&{ idx }) {
None => { None => {
return Err(UnpicklingError::UnpicklingError( return Err(UnpicklingError::UnpicklingError(
"LONG_BINGET index out of range".to_string(), "LONG_BINGET index out of range".to_string(),

Loading…
Cancel
Save