Add partial OpenCL support; it is used in the feed-forward network only.

Branch: broken-opencl-code
Mikko Juola, 3 years ago
parent df079bceb0
commit 63d27dba90

--- a/src/main.rs
+++ b/src/main.rs
@@ -3,7 +3,7 @@ use crate::embedding::Embedding;
use crate::tensor_opencl_support::OpenCL;
use crate::token_sampler::TokenSampler;
use crate::tokenizer::{TokenId, Tokenizer};
-use crate::transformer::Transformer;
+use crate::transformer::{DataSettings, Transformer};
use crate::unpickler;
use crate::unpickler::Value;
use clap::Parser;
@@ -38,6 +38,7 @@ struct Cli {
top_k: Option<i32>,
#[cfg(feature = "opencl")]
#[arg(long)]
opencl_device: Option<usize>,
}
@@ -63,7 +64,7 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
}
#[cfg(feature = "opencl")]
-let _opencl: Option<OpenCL> = {
+let opencl: Option<OpenCL> = {
let opencl_device = cli.opencl_device.unwrap_or(0);
match OpenCL::new(!be_quiet, opencl_device) {
Err(openclerr) => {
@@ -154,6 +155,20 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
None => 1024,
};
let data_settings = {
#[cfg(feature = "opencl")]
{
if let Some(opencl) = opencl {
let ds = DataSettings::new(Some(opencl));
ds.use_opencl()
} else {
DataSettings::new(None)
}
}
#[cfg(not(feature = "opencl"))]
DataSettings::new()
};
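The feature-gate-plus-runtime-Option selection above recurs throughout this commit. As a minimal sketch, using the DataSettings API this commit adds in transformer.rs (the helper name is hypothetical):

#[cfg(feature = "opencl")]
fn make_data_settings(opencl: Option<OpenCL>) -> DataSettings {
    // The feature gate decides at compile time which constructor exists;
    // the Option decides at run time whether the GPU path is taken.
    match opencl {
        Some(cl) => DataSettings::new(Some(cl)).use_opencl(),
        None => DataSettings::new(None),
    }
}

#[cfg(not(feature = "opencl"))]
fn make_data_settings() -> DataSettings {
    DataSettings::new()
}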
pln!("Loading transformer weights from {}...", model_path);
let tr = Transformer::from_unpickled(
&unpickle_results,
@@ -163,6 +178,7 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
params.n_heads,
max_seq_len,
params.norm_eps,
data_settings,
model_path,
)?;
pln!("All is loaded. Starting inference.");

--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -151,6 +151,18 @@ fn horizontal_sum(mut ymm: __m256) -> f32 {
}
impl Tensor {
#[inline]
pub fn assume_on_gpu(&self) {
#[cfg(feature = "opencl")]
{
self.process_waiting_for_data();
let od = self.opencl_data.read().unwrap();
if !od.is_some() {
panic!("Tried to assume_on_gpu on a tensor that is on the CPU");
}
}
}
#[inline]
pub fn assume_on_cpu(&self) {
#[cfg(feature = "opencl")]
@@ -544,14 +556,52 @@ impl Tensor {
}
pub fn hadamard_product(&self, other: &Tensor) -> Tensor {
-self.assume_on_cpu();
-other.assume_on_cpu();
if self.cols != other.cols || self.rows != other.rows {
panic!(
"Invalid hadamard product: incompatible shapes, {}x{} vs {}x{}",
self.rows, self.cols, other.rows, other.cols
);
}
#[cfg(feature = "opencl")]
{
if self.is_on_gpu() {
self.hadamard_product_gpu(other)
} else {
self.hadamard_product_cpu(other)
}
}
#[cfg(not(feature = "opencl"))]
{
self.hadamard_product_cpu(other)
}
}
#[cfg(feature = "opencl")]
fn hadamard_product_gpu(&self, other: &Tensor) -> Tensor {
// Assume: sizes have been checked already
self.assume_on_gpu();
other.assume_on_gpu();
self.with_opencl_data(|self_tensor| {
let cl = self_tensor.cl();
// TODO: do not create a CPU-side copy
let result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
let mut result = result.to_f16();
result.to_gpu(&cl).unwrap();
result.with_opencl_data_mut(|tgt_tensor| {
tgt_tensor.copy_inplace(self_tensor).unwrap();
other.with_opencl_data(|other_tensor| {
tgt_tensor.hadamard_product_inplace(other_tensor).unwrap();
});
});
result
})
}
fn hadamard_product_cpu(&self, other: &Tensor) -> Tensor {
// Assume: sizes have been checked already
self.assume_on_cpu();
other.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
for col in 0..self.cols {
@@ -595,7 +645,60 @@ impl Tensor {
}
pub fn silu(&self) -> Tensor {
-self.assume_on_cpu();
#[cfg(feature = "opencl")]
{
if self.is_on_gpu() {
self.silu_gpu()
} else {
self.silu_cpu()
}
}
#[cfg(not(feature = "opencl"))]
{
self.silu_cpu()
}
}
// with_opencl_data & with_opencl_data_mut are utilities to get access to the underlying
// OpenCLTensor, if the tensor is on the GPU. Panics if the tensor is not on the GPU.
#[cfg(feature = "opencl")]
fn with_opencl_data<F, R>(&self, f: F) -> R
where
F: FnOnce(&OpenCLTensor) -> R,
{
let opencl_data = self.opencl_data.read().unwrap();
let opencl_data = opencl_data.as_ref();
f(opencl_data.unwrap())
}
#[cfg(feature = "opencl")]
fn with_opencl_data_mut<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut OpenCLTensor) -> R,
{
let mut opencl_data = self.opencl_data.write().unwrap();
let opencl_data = opencl_data.as_mut();
f(opencl_data.unwrap())
}
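These closure-based accessors keep the RwLock guard alive only for the duration of the callback, so a borrow of the guarded OpenCLTensor can never escape the lock. A self-contained sketch of the same pattern, with String standing in for OpenCLTensor:

use std::sync::RwLock;

struct Holder {
    // Plays the role of Tensor's opencl_data field.
    data: RwLock<Option<String>>,
}

impl Holder {
    fn with_data<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&String) -> R,
    {
        // The guard is dropped when this function returns, so the borrow
        // handed to `f` cannot outlive the lock.
        let guard = self.data.read().unwrap();
        f(guard.as_ref().expect("not on GPU"))
    }
}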
#[cfg(feature = "opencl")]
fn silu_gpu(&self) -> Tensor {
self.assume_on_gpu();
self.with_opencl_data(|src_tensor| {
let cl: OpenCL = src_tensor.cl();
// TODO: don't generate a CPU-side copy, create the result directly on OpenCL side
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
result = result.to_f16();
result.to_gpu(&cl).unwrap();
result.with_opencl_data_mut(|tgt_tensor| {
tgt_tensor.copy_inplace(src_tensor).unwrap();
tgt_tensor.silu_inplace().unwrap();
});
result
})
}
fn silu_cpu(&self) -> Tensor {
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
for col in 0..self.cols {
@@ -608,6 +711,37 @@ impl Tensor {
}
pub fn transpose(&self) -> Tensor {
#[cfg(feature = "opencl")]
{
if self.is_on_gpu() {
self.transpose_gpu()
} else {
self.transpose_cpu()
}
}
#[cfg(not(feature = "opencl"))]
{
self.transpose_cpu()
}
}
#[cfg(feature = "opencl")]
fn transpose_gpu(&self) -> Tensor {
self.assume_on_gpu();
self.with_opencl_data(|src_tensor| {
let cl: OpenCL = src_tensor.cl();
// TODO: don't generate a CPU-side copy, create the result directly on OpenCL side
let mut result = unsafe { Tensor::uninitialized(self.cols, self.rows, self.dtype) };
result = result.to_f16();
result.to_gpu(&cl).unwrap();
result.with_opencl_data_mut(|tgt_tensor| {
tgt_tensor.transpose_from(src_tensor).unwrap();
});
result
})
}
fn transpose_cpu(&self) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.cols, self.rows, self.dtype) };
for row in 0..self.rows {
@@ -665,18 +799,27 @@ impl Tensor {
}
pub fn matrix_mul_transposed(&self, other: &Tensor) -> Tensor {
-self.assume_on_cpu();
-other.assume_on_cpu();
if self.cols != other.cols {
panic!(
"Invalid matrix transposed multiplication {}x{} vs {}x{}",
self.rows, self.cols, other.cols, other.rows
);
}
#[cfg(not(feature = "opencl"))]
if other.rows == 1 {
return self.matrix_vector_mul_transposed(other);
}
#[cfg(feature = "opencl")]
if other.rows == 1 && self.is_on_cpu() {
return self.matrix_vector_mul_transposed(other);
}
let mut result = unsafe { Tensor::uninitialized(self.rows, other.rows, self.dtype) };
#[cfg(feature = "opencl")]
if self.is_on_gpu() {
let od = self.opencl_data.write().unwrap();
result.to_gpu(&od.as_ref().unwrap().cl()).unwrap();
}
result.matrix_mul_inplace_transposed(self, other);
result
}
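For reference: with self of shape m x k and other of shape n x k, matrix_mul_transposed computes the m x n product A·Bᵀ, i.e. (A·Bᵀ)[i][j] = Σ_p A[i][p] · B[j][p]. The cols check above enforces the shared inner dimension k, and other.rows == 1 is the matrix-vector special case.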
@@ -839,6 +982,11 @@ impl Tensor {
false
}
#[cfg(feature = "opencl")]
pub fn is_on_cpu(&self) -> bool {
return !self.is_on_gpu();
}
#[cfg(feature = "opencl")]
fn matrix_mul_inplace_transposed_gpu(&mut self, src: &Tensor, other: &Tensor) {
let mut self_od = self.opencl_data.write().unwrap();
@@ -2031,10 +2179,110 @@ mod tests {
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_silu_and_cpu_silu_agree() {
let cl = OpenCL::new(false, 0).unwrap();
for _trial in 0..300 {
let mut rng = rand::thread_rng();
let a = rng.gen_range(1..=300);
let b = rng.gen_range(1..=300);
let mat1 = Tensor::random(a, b, TensorDType::Float16);
let mat2 = mat1.clone();
let mut mat2 = mat2.to_f16();
mat2.to_gpu(&cl).unwrap();
let mat1_result = mat1.silu();
let mut mat2_result = mat2.silu();
mat2_result.to_cpu().unwrap();
assert_eq!(mat1_result.rows(), mat2_result.rows());
assert_eq!(mat1_result.cols(), mat2_result.cols());
for row in 0..mat1_result.rows {
for col in 0..mat1_result.cols {
assert_relative_eq!(
mat1_result.get_f32(row, col),
mat2_result.get_f32(row, col),
epsilon = 1e-2
);
}
}
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_hadamard_product_and_cpu_hadamard_product_agree() {
let cl = OpenCL::new(false, 0).unwrap();
for _trial in 0..300 {
let mut rng = rand::thread_rng();
let a = rng.gen_range(1..=300);
let b = rng.gen_range(1..=300);
let mat1 = Tensor::random(a, b, TensorDType::Float16);
let mat2 = Tensor::random(a, b, TensorDType::Float16);
let mut mat1_gpu = mat1.to_f16();
let mut mat2_gpu = mat2.to_f16();
mat1_gpu.to_gpu(&cl).unwrap();
mat2_gpu.to_gpu(&cl).unwrap();
let result1 = mat1.hadamard_product(&mat2);
let mut result2 = mat1_gpu.hadamard_product(&mat2_gpu);
result2.to_cpu().unwrap();
assert_eq!(result1.rows(), result2.rows());
assert_eq!(result1.cols(), result2.cols());
for row in 0..result1.rows() {
for col in 0..result2.cols() {
assert_relative_eq!(
result1.get_f32(row, col),
result2.get_f32(row, col),
epsilon = 1e-2
);
}
}
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_transpose_and_cpu_transpose_agree() {
let cl = OpenCL::new(false, 0).unwrap();
let mut rng = rand::thread_rng();
for _trial in 0..300 {
let a = rng.gen_range(1..=100);
let b = rng.gen_range(1..=100);
let mat1 = Tensor::random(a, b, TensorDType::Float16);
let mut mat1_gpu = mat1.to_f16();
mat1_gpu.to_gpu(&cl).unwrap();
let mat1_transposed = mat1.transpose();
let mut mat1_gpu_transposed = mat1_gpu.transpose();
mat1_gpu_transposed.to_cpu().unwrap();
assert_eq!(mat1_transposed.rows(), mat1_gpu_transposed.rows());
assert_eq!(mat1_transposed.cols(), mat1_gpu_transposed.cols());
for row in 0..mat1_transposed.rows {
for col in 0..mat1_transposed.cols {
assert_relative_eq!(
mat1_transposed.get_f32(row, col),
mat1_gpu_transposed.get_f32(row, col),
epsilon = 1e-2,
);
}
}
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_matrix_mul_transposed_is_close_to_cpu_matrix_mul_transposed() {
-let cl = OpenCL::new(true, 1).unwrap();
+let cl = OpenCL::new(false, 0).unwrap();
let mut rng = rand::thread_rng();
for _trial in 0..300 {

--- a/src/tensor_opencl_support.rs
+++ b/src/tensor_opencl_support.rs
@@ -12,9 +12,15 @@ use thiserror::Error;
struct Programs {
matrix_mul_transposed_by_row_f16_program: Program,
matrix_mul_transposed_by_row_f16: Kernel,
silu_f16_program: Program,
silu_f16: Kernel,
hadamard_product_f16_program: Program,
hadamard_product_f16: Kernel,
transpose_f16_program: Program,
transpose_f16: Kernel,
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct OpenCL {
ctx: Context,
@@ -34,7 +40,7 @@ pub struct OpenCLTensor {
cols: i64,
cols_capacity: i64,
queue: Queue,
-programs: Arc<RwLock<Programs>>,
+cl: OpenCL,
}
#[derive(Debug)]
@@ -143,13 +149,17 @@ impl OpenCL {
cols,
cols_capacity,
queue: self.queue.clone(),
-programs: self.programs.clone(),
+cl: self.clone(),
})
}
}
}
impl OpenCLTensor {
pub fn cl(&self) -> OpenCL {
self.cl.clone()
}
pub fn wait_until_ready(&mut self) {
if self.last_event.is_some() {
self.last_event.as_ref().unwrap().wait_for().unwrap();
@@ -187,6 +197,93 @@ impl OpenCLTensor {
}
}
/// Copies all values from another tensor
pub fn copy_inplace(&mut self, other: &OpenCLTensor) -> Result<OpenCLEvent, OpenCLError> {
if other.rows != self.rows || other.cols != self.cols {
panic!(
"Cannot in-place copy tensors of different sizes: {}x{} <-- {}x{}",
self.rows, self.cols, other.rows, other.cols
);
}
let mut event = Event::empty();
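// In ocl, `buf.cmd().copy(dst, ..)` copies *from* the command's buffer,
// so `other.buf` is the source here and `self.buf` the destination.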
other
.buf
.cmd()
.queue(&other.queue)
.copy(&self.buf, None, None)
.enew(&mut event)
.enq()?;
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
}
pub fn transpose_from(&mut self, other: &OpenCLTensor) -> Result<OpenCLEvent, OpenCLError> {
let prg = self.cl.programs.write().unwrap();
prg.transpose_f16.set_arg(0, self.buf.clone()).unwrap();
prg.transpose_f16.set_arg(1, other.buf.clone()).unwrap();
prg.transpose_f16
.set_arg(2, self.cols_capacity as i32)
.unwrap();
prg.transpose_f16
.set_arg(3, other.cols_capacity as i32)
.unwrap();
let mut event = Event::empty();
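// One work-item per target element: in the kernel, get_global_id(0) is
// the target row and get_global_id(1) the target column.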
unsafe {
let b = prg
.transpose_f16
.cmd()
.queue(&self.queue)
.global_work_size([self.rows as usize, self.cols as usize])
.enew(&mut event);
b.enq().unwrap();
}
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
}
pub fn hadamard_product_inplace(
&mut self,
other: &OpenCLTensor,
) -> Result<OpenCLEvent, OpenCLError> {
let prg = self.cl.programs.write().unwrap();
prg.hadamard_product_f16.set_arg(0, self.buf.clone())?;
prg.hadamard_product_f16.set_arg(1, other.buf.clone())?;
prg.hadamard_product_f16
.set_arg(2, self.cols_capacity as i32)?;
prg.hadamard_product_f16
.set_arg(3, other.cols_capacity as i32)?;
let mut event = Event::empty();
unsafe {
let b = prg
.hadamard_product_f16
.cmd()
.queue(&self.queue)
.global_work_size([self.rows as usize, self.cols as usize])
.enew(&mut event);
b.enq()?;
}
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
}
pub fn silu_inplace(&mut self) -> Result<OpenCLEvent, OpenCLError> {
let prg = self.cl.programs.write().unwrap();
prg.silu_f16.set_arg(0, self.buf.clone())?;
prg.silu_f16.set_arg(1, self.cols_capacity as i32)?;
let mut event = Event::empty();
unsafe {
let b = prg
.silu_f16
.cmd()
.queue(&self.queue)
.global_work_size([self.rows as usize, self.cols as usize])
.enew(&mut event);
b.enq()?;
}
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
}
pub fn matrix_mul_inplace_transposed(
&mut self,
src: &OpenCLTensor,
@@ -208,7 +305,7 @@ impl OpenCLTensor {
// Clear out the target memory
unsafe { self.buf.cmd().fill(0u16, None).block(false).enq()? };
-let prg = self.programs.write().unwrap();
+let prg = self.cl.programs.write().unwrap();
prg.matrix_mul_transposed_by_row_f16
.set_arg(0, self.buf.clone())?;
prg.matrix_mul_transposed_by_row_f16
@@ -234,7 +331,7 @@ impl OpenCLTensor {
.matrix_mul_transposed_by_row_f16
.cmd()
.queue(&self.queue)
-.global_work_size([self.rows as usize, self.cols_capacity as usize])
+.global_work_size([self.rows as usize, self.cols as usize])
.enew(&mut event);
b.enq()?;
}
@@ -251,18 +348,15 @@ impl OpenCLEvent {
}
fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError> {
-let mut last_err: Option<OpenCLError> = None;
-// There used to be more sources here but now it's just one. This can go through programs and
-// accept first one that compiles
-for src in &[MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC] {
-fn make_programs_with_src(
-ctx: &Context,
-queue: &Queue,
-src: &str,
-) -> Result<Programs, OpenCLError> {
+fn make_program_with_src(ctx: &Context, src: &str) -> Result<Program, OpenCLError> {
let program = Program::builder().src(src).build(&ctx)?;
-let kernel = Kernel::builder()
-.program(&program)
+Ok(program)
}
+let matrix_mul_transposed_by_row_f16_program =
+make_program_with_src(ctx, MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC)?;
+let matrix_mul_transposed_by_row_f16 = Kernel::builder()
+.program(&matrix_mul_transposed_by_row_f16_program)
.name("matrix_mul_transposed_by_row_f16")
.arg(None::<&Buffer<u16>>)
.arg(None::<&Buffer<u16>>)
@@ -275,23 +369,44 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError>
.arg(&0)
.queue(queue.clone())
.build()?;
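// The `None::<&Buffer<u16>>` and `&0` placeholders above (and in the
// builders below) only declare argument types; the actual buffers and
// sizes are bound with set_arg() right before each enqueue.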
let silu_f16_program = make_program_with_src(ctx, SILU_F16_SRC)?;
let silu_f16 = Kernel::builder()
.program(&silu_f16_program)
.name("silu_f16")
.arg(None::<&Buffer<u16>>)
.arg(&0)
.queue(queue.clone())
.build()?;
let hadamard_product_f16_program = make_program_with_src(ctx, HADAMARD_PRODUCT_F16_SRC)?;
let hadamard_product_f16 = Kernel::builder()
.program(&hadamard_product_f16_program)
.name("hadamard_product_f16")
.arg(None::<&Buffer<u16>>)
.arg(None::<&Buffer<u16>>)
.arg(&0)
.arg(&0)
.queue(queue.clone())
.build()?;
let transpose_f16_program = make_program_with_src(ctx, TRANSPOSE_F16_SRC)?;
let transpose_f16 = Kernel::builder()
.program(&transpose_f16_program)
.name("transpose_f16")
.arg(None::<&Buffer<u16>>)
.arg(None::<&Buffer<u16>>)
.arg(&0)
.arg(&0)
.queue(queue.clone())
.build()?;
Ok(Programs {
-matrix_mul_transposed_by_row_f16_program: program,
-matrix_mul_transposed_by_row_f16: kernel,
+matrix_mul_transposed_by_row_f16_program,
+matrix_mul_transposed_by_row_f16,
silu_f16_program,
silu_f16,
hadamard_product_f16_program,
hadamard_product_f16,
transpose_f16_program,
transpose_f16,
})
}
-match make_programs_with_src(ctx, queue, src) {
-Err(e) => {
-last_err = Some(e);
-continue;
-}
-Ok(p) => return Ok(p),
-}
-}
-if last_err.is_none() {
-unreachable!();
-}
-Err(last_err.unwrap())
}
const MATRIX_MUL_TRANSPOSED_BY_ROW_F16_SRC: &str = r#"
@@ -367,3 +482,53 @@ __kernel void matrix_mul_transposed_by_row_f16(
vstore_half(total, 0, (__global half*) &tgt[tgt_row * ncols_capacity + tgt_col]);
}
"#;
/// Computes SiLU for every f16 value in the tensor
const SILU_F16_SRC: &str = r#"
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void silu_f16(__global half *tgt,
const int ncols_capacity)
{
const int tgt_row = get_global_id(0);
const int tgt_col = get_global_id(1);
const float val = vload_half(tgt_row * ncols_capacity + tgt_col, (__global const half*) tgt);
const float result = val * (1.0 / (1.0 + exp(-val)));
vstore_half(result, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt);
}
"#;
/// Computes the Hadamard product of two identically sized tensors
const HADAMARD_PRODUCT_F16_SRC: &str = r#"
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void hadamard_product_f16(__global half *tgt,
__global const half *left,
const int ncols_capacity,
const int left_cols_capacity) {
const int tgt_row = get_global_id(0);
const int tgt_col = get_global_id(1);
const float tgt_value = vload_half(tgt_row * ncols_capacity + tgt_col, (__global const half*) tgt);
const float left_value = vload_half(tgt_row * left_cols_capacity + tgt_col, (__global const half*) left);
const float result = tgt_value * left_value;
vstore_half(result, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt);
}
"#;
/// Computes the transpose of a matrix
const TRANSPOSE_F16_SRC: &str = r#"
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void transpose_f16(__global half *tgt,
__global const half *left,
const int ncols_capacity,
const int left_cols_capacity)
{
const int tgt_row = get_global_id(0);
const int tgt_col = get_global_id(1);
const int src_row = tgt_col;
const int src_col = tgt_row;
const float val = vload_half(src_row * left_cols_capacity + src_col, (__global const half*) left);
vstore_half(val, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt);
}
"#;

--- a/src/transformer.rs
+++ b/src/transformer.rs
@@ -1,5 +1,7 @@
use crate::embedding::Embedding;
use crate::tensor::{FromPiecesDirection, Tensor, TensorDType};
#[cfg(feature = "opencl")]
use crate::tensor_opencl_support::OpenCL;
use crate::tokenizer::TokenId;
use crate::unpickler;
use crate::unpickler::UnpicklingError;
@@ -28,6 +30,43 @@ pub struct Transformer {
layers: Vec<TransformerBlock>,
}
// Clone is cheap
#[derive(Clone)]
pub struct DataSettings {
#[cfg(feature = "opencl")]
use_opencl_for_feedforward: bool,
#[cfg(feature = "opencl")]
cl: Option<OpenCL>,
}
// OpenCL is safe to send to threads but Rust doesn't know that
unsafe impl Send for DataSettings {}
unsafe impl Sync for DataSettings {}
impl DataSettings {
#[cfg(feature = "opencl")]
pub fn new(cl: Option<OpenCL>) -> Self {
DataSettings {
use_opencl_for_feedforward: false,
cl: cl.clone(),
}
}
#[cfg(not(feature = "opencl"))]
pub fn new() -> Self {
DataSettings {}
}
#[cfg(feature = "opencl")]
pub fn use_opencl(mut self) -> DataSettings {
if self.cl.is_none() {
panic!("OpenCL is not available, cannot call use_opencl() on DataSettings.");
}
self.use_opencl_for_feedforward = true;
self
}
}
pub struct TransformerCaches {
layer_caches: Vec<AttentionCache>,
}
@@ -105,10 +144,12 @@ pub struct Attention {
head_dim: usize,
}
#[allow(dead_code)]
pub struct FeedForward {
w1: Tensor,
w2: Tensor,
w3: Tensor,
data_settings: DataSettings,
}
impl Transformer {
@@ -121,6 +162,7 @@ impl Transformer {
n_heads: usize,
max_seq_len: usize,
eps: f64,
data_settings: DataSettings,
data_dir: P,
) -> Result<Transformer, UnpicklingError> {
assert_eq!(dim % n_heads, 0);
@@ -141,6 +183,7 @@ impl Transformer {
eps,
n_local_heads,
head_dim,
data_settings.clone(),
data_dir,
);
progress_bar.inc(1);
@@ -238,10 +281,11 @@ impl TransformerBlock {
eps: f64,
n_local_heads: usize,
head_dim: usize,
data_settings: DataSettings,
data_dir: P,
) -> Result<Self, UnpicklingError> {
let data_dir: &Path = data_dir.as_ref();
-let ff = FeedForward::from_unpickled(unpickled, layer_id, data_dir)?;
+let ff = FeedForward::from_unpickled(unpickled, layer_id, data_dir, data_settings)?;
let attn =
Attention::from_unpickled(unpickled, layer_id, n_local_heads, head_dim, data_dir)?;
let ffn_norm = RMSNorm::from_unpickled(
@@ -277,8 +321,8 @@ impl TransformerBlock {
.attn
.forward(&attnorm_out, start_pos, freqs_cis, mask, attention_cache);
let h = x.add(&att_out);
-let att_out = self.ffn_norm.forward(&h);
-let att_out = self.feed_forward.forward(&att_out).transpose();
+let mut att_out = self.ffn_norm.forward(&h);
+let att_out = self.feed_forward.forward(&mut att_out).transpose();
h.add(&att_out)
}
}
@@ -316,35 +360,70 @@ impl FeedForward {
unpickled: &[unpickler::Value],
layer_id: usize,
data_dir: P,
data_settings: DataSettings,
) -> Result<FeedForward, UnpicklingError> {
let data_dir: &Path = data_dir.as_ref();
-let w1 = Tensor::from_unpickled_pieces(
+let mut w1 = Tensor::from_unpickled_pieces(
unpickled,
format!("layers.{}.feed_forward.w1.weight", layer_id),
data_dir,
FromPiecesDirection::Rows,
-)?
-.to_f32();
-let w2 = Tensor::from_unpickled_pieces(
+)?;
+let mut w2 = Tensor::from_unpickled_pieces(
unpickled,
format!("layers.{}.feed_forward.w2.weight", layer_id),
data_dir,
FromPiecesDirection::Cols,
-)?
-.to_f32();
-let w3 = Tensor::from_unpickled_pieces(
+)?;
+let mut w3 = Tensor::from_unpickled_pieces(
unpickled,
format!("layers.{}.feed_forward.w3.weight", layer_id),
data_dir,
FromPiecesDirection::Rows,
-)?
-.to_f32();
+)?;
-Ok(Self { w1, w2, w3 })
#[cfg(feature = "opencl")]
{
if data_settings.use_opencl_for_feedforward {
w1 = w1.to_f16();
w2 = w2.to_f16();
w3 = w3.to_f16();
let ds = data_settings.clone();
w1.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap();
w2.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap();
w3.to_gpu(&ds.cl.unwrap()).unwrap();
}
}
#[cfg(not(feature = "opencl"))]
{
w1 = w1.to_f32();
w2 = w2.to_f32();
w3 = w3.to_f32();
}
Ok(Self {
w1,
w2,
w3,
data_settings,
})
}
-pub fn forward(&self, x: &Tensor) -> Tensor {
+pub fn forward(&self, x: &mut Tensor) -> Tensor {
#[cfg(feature = "opencl")]
let x_was_on_cpu: bool;
#[cfg(feature = "opencl")]
{
x_was_on_cpu = x.is_on_cpu();
}
#[cfg(feature = "opencl")]
{
if self.data_settings.use_opencl_for_feedforward {
*x = x.to_f16();
x.to_gpu(self.data_settings.cl.as_ref().unwrap()).unwrap();
}
}
let (w1_out, w3_out) = rayon::join(
|| self.w1.matrix_mul_transposed(x),
|| self.w3.matrix_mul_transposed(x),
@@ -352,12 +431,24 @@ impl FeedForward {
let w1_out = w1_out.silu();
let w1w3_out = w1_out.hadamard_product(&w3_out).transpose();
#[cfg(not(feature = "opencl"))]
if w1w3_out.rows() == 1 {
return self
.w2
.matrix_vector_mul_transposed_multithreaded(&w1w3_out);
} else {
return self.w2.matrix_mul_transposed(&w1w3_out);
}
#[cfg(feature = "opencl")]
{
let mut result = self.w2.matrix_mul_transposed(&w1w3_out);
if x_was_on_cpu {
result.to_cpu().unwrap();
result
} else {
result
}
}
-self.w2.matrix_mul_transposed(&w1w3_out)
}
}
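Taken together, FeedForward::forward implements the SwiGLU-style feed-forward block used by LLaMA, FFN(x) = W2 · (silu(W1·x) ∘ (W3·x)), with the W1 and W3 projections computed in parallel via rayon::join and, on the OpenCL path, the result moved back to the CPU if the input started there.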
