|
|
|
@ -1068,6 +1068,7 @@ impl Tensor {
|
|
|
|
let src_data: *const f32 = src.data as *const f32;
|
|
|
|
let src_data: *const f32 = src.data as *const f32;
|
|
|
|
let other_data: *const f32 = other.data as *const f32;
|
|
|
|
let other_data: *const f32 = other.data as *const f32;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let src_rows: usize = src.rows as usize;
|
|
|
|
let src_cols: usize = src.cols as usize;
|
|
|
|
let src_cols: usize = src.cols as usize;
|
|
|
|
let self_rows: usize = self.rows as usize;
|
|
|
|
let self_rows: usize = self.rows as usize;
|
|
|
|
let self_cols: usize = self.cols as usize;
|
|
|
|
let self_cols: usize = self.cols as usize;
|
|
|
|
@ -1080,25 +1081,77 @@ impl Tensor {
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
src_cols / ITEMS_PER_CACHE_LINE + 1
|
|
|
|
src_cols / ITEMS_PER_CACHE_LINE + 1
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
let row_its = if self_rows % 4 == 0 {
|
|
|
|
|
|
|
|
self_rows / 4
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
self_rows / 4 + 1
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
unsafe {
|
|
|
|
unsafe {
|
|
|
|
for row in 0..self_rows {
|
|
|
|
for row in 0..row_its {
|
|
|
|
let row = row;
|
|
|
|
let row0 = row * 4;
|
|
|
|
|
|
|
|
let row1 = row * 4 + 1;
|
|
|
|
|
|
|
|
let row2 = row * 4 + 2;
|
|
|
|
|
|
|
|
let row3 = row * 4 + 3;
|
|
|
|
for col in 0..self_cols {
|
|
|
|
for col in 0..self_cols {
|
|
|
|
let mut target8: __m256 = _mm256_setzero_ps();
|
|
|
|
let mut targets8: [__m256; 4] = [
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
];
|
|
|
|
for p in 0..src_cols_its {
|
|
|
|
for p in 0..src_cols_its {
|
|
|
|
let src8: __m256 = _mm256_loadu_ps(
|
|
|
|
|
|
|
|
src_data
|
|
|
|
|
|
|
|
.add(row * src_cols_capacity + p * ITEMS_PER_CACHE_LINE),
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
let other8: __m256 = _mm256_loadu_ps(
|
|
|
|
let other8: __m256 = _mm256_loadu_ps(
|
|
|
|
other_data
|
|
|
|
other_data
|
|
|
|
.add(col * other_cols_capacity + p * ITEMS_PER_CACHE_LINE),
|
|
|
|
.add(col * other_cols_capacity + p * ITEMS_PER_CACHE_LINE),
|
|
|
|
);
|
|
|
|
);
|
|
|
|
target8 = _mm256_fmadd_ps(src8, other8, target8);
|
|
|
|
let src8_0: __m256 = _mm256_loadu_ps(
|
|
|
|
|
|
|
|
src_data
|
|
|
|
|
|
|
|
.add(row0 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE),
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
let src8_1: __m256 =
|
|
|
|
|
|
|
|
if row1 < src_rows {
|
|
|
|
|
|
|
|
_mm256_loadu_ps(src_data.add(
|
|
|
|
|
|
|
|
row1 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
_mm256_setzero_ps()
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
let src8_2: __m256 =
|
|
|
|
|
|
|
|
if row2 < src_rows {
|
|
|
|
|
|
|
|
_mm256_loadu_ps(src_data.add(
|
|
|
|
|
|
|
|
row2 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
_mm256_setzero_ps()
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
let src8_3: __m256 =
|
|
|
|
|
|
|
|
if row3 < src_rows {
|
|
|
|
|
|
|
|
_mm256_loadu_ps(src_data.add(
|
|
|
|
|
|
|
|
row3 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
_mm256_setzero_ps()
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
targets8[0] = _mm256_fmadd_ps(src8_0, other8, targets8[0]);
|
|
|
|
|
|
|
|
targets8[1] = _mm256_fmadd_ps(src8_1, other8, targets8[1]);
|
|
|
|
|
|
|
|
targets8[2] = _mm256_fmadd_ps(src8_2, other8, targets8[2]);
|
|
|
|
|
|
|
|
targets8[3] = _mm256_fmadd_ps(src8_3, other8, targets8[3]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
let target0: f32 = horizontal_sum(targets8[0]);
|
|
|
|
|
|
|
|
let target1: f32 = horizontal_sum(targets8[1]);
|
|
|
|
|
|
|
|
let target2: f32 = horizontal_sum(targets8[2]);
|
|
|
|
|
|
|
|
let target3: f32 = horizontal_sum(targets8[3]);
|
|
|
|
|
|
|
|
*tgt_data.add(row0 * self_cols_capacity + col) = target0;
|
|
|
|
|
|
|
|
if row1 < self_rows {
|
|
|
|
|
|
|
|
*tgt_data.add(row1 * self_cols_capacity + col) = target1;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if row2 < self_rows {
|
|
|
|
|
|
|
|
*tgt_data.add(row2 * self_cols_capacity + col) = target2;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if row3 < self_rows {
|
|
|
|
|
|
|
|
*tgt_data.add(row3 * self_cols_capacity + col) = target3;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
let target: f32 = horizontal_sum(target8);
|
|
|
|
|
|
|
|
*tgt_data.add(row * self_cols_capacity + col) = target;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|