You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
rllama/src/simd_support.rs

357 lines
8.4 KiB
Rust

// This file contains platform-specific SIMD code so that the rest of rllama does not need to
// care which platform it is on.
use core::arch::x86_64::*;
use half::f16;
use std::fmt::Write;
pub type I32x8 = __m256i;
pub type F32x8 = __m256;
pub type I16x8 = __m128i;
/* ------------------ */
/* Loading and storing things */
/* ------------------ */
#[inline]
pub fn load_i16x8(ptr: *const I16x8) -> I16x8 {
unsafe { _mm_loadu_si128(ptr) }
}
#[inline]
pub fn store_i16x8(ptr: *mut I16x8, a: I16x8) {
unsafe { _mm_storeu_si128(ptr, a) }
}
#[inline]
pub fn load_f32x8(ptr: *const F32x8) -> F32x8 {
unsafe { _mm256_loadu_ps(ptr as *const f32) }
}
#[inline]
pub fn store_f32x8(ptr: *mut F32x8, a: F32x8) {
unsafe { _mm256_storeu_ps(ptr as *mut f32, a) }
}
/// Gathers eight `f32` values starting from `ptr`, one per lane of `indices`.
///
/// The gather uses a scale factor of 1, so each lane of `indices` is a *byte* offset from
/// `ptr`, not an element index. See `gather_scale4_f32x8` for the element-indexed variant.
///
/// This function is not marked `unsafe`: the caller must make sure every addressed location is
/// readable. The underlying intrinsic requires AVX2.
#[inline]
pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 {
    // SAFETY: not checked here; the caller guarantees all eight `ptr + offset` addresses are
    // valid for a 4-byte read.
    unsafe { _mm256_i32gather_ps::<1>(ptr, indices) }
}
/// Gathers eight `f32` values starting from `ptr`, one per lane of `indices`.
///
/// The gather uses a scale factor of 4 (the size of an `f32`), so each lane of `indices` is an
/// element index: lane `i` of the result is `*ptr.offset(indices[i])`.
///
/// This function is not marked `unsafe`: the caller must make sure every addressed element is
/// readable. The underlying intrinsic requires AVX2.
#[inline]
pub fn gather_scale4_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 {
    // SAFETY: not checked here; the caller guarantees all eight indexed elements are valid
    // for reading.
    unsafe { _mm256_i32gather_ps::<4>(ptr, indices) }
}
/* ------------------ */
/* Conversions */
/* ------------------ */
/// Treats the eight 16-bit lanes of `a` as half-precision floats and widens each one to `f32`.
///
/// The underlying intrinsic requires the F16C CPU feature; this wrapper does not check for it.
#[inline]
pub fn i16x8_as_f16_to_f32x8(a: I16x8) -> F32x8 {
    // SAFETY: operates on register values only; assumes the build targets a CPU with F16C.
    unsafe { _mm256_cvtph_ps(a) }
}
/// Narrows each `f32` lane of `a` to a half-precision float and returns the eight 16-bit
/// patterns packed in an `I16x8`.
///
/// Rounding mode is round-to-nearest (immediate value 0). The underlying intrinsic requires
/// the F16C CPU feature; this wrapper does not check for it.
#[inline]
pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 {
    // SAFETY: operates on register values only; assumes the build targets a CPU with F16C.
    unsafe { _mm256_cvtps_ph::<_MM_FROUND_TO_NEAREST_INT>(a) }
}
/// Reinterprets the 256 bits of an `F32x8` as an `I32x8`.
///
/// No numeric conversion or rounding takes place: each lane keeps exactly the same bit
/// pattern, it is only viewed as an integer afterwards.
#[inline]
pub fn f32x8_to_i32x8_bitcast(a: F32x8) -> I32x8 {
    // SAFETY: a pure register reinterpretation; assumes the build targets a CPU with AVX.
    unsafe { _mm256_castps_si256(a) }
}
/*
* ------------------
* Accessing individual elements
*/
// The extract intrinsics take the lane index as a compile-time constant, while these accessors
// receive it at runtime. So we have this awkward match statement in each of these.
#[inline]
pub fn f32x8_get(a: F32x8, idx: usize) -> f32 {
unsafe {
let a = f32x8_to_i32x8_bitcast(a);
let a = match idx {
0 => _mm256_extract_epi32(a, 0),
1 => _mm256_extract_epi32(a, 1),
2 => _mm256_extract_epi32(a, 2),
3 => _mm256_extract_epi32(a, 3),
4 => _mm256_extract_epi32(a, 4),
5 => _mm256_extract_epi32(a, 5),
6 => _mm256_extract_epi32(a, 6),
7 => _mm256_extract_epi32(a, 7),
_ => panic!("f32x8_get: index out of bounds"),
};
// bitcast the i32 back to f32
core::mem::transmute(a)
}
}
/// Returns lane `idx` (0 through 7) of `a` as an `i32`.
///
/// # Panics
///
/// Panics if `idx` is greater than 7.
#[inline]
pub fn i32x8_get(a: I32x8, idx: usize) -> i32 {
    // SAFETY: operates on register values only; assumes the build targets a CPU with AVX/AVX2.
    unsafe {
        match idx {
            0 => _mm256_extract_epi32::<0>(a),
            1 => _mm256_extract_epi32::<1>(a),
            2 => _mm256_extract_epi32::<2>(a),
            3 => _mm256_extract_epi32::<3>(a),
            4 => _mm256_extract_epi32::<4>(a),
            5 => _mm256_extract_epi32::<5>(a),
            6 => _mm256_extract_epi32::<6>(a),
            7 => _mm256_extract_epi32::<7>(a),
            _ => panic!("i32x8_get: index out of bounds"),
        }
    }
}
/// Returns lane `idx` (0 through 7) of `a` as an `i16`.
///
/// # Panics
///
/// Panics if `idx` is greater than 7.
#[inline]
pub fn i16x8_get(a: I16x8, idx: usize) -> i16 {
    // SAFETY: operates on register values only (SSE2).
    let widened: i32 = unsafe {
        match idx {
            0 => _mm_extract_epi16::<0>(a),
            1 => _mm_extract_epi16::<1>(a),
            2 => _mm_extract_epi16::<2>(a),
            3 => _mm_extract_epi16::<3>(a),
            4 => _mm_extract_epi16::<4>(a),
            5 => _mm_extract_epi16::<5>(a),
            6 => _mm_extract_epi16::<6>(a),
            7 => _mm_extract_epi16::<7>(a),
            _ => panic!("i16x8_get: index out of bounds"),
        }
    };
    // The intrinsic hands the lane back inside an i32; keeping only the low 16 bits restores
    // the original signed value.
    widened as i16
}
/*
* Constants, creating from constants
*/
/// Returns an `F32x8` with all eight lanes set to `0.0`.
#[inline]
pub fn f32x8_zero() -> F32x8 {
    // SAFETY: produces a register value only; assumes the build targets a CPU with AVX.
    unsafe { _mm256_setzero_ps() }
}
/// Returns an `I16x8` with all eight lanes set to `0`.
#[inline]
pub fn i16x8_zero() -> I16x8 {
    // SAFETY: produces a register value only (SSE2).
    unsafe { _mm_setzero_si128() }
}
/// Returns an `I16x8` in which every one of the eight lanes holds `value`.
#[inline]
pub fn i16x8_singleton(value: i16) -> I16x8 {
    // SAFETY: produces a register value only (SSE2).
    unsafe { _mm_set1_epi16(value) }
}
/// Returns an `I16x8` in which every one of the eight 16-bit lanes holds the bit pattern of
/// `value`.
///
/// Values above `i16::MAX` are not clamped; their bits are stored unchanged, so the lane reads
/// back as a negative `i16`.
#[inline]
pub fn i16x8_singleton_u16(value: u16) -> I16x8 {
    // Same 16 bits, just viewed as signed because that is what the intrinsic accepts.
    let bits = value as i16;
    // SAFETY: produces a register value only (SSE2).
    unsafe { _mm_set1_epi16(bits) }
}
/// Returns an `F32x8` in which every one of the eight lanes holds `value`.
#[inline]
pub fn f32x8_singleton(value: f32) -> F32x8 {
    // SAFETY: produces a register value only; assumes the build targets a CPU with AVX.
    unsafe { _mm256_set1_ps(value) }
}
/// Builds an `F32x8` from eight individual values.
///
/// Note the lane order: the arguments are placed from the highest lane downwards, so `val0`
/// ends up in lane 7 and `val7` in lane 0 (the lane that sits first in memory when stored).
#[inline]
pub fn f32x8_from_values(
    val0: f32,
    val1: f32,
    val2: f32,
    val3: f32,
    val4: f32,
    val5: f32,
    val6: f32,
    val7: f32,
) -> F32x8 {
    // SAFETY: produces a register value only; assumes the build targets a CPU with AVX.
    // `setr` takes lane 0 first, hence the reversed argument list.
    unsafe { _mm256_setr_ps(val7, val6, val5, val4, val3, val2, val1, val0) }
}
/// Builds an `I32x8` from eight individual values.
///
/// Note the lane order: the arguments are placed from the highest lane downwards, so `val0`
/// ends up in lane 7 and `val7` in lane 0 (the lane that sits first in memory when stored).
#[inline]
pub fn i32x8_from_values(
    val0: i32,
    val1: i32,
    val2: i32,
    val3: i32,
    val4: i32,
    val5: i32,
    val6: i32,
    val7: i32,
) -> I32x8 {
    // SAFETY: produces a register value only; assumes the build targets a CPU with AVX.
    // `setr` takes lane 0 first, hence the reversed argument list.
    unsafe { _mm256_setr_epi32(val7, val6, val5, val4, val3, val2, val1, val0) }
}
/// Builds an `I32x8` from eight unsigned values, storing each one's 32-bit pattern unchanged.
///
/// Note the lane order: the arguments are placed from the highest lane downwards, so `val0`
/// ends up in lane 7 and `val7` in lane 0 (the lane that sits first in memory when stored).
#[inline]
pub fn i32x8_from_values_u32(
    val0: u32,
    val1: u32,
    val2: u32,
    val3: u32,
    val4: u32,
    val5: u32,
    val6: u32,
    val7: u32,
) -> I32x8 {
    // The casts keep the bits and only change the signedness the intrinsic expects.
    // SAFETY: produces a register value only; assumes the build targets a CPU with AVX.
    // `setr` takes lane 0 first, hence the reversed argument list.
    unsafe {
        _mm256_setr_epi32(
            val7 as i32,
            val6 as i32,
            val5 as i32,
            val4 as i32,
            val3 as i32,
            val2 as i32,
            val1 as i32,
            val0 as i32,
        )
    }
}
/// Builds an `I16x8` from eight individual values.
///
/// Note the lane order: the arguments are placed from the highest lane downwards, so `val0`
/// ends up in lane 7 and `val7` in lane 0 (the lane that sits first in memory when stored).
#[inline]
pub fn i16x8_from_values(
    val0: i16,
    val1: i16,
    val2: i16,
    val3: i16,
    val4: i16,
    val5: i16,
    val6: i16,
    val7: i16,
) -> I16x8 {
    // SAFETY: produces a register value only (SSE2).
    // `setr` takes lane 0 first, hence the reversed argument list.
    unsafe { _mm_setr_epi16(val7, val6, val5, val4, val3, val2, val1, val0) }
}
/*
* Operations
*/
// Fused multiply-add
/// Computes `a * b + c` for each of the eight lanes.
///
/// The underlying intrinsic requires the FMA CPU feature; this wrapper does not check for it.
#[inline]
pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 {
    // SAFETY: operates on register values only; assumes the build targets a CPU with FMA.
    unsafe { _mm256_fmadd_ps(a, b, c) }
}
// Bitwise AND
/// Returns the bitwise AND of `a` and `b` across all 128 bits.
#[inline]
pub fn and_i16x8(a: I16x8, b: I16x8) -> I16x8 {
    // SAFETY: operates on register values only (SSE2).
    unsafe { _mm_and_si128(a, b) }
}
/// Returns the bitwise AND of `a` and `b` across all 256 bits.
///
/// The underlying intrinsic requires AVX2; this wrapper does not check for it.
#[inline]
pub fn and_i32x8(a: I32x8, b: I32x8) -> I32x8 {
    // SAFETY: operates on register values only; assumes the build targets a CPU with AVX2.
    unsafe { _mm256_and_si256(a, b) }
}
/// Applies the integer mask `b` to the raw bits of the float lanes in `a` and returns the
/// result as floats again.
///
/// No numeric conversion takes place; the float bits are ANDed with the mask as they are.
/// The underlying integer AND requires AVX2; this wrapper does not check for it.
#[inline]
pub fn and_f32x8(a: F32x8, b: I32x8) -> F32x8 {
    // SAFETY: operates on register values only; assumes the build targets a CPU with AVX2.
    unsafe {
        let float_bits = _mm256_castps_si256(a);
        let masked = _mm256_and_si256(float_bits, b);
        _mm256_castsi256_ps(masked)
    }
}
/// Shifts every 16-bit lane of `a` right by exactly 4 bits.
///
/// This is a logical shift: zeros come in from the left, the sign bit is not replicated.
#[inline]
pub fn shift_right_by_4_i16x8(a: I16x8) -> I16x8 {
    // SAFETY: operates on register values only (SSE2).
    unsafe { _mm_srli_epi16::<4>(a) }
}
/// Shifts the whole 128-bit value right by 64 bits, i.e. by half its width.
///
/// The upper four 16-bit lanes move into the lower four positions and the upper half is
/// filled with zeros.
#[inline]
pub fn shift_right_by_64_i128(a: I16x8) -> I16x8 {
    // The intrinsic counts in bytes: 64 bits are 8 bytes.
    // SAFETY: operates on register values only (SSE2).
    unsafe { _mm_srli_si128::<8>(a) }
}
// Shuffle/permute
/// Rearranges the bytes of `a` according to `permutation`.
///
/// Despite the `i16x8` name, the shuffle works on the sixteen individual *bytes* of the
/// register. For each output byte, the low four bits of the matching byte in `permutation`
/// choose which byte of `a` to copy; if that control byte has its top bit set, the output byte
/// is zero instead.
///
/// The underlying intrinsic requires SSSE3; this wrapper does not check for it.
#[inline]
pub fn shuffle_i16x8(a: I16x8, permutation: I16x8) -> I16x8 {
    // SAFETY: operates on register values only; assumes the build targets a CPU with SSSE3.
    unsafe { _mm_shuffle_epi8(a, permutation) }
}
/// Widens the lowest eight `i8` values of `a` into eight `i16` lanes.
///
/// The upper eight bytes of `a` are ignored. The widening is a *sign* extension: for
/// non-negative bytes this looks like `XXYYZZ -> 00XX00YY00ZZ`, while a negative byte such as
/// `0xFF` becomes `0xFFFF` (still -1).
///
/// The underlying intrinsic requires SSE4.1; this wrapper does not check for it.
#[inline]
pub fn extend_i8_to_i16_i16x8(a: I16x8) -> I16x8 {
    // SAFETY: operates on register values only; assumes the build targets a CPU with SSE4.1.
    unsafe { _mm_cvtepi8_epi16(a) }
}
/// Widens the lowest eight `i8` values of `a` into eight `i32` lanes.
///
/// The widening happens in two steps: bytes are first sign-extended to 16 bits, then those
/// 16-bit values are *zero*-extended to 32 bits. For bytes in `0..=127` the result is simply
/// the same number. A negative byte does not stay negative: `0xFF` becomes `0x0000FFFF`.
///
/// NOTE(review): the mixed sign/zero extension looks unintentional, but callers may only ever
/// pass non-negative bytes — confirm before changing it.
///
/// The second step requires AVX2; this wrapper does not check for it.
#[inline]
pub fn extend_i8_to_i32_i32x8(a: I16x8) -> I32x8 {
    let widened_to_i16 = extend_i8_to_i16_i16x8(a);
    // SAFETY: operates on register values only; assumes the build targets a CPU with AVX2.
    unsafe { _mm256_cvtepu16_epi32(widened_to_i16) }
}
// Horizontal sums
/// Adds the eight lanes of `ymm` together and returns the total.
///
/// The additions are done pairwise (lane `i` with lane `i + 4` first, then two rounds of
/// neighbouring pairs), so rounding can differ from a plain left-to-right scalar sum.
#[inline]
pub fn horizontal_sum_f32x8(ymm: __m256) -> f32 {
    // SAFETY: operates on register values only; assumes the build targets a CPU with AVX.
    unsafe {
        // Swap the two 128-bit halves so each lane meets its partner from the other half.
        let swapped = _mm256_permute2f128_ps::<1>(ymm, ymm);
        let halves = _mm256_add_ps(ymm, swapped);
        let pairs = _mm256_hadd_ps(halves, halves);
        let total = _mm256_hadd_ps(pairs, pairs);
        // Every lane now holds the full sum; read it from lane 0.
        _mm256_cvtss_f32(total)
    }
}
#[inline]
pub fn horizontal_sum_and_f32_to_f16(mut ymm: __m256) -> f16 {
unsafe {
let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
ymm = _mm256_add_ps(ymm, ymm2);
ymm = _mm256_hadd_ps(ymm, ymm);
ymm = _mm256_hadd_ps(ymm, ymm);
f16::from_f32(_mm256_cvtss_f32(ymm))
}
}
/*
* Debugging
*/
/// Debugging aid: dumps the eight lanes of an `I16x8` to stdout as three lines.
///
/// ```ignore
///     0     0     0     0 ...
/// 0x0000 0x0000 0x0000 0x0000 ...
/// 0000000000000000 0000000000000000 0000000000000000 0000000000000000 ...
/// ```
///
/// The first line shows each lane in decimal, the second in hexadecimal and the third in
/// binary. Lanes are listed from index 0 to index 7. Negative lanes appear in hex and binary
/// as their 16-bit two's-complement pattern.
pub fn print_i16x8(a: I16x8) {
    let (mut dec, mut hex, mut bin) = (String::new(), String::new(), String::new());
    for lane in (0..8).map(|idx| i16x8_get(a, idx)) {
        write!(dec, "{:>5} ", lane).unwrap();
        write!(hex, "0x{:04X} ", lane).unwrap();
        write!(bin, "{:016b} ", lane).unwrap();
    }
    // Each entry was written with a trailing separator; drop the one after the last lane.
    for line in [&dec, &hex, &bin] {
        println!("{}", line.trim_end());
    }
}
/// Debugging aid: dumps the eight lanes of an `I32x8` to stdout as three lines.
///
/// The first line shows each lane in decimal, the second in hexadecimal and the third in
/// binary. Lanes are listed from index 0 to index 7. Negative lanes appear in hex and binary
/// as their 32-bit two's-complement pattern.
pub fn print_i32x8(a: I32x8) {
    let (mut dec, mut hex, mut bin) = (String::new(), String::new(), String::new());
    for lane in (0..8).map(|idx| i32x8_get(a, idx)) {
        write!(dec, "{:>10} ", lane).unwrap();
        write!(hex, "0x{:08X} ", lane).unwrap();
        write!(bin, "{:032b} ", lane).unwrap();
    }
    // Each entry was written with a trailing separator; drop the one after the last lane.
    for line in [&dec, &hex, &bin] {
        println!("{}", line.trim_end());
    }
}