Add f16 version of matrix multiplication that works without any OpenCL.

In benchmark it is modestly faster than f32. The main transformer loop
doesn't know how to use f16 yet though, and I need to implement some
other ops for that to start working.
master
Mikko Juola 3 years ago
parent a1970b8a9c
commit baecd25ee3

@ -99,12 +99,40 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
let orig_84096_2 = Tensor::zeros(4096, 4096, TensorDType::Float32);
let mut result_84096 = Tensor::zeros(8, 4096, TensorDType::Float32);
let orig_84096_1_f16 = Tensor::zeros(8, 4096, TensorDType::Float16);
let orig_84096_2_f16 = Tensor::zeros(4096, 4096, TensorDType::Float16);
let mut result_84096_f16 = Tensor::zeros(8, 4096, TensorDType::Float16);
let orig_f32 = Tensor::zeros(1024, 1024, TensorDType::Float32);
let orig_f16 = Tensor::zeros(1024, 1024, TensorDType::Float16);
let m1 = Tensor::random(1024, 128, TensorDType::Float32);
let m2 = Tensor::random(1, 128, TensorDType::Float32);
c.bench_function(
"matrix multiplication 8x4096 @ 4096x4096 f32 in-place, transposed",
|b| {
b.iter(|| {
let _ = result_84096.matrix_mul_inplace_transposed(
black_box(&orig_84096_1),
black_box(&orig_84096_2),
);
})
},
);
c.bench_function(
"matrix multiplication 8x4096 @ 4096x4096 f16 in-place, transposed",
|b| {
b.iter(|| {
let _ = result_84096_f16.matrix_mul_inplace_transposed(
black_box(&orig_84096_1_f16),
black_box(&orig_84096_2_f16),
);
})
},
);
c.bench_function(
"1024x128 * 1x128 matrix vector transposed multiplication",
|b| {
@ -136,18 +164,6 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
},
);
c.bench_function(
"matrix multiplication 8x4096 @ 4096x4096 f32 in-place, transposed",
|b| {
b.iter(|| {
let _ = result_84096.matrix_mul_inplace_transposed(
black_box(&orig_84096_1),
black_box(&orig_84096_2),
);
})
},
);
c.bench_function("matrix multiplication f32 not in-place", |b| {
b.iter(|| {
let _ = black_box(&orig32_1).matrix_mul(black_box(&orig32_2));

@ -169,6 +169,17 @@ fn horizontal_sum(mut ymm: __m256) -> f32 {
}
}
#[inline]
fn horizontal_sum_f32_to_f16(mut ymm: __m256) -> f16 {
unsafe {
let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
ymm = _mm256_add_ps(ymm, ymm2);
ymm = _mm256_hadd_ps(ymm, ymm);
ymm = _mm256_hadd_ps(ymm, ymm);
f16::from_f32(_mm256_cvtss_f32(ymm))
}
}
impl Tensor {
#[inline]
pub fn assume_on_gpu(&self) {
@ -824,8 +835,10 @@ impl Tensor {
self.rows, self.cols, other.cols, other.rows
);
}
// We don't have implementation for f16, so don't use the vector function if we have
// f16
#[cfg(not(feature = "opencl"))]
if other.rows == 1 {
if other.rows == 1 && other.dtype != TensorDType::Float16 {
return self.matrix_vector_mul_transposed(other);
}
#[cfg(feature = "opencl")]
@ -1054,8 +1067,7 @@ impl Tensor {
match src.dtype {
TensorDType::Float32 => {
const CACHE_LINE_SIZE: usize = 32;
const ITEMS_PER_CACHE_LINE: usize = CACHE_LINE_SIZE / std::mem::size_of::<f32>();
const ITEMS_PER_LINE: usize = 8;
let tgt_data: *mut f32 = self.data as *mut f32;
unsafe {
@ -1078,10 +1090,10 @@ impl Tensor {
let src_cols_capacity: usize = src.capacity_cols as usize;
let self_cols_capacity: usize = self.capacity_cols as usize;
let src_cols_its = if src_cols % ITEMS_PER_CACHE_LINE == 0 {
src_cols / ITEMS_PER_CACHE_LINE
let src_cols_its = if src_cols % ITEMS_PER_LINE == 0 {
src_cols / ITEMS_PER_LINE
} else {
src_cols / ITEMS_PER_CACHE_LINE + 1
src_cols / ITEMS_PER_LINE + 1
};
let row_its = if self_rows % 4 == 0 {
self_rows / 4
@ -1133,58 +1145,53 @@ impl Tensor {
];
for p in 0..src_cols_its {
let other8_0: __m256 = _mm256_loadu_ps(
other_data
.add(col0 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE),
other_data.add(col0 * other_cols_capacity + p * ITEMS_PER_LINE),
);
let other8_1: __m256 =
if col1 < other_rows {
_mm256_loadu_ps(other_data.add(
col1 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
))
let other8_1: __m256 = if col1 < other_rows {
_mm256_loadu_ps(
other_data
.add(col1 * other_cols_capacity + p * ITEMS_PER_LINE),
)
} else {
_mm256_setzero_ps()
};
let other8_2: __m256 =
if col2 < other_rows {
_mm256_loadu_ps(other_data.add(
col2 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
))
let other8_2: __m256 = if col2 < other_rows {
_mm256_loadu_ps(
other_data
.add(col2 * other_cols_capacity + p * ITEMS_PER_LINE),
)
} else {
_mm256_setzero_ps()
};
let other8_3: __m256 =
if col3 < other_rows {
_mm256_loadu_ps(other_data.add(
col3 * other_cols_capacity + p * ITEMS_PER_CACHE_LINE,
))
let other8_3: __m256 = if col3 < other_rows {
_mm256_loadu_ps(
other_data
.add(col3 * other_cols_capacity + p * ITEMS_PER_LINE),
)
} else {
_mm256_setzero_ps()
};
let src8_0: __m256 = _mm256_loadu_ps(
src_data
.add(row0 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE),
src_data.add(row0 * src_cols_capacity + p * ITEMS_PER_LINE),
);
let src8_1: __m256 =
if row1 < src_rows {
_mm256_loadu_ps(src_data.add(
row1 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
))
let src8_1: __m256 = if row1 < src_rows {
_mm256_loadu_ps(
src_data.add(row1 * src_cols_capacity + p * ITEMS_PER_LINE),
)
} else {
_mm256_setzero_ps()
};
let src8_2: __m256 =
if row2 < src_rows {
_mm256_loadu_ps(src_data.add(
row2 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
))
let src8_2: __m256 = if row2 < src_rows {
_mm256_loadu_ps(
src_data.add(row2 * src_cols_capacity + p * ITEMS_PER_LINE),
)
} else {
_mm256_setzero_ps()
};
let src8_3: __m256 =
if row3 < src_rows {
_mm256_loadu_ps(src_data.add(
row3 * src_cols_capacity + p * ITEMS_PER_CACHE_LINE,
))
let src8_3: __m256 = if row3 < src_rows {
_mm256_loadu_ps(
src_data.add(row3 * src_cols_capacity + p * ITEMS_PER_LINE),
)
} else {
_mm256_setzero_ps()
};
@ -1248,7 +1255,203 @@ impl Tensor {
}
}
}
TensorDType::Float16 => unimplemented!(),
TensorDType::Float16 => {
const ITEMS_PER_LINE: usize = 8;
let tgt_data: *mut f16 = self.data as *mut f16;
unsafe {
std::ptr::write_bytes(
tgt_data,
0,
self.rows as usize * self.capacity_cols as usize,
);
}
let src_data: *const f16 = src.data as *const f16;
let other_data: *const f16 = other.data as *const f16;
let src_rows: usize = src.rows as usize;
let src_cols: usize = src.cols as usize;
let self_rows: usize = self.rows as usize;
let self_cols: usize = self.cols as usize;
let _other_cols: usize = other.cols as usize;
let other_rows: usize = other.rows as usize;
let other_cols_capacity: usize = other.capacity_cols as usize;
let src_cols_capacity: usize = src.capacity_cols as usize;
let self_cols_capacity: usize = self.capacity_cols as usize;
let src_cols_its = if src_cols % ITEMS_PER_LINE == 0 {
src_cols / ITEMS_PER_LINE
} else {
src_cols / ITEMS_PER_LINE + 1
};
let row_its = if self_rows % 4 == 0 {
self_rows / 4
} else {
self_rows / 4 + 1
};
let self_cols_its = if self_cols % 4 == 0 {
self_cols / 4
} else {
self_cols / 4 + 1
};
unsafe {
for row in 0..row_its {
let row0 = row * 4;
let row1 = row * 4 + 1;
let row2 = row * 4 + 2;
let row3 = row * 4 + 3;
for col in 0..self_cols_its {
let col0 = col * 4;
let col1 = col * 4 + 1;
let col2 = col * 4 + 2;
let col3 = col * 4 + 3;
let mut targets8: [[__m256; 4]; 4] = [
[
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
],
[
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
],
[
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
],
[
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
_mm256_setzero_ps(),
],
];
// Loads from (row, column..column+8) and (row+1, column..column+8)
#[inline]
fn load2_rows(
ptr: *const f16,
row: usize,
column: usize,
cols_capacity: usize,
nrows: usize,
) -> (__m256, __m256) {
unsafe {
let (left, right) = if row + 1 < nrows {
(
_mm_loadu_si128(ptr.add(row * cols_capacity + column)
as *const __m128i),
_mm_loadu_si128(
ptr.add((row + 1) * cols_capacity + column)
as *const __m128i,
),
)
} else {
(
_mm_loadu_si128(ptr.add(row * cols_capacity + column)
as *const __m128i),
_mm_setzero_si128(),
)
};
let left: __m256 = _mm256_cvtph_ps(left);
let right: __m256 = _mm256_cvtph_ps(right);
(left, right)
}
}
for p in 0..src_cols_its {
let (other8_0, other8_1) = load2_rows(
other_data,
col0,
p * ITEMS_PER_LINE,
other_cols_capacity,
other_rows,
);
let (other8_2, other8_3) = load2_rows(
other_data,
col2,
p * ITEMS_PER_LINE,
other_cols_capacity,
other_rows,
);
let (src8_0, src8_1) = load2_rows(
src_data,
row0,
p * ITEMS_PER_LINE,
src_cols_capacity,
src_rows,
);
let (src8_2, src8_3) = load2_rows(
src_data,
row2,
p * ITEMS_PER_LINE,
src_cols_capacity,
src_rows,
);
targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
}
let target00: f16 = horizontal_sum_f32_to_f16(targets8[0][0]);
let target01: f16 = horizontal_sum_f32_to_f16(targets8[0][1]);
let target02: f16 = horizontal_sum_f32_to_f16(targets8[0][2]);
let target03: f16 = horizontal_sum_f32_to_f16(targets8[0][3]);
let target10: f16 = horizontal_sum_f32_to_f16(targets8[1][0]);
let target11: f16 = horizontal_sum_f32_to_f16(targets8[1][1]);
let target12: f16 = horizontal_sum_f32_to_f16(targets8[1][2]);
let target13: f16 = horizontal_sum_f32_to_f16(targets8[1][3]);
let target20: f16 = horizontal_sum_f32_to_f16(targets8[2][0]);
let target21: f16 = horizontal_sum_f32_to_f16(targets8[2][1]);
let target22: f16 = horizontal_sum_f32_to_f16(targets8[2][2]);
let target23: f16 = horizontal_sum_f32_to_f16(targets8[2][3]);
let target30: f16 = horizontal_sum_f32_to_f16(targets8[3][0]);
let target31: f16 = horizontal_sum_f32_to_f16(targets8[3][1]);
let target32: f16 = horizontal_sum_f32_to_f16(targets8[3][2]);
let target33: f16 = horizontal_sum_f32_to_f16(targets8[3][3]);
*tgt_data.add(row0 * self_cols_capacity + col0) += target00;
*tgt_data.add(row0 * self_cols_capacity + col1) += target10;
*tgt_data.add(row0 * self_cols_capacity + col2) += target20;
*tgt_data.add(row0 * self_cols_capacity + col3) += target30;
if row1 < self_rows {
*tgt_data.add(row1 * self_cols_capacity + col0) += target01;
*tgt_data.add(row1 * self_cols_capacity + col1) += target11;
*tgt_data.add(row1 * self_cols_capacity + col2) += target21;
*tgt_data.add(row1 * self_cols_capacity + col3) += target31;
}
if row2 < self_rows {
*tgt_data.add(row2 * self_cols_capacity + col0) += target02;
*tgt_data.add(row2 * self_cols_capacity + col1) += target12;
*tgt_data.add(row2 * self_cols_capacity + col2) += target22;
*tgt_data.add(row2 * self_cols_capacity + col3) += target32;
}
if row3 < self_rows {
*tgt_data.add(row3 * self_cols_capacity + col0) += target03;
*tgt_data.add(row3 * self_cols_capacity + col1) += target13;
*tgt_data.add(row3 * self_cols_capacity + col2) += target23;
*tgt_data.add(row3 * self_cols_capacity + col3) += target33;
}
}
}
}
}
}
}
@ -2088,6 +2291,36 @@ mod tests {
}
}
#[test]
fn mat_mul_transposed_f32_agrees_mat_mul_transposed_f16() {
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let a = rng.gen_range(1..=128);
let b = rng.gen_range(1..=128);
let r = rng.gen_range(1..=128);
// Make matrixes AxR and RxB
let a = Tensor::random(a, r, TensorDType::Float32);
let b = Tensor::random(r, b, TensorDType::Float32);
let a2 = a.clone().to_f16();
let b2 = b.clone().to_f16();
let b_transposed = b.transpose();
let b2_transposed = b2.transpose();
let c = a.matrix_mul_transposed(&b_transposed);
let c2 = a2.matrix_mul_transposed(&b2_transposed);
assert_eq!(c.rows, c2.rows);
assert_eq!(c.cols, c2.cols);
for row in 0..c.rows {
for col in 0..c.cols {
assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-1);
}
}
}
}
#[test]
fn view_preserves_values() {
fn test_with_type(dtype: TensorDType) {

Loading…
Cancel
Save