Improve the handwritten AVX2 code for matrix_mul_inplace_transposed.

This is roughly 60% faster than the old version.
master
Mikko Juola 3 years ago
parent 0cce655763
commit 61bc42b728

@@ -1068,6 +1068,7 @@ impl Tensor {
         let src_data: *const f32 = src.data as *const f32;
         let other_data: *const f32 = other.data as *const f32;
+        let src_rows: usize = src.rows as usize;
         let src_cols: usize = src.cols as usize;
         let self_rows: usize = self.rows as usize;
         let self_cols: usize = self.cols as usize;
@@ -1080,25 +1081,77 @@ impl Tensor {
         } else {
             src_cols / ITEMS_PER_CACHE_LINE + 1
         };
+        let row_its = if self_rows % 4 == 0 {
+            self_rows / 4
+        } else {
+            self_rows / 4 + 1
+        };
         unsafe {
-            for row in 0..self_rows {
-                let row = row;
+            for row in 0..row_its {
+                let row0 = row * 4;
+                let row1 = row * 4 + 1;
+                let row2 = row * 4 + 2;
+                let row3 = row * 4 + 3;
                 for col in 0..self_cols {
-                    let mut target8: __m256 = _mm256_setzero_ps();
+                    let mut targets8: [__m256; 4] = [
+                        _mm256_setzero_ps(),
+                        _mm256_setzero_ps(),
+                        _mm256_setzero_ps(),
+                        _mm256_setzero_ps(),
+                    ];
                     for p in 0..src_cols_its {
-                        let src8: __m256 = _mm256_loadu_ps(
-                            src_data
-                                .add(row * src_cols_capacity + p * ITEMS_PER_CACHE_LINE),
-                        );
                         let other8: __m256 = _mm256_loadu_ps(
                             other_data
                                 .add(col * other_cols_capacity + p * ITEMS_PER_CACHE_LINE),
                         );
-                        target8 = _mm256_fmadd_ps(src8, other8, target8);
+                        let src8_0: __m256 = _mm256_loadu_ps(
+                            src_data
+                                .add(row0 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE),
+                        );
+                        let src8_1: __m256 =
+                            if row1 < src_rows {
+                                _mm256_loadu_ps(src_data.add(
+                                    row1 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
+                                ))
+                            } else {
+                                _mm256_setzero_ps()
+                            };
+                        let src8_2: __m256 =
+                            if row2 < src_rows {
+                                _mm256_loadu_ps(src_data.add(
+                                    row2 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
+                                ))
+                            } else {
+                                _mm256_setzero_ps()
+                            };
+                        let src8_3: __m256 =
+                            if row3 < src_rows {
+                                _mm256_loadu_ps(src_data.add(
+                                    row3 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
+                                ))
+                            } else {
+                                _mm256_setzero_ps()
+                            };
+                        targets8[0] = _mm256_fmadd_ps(src8_0, other8, targets8[0]);
+                        targets8[1] = _mm256_fmadd_ps(src8_1, other8, targets8[1]);
+                        targets8[2] = _mm256_fmadd_ps(src8_2, other8, targets8[2]);
+                        targets8[3] = _mm256_fmadd_ps(src8_3, other8, targets8[3]);
                     }
-                    let target: f32 = horizontal_sum(target8);
-                    *tgt_data.add(row * self_cols_capacity + col) = target;
+                    let target0: f32 = horizontal_sum(targets8[0]);
+                    let target1: f32 = horizontal_sum(targets8[1]);
+                    let target2: f32 = horizontal_sum(targets8[2]);
+                    let target3: f32 = horizontal_sum(targets8[3]);
+                    *tgt_data.add(row0 * self_cols_capacity + col) = target0;
+                    if row1 < self_rows {
+                        *tgt_data.add(row1 * self_cols_capacity + col) = target1;
+                    }
+                    if row2 < self_rows {
+                        *tgt_data.add(row2 * self_cols_capacity + col) = target2;
+                    }
+                    if row3 < self_rows {
+                        *tgt_data.add(row3 * self_cols_capacity + col) = target3;
+                    }
                 }
            }
        }
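The win here is register blocking: the new inner loop computes four output rows per iteration, so each 8-float load from `other_data` feeds four independent FMA accumulators instead of one. That quarters the loads from the transposed operand and gives the CPU four independent dependency chains to overlap. Below is a minimal scalar sketch of the same blocking scheme, runnable without AVX2; the names (`a`, `bt`, `out`) and the flat row-major layout are illustrative assumptions, not code from this repository.

// Scalar sketch of the 4-row register blocking used in the AVX2 version.
// `bt` is the second operand already transposed, so out[r][c] is the dot
// product of row r of `a` and row c of `bt`. All names are illustrative.
fn matmul_transposed_blocked(
    a: &[f32],       // rows x inner, row-major
    bt: &[f32],      // cols x inner, row-major (i.e. B transposed)
    out: &mut [f32], // rows x cols, row-major
    rows: usize,
    inner: usize,
    cols: usize,
) {
    for r in (0..rows).step_by(4) {
        for c in 0..cols {
            // Four independent accumulators, one per row in the block,
            // mirroring targets8[0..4] in the AVX2 code above.
            let mut acc = [0.0f32; 4];
            for k in 0..inner {
                // A single load from the transposed operand is shared by
                // up to four rows of `a`; this reuse is the speedup.
                let b = bt[c * inner + k];
                for i in 0..4 {
                    if r + i < rows {
                        acc[i] += a[(r + i) * inner + k] * b;
                    }
                }
            }
            for i in 0..4 {
                if r + i < rows {
                    out[(r + i) * cols + c] = acc[i];
                }
            }
        }
    }
}

The `r + i < rows` guards play the same role as the `row1 < src_rows` checks in the diff: the last block of four can hang past the matrix edge, so out-of-range rows are skipped (the AVX2 version substitutes zero vectors for their loads instead).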
