|
|
|
@ -1508,31 +1508,43 @@ impl Tensor {
|
|
|
|
other.assume_on_cpu();
|
|
|
|
other.assume_on_cpu();
|
|
|
|
unsafe {
|
|
|
|
unsafe {
|
|
|
|
let mut result = Tensor::uninitialized(self.rows, 1, self.dtype);
|
|
|
|
let mut result = Tensor::uninitialized(self.rows, 1, self.dtype);
|
|
|
|
let col_its: usize = if self.cols % 8 == 0 {
|
|
|
|
let col_its: usize = if self.cols % 16 == 0 {
|
|
|
|
(self.cols / 8) as usize
|
|
|
|
(self.cols / 16) as usize
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
(self.cols / 8 + 1) as usize
|
|
|
|
(self.cols / 16 + 1) as usize
|
|
|
|
};
|
|
|
|
};
|
|
|
|
let row_its: usize = if self.rows % 4 == 0 {
|
|
|
|
let row_its: usize = if self.rows % 4 == 0 {
|
|
|
|
(self.rows / 4) as usize
|
|
|
|
(self.rows / 4) as usize
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
(self.rows / 4 + 1) as usize
|
|
|
|
(self.rows / 4 + 1) as usize
|
|
|
|
};
|
|
|
|
};
|
|
|
|
let mut sum8s: [__m256; 4] = [
|
|
|
|
let mut sum8s: [[__m256; 4]; 2] = [
|
|
|
|
|
|
|
|
[
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
],
|
|
|
|
];
|
|
|
|
];
|
|
|
|
let self_data: *const f16 = self.data as *const f16;
|
|
|
|
let self_data: *const f16 = self.data as *const f16;
|
|
|
|
let other_data: *const f16 = other.data as *const f16;
|
|
|
|
let other_data: *const f16 = other.data as *const f16;
|
|
|
|
let _ncols_capacity: usize = result.capacity_cols as usize;
|
|
|
|
let _ncols_capacity: usize = result.capacity_cols as usize;
|
|
|
|
for row in 0..row_its {
|
|
|
|
for row in 0..row_its {
|
|
|
|
let row: i64 = row as i64;
|
|
|
|
let row: i64 = row as i64;
|
|
|
|
sum8s[0] = _mm256_setzero_ps();
|
|
|
|
sum8s[0][0] = _mm256_setzero_ps();
|
|
|
|
sum8s[1] = _mm256_setzero_ps();
|
|
|
|
sum8s[0][1] = _mm256_setzero_ps();
|
|
|
|
sum8s[2] = _mm256_setzero_ps();
|
|
|
|
sum8s[0][2] = _mm256_setzero_ps();
|
|
|
|
sum8s[3] = _mm256_setzero_ps();
|
|
|
|
sum8s[0][3] = _mm256_setzero_ps();
|
|
|
|
|
|
|
|
sum8s[1][0] = _mm256_setzero_ps();
|
|
|
|
|
|
|
|
sum8s[1][1] = _mm256_setzero_ps();
|
|
|
|
|
|
|
|
sum8s[1][2] = _mm256_setzero_ps();
|
|
|
|
|
|
|
|
sum8s[1][3] = _mm256_setzero_ps();
|
|
|
|
let row4_0 = row * 4;
|
|
|
|
let row4_0 = row * 4;
|
|
|
|
let row4_1 = row * 4 + 1;
|
|
|
|
let row4_1 = row * 4 + 1;
|
|
|
|
let row4_2 = row * 4 + 2;
|
|
|
|
let row4_2 = row * 4 + 2;
|
|
|
|
@ -1565,25 +1577,39 @@ impl Tensor {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for col in 0..col_its {
|
|
|
|
for col in 0..col_its {
|
|
|
|
let col = col * 8;
|
|
|
|
let col = col * 16;
|
|
|
|
let right_side8 = load2(other_data, col);
|
|
|
|
let col2 = col + 8;
|
|
|
|
let left_side8_0 =
|
|
|
|
let right_side8_0 = load2(other_data, col);
|
|
|
|
|
|
|
|
let left_side8_00 =
|
|
|
|
load2row(self_data, row4_0, col, self.capacity_cols, self.rows);
|
|
|
|
load2row(self_data, row4_0, col, self.capacity_cols, self.rows);
|
|
|
|
let left_side8_1 =
|
|
|
|
let left_side8_10 =
|
|
|
|
load2row(self_data, row4_1, col, self.capacity_cols, self.rows);
|
|
|
|
load2row(self_data, row4_1, col, self.capacity_cols, self.rows);
|
|
|
|
let left_side8_2 =
|
|
|
|
let left_side8_20 =
|
|
|
|
load2row(self_data, row4_2, col, self.capacity_cols, self.rows);
|
|
|
|
load2row(self_data, row4_2, col, self.capacity_cols, self.rows);
|
|
|
|
let left_side8_3 =
|
|
|
|
let left_side8_30 =
|
|
|
|
load2row(self_data, row4_3, col, self.capacity_cols, self.rows);
|
|
|
|
load2row(self_data, row4_3, col, self.capacity_cols, self.rows);
|
|
|
|
sum8s[0] = _mm256_fmadd_ps(left_side8_0, right_side8, sum8s[0]);
|
|
|
|
sum8s[0][0] = _mm256_fmadd_ps(left_side8_00, right_side8_0, sum8s[0][0]);
|
|
|
|
sum8s[1] = _mm256_fmadd_ps(left_side8_1, right_side8, sum8s[1]);
|
|
|
|
sum8s[0][1] = _mm256_fmadd_ps(left_side8_10, right_side8_0, sum8s[0][1]);
|
|
|
|
sum8s[2] = _mm256_fmadd_ps(left_side8_2, right_side8, sum8s[2]);
|
|
|
|
sum8s[0][2] = _mm256_fmadd_ps(left_side8_20, right_side8_0, sum8s[0][2]);
|
|
|
|
sum8s[3] = _mm256_fmadd_ps(left_side8_3, right_side8, sum8s[3]);
|
|
|
|
sum8s[0][3] = _mm256_fmadd_ps(left_side8_30, right_side8_0, sum8s[0][3]);
|
|
|
|
}
|
|
|
|
let right_side8_1 = load2(other_data, col2);
|
|
|
|
let sum_0: f32 = horizontal_sum(sum8s[0]);
|
|
|
|
let left_side8_01 =
|
|
|
|
let sum_1: f32 = horizontal_sum(sum8s[1]);
|
|
|
|
load2row(self_data, row4_0, col2, self.capacity_cols, self.rows);
|
|
|
|
let sum_2: f32 = horizontal_sum(sum8s[2]);
|
|
|
|
let left_side8_11 =
|
|
|
|
let sum_3: f32 = horizontal_sum(sum8s[3]);
|
|
|
|
load2row(self_data, row4_1, col2, self.capacity_cols, self.rows);
|
|
|
|
|
|
|
|
let left_side8_21 =
|
|
|
|
|
|
|
|
load2row(self_data, row4_2, col2, self.capacity_cols, self.rows);
|
|
|
|
|
|
|
|
let left_side8_31 =
|
|
|
|
|
|
|
|
load2row(self_data, row4_3, col2, self.capacity_cols, self.rows);
|
|
|
|
|
|
|
|
sum8s[1][0] = _mm256_fmadd_ps(left_side8_01, right_side8_1, sum8s[1][0]);
|
|
|
|
|
|
|
|
sum8s[1][1] = _mm256_fmadd_ps(left_side8_11, right_side8_1, sum8s[1][1]);
|
|
|
|
|
|
|
|
sum8s[1][2] = _mm256_fmadd_ps(left_side8_21, right_side8_1, sum8s[1][2]);
|
|
|
|
|
|
|
|
sum8s[1][3] = _mm256_fmadd_ps(left_side8_31, right_side8_1, sum8s[1][3]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
let sum_0: f32 = horizontal_sum(sum8s[0][0]) + horizontal_sum(sum8s[1][0]);
|
|
|
|
|
|
|
|
let sum_1: f32 = horizontal_sum(sum8s[0][1]) + horizontal_sum(sum8s[1][1]);
|
|
|
|
|
|
|
|
let sum_2: f32 = horizontal_sum(sum8s[0][2]) + horizontal_sum(sum8s[1][2]);
|
|
|
|
|
|
|
|
let sum_3: f32 = horizontal_sum(sum8s[0][3]) + horizontal_sum(sum8s[1][3]);
|
|
|
|
if row4_0 < result.rows {
|
|
|
|
if row4_0 < result.rows {
|
|
|
|
result.set_f32(row4_0, 0, sum_0);
|
|
|
|
result.set_f32(row4_0, 0, sum_0);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|