Make matrix multiplication multithreaded.

This greatly improves performance with f16; it is now faster than the OpenCL path on LLaMA-7B.
Branch: master
Author: Mikko Juola, 3 years ago
Parent: 8134c20d57
Commit: 3d0afcf243
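
For orientation before the diff: the kernels below stripe their output tiles across 32 rayon tasks (each task skips any tile whose index modulo 32 is not its own), and the raw tensor pointers are handed to the parallel closure through a small usize wrapper so the closure is Send. The following is a minimal, self-contained sketch of that scheme, not the repository's kernel; PtrWrap and matmul_transposed_striped are invented names, and the real code uses AVX2 intrinsics and the Tensor memory layout instead of plain slices.

use rayon::prelude::*;

// Mirrors the WrappedPtr trick in the diff: a raw pointer stored as a usize
// is Copy + Send, so a rayon closure can capture it.
#[derive(Copy, Clone)]
struct PtrWrap(usize);

impl PtrWrap {
    fn wrap(p: *mut f32) -> Self {
        PtrWrap(p as usize)
    }
    fn get(self) -> *mut f32 {
        self.0 as *mut f32
    }
}

// C = A * B^T, where A is m x k and B is n x k, with output elements striped
// over 32 parallel tasks.
fn matmul_transposed_striped(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, k: usize) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), n * k);
    assert_eq!(c.len(), m * n);
    let c_ptr = PtrWrap::wrap(c.as_mut_ptr());
    (0..32usize).into_par_iter().for_each(|thread_idx| {
        let out: *mut f32 = c_ptr.get();
        for row in 0..m {
            for col in 0..n {
                // Same striping as `row_col % 32 != thread_idx` in the kernels:
                // skip every output element owned by another task.
                if (row * n + col) % 32 != thread_idx {
                    continue;
                }
                let mut sum = 0.0f32;
                for p in 0..k {
                    sum += a[row * k + p] * b[col * k + p];
                }
                // Each (row, col) is written by exactly one task, so this
                // unsynchronized write cannot race.
                unsafe { *out.add(row * n + col) = sum };
            }
        }
    });
}

Splitting by output element keeps the scheme independent of the matrix shape; with only 32 coarse tasks the scheduling overhead stays negligible next to the 4096-wide dot products.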

@@ -129,6 +129,18 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
         },
     );
+    c.bench_function(
+        "matrix multiplication 8x4096 @ 4096x4096 f16 in-place, transposed",
+        |b| {
+            b.iter(|| {
+                let _ = result_84096_f16.matrix_mul_inplace_transposed(
+                    black_box(&orig_84096_1_f16),
+                    black_box(&orig_84096_2_f16),
+                );
+            })
+        },
+    );
     c.bench_function(
         "matrix multiplication 8x4096 @ 4096x4096 f32 in-place, transposed",
         |b| {
@@ -142,13 +154,11 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
     );
     c.bench_function(
-        "matrix multiplication 8x4096 @ 4096x4096 f16 in-place, transposed",
+        "matrix multiplication 8x4096 @ 4096x4096 f32 in-place",
         |b| {
             b.iter(|| {
-                let _ = result_84096_f16.matrix_mul_inplace_transposed(
-                    black_box(&orig_84096_1_f16),
-                    black_box(&orig_84096_2_f16),
-                );
+                let _ = result_84096
+                    .matrix_mul_inplace(black_box(&orig_84096_1), black_box(&orig_84096_2));
             })
         },
     );
@@ -165,16 +175,6 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
         })
     });
-    c.bench_function(
-        "matrix multiplication 8x4096 @ 4096x4096 f32 in-place",
-        |b| {
-            b.iter(|| {
-                let _ = result_84096
-                    .matrix_mul_inplace(black_box(&orig_84096_1), black_box(&orig_84096_2));
-            })
-        },
-    );
     c.bench_function("matrix multiplication f32 not in-place", |b| {
         b.iter(|| {
             let _ = black_box(&orig32_1).matrix_mul(black_box(&orig32_2));

@@ -171,7 +171,7 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
         DataSettings::new()
     };
-    if cli.f16 == true {
+    if cli.f16 {
         data_settings = data_settings.force_f16();
     }

@@ -135,6 +135,23 @@ impl Drop for Tensor {
     }
 }
 
+// Use this to smuggle pointers to threads without Rust getting so goddamn mad
+//
+// Assumption usize = pointer size.
+#[derive(Copy, Clone)]
+struct WrappedPtr {
+    ptr: usize,
+}
+
+impl WrappedPtr {
+    fn wrap(ptr: *const u8) -> WrappedPtr {
+        WrappedPtr { ptr: ptr as usize }
+    }
+
+    fn unwrap(self) -> *const u8 {
+        self.ptr as *const u8
+    }
+}
+
 fn compute_capacity_cols(dtype: TensorDType, cols: i64) -> i64 {
     match dtype {
         TensorDType::Float16 => compute_capacity_cols_f16(cols),
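
The comment in the hunk above is the whole reason WrappedPtr exists: a raw pointer is neither Send nor Sync, so rayon refuses to let a parallel closure capture it, while a plain usize is Copy and Send and can be cast back inside the closure. Below is a tiny illustrative demo of the same trick, not code from this repository; once the pointer is smuggled through, avoiding data races is entirely up to the caller, which the kernels handle with the modulo-32 striping of output tiles.

use rayon::prelude::*;

// Scale a slice in parallel by smuggling its data pointer as a usize.
fn scale_in_place(data: &mut [f32], factor: f32) {
    let ptr = data.as_mut_ptr() as usize; // usize is Send; *mut f32 is not
    let n = data.len();
    (0..n).into_par_iter().for_each(|i| {
        let p = ptr as *mut f32;
        // Safety: each index is written by exactly one task.
        unsafe { *p.add(i) *= factor };
    });
}

fn main() {
    let mut v = vec![1.0f32, 2.0, 3.0, 4.0];
    scale_in_place(&mut v, 2.0);
    assert_eq!(v, vec![2.0, 4.0, 6.0, 8.0]);
}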
@@ -1025,12 +1042,12 @@ impl Tensor {
     }
 
     pub fn is_on_cpu(&self) -> bool {
-        return !self.is_on_gpu();
+        !self.is_on_gpu()
     }
 
     // Casts data type to whatever the other tensors data type is.
     pub fn to_same_type(&self, other: &Tensor) -> Tensor {
-        let mut result = self.clone();
+        let result = self.clone();
         if result.dtype() == other.dtype() {
             return result;
         }
@@ -1115,8 +1132,8 @@ impl Tensor {
                         self.rows as usize * self.capacity_cols as usize,
                     );
                 }
-                let src_data: *const f32 = src.data as *const f32;
-                let other_data: *const f32 = other.data as *const f32;
+                let _src_data: *const f32 = src.data as *const f32;
+                let _other_data: *const f32 = other.data as *const f32;
                 let src_rows: usize = src.rows as usize;
                 let src_cols: usize = src.cols as usize;
@@ -1145,12 +1162,23 @@ impl Tensor {
                 };
                 unsafe {
+                    let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
+                    let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
+                    let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
+                    (0..32).into_par_iter().for_each(|thread_idx| {
+                        let src_data: *const f32 = src_data_wrap.unwrap() as *const f32;
+                        let other_data: *const f32 = other_data.unwrap() as *const f32;
+                        let tgt_data: *mut f32 = tgt_data.unwrap() as *mut f32;
                     for row in 0..row_its {
                         let row0 = row * 4;
                         let row1 = row * 4 + 1;
                         let row2 = row * 4 + 2;
                         let row3 = row * 4 + 3;
                         for col in 0..self_cols_its {
+                            let row_col = row * self_cols_its + col;
+                            if row_col % 32 != thread_idx {
+                                continue;
+                            }
                             let col0 = col * 4;
                             let col1 = col * 4 + 1;
                             let col2 = col * 4 + 2;
@@ -1183,29 +1211,30 @@ impl Tensor {
                             ];
                             for p in 0..src_cols_its {
-                                let other8_0: __m256 = _mm256_loadu_ps(
-                                    other_data.add(col0 * other_cols_capacity + p * ITEMS_PER_LINE),
-                                );
-                                let other8_1: __m256 = if col1 < other_rows {
-                                    _mm256_loadu_ps(
-                                        other_data
-                                            .add(col1 * other_cols_capacity + p * ITEMS_PER_LINE),
-                                    )
-                                } else {
-                                    _mm256_setzero_ps()
-                                };
-                                let other8_2: __m256 = if col2 < other_rows {
-                                    _mm256_loadu_ps(
-                                        other_data
-                                            .add(col2 * other_cols_capacity + p * ITEMS_PER_LINE),
-                                    )
-                                } else {
-                                    _mm256_setzero_ps()
-                                };
-                                let other8_3: __m256 = if col3 < other_rows {
-                                    _mm256_loadu_ps(
-                                        other_data
-                                            .add(col3 * other_cols_capacity + p * ITEMS_PER_LINE),
-                                    )
-                                } else {
-                                    _mm256_setzero_ps()
-                                };
+                                let other8_0: __m256 = _mm256_loadu_ps(
+                                    other_data
+                                        .add(col0 * other_cols_capacity + p * ITEMS_PER_LINE),
+                                );
+                                let other8_1: __m256 =
+                                    if col1 < other_rows {
+                                        _mm256_loadu_ps(other_data.add(
+                                            col1 * other_cols_capacity + p * ITEMS_PER_LINE,
+                                        ))
+                                    } else {
+                                        _mm256_setzero_ps()
+                                    };
+                                let other8_2: __m256 =
+                                    if col2 < other_rows {
+                                        _mm256_loadu_ps(other_data.add(
+                                            col2 * other_cols_capacity + p * ITEMS_PER_LINE,
+                                        ))
+                                    } else {
+                                        _mm256_setzero_ps()
+                                    };
+                                let other8_3: __m256 =
+                                    if col3 < other_rows {
+                                        _mm256_loadu_ps(other_data.add(
+                                            col3 * other_cols_capacity + p * ITEMS_PER_LINE,
+                                        ))
+                                    } else {
+                                        _mm256_setzero_ps()
+                                    };
@@ -1214,41 +1243,60 @@ impl Tensor {
                                 );
                                 let src8_1: __m256 = if row1 < src_rows {
                                     _mm256_loadu_ps(
-                                        src_data.add(row1 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                        src_data
+                                            .add(row1 * src_cols_capacity + p * ITEMS_PER_LINE),
                                     )
                                 } else {
                                     _mm256_setzero_ps()
                                 };
                                 let src8_2: __m256 = if row2 < src_rows {
                                     _mm256_loadu_ps(
-                                        src_data.add(row2 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                        src_data
+                                            .add(row2 * src_cols_capacity + p * ITEMS_PER_LINE),
                                     )
                                 } else {
                                     _mm256_setzero_ps()
                                 };
                                 let src8_3: __m256 = if row3 < src_rows {
                                     _mm256_loadu_ps(
-                                        src_data.add(row3 * src_cols_capacity + p * ITEMS_PER_LINE),
+                                        src_data
+                                            .add(row3 * src_cols_capacity + p * ITEMS_PER_LINE),
                                     )
                                 } else {
                                     _mm256_setzero_ps()
                                 };
-                                targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
-                                targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
-                                targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
-                                targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
-                                targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
-                                targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
-                                targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
-                                targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
-                                targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
-                                targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
-                                targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
-                                targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
-                                targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
-                                targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
-                                targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
-                                targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
+                                targets8[0][0] =
+                                    _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
+                                targets8[0][1] =
+                                    _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
+                                targets8[0][2] =
+                                    _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
+                                targets8[0][3] =
+                                    _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
+                                targets8[1][0] =
+                                    _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
+                                targets8[1][1] =
+                                    _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
+                                targets8[1][2] =
+                                    _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
+                                targets8[1][3] =
+                                    _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
+                                targets8[2][0] =
+                                    _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
+                                targets8[2][1] =
+                                    _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
+                                targets8[2][2] =
+                                    _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
+                                targets8[2][3] =
+                                    _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
+                                targets8[3][0] =
+                                    _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
+                                targets8[3][1] =
+                                    _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
+                                targets8[3][2] =
+                                    _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
+                                targets8[3][3] =
+                                    _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
                             }
                             let target00: f32 = horizontal_sum(targets8[0][0]);
                             let target01: f32 = horizontal_sum(targets8[0][1]);
@@ -1291,6 +1339,7 @@ impl Tensor {
                             }
                         }
                     }
+                    });
                 }
             }
             TensorDType::Float16 => {
@@ -1304,8 +1353,8 @@ impl Tensor {
                         self.rows as usize * self.capacity_cols as usize,
                     );
                 }
-                let src_data: *const f16 = src.data as *const f16;
-                let other_data: *const f16 = other.data as *const f16;
+                let _src_data: *const f16 = src.data as *const f16;
+                let _other_data: *const f16 = other.data as *const f16;
                 let src_rows: usize = src.rows as usize;
                 let src_cols: usize = src.cols as usize;
@@ -1334,12 +1383,23 @@ impl Tensor {
                 };
                 unsafe {
+                    let src_data_wrap: WrappedPtr = WrappedPtr::wrap(src.data);
+                    let other_data: WrappedPtr = WrappedPtr::wrap(other.data);
+                    let tgt_data: WrappedPtr = WrappedPtr::wrap(self.data);
+                    (0..32).into_par_iter().for_each(|thread_idx| {
+                        let src_data: *const f16 = src_data_wrap.unwrap() as *const f16;
+                        let other_data: *const f16 = other_data.unwrap() as *const f16;
+                        let tgt_data: *mut f16 = tgt_data.unwrap() as *mut f16;
                     for row in 0..row_its {
                         let row0 = row * 4;
                         let row1 = row * 4 + 1;
                         let row2 = row * 4 + 2;
                         let row3 = row * 4 + 3;
                         for col in 0..self_cols_its {
+                            let row_col = row * self_cols_its + col;
+                            if row_col % 32 != thread_idx {
+                                continue;
+                            }
                             let col0 = col * 4;
                             let col1 = col * 4 + 1;
                             let col2 = col * 4 + 2;
@@ -1382,8 +1442,10 @@ impl Tensor {
                                 unsafe {
                                     let (left, right) = if row + 1 < nrows {
                                         (
-                                            _mm_loadu_si128(ptr.add(row * cols_capacity + column)
-                                                as *const __m128i),
+                                            _mm_loadu_si128(
+                                                ptr.add(row * cols_capacity + column)
+                                                    as *const __m128i,
+                                            ),
                                             _mm_loadu_si128(
                                                 ptr.add((row + 1) * cols_capacity + column)
                                                     as *const __m128i,
@@ -1391,8 +1453,10 @@ impl Tensor {
                                         )
                                     } else {
                                         (
-                                            _mm_loadu_si128(ptr.add(row * cols_capacity + column)
-                                                as *const __m128i),
+                                            _mm_loadu_si128(
+                                                ptr.add(row * cols_capacity + column)
+                                                    as *const __m128i,
+                                            ),
                                             _mm_setzero_si128(),
                                         )
                                     };
@@ -1430,22 +1494,38 @@ impl Tensor {
                                     src_cols_capacity,
                                     src_rows,
                                 );
-                                targets8[0][0] = _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
-                                targets8[0][1] = _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
-                                targets8[0][2] = _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
-                                targets8[0][3] = _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
-                                targets8[1][0] = _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
-                                targets8[1][1] = _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
-                                targets8[1][2] = _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
-                                targets8[1][3] = _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
-                                targets8[2][0] = _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
-                                targets8[2][1] = _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
-                                targets8[2][2] = _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
-                                targets8[2][3] = _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
-                                targets8[3][0] = _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
-                                targets8[3][1] = _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
-                                targets8[3][2] = _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
-                                targets8[3][3] = _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
+                                targets8[0][0] =
+                                    _mm256_fmadd_ps(src8_0, other8_0, targets8[0][0]);
+                                targets8[0][1] =
+                                    _mm256_fmadd_ps(src8_1, other8_0, targets8[0][1]);
+                                targets8[0][2] =
+                                    _mm256_fmadd_ps(src8_2, other8_0, targets8[0][2]);
+                                targets8[0][3] =
+                                    _mm256_fmadd_ps(src8_3, other8_0, targets8[0][3]);
+                                targets8[1][0] =
+                                    _mm256_fmadd_ps(src8_0, other8_1, targets8[1][0]);
+                                targets8[1][1] =
+                                    _mm256_fmadd_ps(src8_1, other8_1, targets8[1][1]);
+                                targets8[1][2] =
+                                    _mm256_fmadd_ps(src8_2, other8_1, targets8[1][2]);
+                                targets8[1][3] =
+                                    _mm256_fmadd_ps(src8_3, other8_1, targets8[1][3]);
+                                targets8[2][0] =
+                                    _mm256_fmadd_ps(src8_0, other8_2, targets8[2][0]);
+                                targets8[2][1] =
+                                    _mm256_fmadd_ps(src8_1, other8_2, targets8[2][1]);
+                                targets8[2][2] =
+                                    _mm256_fmadd_ps(src8_2, other8_2, targets8[2][2]);
+                                targets8[2][3] =
+                                    _mm256_fmadd_ps(src8_3, other8_2, targets8[2][3]);
+                                targets8[3][0] =
+                                    _mm256_fmadd_ps(src8_0, other8_3, targets8[3][0]);
+                                targets8[3][1] =
+                                    _mm256_fmadd_ps(src8_1, other8_3, targets8[3][1]);
+                                targets8[3][2] =
+                                    _mm256_fmadd_ps(src8_2, other8_3, targets8[3][2]);
+                                targets8[3][3] =
+                                    _mm256_fmadd_ps(src8_3, other8_3, targets8[3][3]);
                             }
                             let target00: f16 = horizontal_sum_f32_to_f16(targets8[0][0]);
                             let target01: f16 = horizontal_sum_f32_to_f16(targets8[0][1]);
@@ -1488,6 +1568,7 @@ impl Tensor {
                             }
                         }
                     }
+                    });
                 }
             }
         }
@@ -1534,6 +1615,7 @@ impl Tensor {
         assert_eq!(other.rows, 1);
         assert_eq!(other.dtype, self.dtype);
 
+        #[allow(unreachable_patterns)]
         match self.dtype {
            TensorDType::Float32 => self.matrix_vector_mul_transposed_f32(other),
            TensorDType::Float16 => self.matrix_vector_mul_transposed_f16(other),
@@ -1669,7 +1751,7 @@ impl Tensor {
         self.assume_on_cpu();
         other.assume_on_cpu();
         unsafe {
-            let mut result = Tensor::uninitialized(self.rows, 1, self.dtype);
+            let result = Tensor::zeros(self.rows, 1, self.dtype);
             let col_its: usize = if self.cols % 8 == 0 {
                 (self.cols / 8) as usize
             } else {
@@ -1680,16 +1762,18 @@ impl Tensor {
             } else {
                 (self.rows / 4 + 1) as usize
             };
+            let self_data: *const f32 = self.data as *const f32;
+            let other_data: *const f32 = other.data as *const f32;
+            let tgt_data: *mut f32 = result.data as *mut f32;
+            let ncols_capacity: usize = result.capacity_cols as usize;
             let mut sum8s: [__m256; 4] = [
                 _mm256_setzero_ps(),
                 _mm256_setzero_ps(),
                 _mm256_setzero_ps(),
                 _mm256_setzero_ps(),
             ];
-            let self_data: *const f32 = self.data as *const f32;
-            let other_data: *const f32 = other.data as *const f32;
-            let _tgt_data: *mut f32 = result.data as *mut f32;
-            let _ncols_capacity: usize = result.capacity_cols as usize;
             for row in 0..row_its {
                 let row: i64 = row as i64;
                 sum8s[0] = _mm256_setzero_ps();
@@ -1732,87 +1816,18 @@ impl Tensor {
                 let sum_2: f32 = horizontal_sum(sum8s[2]);
                 let sum_3: f32 = horizontal_sum(sum8s[3]);
                 if row4_0 < result.rows {
-                    result.set_f32(row4_0, 0, sum_0);
+                    *(tgt_data.add(row4_0 as usize * ncols_capacity)) = sum_0;
                 }
                 if row4_1 < result.rows {
-                    result.set_f32(row4_1, 0, sum_1);
+                    *(tgt_data.add(row4_1 as usize * ncols_capacity)) = sum_1;
                 }
                 if row4_2 < result.rows {
-                    result.set_f32(row4_2, 0, sum_2);
+                    *(tgt_data.add(row4_2 as usize * ncols_capacity)) = sum_2;
                 }
                 if row4_3 < result.rows {
-                    result.set_f32(row4_3, 0, sum_3);
+                    *(tgt_data.add(row4_3 as usize * ncols_capacity)) = sum_3;
                 }
             }
-            result
-        }
-    }
-
-    /// Same as matrix_vector_mul but uses threading.
-    pub fn matrix_vector_mul_transposed_multithreaded(&self, other: &Tensor) -> Tensor {
-        self.assume_on_cpu();
-        other.assume_on_cpu();
-        if self.cols != other.cols {
-            panic!(
-                "Invalid matrix-vector transposed multiplication {}x{} vs {}x{}",
-                self.rows, self.cols, other.rows, other.cols
-            );
-        }
-        assert_eq!(other.rows, 1);
-        assert_eq!(other.dtype, self.dtype);
-        assert_eq!(self.dtype, TensorDType::Float32);
-
-        // Use this to smuggle pointers to threads without Rust getting so goddamn mad
-        //
-        // Assumption usize = pointer size.
-        #[derive(Copy, Clone)]
-        struct WrappedPtr {
-            ptr: usize,
-        }
-        impl WrappedPtr {
-            fn wrap(ptr: *const u8) -> WrappedPtr {
-                WrappedPtr { ptr: ptr as usize }
-            }
-            fn unwrap(self) -> *const u8 {
-                self.ptr as *const u8
-            }
-        }
-        unsafe {
-            let result = Tensor::uninitialized(self.rows, 1, self.dtype);
-            let capacity_cols: i64 = self.capacity_cols;
-            let result_capacity_cols: i64 = result.capacity_cols;
-            let col_its: usize = if self.cols % 8 == 0 {
-                (self.cols / 8) as usize
-            } else {
-                (self.cols / 8 + 1) as usize
-            };
-            let self_data = WrappedPtr::wrap(self.data);
-            let other_data = WrappedPtr::wrap(other.data);
-            let result_data = WrappedPtr::wrap(result.data);
-            (0..self.rows as usize)
-                .into_par_iter()
-                .with_min_len(64)
-                .for_each(|row| {
-                    let row = row as i64;
-                    let self_data: *const f32 = self_data.unwrap() as *const f32;
-                    let other_data: *const f32 = other_data.unwrap() as *const f32;
-                    let result_data: *mut f32 = result_data.unwrap() as *mut f32;
-                    let mut sum8: __m256 = _mm256_setzero_ps();
-                    for col in 0..col_its {
-                        let col = col * 8;
-                        let left_side8 =
-                            _mm256_loadu_ps(self_data.add((row * capacity_cols) as usize + col));
-                        let right_side8 = _mm256_loadu_ps(other_data.add(col));
-                        sum8 = _mm256_fmadd_ps(left_side8, right_side8, sum8);
-                    }
-                    let sum: f32 = horizontal_sum(sum8);
-                    result_data
-                        .add((row * result_capacity_cols) as usize)
-                        .write(sum);
-                });
             result
         }
     }
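
The helper removed above parallelized the matrix-vector product over rows, using with_min_len(64) so rayon does not split the cheap per-row work too finely. For reference, here is a hedged sketch of that row-parallel pattern with safe indexing in place of the original raw-pointer writes; mat_vec_mul is an invented name, not this repository's API.

use rayon::prelude::*;

// y = M * x, where M is rows x cols in row-major order and x has length cols.
fn mat_vec_mul(m: &[f32], x: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(m.len(), rows * cols);
    assert_eq!(x.len(), cols);
    (0..rows)
        .into_par_iter()
        .with_min_len(64) // same batching hint as the removed helper
        .map(|row| {
            m[row * cols..(row + 1) * cols]
                .iter()
                .zip(x)
                .map(|(a, b)| a * b)
                .sum::<f32>()
        })
        .collect()
}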

@@ -58,6 +58,7 @@ impl DataSettings {
         }
     }
 
+    #[allow(clippy::new_without_default)]
     #[cfg(not(feature = "opencl"))]
     pub fn new() -> Self {
         DataSettings { force_f16: false }
@@ -147,6 +148,7 @@ pub struct RMSNorm {
     weight: Tensor,
 }
 
+#[allow(dead_code)]
 pub struct Attention {
     wq: Tensor,
     wk: Tensor,
