|
|
|
@ -1072,6 +1072,8 @@ impl Tensor {
|
|
|
|
let src_cols: usize = src.cols as usize;
|
|
|
|
let src_cols: usize = src.cols as usize;
|
|
|
|
let self_rows: usize = self.rows as usize;
|
|
|
|
let self_rows: usize = self.rows as usize;
|
|
|
|
let self_cols: usize = self.cols as usize;
|
|
|
|
let self_cols: usize = self.cols as usize;
|
|
|
|
|
|
|
|
let _other_cols: usize = other.cols as usize;
|
|
|
|
|
|
|
|
let other_rows: usize = other.rows as usize;
|
|
|
|
let other_cols_capacity: usize = other.capacity_cols as usize;
|
|
|
|
let other_cols_capacity: usize = other.capacity_cols as usize;
|
|
|
|
let src_cols_capacity: usize = src.capacity_cols as usize;
|
|
|
|
let src_cols_capacity: usize = src.capacity_cols as usize;
|
|
|
|
let self_cols_capacity: usize = self.capacity_cols as usize;
|
|
|
|
let self_cols_capacity: usize = self.capacity_cols as usize;
|
|
|
|
@ -1086,6 +1088,11 @@ impl Tensor {
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
self_rows / 4 + 1
|
|
|
|
self_rows / 4 + 1
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
let self_cols_its = if self_cols % 4 == 0 {
|
|
|
|
|
|
|
|
self_cols / 4
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
self_cols / 4 + 1
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
unsafe {
|
|
|
|
unsafe {
|
|
|
|
for row in 0..row_its {
|
|
|
|
for row in 0..row_its {
|
|
|
|
@ -1093,18 +1100,66 @@ impl Tensor {
|
|
|
|
let row1 = row * 4 + 1;
|
|
|
|
let row1 = row * 4 + 1;
|
|
|
|
let row2 = row * 4 + 2;
|
|
|
|
let row2 = row * 4 + 2;
|
|
|
|
let row3 = row * 4 + 3;
|
|
|
|
let row3 = row * 4 + 3;
|
|
|
|
for col in 0..self_cols {
|
|
|
|
for col in 0..self_cols_its {
|
|
|
|
let mut targets8: [__m256; 4] = [
|
|
|
|
let col0 = col * 4;
|
|
|
|
|
|
|
|
let col1 = col * 4 + 1;
|
|
|
|
|
|
|
|
let col2 = col * 4 + 2;
|
|
|
|
|
|
|
|
let col3 = col * 4 + 3;
|
|
|
|
|
|
|
|
let mut targets8: [[__m256; 4]; 4] = [
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
[
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
_mm256_setzero_ps(),
|
|
|
|
|
|
|
|
],
|
|
|
|
];
|
|
|
|
];
|
|
|
|
for p in 0..src_cols_its {
|
|
|
|
for p in 0..src_cols_its {
|
|
|
|
let other8: __m256 = _mm256_loadu_ps(
|
|
|
|
let other8_0: __m256 = _mm256_loadu_ps(
|
|
|
|
other_data
|
|
|
|
other_data
|
|
|
|
.add(col * other_cols_capacity + p * ITEMS_PER_CACHE_LINE),
|
|
|
|
.add(col0 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE),
|
|
|
|
);
|
|
|
|
);
|
|
|
|
|
|
|
|
let other8_1: __m256 =
|
|
|
|
|
|
|
|
if col1 < other_rows {
|
|
|
|
|
|
|
|
_mm256_loadu_ps(other_data.add(
|
|
|
|
|
|
|
|
col1 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
_mm256_setzero_ps()
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
let other8_2: __m256 =
|
|
|
|
|
|
|
|
if col2 < other_rows {
|
|
|
|
|
|
|
|
_mm256_loadu_ps(other_data.add(
|
|
|
|
|
|
|
|
col2 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
_mm256_setzero_ps()
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
let other8_3: __m256 =
|
|
|
|
|
|
|
|
if col3 < other_rows {
|
|
|
|
|
|
|
|
_mm256_loadu_ps(other_data.add(
|
|
|
|
|
|
|
|
col3 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
_mm256_setzero_ps()
|
|
|
|
|
|
|
|
};
|
|
|
|
let src8_0: __m256 = _mm256_loadu_ps(
|
|
|
|
let src8_0: __m256 = _mm256_loadu_ps(
|
|
|
|
src_data
|
|
|
|
src_data
|
|
|
|
.add(row0 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE),
|
|
|
|
.add(row0 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE),
|
|
|
|
@ -1133,24 +1188,61 @@ impl Tensor {
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
_mm256_setzero_ps()
|
|
|
|
_mm256_setzero_ps()
|
|
|
|
};
|
|
|
|
};
|
|
|
|
targets8[0] = _mm256_fmadd_ps(src8_0, other8, targets8[0]);
|
|
|
|
targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
|
|
|
|
targets8[1] = _mm256_fmadd_ps(src8_1, other8, targets8[1]);
|
|
|
|
targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
|
|
|
|
targets8[2] = _mm256_fmadd_ps(src8_2, other8, targets8[2]);
|
|
|
|
targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
|
|
|
|
targets8[3] = _mm256_fmadd_ps(src8_3, other8, targets8[3]);
|
|
|
|
targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
|
|
|
|
}
|
|
|
|
targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
|
|
|
|
let target0: f32 = horizontal_sum(targets8[0]);
|
|
|
|
targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
|
|
|
|
let target1: f32 = horizontal_sum(targets8[1]);
|
|
|
|
targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
|
|
|
|
let target2: f32 = horizontal_sum(targets8[2]);
|
|
|
|
targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
|
|
|
|
let target3: f32 = horizontal_sum(targets8[3]);
|
|
|
|
targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
|
|
|
|
*tgt_data.add(row0 * self_cols_capacity + col) = target0;
|
|
|
|
targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
|
|
|
|
|
|
|
|
targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
|
|
|
|
|
|
|
|
targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
|
|
|
|
|
|
|
|
targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
|
|
|
|
|
|
|
|
targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
|
|
|
|
|
|
|
|
targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
|
|
|
|
|
|
|
|
targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
let target00: f32 = horizontal_sum(targets8[0][0]);
|
|
|
|
|
|
|
|
let target01: f32 = horizontal_sum(targets8[0][1]);
|
|
|
|
|
|
|
|
let target02: f32 = horizontal_sum(targets8[0][2]);
|
|
|
|
|
|
|
|
let target03: f32 = horizontal_sum(targets8[0][3]);
|
|
|
|
|
|
|
|
let target10: f32 = horizontal_sum(targets8[1][0]);
|
|
|
|
|
|
|
|
let target11: f32 = horizontal_sum(targets8[1][1]);
|
|
|
|
|
|
|
|
let target12: f32 = horizontal_sum(targets8[1][2]);
|
|
|
|
|
|
|
|
let target13: f32 = horizontal_sum(targets8[1][3]);
|
|
|
|
|
|
|
|
let target20: f32 = horizontal_sum(targets8[2][0]);
|
|
|
|
|
|
|
|
let target21: f32 = horizontal_sum(targets8[2][1]);
|
|
|
|
|
|
|
|
let target22: f32 = horizontal_sum(targets8[2][2]);
|
|
|
|
|
|
|
|
let target23: f32 = horizontal_sum(targets8[2][3]);
|
|
|
|
|
|
|
|
let target30: f32 = horizontal_sum(targets8[3][0]);
|
|
|
|
|
|
|
|
let target31: f32 = horizontal_sum(targets8[3][1]);
|
|
|
|
|
|
|
|
let target32: f32 = horizontal_sum(targets8[3][2]);
|
|
|
|
|
|
|
|
let target33: f32 = horizontal_sum(targets8[3][3]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
*tgt_data.add(row0 * self_cols_capacity + col0) += target00;
|
|
|
|
|
|
|
|
*tgt_data.add(row0 * self_cols_capacity + col1) += target10;
|
|
|
|
|
|
|
|
*tgt_data.add(row0 * self_cols_capacity + col2) += target20;
|
|
|
|
|
|
|
|
*tgt_data.add(row0 * self_cols_capacity + col3) += target30;
|
|
|
|
if row1 < self_rows {
|
|
|
|
if row1 < self_rows {
|
|
|
|
*tgt_data.add(row1 * self_cols_capacity + col) = target1;
|
|
|
|
*tgt_data.add(row1 * self_cols_capacity + col0) += target01;
|
|
|
|
|
|
|
|
*tgt_data.add(row1 * self_cols_capacity + col1) += target11;
|
|
|
|
|
|
|
|
*tgt_data.add(row1 * self_cols_capacity + col2) += target21;
|
|
|
|
|
|
|
|
*tgt_data.add(row1 * self_cols_capacity + col3) += target31;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if row2 < self_rows {
|
|
|
|
if row2 < self_rows {
|
|
|
|
*tgt_data.add(row2 * self_cols_capacity + col) = target2;
|
|
|
|
*tgt_data.add(row2 * self_cols_capacity + col0) += target02;
|
|
|
|
|
|
|
|
*tgt_data.add(row2 * self_cols_capacity + col1) += target12;
|
|
|
|
|
|
|
|
*tgt_data.add(row2 * self_cols_capacity + col2) += target22;
|
|
|
|
|
|
|
|
*tgt_data.add(row2 * self_cols_capacity + col3) += target32;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if row3 < self_rows {
|
|
|
|
if row3 < self_rows {
|
|
|
|
*tgt_data.add(row3 * self_cols_capacity + col) = target3;
|
|
|
|
*tgt_data.add(row3 * self_cols_capacity + col0) += target03;
|
|
|
|
|
|
|
|
*tgt_data.add(row3 * self_cols_capacity + col1) += target13;
|
|
|
|
|
|
|
|
*tgt_data.add(row3 * self_cols_capacity + col2) += target23;
|
|
|
|
|
|
|
|
*tgt_data.add(row3 * self_cols_capacity + col3) += target33;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -1222,8 +1314,8 @@ impl Tensor {
|
|
|
|
];
|
|
|
|
];
|
|
|
|
let self_data: *const f32 = self.data as *const f32;
|
|
|
|
let self_data: *const f32 = self.data as *const f32;
|
|
|
|
let other_data: *const f32 = other.data as *const f32;
|
|
|
|
let other_data: *const f32 = other.data as *const f32;
|
|
|
|
let tgt_data: *mut f32 = result.data as *mut f32;
|
|
|
|
let _tgt_data: *mut f32 = result.data as *mut f32;
|
|
|
|
let ncols_capacity: usize = result.capacity_cols as usize;
|
|
|
|
let _ncols_capacity: usize = result.capacity_cols as usize;
|
|
|
|
for row in 0..row_its {
|
|
|
|
for row in 0..row_its {
|
|
|
|
let row: i64 = row as i64;
|
|
|
|
let row: i64 = row as i64;
|
|
|
|
sum8s[0] = _mm256_setzero_ps();
|
|
|
|
sum8s[0] = _mm256_setzero_ps();
|
|
|
|
@ -1670,8 +1762,8 @@ impl Tensor {
|
|
|
|
|
|
|
|
|
|
|
|
let self_data: *const f16 = self.data as *const f16;
|
|
|
|
let self_data: *const f16 = self.data as *const f16;
|
|
|
|
let tgt_data: *mut f32 = result.data as *mut f32;
|
|
|
|
let tgt_data: *mut f32 = result.data as *mut f32;
|
|
|
|
let tgt_capacity_cols = result.capacity_cols as i64;
|
|
|
|
let tgt_capacity_cols = result.capacity_cols;
|
|
|
|
let self_capacity_cols = self.capacity_cols as i64;
|
|
|
|
let self_capacity_cols = self.capacity_cols;
|
|
|
|
for row in 0..self.rows {
|
|
|
|
for row in 0..self.rows {
|
|
|
|
for col in 0..cols_it {
|
|
|
|
for col in 0..cols_it {
|
|
|
|
let col = col * 8;
|
|
|
|
let col = col * 8;
|
|
|
|
@ -1719,8 +1811,8 @@ impl Tensor {
|
|
|
|
let result = Tensor::uninitialized(self.rows, self.cols, TensorDType::Float16);
|
|
|
|
let result = Tensor::uninitialized(self.rows, self.cols, TensorDType::Float16);
|
|
|
|
let self_data: *const f32 = self.data as *const f32;
|
|
|
|
let self_data: *const f32 = self.data as *const f32;
|
|
|
|
let tgt_data: *mut f16 = result.data as *mut f16;
|
|
|
|
let tgt_data: *mut f16 = result.data as *mut f16;
|
|
|
|
let tgt_capacity_cols = result.capacity_cols as i64;
|
|
|
|
let tgt_capacity_cols = result.capacity_cols;
|
|
|
|
let self_capacity_cols = self.capacity_cols as i64;
|
|
|
|
let self_capacity_cols = self.capacity_cols;
|
|
|
|
|
|
|
|
|
|
|
|
for row in 0..self.rows {
|
|
|
|
for row in 0..self.rows {
|
|
|
|
for col in 0..cols_it {
|
|
|
|
for col in 0..cols_it {
|
|
|
|
@ -1973,9 +2065,9 @@ mod tests {
|
|
|
|
fn mat_mul_transposed_agrees_with_regular_mat_mul() {
|
|
|
|
fn mat_mul_transposed_agrees_with_regular_mat_mul() {
|
|
|
|
let mut rng = rand::thread_rng();
|
|
|
|
let mut rng = rand::thread_rng();
|
|
|
|
for _ in 0..1000 {
|
|
|
|
for _ in 0..1000 {
|
|
|
|
let a = rng.gen_range(8..64);
|
|
|
|
let a = rng.gen_range(1..=128);
|
|
|
|
let b = rng.gen_range(8..64);
|
|
|
|
let b = rng.gen_range(1..=128);
|
|
|
|
let r = rng.gen_range(8..64);
|
|
|
|
let r = rng.gen_range(1..=128);
|
|
|
|
|
|
|
|
|
|
|
|
// Make matrixes AxR and RxB
|
|
|
|
// Make matrixes AxR and RxB
|
|
|
|
let a = Tensor::random(a, r, TensorDType::Float32);
|
|
|
|
let a = Tensor::random(a, r, TensorDType::Float32);
|
|
|
|
@ -1990,7 +2082,7 @@ mod tests {
|
|
|
|
|
|
|
|
|
|
|
|
for row in 0..c.rows {
|
|
|
|
for row in 0..c.rows {
|
|
|
|
for col in 0..c.cols {
|
|
|
|
for col in 0..c.cols {
|
|
|
|
assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-5);
|
|
|
|
assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-3);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|