@@ -135,6 +135,23 @@ impl Drop for Tensor {
     }
 }
 
+// Use this to smuggle pointers to threads without Rust getting so goddamn mad
+//
+// Assumption usize = pointer size.
+#[derive(Copy, Clone)]
+struct WrappedPtr {
+    ptr: usize,
+}
+impl WrappedPtr {
+    fn wrap(ptr: *const u8) -> WrappedPtr {
+        WrappedPtr { ptr: ptr as usize }
+    }
+
+    fn unwrap(self) -> *const u8 {
+        self.ptr as *const u8
+    }
+}
+
 fn compute_capacity_cols(dtype: TensorDType, cols: i64) -> i64 {
     match dtype {
         TensorDType::Float16 => compute_capacity_cols_f16(cols),
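
The hunk above promotes `WrappedPtr` to module scope so several functions can share it. The point of the type is that raw pointers are not `Send`, so the compiler refuses to move them into a rayon closure; a `usize` is plain `Copy + Send` data, so storing the address as an integer sidesteps the check and makes the caller responsible for lifetime and aliasing. A minimal standalone sketch of the same pattern (editorial illustration, not part of the patch; assumes the rayon crate this diff already uses):

```rust
use rayon::prelude::*;

#[derive(Copy, Clone)]
struct WrappedPtr {
    ptr: usize,
}

impl WrappedPtr {
    fn wrap(ptr: *const u8) -> WrappedPtr {
        WrappedPtr { ptr: ptr as usize }
    }

    fn unwrap(self) -> *const u8 {
        self.ptr as *const u8
    }
}

fn main() {
    let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    // The integer wrapper is Send, so it can cross into the parallel closure.
    let wrapped = WrappedPtr::wrap(data.as_ptr() as *const u8);
    (0..data.len()).into_par_iter().for_each(|i| {
        let p = wrapped.unwrap() as *const f32;
        // Sound only because every task reads a distinct element and `data`
        // outlives the parallel loop.
        let value = unsafe { *p.add(i) };
        assert_eq!(value, (i + 1) as f32);
    });
}
```
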
@@ -1025,12 +1042,12 @@ impl Tensor {
     }
 
     pub fn is_on_cpu(&self) -> bool {
-        return !self.is_on_gpu();
+        !self.is_on_gpu()
     }
 
     // Casts data type to whatever the other tensors data type is.
     pub fn to_same_type(&self, other: &Tensor) -> Tensor {
-        let mut result = self.clone();
+        let result = self.clone();
         if result.dtype() == other.dtype() {
             return result;
         }
@@ -1115,8 +1132,8 @@ impl Tensor {
                         self.rows as usize * self.capacity_cols as usize,
                     );
                 }
-                let src_data: *const f32 = src.data as *const f32;
-                let other_data: *const f32 = other.data as *const f32;
+                let _src_data: *const f32 = src.data as *const f32;
+                let _other_data: *const f32 = other.data as *const f32;
 
                 let src_rows: usize = src.rows as usize;
                 let src_cols: usize = src.cols as usize;
@@ -1145,152 +1162,184 @@ impl Tensor {
                 };
 
                 unsafe {
-                    for row in 0..row_its {
-                        let row0 = row * 4;
-                        let row1 = row * 4 + 1;
-                        let row2 = row * 4 + 2;
-                        let row3 = row * 4 + 3;
-                        for col in 0..self_cols_its {
-                            let col0 = col * 4;
-                            let col1 = col * 4 + 1;
-                            let col2 = col * 4 + 2;
-                            let col3 = col * 4 + 3;
-                            let mut targets8: [[__m256; 4]; 4] = [
-                                [
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                ],
-                                [
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                ],
-                                [
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                ],
-                                [
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                ],
-                            ];
-                            for p in 0..src_cols_its {
-                                let other8_0: __m256 = _mm256_loadu_ps(
-                                    other_data.add(col0 * other_cols_capacity + p * ITEMS_PER_LINE),
-                                );
-                                let other8_1: __m256 = if col1 < other_rows {
-                                    _mm256_loadu_ps(
-                                        other_data
-                                            .add(col1 * other_cols_capacity + p * ITEMS_PER_LINE),
-                                    )
-                                } else {
-                                    _mm256_setzero_ps()
-                                };
-                                let other8_2: __m256 = if col2 < other_rows {
-                                    _mm256_loadu_ps(
-                                        other_data
-                                            .add(col2 * other_cols_capacity + p * ITEMS_PER_LINE),
-                                    )
-                                } else {
-                                    _mm256_setzero_ps()
-                                };
-                                let other8_3: __m256 = if col3 < other_rows {
-                                    _mm256_loadu_ps(
-                                        other_data
-                                            .add(col3 * other_cols_capacity + p * ITEMS_PER_LINE),
-                                    )
-                                } else {
-                                    _mm256_setzero_ps()
-                                };
-                                let src8_0: __m256 = _mm256_loadu_ps(
-                                    src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE),
-                                );
-                                let src8_1: __m256 = if row1 < src_rows {
-                                    _mm256_loadu_ps(
-                                        src_data.add(row1 * src_cols_capacity + p * ITEMS_PER_LINE),
-                                    )
-                                } else {
-                                    _mm256_setzero_ps()
-                                };
-                                let src8_2: __m256 = if row2 < src_rows {
-                                    _mm256_loadu_ps(
-                                        src_data.add(row2 * src_cols_capacity + p * ITEMS_PER_LINE),
-                                    )
-                                } else {
-                                    _mm256_setzero_ps()
-                                };
-                                let src8_3: __m256 = if row3 < src_rows {
-                                    _mm256_loadu_ps(
-                                        src_data.add(row3 * src_cols_capacity + p * ITEMS_PER_LINE),
-                                    )
-                                } else {
-                                    _mm256_setzero_ps()
-                                };
-                                targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
-                                targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
-                                targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
-                                targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
-                                targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
-                                targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
-                                targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
-                                targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
-                                targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
-                                targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
-                                targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
-                                targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
-                                targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
-                                targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
-                                targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
-                                targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
-                            }
-                            let target00: f32 = horizontal_sum(targets8[0][0]);
-                            let target01: f32 = horizontal_sum(targets8[0][1]);
-                            let target02: f32 = horizontal_sum(targets8[0][2]);
-                            let target03: f32 = horizontal_sum(targets8[0][3]);
-                            let target10: f32 = horizontal_sum(targets8[1][0]);
-                            let target11: f32 = horizontal_sum(targets8[1][1]);
-                            let target12: f32 = horizontal_sum(targets8[1][2]);
-                            let target13: f32 = horizontal_sum(targets8[1][3]);
-                            let target20: f32 = horizontal_sum(targets8[2][0]);
-                            let target21: f32 = horizontal_sum(targets8[2][1]);
-                            let target22: f32 = horizontal_sum(targets8[2][2]);
-                            let target23: f32 = horizontal_sum(targets8[2][3]);
-                            let target30: f32 = horizontal_sum(targets8[3][0]);
-                            let target31: f32 = horizontal_sum(targets8[3][1]);
-                            let target32: f32 = horizontal_sum(targets8[3][2]);
-                            let target33: f32 = horizontal_sum(targets8[3][3]);
-
-                            *tgt_data.add(row0 * self_cols_capacity + col0) += target00;
-                            *tgt_data.add(row0 * self_cols_capacity + col1) += target10;
-                            *tgt_data.add(row0 * self_cols_capacity + col2) += target20;
-                            *tgt_data.add(row0 * self_cols_capacity + col3) += target30;
-                            if row1 < self_rows {
-                                *tgt_data.add(row1 * self_cols_capacity + col0) += target01;
-                                *tgt_data.add(row1 * self_cols_capacity + col1) += target11;
-                                *tgt_data.add(row1 * self_cols_capacity + col2) += target21;
-                                *tgt_data.add(row1 * self_cols_capacity + col3) += target31;
-                            }
-                            if row2 < self_rows {
-                                *tgt_data.add(row2 * self_cols_capacity + col0) += target02;
-                                *tgt_data.add(row2 * self_cols_capacity + col1) += target12;
-                                *tgt_data.add(row2 * self_cols_capacity + col2) += target22;
-                                *tgt_data.add(row2 * self_cols_capacity + col3) += target32;
-                            }
-                            if row3 < self_rows {
-                                *tgt_data.add(row3 * self_cols_capacity + col0) += target03;
-                                *tgt_data.add(row3 * self_cols_capacity + col1) += target13;
-                                *tgt_data.add(row3 * self_cols_capacity + col2) += target23;
-                                *tgt_data.add(row3 * self_cols_capacity + col3) += target33;
-                            }
-                        }
-                    }
+                    let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
+                    let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
+                    let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
+                    (0..32).into_par_iter().for_each(|thread_idx| {
+                        let src_data: *const f32 = src_data_wrap.unwrap() as *const f32;
+                        let other_data: *const f32 = other_data.unwrap() as *const f32;
+                        let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
+                        for row in 0..row_its {
+                            let row0 = row * 4;
+                            let row1 = row * 4 + 1;
+                            let row2 = row * 4 + 2;
+                            let row3 = row * 4 + 3;
+                            for col in 0..self_cols_its {
+                                let row_col = row * self_cols_its + col;
+                                if row_col % 32 != thread_idx {
+                                    continue;
+                                }
+                                let col0 = col * 4;
+                                let col1 = col * 4 + 1;
+                                let col2 = col * 4 + 2;
+                                let col3 = col * 4 + 3;
+                                let mut targets8: [[__m256; 4]; 4] = [
+                                    [
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                    ],
+                                    [
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                    ],
+                                    [
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                    ],
+                                    [
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                    ],
+                                ];
+                                for p in 0..src_cols_its {
+                                    let other8_0: __m256 = _mm256_loadu_ps(
+                                        other_data
+                                            .add(col0 * other_cols_capacity + p * ITEMS_PER_LINE),
+                                    );
+                                    let other8_1: __m256 =
+                                        if col1 < other_rows {
+                                            _mm256_loadu_ps(other_data.add(
+                                                col1 * other_cols_capacity + p * ITEMS_PER_LINE,
+                                            ))
+                                        } else {
+                                            _mm256_setzero_ps()
+                                        };
+                                    let other8_2: __m256 =
+                                        if col2 < other_rows {
+                                            _mm256_loadu_ps(other_data.add(
+                                                col2 * other_cols_capacity + p * ITEMS_PER_LINE,
+                                            ))
+                                        } else {
+                                            _mm256_setzero_ps()
+                                        };
+                                    let other8_3: __m256 =
+                                        if col3 < other_rows {
+                                            _mm256_loadu_ps(other_data.add(
+                                                col3 * other_cols_capacity + p * ITEMS_PER_LINE,
+                                            ))
+                                        } else {
+                                            _mm256_setzero_ps()
+                                        };
+                                    let src8_0: __m256 = _mm256_loadu_ps(
+                                        src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                    );
+                                    let src8_1: __m256 = if row1 < src_rows {
+                                        _mm256_loadu_ps(
+                                            src_data
+                                                .add(row1 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                        )
+                                    } else {
+                                        _mm256_setzero_ps()
+                                    };
+                                    let src8_2: __m256 = if row2 < src_rows {
+                                        _mm256_loadu_ps(
+                                            src_data
+                                                .add(row2 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                        )
+                                    } else {
+                                        _mm256_setzero_ps()
+                                    };
+                                    let src8_3: __m256 = if row3 < src_rows {
+                                        _mm256_loadu_ps(
+                                            src_data
+                                                .add(row3 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                        )
+                                    } else {
+                                        _mm256_setzero_ps()
+                                    };
+                                    targets8[0][0] =
+                                        _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
+                                    targets8[0][1] =
+                                        _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
+                                    targets8[0][2] =
+                                        _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
+                                    targets8[0][3] =
+                                        _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
+                                    targets8[1][0] =
+                                        _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
+                                    targets8[1][1] =
+                                        _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
+                                    targets8[1][2] =
+                                        _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
+                                    targets8[1][3] =
+                                        _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
+                                    targets8[2][0] =
+                                        _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
+                                    targets8[2][1] =
+                                        _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
+                                    targets8[2][2] =
+                                        _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
+                                    targets8[2][3] =
+                                        _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
+                                    targets8[3][0] =
+                                        _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
+                                    targets8[3][1] =
+                                        _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
+                                    targets8[3][2] =
+                                        _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
+                                    targets8[3][3] =
+                                        _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
+                                }
+                                let target00: f32 = horizontal_sum(targets8[0][0]);
+                                let target01: f32 = horizontal_sum(targets8[0][1]);
+                                let target02: f32 = horizontal_sum(targets8[0][2]);
+                                let target03: f32 = horizontal_sum(targets8[0][3]);
+                                let target10: f32 = horizontal_sum(targets8[1][0]);
+                                let target11: f32 = horizontal_sum(targets8[1][1]);
+                                let target12: f32 = horizontal_sum(targets8[1][2]);
+                                let target13: f32 = horizontal_sum(targets8[1][3]);
+                                let target20: f32 = horizontal_sum(targets8[2][0]);
+                                let target21: f32 = horizontal_sum(targets8[2][1]);
+                                let target22: f32 = horizontal_sum(targets8[2][2]);
+                                let target23: f32 = horizontal_sum(targets8[2][3]);
+                                let target30: f32 = horizontal_sum(targets8[3][0]);
+                                let target31: f32 = horizontal_sum(targets8[3][1]);
+                                let target32: f32 = horizontal_sum(targets8[3][2]);
+                                let target33: f32 = horizontal_sum(targets8[3][3]);
+
+                                *tgt_data.add(row0 * self_cols_capacity + col0) += target00;
+                                *tgt_data.add(row0 * self_cols_capacity + col1) += target10;
+                                *tgt_data.add(row0 * self_cols_capacity + col2) += target20;
+                                *tgt_data.add(row0 * self_cols_capacity + col3) += target30;
+                                if row1 < self_rows {
+                                    *tgt_data.add(row1 * self_cols_capacity + col0) += target01;
+                                    *tgt_data.add(row1 * self_cols_capacity + col1) += target11;
+                                    *tgt_data.add(row1 * self_cols_capacity + col2) += target21;
+                                    *tgt_data.add(row1 * self_cols_capacity + col3) += target31;
+                                }
+                                if row2 < self_rows {
+                                    *tgt_data.add(row2 * self_cols_capacity + col0) += target02;
+                                    *tgt_data.add(row2 * self_cols_capacity + col1) += target12;
+                                    *tgt_data.add(row2 * self_cols_capacity + col2) += target22;
+                                    *tgt_data.add(row2 * self_cols_capacity + col3) += target32;
+                                }
+                                if row3 < self_rows {
+                                    *tgt_data.add(row3 * self_cols_capacity + col0) += target03;
+                                    *tgt_data.add(row3 * self_cols_capacity + col1) += target13;
+                                    *tgt_data.add(row3 * self_cols_capacity + col2) += target23;
+                                    *tgt_data.add(row3 * self_cols_capacity + col3) += target33;
+                                }
+                            }
+                        }
+                    });
                 }
             }
             TensorDType::Float16 => {
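
In the rewritten f32 path above, the entire work split is the guard `if row_col % 32 != thread_idx { continue; }`: all 32 workers walk the same (row, col) tile grid and each keeps only the tiles whose linear index matches its id modulo 32, so every output tile is written by exactly one worker and no locking is needed. A small self-contained check of that invariant (editorial sketch, not part of the patch):

```rust
fn main() {
    let row_its = 4;
    let self_cols_its = 16;
    let workers = 32;
    let mut claimed = vec![0u32; row_its * self_cols_its];
    for thread_idx in 0..workers {
        for row in 0..row_its {
            for col in 0..self_cols_its {
                let row_col = row * self_cols_its + col;
                if row_col % workers != thread_idx {
                    continue; // same striping rule as the hunk above
                }
                claimed[row_col] += 1;
            }
        }
    }
    // Each tile is claimed by exactly one worker, so writes never overlap.
    assert!(claimed.iter().all(|&c| c == 1));
}
```
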
@@ -1304,8 +1353,8 @@ impl Tensor {
                         self.rows as usize * self.capacity_cols as usize,
                     );
                 }
-                let src_data: *const f16 = src.data as *const f16;
-                let other_data: *const f16 = other.data as *const f16;
+                let _src_data: *const f16 = src.data as *const f16;
+                let _other_data: *const f16 = other.data as *const f16;
 
                 let src_rows: usize = src.rows as usize;
                 let src_cols: usize = src.cols as usize;
@@ -1334,160 +1383,192 @@ impl Tensor {
                 };
 
                 unsafe {
-                    for row in 0..row_its {
-                        let row0 = row * 4;
-                        let row1 = row * 4 + 1;
-                        let row2 = row * 4 + 2;
-                        let row3 = row * 4 + 3;
-                        for col in 0..self_cols_its {
-                            let col0 = col * 4;
-                            let col1 = col * 4 + 1;
-                            let col2 = col * 4 + 2;
-                            let col3 = col * 4 + 3;
-                            let mut targets8: [[__m256; 4]; 4] = [
-                                [
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                ],
-                                [
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                ],
-                                [
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                ],
-                                [
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                    _mm256_setzero_ps(),
-                                ],
-                            ];
-                            // Loads from (row, column..column+8) and (row+1, column..column+8)
-                            #[inline]
-                            fn load2_rows(
-                                ptr: *const f16,
-                                row: usize,
-                                column: usize,
-                                cols_capacity: usize,
-                                nrows: usize,
-                            ) -> (__m256, __m256) {
-                                unsafe {
-                                    let (left, right) = if row + 1 < nrows {
-                                        (
-                                            _mm_loadu_si128(ptr.add(row * cols_capacity + column)
-                                                as *const __m128i),
-                                            _mm_loadu_si128(
-                                                ptr.add((row + 1) * cols_capacity + column)
-                                                    as *const __m128i,
-                                            ),
-                                        )
-                                    } else {
-                                        (
-                                            _mm_loadu_si128(ptr.add(row * cols_capacity + column)
-                                                as *const __m128i),
-                                            _mm_setzero_si128(),
-                                        )
-                                    };
-                                    let left: __m256 = _mm256_cvtph_ps(left);
-                                    let right: __m256 = _mm256_cvtph_ps(right);
-                                    (left, right)
-                                }
-                            }
-                            for p in 0..src_cols_its {
-                                let (other8_0, other8_1) = load2_rows(
-                                    other_data,
-                                    col0,
-                                    p * ITEMS_PER_LINE,
-                                    other_cols_capacity,
-                                    other_rows,
-                                );
-                                let (other8_2, other8_3) = load2_rows(
-                                    other_data,
-                                    col2,
-                                    p * ITEMS_PER_LINE,
-                                    other_cols_capacity,
-                                    other_rows,
-                                );
-                                let (src8_0, src8_1) = load2_rows(
-                                    src_data,
-                                    row0,
-                                    p * ITEMS_PER_LINE,
-                                    src_cols_capacity,
-                                    src_rows,
-                                );
-                                let (src8_2, src8_3) = load2_rows(
-                                    src_data,
-                                    row2,
-                                    p * ITEMS_PER_LINE,
-                                    src_cols_capacity,
-                                    src_rows,
-                                );
-                                targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
-                                targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
-                                targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
-                                targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
-                                targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
-                                targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
-                                targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
-                                targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
-                                targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
-                                targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
-                                targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
-                                targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
-                                targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
-                                targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
-                                targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
-                                targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
-                            }
-                            let target00: f16 = horizontal_sum_f32_to_f16(targets8[0][0]);
-                            let target01: f16 = horizontal_sum_f32_to_f16(targets8[0][1]);
-                            let target02: f16 = horizontal_sum_f32_to_f16(targets8[0][2]);
-                            let target03: f16 = horizontal_sum_f32_to_f16(targets8[0][3]);
-                            let target10: f16 = horizontal_sum_f32_to_f16(targets8[1][0]);
-                            let target11: f16 = horizontal_sum_f32_to_f16(targets8[1][1]);
-                            let target12: f16 = horizontal_sum_f32_to_f16(targets8[1][2]);
-                            let target13: f16 = horizontal_sum_f32_to_f16(targets8[1][3]);
-                            let target20: f16 = horizontal_sum_f32_to_f16(targets8[2][0]);
-                            let target21: f16 = horizontal_sum_f32_to_f16(targets8[2][1]);
-                            let target22: f16 = horizontal_sum_f32_to_f16(targets8[2][2]);
-                            let target23: f16 = horizontal_sum_f32_to_f16(targets8[2][3]);
-                            let target30: f16 = horizontal_sum_f32_to_f16(targets8[3][0]);
-                            let target31: f16 = horizontal_sum_f32_to_f16(targets8[3][1]);
-                            let target32: f16 = horizontal_sum_f32_to_f16(targets8[3][2]);
-                            let target33: f16 = horizontal_sum_f32_to_f16(targets8[3][3]);
-
-                            *tgt_data.add(row0 * self_cols_capacity + col0) += target00;
-                            *tgt_data.add(row0 * self_cols_capacity + col1) += target10;
-                            *tgt_data.add(row0 * self_cols_capacity + col2) += target20;
-                            *tgt_data.add(row0 * self_cols_capacity + col3) += target30;
-                            if row1 < self_rows {
-                                *tgt_data.add(row1 * self_cols_capacity + col0) += target01;
-                                *tgt_data.add(row1 * self_cols_capacity + col1) += target11;
-                                *tgt_data.add(row1 * self_cols_capacity + col2) += target21;
-                                *tgt_data.add(row1 * self_cols_capacity + col3) += target31;
-                            }
-                            if row2 < self_rows {
-                                *tgt_data.add(row2 * self_cols_capacity + col0) += target02;
-                                *tgt_data.add(row2 * self_cols_capacity + col1) += target12;
-                                *tgt_data.add(row2 * self_cols_capacity + col2) += target22;
-                                *tgt_data.add(row2 * self_cols_capacity + col3) += target32;
-                            }
-                            if row3 < self_rows {
-                                *tgt_data.add(row3 * self_cols_capacity + col0) += target03;
-                                *tgt_data.add(row3 * self_cols_capacity + col1) += target13;
-                                *tgt_data.add(row3 * self_cols_capacity + col2) += target23;
-                                *tgt_data.add(row3 * self_cols_capacity + col3) += target33;
-                            }
-                        }
-                    }
+                    let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
+                    let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
+                    let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
+                    (0..32).into_par_iter().for_each(|thread_idx| {
+                        let src_data: *const f16 = src_data_wrap.unwrap() as *const f16;
+                        let other_data: *const f16 = other_data.unwrap() as *const f16;
+                        let tgt_data: *mut f16 = tgt_data.unwrap() as *mut f16;
+                        for row in 0..row_its {
+                            let row0 = row * 4;
+                            let row1 = row * 4 + 1;
+                            let row2 = row * 4 + 2;
+                            let row3 = row * 4 + 3;
+                            for col in 0..self_cols_its {
+                                let row_col = row * self_cols_its + col;
+                                if row_col % 32 != thread_idx {
+                                    continue;
+                                }
+                                let col0 = col * 4;
+                                let col1 = col * 4 + 1;
+                                let col2 = col * 4 + 2;
+                                let col3 = col * 4 + 3;
+                                let mut targets8: [[__m256; 4]; 4] = [
+                                    [
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                    ],
+                                    [
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                    ],
+                                    [
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                    ],
+                                    [
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                        _mm256_setzero_ps(),
+                                    ],
+                                ];
+                                // Loads from (row, column..column+8) and (row+1, column..column+8)
+                                #[inline]
+                                fn load2_rows(
+                                    ptr: *const f16,
+                                    row: usize,
+                                    column: usize,
+                                    cols_capacity: usize,
+                                    nrows: usize,
+                                ) -> (__m256, __m256) {
+                                    unsafe {
+                                        let (left, right) = if row + 1 < nrows {
+                                            (
+                                                _mm_loadu_si128(
+                                                    ptr.add(row * cols_capacity + column)
+                                                        as *const __m128i,
+                                                ),
+                                                _mm_loadu_si128(
+                                                    ptr.add((row + 1) * cols_capacity + column)
+                                                        as *const __m128i,
+                                                ),
+                                            )
+                                        } else {
+                                            (
+                                                _mm_loadu_si128(
+                                                    ptr.add(row * cols_capacity + column)
+                                                        as *const __m128i,
+                                                ),
+                                                _mm_setzero_si128(),
+                                            )
+                                        };
+                                        let left: __m256 = _mm256_cvtph_ps(left);
+                                        let right: __m256 = _mm256_cvtph_ps(right);
+                                        (left, right)
+                                    }
+                                }
+                                for p in 0..src_cols_its {
+                                    let (other8_0, other8_1) = load2_rows(
+                                        other_data,
+                                        col0,
+                                        p * ITEMS_PER_LINE,
+                                        other_cols_capacity,
+                                        other_rows,
+                                    );
+                                    let (other8_2, other8_3) = load2_rows(
+                                        other_data,
+                                        col2,
+                                        p * ITEMS_PER_LINE,
+                                        other_cols_capacity,
+                                        other_rows,
+                                    );
+                                    let (src8_0, src8_1) = load2_rows(
+                                        src_data,
+                                        row0,
+                                        p * ITEMS_PER_LINE,
+                                        src_cols_capacity,
+                                        src_rows,
+                                    );
+                                    let (src8_2, src8_3) = load2_rows(
+                                        src_data,
+                                        row2,
+                                        p * ITEMS_PER_LINE,
+                                        src_cols_capacity,
+                                        src_rows,
+                                    );
+                                    targets8[0][0] =
+                                        _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
+                                    targets8[0][1] =
+                                        _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
+                                    targets8[0][2] =
+                                        _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
+                                    targets8[0][3] =
+                                        _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
+                                    targets8[1][0] =
+                                        _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
+                                    targets8[1][1] =
+                                        _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
+                                    targets8[1][2] =
+                                        _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
+                                    targets8[1][3] =
+                                        _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
+                                    targets8[2][0] =
+                                        _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
+                                    targets8[2][1] =
+                                        _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
+                                    targets8[2][2] =
+                                        _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
+                                    targets8[2][3] =
+                                        _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
+                                    targets8[3][0] =
+                                        _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
+                                    targets8[3][1] =
+                                        _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
+                                    targets8[3][2] =
+                                        _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
+                                    targets8[3][3] =
+                                        _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
+                                }
+                                let target00: f16 = horizontal_sum_f32_to_f16(targets8[0][0]);
+                                let target01: f16 = horizontal_sum_f32_to_f16(targets8[0][1]);
+                                let target02: f16 = horizontal_sum_f32_to_f16(targets8[0][2]);
+                                let target03: f16 = horizontal_sum_f32_to_f16(targets8[0][3]);
+                                let target10: f16 = horizontal_sum_f32_to_f16(targets8[1][0]);
+                                let target11: f16 = horizontal_sum_f32_to_f16(targets8[1][1]);
+                                let target12: f16 = horizontal_sum_f32_to_f16(targets8[1][2]);
+                                let target13: f16 = horizontal_sum_f32_to_f16(targets8[1][3]);
+                                let target20: f16 = horizontal_sum_f32_to_f16(targets8[2][0]);
+                                let target21: f16 = horizontal_sum_f32_to_f16(targets8[2][1]);
+                                let target22: f16 = horizontal_sum_f32_to_f16(targets8[2][2]);
+                                let target23: f16 = horizontal_sum_f32_to_f16(targets8[2][3]);
+                                let target30: f16 = horizontal_sum_f32_to_f16(targets8[3][0]);
+                                let target31: f16 = horizontal_sum_f32_to_f16(targets8[3][1]);
+                                let target32: f16 = horizontal_sum_f32_to_f16(targets8[3][2]);
+                                let target33: f16 = horizontal_sum_f32_to_f16(targets8[3][3]);
+
+                                *tgt_data.add(row0 * self_cols_capacity + col0) += target00;
+                                *tgt_data.add(row0 * self_cols_capacity + col1) += target10;
+                                *tgt_data.add(row0 * self_cols_capacity + col2) += target20;
+                                *tgt_data.add(row0 * self_cols_capacity + col3) += target30;
+                                if row1 < self_rows {
+                                    *tgt_data.add(row1 * self_cols_capacity + col0) += target01;
+                                    *tgt_data.add(row1 * self_cols_capacity + col1) += target11;
+                                    *tgt_data.add(row1 * self_cols_capacity + col2) += target21;
+                                    *tgt_data.add(row1 * self_cols_capacity + col3) += target31;
+                                }
+                                if row2 < self_rows {
+                                    *tgt_data.add(row2 * self_cols_capacity + col0) += target02;
+                                    *tgt_data.add(row2 * self_cols_capacity + col1) += target12;
+                                    *tgt_data.add(row2 * self_cols_capacity + col2) += target22;
+                                    *tgt_data.add(row2 * self_cols_capacity + col3) += target32;
+                                }
+                                if row3 < self_rows {
+                                    *tgt_data.add(row3 * self_cols_capacity + col0) += target03;
+                                    *tgt_data.add(row3 * self_cols_capacity + col1) += target13;
+                                    *tgt_data.add(row3 * self_cols_capacity + col2) += target23;
+                                    *tgt_data.add(row3 * self_cols_capacity + col3) += target33;
+                                }
+                            }
+                        }
+                    });
                 }
             }
         }
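
The f16 path above keeps weights in half precision and widens them on load: `load2_rows` pulls eight f16 values (128 bits) per row with `_mm_loadu_si128` and converts them to eight f32 lanes with the F16C instruction behind `_mm256_cvtph_ps`, so the accumulation still happens in f32. The conversion step in isolation (editorial sketch; assumes the `half` crate's `f16`, which this codebase appears to use):

```rust
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{__m128i, __m256, _mm256_cvtph_ps, _mm_loadu_si128};

/// Load eight consecutive f16 values and widen them to eight f32 lanes.
/// Caller must ensure `ptr` points at 8 valid f16 values and that the
/// CPU supports F16C.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "f16c")]
unsafe fn load8_f16_as_f32(ptr: *const half::f16) -> __m256 {
    let raw: __m128i = _mm_loadu_si128(ptr as *const __m128i);
    _mm256_cvtph_ps(raw)
}
```
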
@@ -1534,6 +1615,7 @@ impl Tensor {
         assert_eq!(other.rows, 1);
         assert_eq!(other.dtype, self.dtype);
 
+        #[allow(unreachable_patterns)]
         match self.dtype {
             TensorDType::Float32 => self.matrix_vector_mul_transposed_f32(other),
             TensorDType::Float16 => self.matrix_vector_mul_transposed_f16(other),
@@ -1669,7 +1751,7 @@ impl Tensor {
         self.assume_on_cpu();
         other.assume_on_cpu();
         unsafe {
-            let mut result = Tensor::uninitialized(self.rows, 1, self.dtype);
+            let result = Tensor::zeros(self.rows, 1, self.dtype);
             let col_its: usize = if self.cols % 8 == 0 {
                 (self.cols / 8) as usize
             } else {
@@ -1680,16 +1762,18 @@ impl Tensor {
             } else {
                 (self.rows / 4 + 1) as usize
             };
+            let self_data: *const f32 = self.data as *const f32;
+            let other_data: *const f32 = other.data as *const f32;
+            let tgt_data: *mut f32 = result.data as *mut f32;
+            let ncols_capacity: usize = result.capacity_cols as usize;
+
             let mut sum8s: [__m256; 4] = [
                 _mm256_setzero_ps(),
                 _mm256_setzero_ps(),
                 _mm256_setzero_ps(),
                 _mm256_setzero_ps(),
             ];
-            let self_data: *const f32 = self.data as *const f32;
-            let other_data: *const f32 = other.data as *const f32;
-            let _tgt_data: *mut f32 = result.data as *mut f32;
-            let _ncols_capacity: usize = result.capacity_cols as usize;
-
             for row in 0..row_its {
                 let row: i64 = row as i64;
                 sum8s[0] = _mm256_setzero_ps();
@@ -1732,91 +1816,22 @@ impl Tensor {
                 let sum_2: f32 = horizontal_sum(sum8s[2]);
                 let sum_3: f32 = horizontal_sum(sum8s[3]);
                 if row4_0 < result.rows {
-                    result.set_f32(row4_0, 0, sum_0);
+                    *(tgt_data.add(row4_0 as usize * ncols_capacity)) = sum_0;
                 }
                 if row4_1 < result.rows {
-                    result.set_f32(row4_1, 0, sum_1);
+                    *(tgt_data.add(row4_1 as usize * ncols_capacity)) = sum_1;
                 }
                 if row4_2 < result.rows {
-                    result.set_f32(row4_2, 0, sum_2);
+                    *(tgt_data.add(row4_2 as usize * ncols_capacity)) = sum_2;
                 }
                 if row4_3 < result.rows {
-                    result.set_f32(row4_3, 0, sum_3);
+                    *(tgt_data.add(row4_3 as usize * ncols_capacity)) = sum_3;
                 }
             }
             result
         }
     }
 
-    /// Same as matrix_vector_mul but uses threading.
-    pub fn matrix_vector_mul_transposed_multithreaded(&self, other: &Tensor) -> Tensor {
-        self.assume_on_cpu();
-        other.assume_on_cpu();
-        if self.cols != other.cols {
-            panic!(
-                "Invalid matrix-vector transposed multiplication {}x{} vs {}x{}",
-                self.rows, self.cols, other.rows, other.cols
-            );
-        }
-        assert_eq!(other.rows, 1);
-        assert_eq!(other.dtype, self.dtype);
-        assert_eq!(self.dtype, TensorDType::Float32);
-
-        // Use this to smuggle pointers to threads without Rust getting so goddamn mad
-        //
-        // Assumption usize = pointer size.
-        #[derive(Copy, Clone)]
-        struct WrappedPtr {
-            ptr: usize,
-        }
-        impl WrappedPtr {
-            fn wrap(ptr: *const u8) -> WrappedPtr {
-                WrappedPtr { ptr: ptr as usize }
-            }
-
-            fn unwrap(self) -> *const u8 {
-                self.ptr as *const u8
-            }
-        }
-
-        unsafe {
-            let result = Tensor::uninitialized(self.rows, 1, self.dtype);
-            let capacity_cols: i64 = self.capacity_cols;
-            let result_capacity_cols: i64 = result.capacity_cols;
-            let col_its: usize = if self.cols % 8 == 0 {
-                (self.cols / 8) as usize
-            } else {
-                (self.cols / 8 + 1) as usize
-            };
-            let self_data = WrappedPtr::wrap(self.data);
-            let other_data = WrappedPtr::wrap(other.data);
-            let result_data = WrappedPtr::wrap(result.data);
-            (0..self.rows as usize)
-                .into_par_iter()
-                .with_min_len(64)
-                .for_each(|row| {
-                    let row = row as i64;
-                    let self_data: *const f32 = self_data.unwrap() as *const f32;
-                    let other_data: *const f32 = other_data.unwrap() as *const f32;
-                    let result_data: *mut f32 = result_data.unwrap() as *mut f32;
-
-                    let mut sum8: __m256 = _mm256_setzero_ps();
-                    for col in 0..col_its {
-                        let col = col * 8;
-                        let left_side8 =
-                            _mm256_loadu_ps(self_data.add((row * capacity_cols) as usize + col));
-                        let right_side8 = _mm256_loadu_ps(other_data.add(col));
-                        sum8 = _mm256_fmadd_ps(left_side8, right_side8, sum8);
-                    }
-                    let sum: f32 = horizontal_sum(sum8);
-                    result_data
-                        .add((row * result_capacity_cols) as usize)
-                        .write(sum);
-                });
-            result
-        }
-    }
-
     // Computes matrix multiplication assuming left side has number of rows as 1
     #[allow(clippy::erasing_op)]
     #[allow(clippy::identity_op)]
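
Throughout this diff the inner loops accumulate eight partial products per AVX register and only collapse them to a scalar at the end, via `horizontal_sum` (or `horizontal_sum_f32_to_f16`, which presumably also narrows the result to half precision). Neither function's body appears in this diff; below is one plausible implementation of the reduction, shown for reference only (the real one may differ):

```rust
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// Sum the eight f32 lanes of an AVX register into a single scalar.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn horizontal_sum(v: __m256) -> f32 {
    // Fold the high 128-bit lane onto the low lane, then reduce 4 -> 2 -> 1.
    let hi = _mm256_extractf128_ps::<1>(v);
    let lo = _mm256_castps256_ps128(v);
    let sum4 = _mm_add_ps(lo, hi);
    let shuf = _mm_movehdup_ps(sum4); // duplicate odd lanes: (1,1,3,3)
    let sum2 = _mm_add_ps(sum4, shuf);
    let shuf2 = _mm_movehl_ps(shuf, sum2);
    let sum1 = _mm_add_ss(sum2, shuf2);
    _mm_cvtss_f32(sum1)
}
```
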