@@ -1495,8 +1495,115 @@ impl Tensor {
     }
     pub fn matrix_vector_mul_transposed(&self, other: &Tensor) -> Tensor {
         assert_eq!(other.rows, 1);
         assert_eq!(other.dtype, self.dtype);
-        assert_eq!(self.dtype, TensorDType::Float32);
+        match self.dtype {
+            TensorDType::Float32 => self.matrix_vector_mul_transposed_f32(other),
+            TensorDType::Float16 => self.matrix_vector_mul_transposed_f16(other),
+            _ => panic!("Unsupported dtype"),
+        }
+    }
+
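+    // f16 kernel: each 8-wide chunk of half-floats is widened to f32 with
+    // F16C (_mm256_cvtph_ps) and accumulated for four output rows at a time.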
+    fn matrix_vector_mul_transposed_f16(&self, other: &Tensor) -> Tensor {
+        self.assume_on_cpu();
+        other.assume_on_cpu();
+        unsafe {
+            let mut result = Tensor::uninitialized(self.rows, 1, self.dtype);
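+            // Ceil-divide: how many 8-column chunks and 4-row groups cover the matrix.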
+            let col_its: usize = if self.cols % 8 == 0 {
+                (self.cols / 8) as usize
+            } else {
+                (self.cols / 8 + 1) as usize
+            };
+            let row_its: usize = if self.rows % 4 == 0 {
+                (self.rows / 4) as usize
+            } else {
+                (self.rows / 4 + 1) as usize
+            };
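+            // One 8-lane f32 accumulator per row in the current 4-row group.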
+            let mut sum8s: [__m256; 4] = [
+                _mm256_setzero_ps(),
+                _mm256_setzero_ps(),
+                _mm256_setzero_ps(),
+                _mm256_setzero_ps(),
+            ];
+            let self_data: *const f16 = self.data as *const f16;
+            let other_data: *const f16 = other.data as *const f16;
+            let _ncols_capacity: usize = result.capacity_cols as usize;
+            for row in 0..row_its {
+                let row: i64 = row as i64;
+                sum8s[0] = _mm256_setzero_ps();
+                sum8s[1] = _mm256_setzero_ps();
+                sum8s[2] = _mm256_setzero_ps();
+                sum8s[3] = _mm256_setzero_ps();
+                let row4_0 = row * 4;
+                let row4_1 = row * 4 + 1;
+                let row4_2 = row * 4 + 2;
+                let row4_3 = row * 4 + 3;
+
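+                // Inlined load helpers; rows past the end read as zeros so the
+                // final 4-row group can safely overshoot self.rows.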
+                // Loads from (0, column..column+8)
+                #[inline]
+                fn load2(ptr: *const f16, col: usize) -> __m256 {
+                    unsafe { _mm256_cvtph_ps(_mm_loadu_si128(ptr.add(col) as *const __m128i)) }
+                }
+                // Loads from (row, column..column+8)
+                #[inline]
+                fn load2row(
+                    ptr: *const f16,
+                    row: i64,
+                    col: usize,
+                    cols_capacity: i64,
+                    nrows: i64,
+                ) -> __m256 {
+                    unsafe {
+                        if row < nrows {
+                            _mm256_cvtph_ps(_mm_loadu_si128(
+                                ptr.add(row as usize * cols_capacity as usize + col)
+                                    as *const __m128i,
+                            ))
+                        } else {
+                            _mm256_setzero_ps()
+                        }
+                    }
+                }
+
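+                // 8 columns per step: one vector load of the right-hand side
+                // is reused against four matrix rows.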
+                for col in 0..col_its {
+                    let col = col * 8;
+                    let right_side8 = load2(other_data, col);
+                    let left_side8_0 =
+                        load2row(self_data, row4_0, col, self.capacity_cols, self.rows);
+                    let left_side8_1 =
+                        load2row(self_data, row4_1, col, self.capacity_cols, self.rows);
+                    let left_side8_2 =
+                        load2row(self_data, row4_2, col, self.capacity_cols, self.rows);
+                    let left_side8_3 =
+                        load2row(self_data, row4_3, col, self.capacity_cols, self.rows);
+                    sum8s[0] = _mm256_fmadd_ps(left_side8_0, right_side8, sum8s[0]);
+                    sum8s[1] = _mm256_fmadd_ps(left_side8_1, right_side8, sum8s[1]);
+                    sum8s[2] = _mm256_fmadd_ps(left_side8_2, right_side8, sum8s[2]);
+                    sum8s[3] = _mm256_fmadd_ps(left_side8_3, right_side8, sum8s[3]);
+                }
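+                // Collapse each 8-lane accumulator into a scalar dot product.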
+                let sum_0: f32 = horizontal_sum(sum8s[0]);
+                let sum_1: f32 = horizontal_sum(sum8s[1]);
+                let sum_2: f32 = horizontal_sum(sum8s[2]);
+                let sum_3: f32 = horizontal_sum(sum8s[3]);
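+                // The last row group may extend past self.rows; store only real rows.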
+                if row4_0 < result.rows {
+                    result.set_f32(row4_0, 0, sum_0);
+                }
+                if row4_1 < result.rows {
+                    result.set_f32(row4_1, 0, sum_1);
+                }
+                if row4_2 < result.rows {
+                    result.set_f32(row4_2, 0, sum_2);
+                }
+                if row4_3 < result.rows {
+                    result.set_f32(row4_3, 0, sum_3);
+                }
+            }
+            result
+        }
+    }
+
+    fn matrix_vector_mul_transposed_f32(&self, other: &Tensor) -> Tensor {
+        self.assume_on_cpu();
+        other.assume_on_cpu();
         unsafe {
             let mut result = Tensor::uninitialized(self.rows, 1, self.dtype);
             let col_its: usize = if self.cols % 8 == 0 {
@@ -2321,6 +2428,35 @@ mod tests {
         }
     }
 
+    #[test]
+    fn mat_vector_mul_transposed_f32_agrees_mat_vector_mul_transposed_f16() {
+        let mut rng = rand::thread_rng();
+        for _ in 0..1000 {
+            let a = rng.gen_range(1..=128);
+            let r = rng.gen_range(1..=128);
+
+            // Make matrices AxR and Rx1
+            let a = Tensor::random(a, r, TensorDType::Float32);
+            let b = Tensor::random(r, 1, TensorDType::Float32);
+            let a2 = a.clone().to_f16();
+            let b2 = b.clone().to_f16();
+            let b_transposed = b.transpose();
+            let b2_transposed = b2.transpose();
+
+            let c = a.matrix_vector_mul_transposed(&b_transposed);
+            let c2 = a2.matrix_vector_mul_transposed(&b2_transposed);
+
+            assert_eq!(c.rows, c2.rows);
+            assert_eq!(c.cols, c2.cols);
+
+            for row in 0..c.rows {
+                for col in 0..c.cols {
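+                    // f16 loses precision over long dot products; compare with
+                    // a loose absolute tolerance.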
+                    assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-1);
+                }
+            }
+        }
+    }
+
     #[test]
     fn view_preserves_values() {
         fn test_with_type(dtype: TensorDType) {