From 3d0afcf24309f28ec540ed7645c35400a865ad6f Mon Sep 17 00:00:00 2001 From: Mikko Juola Date: Fri, 17 Mar 2023 23:25:19 -0700 Subject: [PATCH] Make matrix multiplication multithreaded. This improves performance greatly with f16. It's faster now than OpenCL on LLaMA-7B. --- src/benches/benchmark.rs | 30 +- src/rllama_main.rs | 2 +- src/tensor.rs | 771 ++++++++++++++++++++------------------- src/transformer.rs | 2 + 4 files changed, 411 insertions(+), 394 deletions(-) diff --git a/src/benches/benchmark.rs b/src/benches/benchmark.rs index 4b126dc..2453d79 100644 --- a/src/benches/benchmark.rs +++ b/src/benches/benchmark.rs @@ -129,6 +129,18 @@ pub fn tensor_benchmarks(c: &mut Criterion) { }, ); + c.bench_function( + "matrix multiplication 8x4096 @ 4096x4096 f16 in-place, transposed", + |b| { + b.iter(|| { + let _ = result_84096_f16.matrix_mul_inplace_transposed( + black_box(&orig_84096_1_f16), + black_box(&orig_84096_2_f16), + ); + }) + }, + ); + c.bench_function( "matrix multiplication 8x4096 @ 4096x4096 f32 in-place, transposed", |b| { @@ -142,13 +154,11 @@ pub fn tensor_benchmarks(c: &mut Criterion) { ); c.bench_function( - "matrix multiplication 8x4096 @ 4096x4096 f16 in-place, transposed", + "matrix multiplication 8x4096 @ 4096x4096 f32 in-place", |b| { b.iter(|| { - let _ = result_84096_f16.matrix_mul_inplace_transposed( - black_box(&orig_84096_1_f16), - black_box(&orig_84096_2_f16), - ); + let _ = result_84096 + .matrix_mul_inplace(black_box(&orig_84096_1), black_box(&orig_84096_2)); }) }, ); @@ -165,16 +175,6 @@ pub fn tensor_benchmarks(c: &mut Criterion) { }) }); - c.bench_function( - "matrix multiplication 8x4096 @ 4096x4096 f32 in-place", - |b| { - b.iter(|| { - let _ = result_84096 - .matrix_mul_inplace(black_box(&orig_84096_1), black_box(&orig_84096_2)); - }) - }, - ); - c.bench_function("matrix multiplication f32 not in-place", |b| { b.iter(|| { let _ = black_box(&orig32_1).matrix_mul(black_box(&orig32_2)); diff --git a/src/rllama_main.rs b/src/rllama_main.rs index b8813e1..aaff326 100644 --- a/src/rllama_main.rs +++ b/src/rllama_main.rs @@ -171,7 +171,7 @@ pub fn main() -> Result<(), Box> { DataSettings::new() }; - if cli.f16 == true { + if cli.f16 { data_settings = data_settings.force_f16(); } diff --git a/src/tensor.rs b/src/tensor.rs index 7dc4371..0691ae7 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -135,6 +135,23 @@ impl Drop for Tensor { } } +// Use this to smuggle pointers to threads without Rust getting so goddamn mad +// +// Assumption usize = pointer size. +#[derive(Copy, Clone)] +struct WrappedPtr { + ptr: usize, +} +impl WrappedPtr { + fn wrap(ptr: *const u8) -> WrappedPtr { + WrappedPtr { ptr: ptr as usize } + } + + fn unwrap(self) -> *const u8 { + self.ptr as *const u8 + } +} + fn compute_capacity_cols(dtype: TensorDType, cols: i64) -> i64 { match dtype { TensorDType::Float16 => compute_capacity_cols_f16(cols), @@ -1025,12 +1042,12 @@ impl Tensor { } pub fn is_on_cpu(&self) -> bool { - return !self.is_on_gpu(); + !self.is_on_gpu() } // Casts data type to whatever the other tensors data type is. 
pub fn to_same_type(&self, other: &Tensor) -> Tensor { - let mut result = self.clone(); + let result = self.clone(); if result.dtype() == other.dtype() { return result; } @@ -1115,8 +1132,8 @@ impl Tensor { self.rows as usize * self.capacity_cols as usize, ); } - let src_data: *const f32 = src.data as *const f32; - let other_data: *const f32 = other.data as *const f32; + let _src_data: *const f32 = src.data as *const f32; + let _other_data: *const f32 = other.data as *const f32; let src_rows: usize = src.rows as usize; let src_cols: usize = src.cols as usize; @@ -1145,152 +1162,184 @@ impl Tensor { }; unsafe { - for row in 0..row_its { - let row0 = row * 4; - let row1 = row * 4 + 1; - let row2 = row * 4 + 2; - let row3 = row * 4 + 3; - for col in 0..self_cols_its { - let col0 = col * 4; - let col1 = col * 4 + 1; - let col2 = col * 4 + 2; - let col3 = col * 4 + 3; - let mut targets8: [[__m256; 4]; 4] = [ - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - ]; - for p in 0..src_cols_its { - let other8_0: __m256 = _mm256_loadu_ps( - other_data.add(col0 * other_cols_capacity + p * ITEMS_PER_LINE), - ); - let other8_1: __m256 = if col1 < other_rows { - _mm256_loadu_ps( - other_data - .add(col1 * other_cols_capacity + p * ITEMS_PER_LINE), - ) - } else { - _mm256_setzero_ps() - }; - let other8_2: __m256 = if col2 < other_rows { - _mm256_loadu_ps( - other_data - .add(col2 * other_cols_capacity + p * ITEMS_PER_LINE), - ) - } else { - _mm256_setzero_ps() - }; - let other8_3: __m256 = if col3 < other_rows { - _mm256_loadu_ps( + let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data); + let other_data: WrappedPtr = WrappedPtr::wrap(other.data); + let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data); + (0..32).into_par_iter().for_each(|thread_idx| { + let src_data: *const f32 = src_data_wrap.unwrap() as *const f32; + let other_data: *const f32 = other_data.unwrap() as *const f32; + let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32; + for row in 0..row_its { + let row0 = row * 4; + let row1 = row * 4 + 1; + let row2 = row * 4 + 2; + let row3 = row * 4 + 3; + for col in 0..self_cols_its { + let row_col = row * self_cols_its + col; + if row_col % 32 != thread_idx { + continue; + } + let col0 = col * 4; + let col1 = col * 4 + 1; + let col2 = col * 4 + 2; + let col3 = col * 4 + 3; + let mut targets8: [[__m256; 4]; 4] = [ + [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + ], + [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + ], + [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + ], + [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + ], + ]; + for p in 0..src_cols_its { + let other8_0: __m256 = _mm256_loadu_ps( other_data - .add(col3 * other_cols_capacity + p * ITEMS_PER_LINE), - ) - } else { - _mm256_setzero_ps() - }; - let src8_0: __m256 = _mm256_loadu_ps( - src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE), - ); - let src8_1: __m256 = if row1 < src_rows { - _mm256_loadu_ps( - src_data.add(row1 * src_cols_capacity + p * ITEMS_PER_LINE), - ) - } else { - 
_mm256_setzero_ps() - }; - let src8_2: __m256 = if row2 < src_rows { - _mm256_loadu_ps( - src_data.add(row2 * src_cols_capacity + p * ITEMS_PER_LINE), - ) - } else { - _mm256_setzero_ps() - }; - let src8_3: __m256 = if row3 < src_rows { - _mm256_loadu_ps( - src_data.add(row3 * src_cols_capacity + p * ITEMS_PER_LINE), - ) - } else { - _mm256_setzero_ps() - }; - targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]); - targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]); - targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]); - targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]); - targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]); - targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]); - targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]); - targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]); - targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]); - targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]); - targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]); - targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]); - targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]); - targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]); - targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]); - targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]); - } - let target00: f32 = horizontal_sum(targets8[0][0]); - let target01: f32 = horizontal_sum(targets8[0][1]); - let target02: f32 = horizontal_sum(targets8[0][2]); - let target03: f32 = horizontal_sum(targets8[0][3]); - let target10: f32 = horizontal_sum(targets8[1][0]); - let target11: f32 = horizontal_sum(targets8[1][1]); - let target12: f32 = horizontal_sum(targets8[1][2]); - let target13: f32 = horizontal_sum(targets8[1][3]); - let target20: f32 = horizontal_sum(targets8[2][0]); - let target21: f32 = horizontal_sum(targets8[2][1]); - let target22: f32 = horizontal_sum(targets8[2][2]); - let target23: f32 = horizontal_sum(targets8[2][3]); - let target30: f32 = horizontal_sum(targets8[3][0]); - let target31: f32 = horizontal_sum(targets8[3][1]); - let target32: f32 = horizontal_sum(targets8[3][2]); - let target33: f32 = horizontal_sum(targets8[3][3]); - - *tgt_data.add(row0 * self_cols_capacity + col0) += target00; - *tgt_data.add(row0 * self_cols_capacity + col1) += target10; - *tgt_data.add(row0 * self_cols_capacity + col2) += target20; - *tgt_data.add(row0 * self_cols_capacity + col3) += target30; - if row1 < self_rows { - *tgt_data.add(row1 * self_cols_capacity + col0) += target01; - *tgt_data.add(row1 * self_cols_capacity + col1) += target11; - *tgt_data.add(row1 * self_cols_capacity + col2) += target21; - *tgt_data.add(row1 * self_cols_capacity + col3) += target31; - } - if row2 < self_rows { - *tgt_data.add(row2 * self_cols_capacity + col0) += target02; - *tgt_data.add(row2 * self_cols_capacity + col1) += target12; - *tgt_data.add(row2 * self_cols_capacity + col2) += target22; - *tgt_data.add(row2 * self_cols_capacity + col3) += target32; - } - if row3 < self_rows { - *tgt_data.add(row3 * self_cols_capacity + col0) += target03; - *tgt_data.add(row3 * self_cols_capacity + col1) += target13; - *tgt_data.add(row3 * self_cols_capacity + col2) += target23; - *tgt_data.add(row3 * self_cols_capacity + col3) += target33; + .add(col0 * other_cols_capacity + p * ITEMS_PER_LINE), + ); + let other8_1: __m256 = + if col1 < other_rows { + 
_mm256_loadu_ps(other_data.add( + col1 * other_cols_capacity + p * ITEMS_PER_LINE, + )) + } else { + _mm256_setzero_ps() + }; + let other8_2: __m256 = + if col2 < other_rows { + _mm256_loadu_ps(other_data.add( + col2 * other_cols_capacity + p * ITEMS_PER_LINE, + )) + } else { + _mm256_setzero_ps() + }; + let other8_3: __m256 = + if col3 < other_rows { + _mm256_loadu_ps(other_data.add( + col3 * other_cols_capacity + p * ITEMS_PER_LINE, + )) + } else { + _mm256_setzero_ps() + }; + let src8_0: __m256 = _mm256_loadu_ps( + src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE), + ); + let src8_1: __m256 = if row1 < src_rows { + _mm256_loadu_ps( + src_data + .add(row1 * src_cols_capacity + p * ITEMS_PER_LINE), + ) + } else { + _mm256_setzero_ps() + }; + let src8_2: __m256 = if row2 < src_rows { + _mm256_loadu_ps( + src_data + .add(row2 * src_cols_capacity + p * ITEMS_PER_LINE), + ) + } else { + _mm256_setzero_ps() + }; + let src8_3: __m256 = if row3 < src_rows { + _mm256_loadu_ps( + src_data + .add(row3 * src_cols_capacity + p * ITEMS_PER_LINE), + ) + } else { + _mm256_setzero_ps() + }; + targets8[0][0] = + _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]); + targets8[0][1] = + _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]); + targets8[0][2] = + _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]); + targets8[0][3] = + _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]); + targets8[1][0] = + _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]); + targets8[1][1] = + _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]); + targets8[1][2] = + _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]); + targets8[1][3] = + _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]); + targets8[2][0] = + _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]); + targets8[2][1] = + _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]); + targets8[2][2] = + _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]); + targets8[2][3] = + _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]); + targets8[3][0] = + _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]); + targets8[3][1] = + _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]); + targets8[3][2] = + _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]); + targets8[3][3] = + _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]); + } + let target00: f32 = horizontal_sum(targets8[0][0]); + let target01: f32 = horizontal_sum(targets8[0][1]); + let target02: f32 = horizontal_sum(targets8[0][2]); + let target03: f32 = horizontal_sum(targets8[0][3]); + let target10: f32 = horizontal_sum(targets8[1][0]); + let target11: f32 = horizontal_sum(targets8[1][1]); + let target12: f32 = horizontal_sum(targets8[1][2]); + let target13: f32 = horizontal_sum(targets8[1][3]); + let target20: f32 = horizontal_sum(targets8[2][0]); + let target21: f32 = horizontal_sum(targets8[2][1]); + let target22: f32 = horizontal_sum(targets8[2][2]); + let target23: f32 = horizontal_sum(targets8[2][3]); + let target30: f32 = horizontal_sum(targets8[3][0]); + let target31: f32 = horizontal_sum(targets8[3][1]); + let target32: f32 = horizontal_sum(targets8[3][2]); + let target33: f32 = horizontal_sum(targets8[3][3]); + + *tgt_data.add(row0 * self_cols_capacity + col0) += target00; + *tgt_data.add(row0 * self_cols_capacity + col1) += target10; + *tgt_data.add(row0 * self_cols_capacity + col2) += target20; + *tgt_data.add(row0 * self_cols_capacity + col3) += target30; + if row1 < self_rows { + *tgt_data.add(row1 * self_cols_capacity + col0) += target01; + *tgt_data.add(row1 * self_cols_capacity + col1) += target11; + *tgt_data.add(row1 * 
self_cols_capacity + col2) += target21; + *tgt_data.add(row1 * self_cols_capacity + col3) += target31; + } + if row2 < self_rows { + *tgt_data.add(row2 * self_cols_capacity + col0) += target02; + *tgt_data.add(row2 * self_cols_capacity + col1) += target12; + *tgt_data.add(row2 * self_cols_capacity + col2) += target22; + *tgt_data.add(row2 * self_cols_capacity + col3) += target32; + } + if row3 < self_rows { + *tgt_data.add(row3 * self_cols_capacity + col0) += target03; + *tgt_data.add(row3 * self_cols_capacity + col1) += target13; + *tgt_data.add(row3 * self_cols_capacity + col2) += target23; + *tgt_data.add(row3 * self_cols_capacity + col3) += target33; + } } } - } + }); } } TensorDType::Float16 => { @@ -1304,8 +1353,8 @@ impl Tensor { self.rows as usize * self.capacity_cols as usize, ); } - let src_data: *const f16 = src.data as *const f16; - let other_data: *const f16 = other.data as *const f16; + let _src_data: *const f16 = src.data as *const f16; + let _other_data: *const f16 = other.data as *const f16; let src_rows: usize = src.rows as usize; let src_cols: usize = src.cols as usize; @@ -1334,160 +1383,192 @@ impl Tensor { }; unsafe { - for row in 0..row_its { - let row0 = row * 4; - let row1 = row * 4 + 1; - let row2 = row * 4 + 2; - let row3 = row * 4 + 3; - for col in 0..self_cols_its { - let col0 = col * 4; - let col1 = col * 4 + 1; - let col2 = col * 4 + 2; - let col3 = col * 4 + 3; - let mut targets8: [[__m256; 4]; 4] = [ - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - [ - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - _mm256_setzero_ps(), - ], - ]; - // Loads from (row, column..column+8) and (row+1, column..column+8) - #[inline] - fn load2_rows( - ptr: *const f16, - row: usize, - column: usize, - cols_capacity: usize, - nrows: usize, - ) -> (__m256, __m256) { - unsafe { - let (left, right) = if row + 1 < nrows { - ( - _mm_loadu_si128(ptr.add(row * cols_capacity + column) - as *const __m128i), - _mm_loadu_si128( - ptr.add((row + 1) * cols_capacity + column) - as *const __m128i, - ), - ) - } else { - ( - _mm_loadu_si128(ptr.add(row * cols_capacity + column) - as *const __m128i), - _mm_setzero_si128(), - ) - }; - let left: __m256 = _mm256_cvtph_ps(left); - let right: __m256 = _mm256_cvtph_ps(right); - (left, right) + let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data); + let other_data: WrappedPtr = WrappedPtr::wrap(other.data); + let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data); + (0..32).into_par_iter().for_each(|thread_idx| { + let src_data: *const f16 = src_data_wrap.unwrap() as *const f16; + let other_data: *const f16 = other_data.unwrap() as *const f16; + let tgt_data: *mut f16 = tgt_data.unwrap() as *mut f16; + for row in 0..row_its { + let row0 = row * 4; + let row1 = row * 4 + 1; + let row2 = row * 4 + 2; + let row3 = row * 4 + 3; + for col in 0..self_cols_its { + let row_col = row * self_cols_its + col; + if row_col % 32 != thread_idx { + continue; + } + let col0 = col * 4; + let col1 = col * 4 + 1; + let col2 = col * 4 + 2; + let col3 = col * 4 + 3; + let mut targets8: [[__m256; 4]; 4] = [ + [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + ], + [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), 
+ ], + [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + ], + [ + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + _mm256_setzero_ps(), + ], + ]; + // Loads from (row, column..column+8) and (row+1, column..column+8) + #[inline] + fn load2_rows( + ptr: *const f16, + row: usize, + column: usize, + cols_capacity: usize, + nrows: usize, + ) -> (__m256, __m256) { + unsafe { + let (left, right) = if row + 1 < nrows { + ( + _mm_loadu_si128( + ptr.add(row * cols_capacity + column) + as *const __m128i, + ), + _mm_loadu_si128( + ptr.add((row + 1) * cols_capacity + column) + as *const __m128i, + ), + ) + } else { + ( + _mm_loadu_si128( + ptr.add(row * cols_capacity + column) + as *const __m128i, + ), + _mm_setzero_si128(), + ) + }; + let left: __m256 = _mm256_cvtph_ps(left); + let right: __m256 = _mm256_cvtph_ps(right); + (left, right) + } + } + for p in 0..src_cols_its { + let (other8_0, other8_1) = load2_rows( + other_data, + col0, + p * ITEMS_PER_LINE, + other_cols_capacity, + other_rows, + ); + let (other8_2, other8_3) = load2_rows( + other_data, + col2, + p * ITEMS_PER_LINE, + other_cols_capacity, + other_rows, + ); + let (src8_0, src8_1) = load2_rows( + src_data, + row0, + p * ITEMS_PER_LINE, + src_cols_capacity, + src_rows, + ); + let (src8_2, src8_3) = load2_rows( + src_data, + row2, + p * ITEMS_PER_LINE, + src_cols_capacity, + src_rows, + ); + targets8[0][0] = + _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]); + targets8[0][1] = + _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]); + targets8[0][2] = + _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]); + targets8[0][3] = + _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]); + targets8[1][0] = + _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]); + targets8[1][1] = + _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]); + targets8[1][2] = + _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]); + targets8[1][3] = + _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]); + targets8[2][0] = + _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]); + targets8[2][1] = + _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]); + targets8[2][2] = + _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]); + targets8[2][3] = + _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]); + targets8[3][0] = + _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]); + targets8[3][1] = + _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]); + targets8[3][2] = + _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]); + targets8[3][3] = + _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]); + } + let target00: f16 = horizontal_sum_f32_to_f16(targets8[0][0]); + let target01: f16 = horizontal_sum_f32_to_f16(targets8[0][1]); + let target02: f16 = horizontal_sum_f32_to_f16(targets8[0][2]); + let target03: f16 = horizontal_sum_f32_to_f16(targets8[0][3]); + let target10: f16 = horizontal_sum_f32_to_f16(targets8[1][0]); + let target11: f16 = horizontal_sum_f32_to_f16(targets8[1][1]); + let target12: f16 = horizontal_sum_f32_to_f16(targets8[1][2]); + let target13: f16 = horizontal_sum_f32_to_f16(targets8[1][3]); + let target20: f16 = horizontal_sum_f32_to_f16(targets8[2][0]); + let target21: f16 = horizontal_sum_f32_to_f16(targets8[2][1]); + let target22: f16 = horizontal_sum_f32_to_f16(targets8[2][2]); + let target23: f16 = horizontal_sum_f32_to_f16(targets8[2][3]); + let target30: f16 = horizontal_sum_f32_to_f16(targets8[3][0]); + let target31: f16 = horizontal_sum_f32_to_f16(targets8[3][1]); + let target32: f16 = 
horizontal_sum_f32_to_f16(targets8[3][2]); + let target33: f16 = horizontal_sum_f32_to_f16(targets8[3][3]); + + *tgt_data.add(row0 * self_cols_capacity + col0) += target00; + *tgt_data.add(row0 * self_cols_capacity + col1) += target10; + *tgt_data.add(row0 * self_cols_capacity + col2) += target20; + *tgt_data.add(row0 * self_cols_capacity + col3) += target30; + if row1 < self_rows { + *tgt_data.add(row1 * self_cols_capacity + col0) += target01; + *tgt_data.add(row1 * self_cols_capacity + col1) += target11; + *tgt_data.add(row1 * self_cols_capacity + col2) += target21; + *tgt_data.add(row1 * self_cols_capacity + col3) += target31; + } + if row2 < self_rows { + *tgt_data.add(row2 * self_cols_capacity + col0) += target02; + *tgt_data.add(row2 * self_cols_capacity + col1) += target12; + *tgt_data.add(row2 * self_cols_capacity + col2) += target22; + *tgt_data.add(row2 * self_cols_capacity + col3) += target32; + } + if row3 < self_rows { + *tgt_data.add(row3 * self_cols_capacity + col0) += target03; + *tgt_data.add(row3 * self_cols_capacity + col1) += target13; + *tgt_data.add(row3 * self_cols_capacity + col2) += target23; + *tgt_data.add(row3 * self_cols_capacity + col3) += target33; } - } - for p in 0..src_cols_its { - let (other8_0, other8_1) = load2_rows( - other_data, - col0, - p * ITEMS_PER_LINE, - other_cols_capacity, - other_rows, - ); - let (other8_2, other8_3) = load2_rows( - other_data, - col2, - p * ITEMS_PER_LINE, - other_cols_capacity, - other_rows, - ); - let (src8_0, src8_1) = load2_rows( - src_data, - row0, - p * ITEMS_PER_LINE, - src_cols_capacity, - src_rows, - ); - let (src8_2, src8_3) = load2_rows( - src_data, - row2, - p * ITEMS_PER_LINE, - src_cols_capacity, - src_rows, - ); - targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]); - targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]); - targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]); - targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]); - targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]); - targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]); - targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]); - targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]); - targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]); - targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]); - targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]); - targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]); - targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]); - targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]); - targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]); - targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]); - } - let target00: f16 = horizontal_sum_f32_to_f16(targets8[0][0]); - let target01: f16 = horizontal_sum_f32_to_f16(targets8[0][1]); - let target02: f16 = horizontal_sum_f32_to_f16(targets8[0][2]); - let target03: f16 = horizontal_sum_f32_to_f16(targets8[0][3]); - let target10: f16 = horizontal_sum_f32_to_f16(targets8[1][0]); - let target11: f16 = horizontal_sum_f32_to_f16(targets8[1][1]); - let target12: f16 = horizontal_sum_f32_to_f16(targets8[1][2]); - let target13: f16 = horizontal_sum_f32_to_f16(targets8[1][3]); - let target20: f16 = horizontal_sum_f32_to_f16(targets8[2][0]); - let target21: f16 = horizontal_sum_f32_to_f16(targets8[2][1]); - let target22: f16 = 
horizontal_sum_f32_to_f16(targets8[2][2]); - let target23: f16 = horizontal_sum_f32_to_f16(targets8[2][3]); - let target30: f16 = horizontal_sum_f32_to_f16(targets8[3][0]); - let target31: f16 = horizontal_sum_f32_to_f16(targets8[3][1]); - let target32: f16 = horizontal_sum_f32_to_f16(targets8[3][2]); - let target33: f16 = horizontal_sum_f32_to_f16(targets8[3][3]); - - *tgt_data.add(row0 * self_cols_capacity + col0) += target00; - *tgt_data.add(row0 * self_cols_capacity + col1) += target10; - *tgt_data.add(row0 * self_cols_capacity + col2) += target20; - *tgt_data.add(row0 * self_cols_capacity + col3) += target30; - if row1 < self_rows { - *tgt_data.add(row1 * self_cols_capacity + col0) += target01; - *tgt_data.add(row1 * self_cols_capacity + col1) += target11; - *tgt_data.add(row1 * self_cols_capacity + col2) += target21; - *tgt_data.add(row1 * self_cols_capacity + col3) += target31; - } - if row2 < self_rows { - *tgt_data.add(row2 * self_cols_capacity + col0) += target02; - *tgt_data.add(row2 * self_cols_capacity + col1) += target12; - *tgt_data.add(row2 * self_cols_capacity + col2) += target22; - *tgt_data.add(row2 * self_cols_capacity + col3) += target32; - } - if row3 < self_rows { - *tgt_data.add(row3 * self_cols_capacity + col0) += target03; - *tgt_data.add(row3 * self_cols_capacity + col1) += target13; - *tgt_data.add(row3 * self_cols_capacity + col2) += target23; - *tgt_data.add(row3 * self_cols_capacity + col3) += target33; } } - } + }); } } } @@ -1534,6 +1615,7 @@ impl Tensor { assert_eq!(other.rows, 1); assert_eq!(other.dtype, self.dtype); + #[allow(unreachable_patterns)] match self.dtype { TensorDType::Float32 => self.matrix_vector_mul_transposed_f32(other), TensorDType::Float16 => self.matrix_vector_mul_transposed_f16(other), @@ -1669,7 +1751,7 @@ impl Tensor { self.assume_on_cpu(); other.assume_on_cpu(); unsafe { - let mut result = Tensor::uninitialized(self.rows, 1, self.dtype); + let result = Tensor::zeros(self.rows, 1, self.dtype); let col_its: usize = if self.cols % 8 == 0 { (self.cols / 8) as usize } else { @@ -1680,16 +1762,18 @@ impl Tensor { } else { (self.rows / 4 + 1) as usize }; + let self_data: *const f32 = self.data as *const f32; + let other_data: *const f32 = other.data as *const f32; + let tgt_data: *mut f32 = result.data as *mut f32; + let ncols_capacity: usize = result.capacity_cols as usize; + let mut sum8s: [__m256; 4] = [ _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), ]; - let self_data: *const f32 = self.data as *const f32; - let other_data: *const f32 = other.data as *const f32; - let _tgt_data: *mut f32 = result.data as *mut f32; - let _ncols_capacity: usize = result.capacity_cols as usize; + for row in 0..row_its { let row: i64 = row as i64; sum8s[0] = _mm256_setzero_ps(); @@ -1732,91 +1816,22 @@ impl Tensor { let sum_2: f32 = horizontal_sum(sum8s[2]); let sum_3: f32 = horizontal_sum(sum8s[3]); if row4_0 < result.rows { - result.set_f32(row4_0, 0, sum_0); + *(tgt_data.add(row4_0 as usize * ncols_capacity)) = sum_0; } if row4_1 < result.rows { - result.set_f32(row4_1, 0, sum_1); + *(tgt_data.add(row4_1 as usize * ncols_capacity)) = sum_1; } if row4_2 < result.rows { - result.set_f32(row4_2, 0, sum_2); + *(tgt_data.add(row4_2 as usize * ncols_capacity)) = sum_2; } if row4_3 < result.rows { - result.set_f32(row4_3, 0, sum_3); + *(tgt_data.add(row4_3 as usize * ncols_capacity)) = sum_3; } } result } } - /// Same as matrix_vector_mul but uses threading. 
- pub fn matrix_vector_mul_transposed_multithreaded(&self, other: &Tensor) -> Tensor { - self.assume_on_cpu(); - other.assume_on_cpu(); - if self.cols != other.cols { - panic!( - "Invalid matrix-vector transposed multiplication {}x{} vs {}x{}", - self.rows, self.cols, other.rows, other.cols - ); - } - assert_eq!(other.rows, 1); - assert_eq!(other.dtype, self.dtype); - assert_eq!(self.dtype, TensorDType::Float32); - - // Use this to smuggle pointers to threads without Rust getting so goddamn mad - // - // Assumption usize = pointer size. - #[derive(Copy, Clone)] - struct WrappedPtr { - ptr: usize, - } - impl WrappedPtr { - fn wrap(ptr: *const u8) -> WrappedPtr { - WrappedPtr { ptr: ptr as usize } - } - - fn unwrap(self) -> *const u8 { - self.ptr as *const u8 - } - } - - unsafe { - let result = Tensor::uninitialized(self.rows, 1, self.dtype); - let capacity_cols: i64 = self.capacity_cols; - let result_capacity_cols: i64 = result.capacity_cols; - let col_its: usize = if self.cols % 8 == 0 { - (self.cols / 8) as usize - } else { - (self.cols / 8 + 1) as usize - }; - let self_data = WrappedPtr::wrap(self.data); - let other_data = WrappedPtr::wrap(other.data); - let result_data = WrappedPtr::wrap(result.data); - (0..self.rows as usize) - .into_par_iter() - .with_min_len(64) - .for_each(|row| { - let row = row as i64; - let self_data: *const f32 = self_data.unwrap() as *const f32; - let other_data: *const f32 = other_data.unwrap() as *const f32; - let result_data: *mut f32 = result_data.unwrap() as *mut f32; - - let mut sum8: __m256 = _mm256_setzero_ps(); - for col in 0..col_its { - let col = col * 8; - let left_side8 = - _mm256_loadu_ps(self_data.add((row * capacity_cols) as usize + col)); - let right_side8 = _mm256_loadu_ps(other_data.add(col)); - sum8 = _mm256_fmadd_ps(left_side8, right_side8, sum8); - } - let sum: f32 = horizontal_sum(sum8); - result_data - .add((row * result_capacity_cols) as usize) - .write(sum); - }); - result - } - } - // Computes matrix multiplication assuming left side has number of rows as 1 #[allow(clippy::erasing_op)] #[allow(clippy::identity_op)] diff --git a/src/transformer.rs b/src/transformer.rs index 07205d6..7836f6c 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -58,6 +58,7 @@ impl DataSettings { } } + #[allow(clippy::new_without_default)] #[cfg(not(feature = "opencl"))] pub fn new() -> Self { DataSettings { force_f16: false } @@ -147,6 +148,7 @@ pub struct RMSNorm { weight: Tensor, } +#[allow(dead_code)] pub struct Attention { wq: Tensor, wk: Tensor,
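A note on the threading scheme the two rewritten kernels share: work is split by striping output tiles across a fixed pool of 32 rayon tasks (the `row_col % 32 != thread_idx` test), and the raw tensor pointers are smuggled into the closures through the Copy `WrappedPtr` wrapper so they satisfy Send. Below is a minimal scalar sketch of that scheme, not the patch's code: the AVX2/FMA kernels, the 4x4 tiling, and the tensor bookkeeping are stripped out, `matmul_transposed_striped` is an illustrative name, and the only assumed dependency is the rayon crate (which the diff already uses via `into_par_iter`).

use rayon::prelude::*;

// Mirrors the patch's WrappedPtr: the address travels as a usize, which is
// Copy + Send, so a rayon closure can capture it. This is only sound because
// the striping below guarantees no two tasks ever write the same output cell.
#[derive(Copy, Clone)]
struct WrappedPtr {
    ptr: usize,
}

impl WrappedPtr {
    fn wrap(ptr: *const u8) -> WrappedPtr {
        WrappedPtr { ptr: ptr as usize }
    }

    fn unwrap(self) -> *const u8 {
        self.ptr as *const u8
    }
}

// result[r * cols + c] = dot(row r of a, row c of b); b is pre-transposed and
// everything is row-major. Illustrative function, not from the patch.
fn matmul_transposed_striped(
    a: &[f32],
    b: &[f32],
    result: &mut [f32],
    rows: usize,
    cols: usize,
    inner: usize,
) {
    let a_ptr = WrappedPtr::wrap(a.as_ptr() as *const u8);
    let b_ptr = WrappedPtr::wrap(b.as_ptr() as *const u8);
    let r_ptr = WrappedPtr::wrap(result.as_mut_ptr() as *const u8);
    (0..32usize).into_par_iter().for_each(|task_idx| {
        let a = a_ptr.unwrap() as *const f32;
        let b = b_ptr.unwrap() as *const f32;
        let r = r_ptr.unwrap() as *mut f32;
        for row in 0..rows {
            for col in 0..cols {
                // The patch's striping test: each output cell belongs to
                // exactly one of the 32 tasks, so the raw writes never alias.
                if (row * cols + col) % 32 != task_idx {
                    continue;
                }
                let mut acc = 0.0f32;
                for k in 0..inner {
                    unsafe {
                        acc += *a.add(row * inner + k) * *b.add(col * inner + k);
                    }
                }
                unsafe {
                    *r.add(row * cols + col) = acc;
                }
            }
        }
    });
}

fn main() {
    let (rows, cols, inner) = (8, 16, 32);
    let a = vec![1.0f32; rows * inner];
    let b = vec![2.0f32; cols * inner];
    let mut result = vec![0.0f32; rows * cols];
    matmul_transposed_striped(&a, &b, &mut result, rows, cols, inner);
    // Every dot product is 32 terms of 1.0 * 2.0 = 64.0.
    assert!(result.iter().all(|&x| (x - 64.0).abs() < 1e-6));
}

One design point grounded in the diff: every task walks the full row/col index space and skips the cells that are not its own, which keeps the loop shape identical to the single-threaded version at the cost of redundant index arithmetic. The removed matrix_vector_mul_transposed_multithreaded took the other route, letting rayon chunk rows directly with with_min_len(64).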
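The f16 kernel never accumulates in half precision: load2_rows widens each group of eight f16 values to f32 with _mm256_cvtph_ps, the dot products run through f32 FMAs, and only the final reduction is narrowed back (horizontal_sum_f32_to_f16). Here is a scalar sketch of that accumulate-wide, store-narrow idea; it assumes the half crate for the f16 type (which the diff's f16 usage suggests), and dot_f16 is an illustrative name, not from the patch.

use half::f16;

// Dot product of two f16 rows, accumulated in f32 the way the patch's AVX
// kernel does, with only the final sum narrowed back to f16.
fn dot_f16(a: &[f16], b: &[f16]) -> f16 {
    let sum: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(x, y)| x.to_f32() * y.to_f32())
        .sum();
    f16::from_f32(sum)
}

fn main() {
    // 4096 matches the inner dimension in the benchmarks above.
    let a = vec![f16::from_f32(0.5); 4096];
    let b = vec![f16::from_f32(2.0); 4096];
    // 4096 terms of 0.5 * 2.0: exact in f32, and 4096.0 is representable in f16.
    assert_eq!(dot_f16(&a, &b).to_f32(), 4096.0);
}

Accumulating directly in f16 would shed precision quickly over a 4096-term dot product; widening to f32 for the accumulation chain is what lets the f16 path track the f32 path's results while halving memory traffic.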
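The kernels also lean on a horizontal_sum(__m256) -> f32 helper whose body lies outside the hunks shown. Its exact implementation in the repo is unknown here; the following is one plausible shape under that assumption (spill the eight lanes and add them), for x86_64 with AVX.

use std::arch::x86_64::*;

// Assumed shape of the helper the patch calls; the real rllama version may
// reduce in-register instead. horizontal_sum_f32_to_f16 would presumably
// narrow this result with f16::from_f32.
#[target_feature(enable = "avx")]
unsafe fn horizontal_sum(v: __m256) -> f32 {
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), v);
    lanes.iter().sum()
}

fn main() {
    if is_x86_feature_detected!("avx") {
        unsafe {
            let v = _mm256_set1_ps(1.5);
            assert_eq!(horizontal_sum(v), 12.0); // 8 lanes of 1.5
        }
    }
}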