Compare commits


1 commit

Author: Mikko Juola · SHA1: 17b9d90570 · Date: 3 years ago
Message: Add lots of code I added for OpenCL, but text generation got broken and I have no idea why.
I don't want to throw this away so I'll park it.

@ -62,10 +62,10 @@ This is a hobby thing for me so don't expect updates or help.
 * Some other CPU implementations use quantization to reduce the size of weights
 * Put some of the operations on the OpenCL GPU/CPU. I've made some initial
-  OpenCL code but it is not used in the transformer loop yet. The CPU OpenCL
-  improves my own AVX2 code by like 100% and massively so on GPU although I am
-  also like 20x slower than equivalent operation on PyTorch on the same GPU.
-* I've heard there is some thing called Tensor Cores on nVidia GPUs. Not
+  OpenCL code but there's still a bunch of stuff that could be OpenCLified.
+  The OpenCL code is fast for both GPU OpenCL and CPU OpenCL (better than my
+  own handwritten AVX2 code, which makes me sad).
+* I've heard there is something called Tensor Cores on NVidia GPUs. Not
   accessible with OpenCL. But might be accessible on Vulkan with an
   extension.
 * More sophisticated token sampling. I saw on Hackernews some comments how the

@ -386,8 +386,40 @@ impl Tensor {
tensor
}
// Computes mean for each row, so that columns become 1.
// Computes mean for each row. The resulting matrix will have 1 column (which will contain the
// mean)
pub fn mean_cols(&self) -> Tensor {
#[cfg(feature = "opencl")]
{
if self.is_on_gpu() {
self.mean_cols_gpu()
} else {
self.mean_cols_cpu()
}
}
#[cfg(not(feature = "opencl"))]
{
self.mean_cols_cpu()
}
}
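A minimal shape sketch of what mean_cols() produces (a hypothetical test snippet; it only uses calls that already appear elsewhere in this diff, such as Tensor::random, rows(), cols() and get_f32):
#[test]
fn mean_cols_shape_sketch() {
    // mean_cols() reduces every row to its mean, so an N x M tensor becomes N x 1.
    let t = Tensor::random(4, 8, TensorDType::Float16);
    let m = t.mean_cols();
    assert_eq!(m.rows(), 4);
    assert_eq!(m.cols(), 1);
    // m.get_f32(r, 0) now holds the mean of row r of t.
}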
#[cfg(feature = "opencl")]
fn mean_cols_gpu(&self) -> Tensor {
self.assume_on_gpu();
self.with_opencl_data(|src_tensor| {
let cl: OpenCL = src_tensor.cl();
// TODO: don't generate a CPU-side copy, create the result directly on OpenCL side
let mut result = unsafe { Tensor::uninitialized(self.rows, 1, self.dtype) };
result = result.to_f16();
result.to_gpu(&cl).unwrap();
result.with_opencl_data_mut(|tgt_tensor| {
tgt_tensor.mean_cols_from(src_tensor).unwrap();
});
result
})
}
fn mean_cols_cpu(&self) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, 1, self.dtype) };
for row in 0..self.rows {
@ -414,6 +446,38 @@ impl Tensor {
}
pub fn pow(&self, power: f32) -> Tensor {
#[cfg(feature = "opencl")]
{
if self.is_on_gpu() {
return self.pow_gpu(power);
} else {
return self.pow_cpu(power);
}
}
#[cfg(not(feature = "opencl"))]
{
return self.pow_cpu(power);
}
}
#[cfg(feature = "opencl")]
fn pow_gpu(&self, power: f32) -> Tensor {
self.assume_on_gpu();
self.with_opencl_data(|src_tensor| {
let cl: OpenCL = src_tensor.cl();
// TODO: don't generate a CPU-side copy, create the result directly on OpenCL side
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
result = result.to_f16();
result.to_gpu(&cl).unwrap();
result.with_opencl_data_mut(|tgt_tensor| {
tgt_tensor.copy_inplace(src_tensor).unwrap();
tgt_tensor.pow_inplace(power).unwrap();
});
result
})
}
fn pow_cpu(&self, power: f32) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
@ -437,7 +501,39 @@ impl Tensor {
result
}
/// Computes 1/sqrt(x) for each element in the tensor.
pub fn rsqrt(&self) -> Tensor {
#[cfg(feature = "opencl")]
{
if self.is_on_gpu() {
return self.rsqrt_gpu();
} else {
return self.rsqrt_cpu();
}
}
#[cfg(not(feature = "opencl"))]
{
return self.rsqrt_cpu();
}
}
#[cfg(feature = "opencl")]
fn rsqrt_gpu(&self) -> Tensor {
self.assume_on_gpu();
self.with_opencl_data(|src_tensor| {
let cl: OpenCL = src_tensor.cl();
// TODO: don't generate a CPU-side copy, create the result directly on OpenCL side
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
result = result.to_f16();
result.to_gpu(&cl).unwrap();
result.with_opencl_data_mut(|tgt_tensor| {
tgt_tensor.copy_inplace(src_tensor).unwrap();
tgt_tensor.rsqrt_inplace().unwrap();
});
result
})
}
fn rsqrt_cpu(&self) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
@ -450,8 +546,6 @@ impl Tensor {
}
pub fn add(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
if self.rows() != other.rows() || self.cols() != other.cols() {
panic!(
"add: Tensors must have the same shape, left: {}x{} right: {}x{}",
@ -461,6 +555,43 @@ impl Tensor {
other.cols()
);
}
#[cfg(feature = "opencl")]
{
if self.is_on_gpu() {
return self.add_gpu(other);
} else {
return self.add_cpu(other);
}
}
#[cfg(not(feature = "opencl"))]
{
return self.add_cpu(other);
}
}
#[cfg(feature = "opencl")]
fn add_gpu(&self, other: &Tensor) -> Tensor {
self.assume_on_gpu();
other.assume_on_gpu();
self.with_opencl_data(|src_tensor| {
let cl: OpenCL = src_tensor.cl();
// TODO: don't generate a CPU-side copy, create the result directly on OpenCL side
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
result = result.to_f16();
result.to_gpu(&cl).unwrap();
other.with_opencl_data(|other_tensor| {
result.with_opencl_data_mut(|tgt_tensor| {
tgt_tensor.copy_inplace(src_tensor).unwrap();
tgt_tensor.add_inplace(other_tensor).unwrap();
});
});
result
})
}
fn add_cpu(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
for col in 0..self.cols {
@ -472,6 +603,37 @@ impl Tensor {
}
pub fn add_scalar(&self, scalar: f32) -> Tensor {
#[cfg(feature = "opencl")]
{
if self.is_on_gpu() {
self.add_scalar_gpu(scalar)
} else {
self.add_scalar_cpu(scalar)
}
}
#[cfg(not(feature = "opencl"))]
{
self.add_scalar_cpu(scalar)
}
}
#[cfg(feature = "opencl")]
fn add_scalar_gpu(&self, scalar: f32) -> Tensor {
self.assume_on_gpu();
self.with_opencl_data(|src_tensor| {
let cl: OpenCL = src_tensor.cl();
// TODO: don't generate a CPU-side copy, create the result directly on OpenCL side
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
result = result.to_f16();
result.to_gpu(&cl).unwrap();
result.with_opencl_data_mut(|tgt_tensor| {
tgt_tensor.copy_inplace(src_tensor).unwrap();
tgt_tensor.add_scalar_inplace(scalar).unwrap();
});
result
})
}
fn add_scalar_cpu(&self, scalar: f32) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
@ -496,6 +658,48 @@ impl Tensor {
}
pub fn scalar_multiply_broadcast(&self, other: &Tensor) -> Tensor {
#[cfg(feature = "opencl")]
{
if self.is_on_gpu() {
self.scalar_multiply_broadcast_gpu(other)
} else {
self.scalar_multiply_broadcast_cpu(other)
}
}
#[cfg(not(feature = "opencl"))]
{
self.scalar_multiply_broadcast_cpu(other)
}
}
#[cfg(feature = "opencl")]
fn scalar_multiply_broadcast_gpu(&self, other: &Tensor) -> Tensor {
self.assume_on_gpu();
other.assume_on_gpu();
if other.cols != 1 {
panic!("Invalid scalar broadcast");
}
if other.rows != self.rows {
panic!("Invalid scalar broadcast");
}
self.with_opencl_data(|src_tensor| {
let cl: OpenCL = src_tensor.cl();
// TODO: don't generate a CPU-side copy, create the result directly on OpenCL side
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
result = result.to_f16();
result.to_gpu(&cl).unwrap();
other.with_opencl_data(|other_tensor| {
result.with_opencl_data_mut(|tgt_tensor| {
tgt_tensor.copy_inplace(src_tensor).unwrap();
tgt_tensor
.scalar_multiply_broadcast_inplace(other_tensor)
.unwrap();
});
});
result
})
}
fn scalar_multiply_broadcast_cpu(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
if other.cols != 1 {
panic!("Invalid scalar broadcast");
@ -531,8 +735,6 @@ impl Tensor {
}
pub fn hadamard_product_broadcast(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
if self.cols != other.cols {
panic!(
"Invalid hadamard product broadcast: {}x{} vs {}x{}",
@ -545,6 +747,44 @@ impl Tensor {
self.rows, self.cols, other.rows, other.cols
);
}
#[cfg(feature = "opencl")]
{
if self.is_on_gpu() {
self.hadamard_product_broadcast_gpu(other)
} else {
self.hadamard_product_broadcast_cpu(other)
}
}
#[cfg(not(feature = "opencl"))]
{
self.hadamard_product_broadcast_cpu(other)
}
}
#[cfg(feature = "opencl")]
fn hadamard_product_broadcast_gpu(&self, other: &Tensor) -> Tensor {
self.assume_on_gpu();
self.with_opencl_data(|src_tensor| {
let cl: OpenCL = src_tensor.cl();
// TODO: don't generate a CPU-side copy, create the result directly on OpenCL side
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
result = result.to_f16();
result.to_gpu(&cl).unwrap();
other.with_opencl_data(|other_tensor| {
result.with_opencl_data_mut(|tgt_tensor| {
tgt_tensor.copy_inplace(src_tensor).unwrap();
tgt_tensor
.hadamard_product_broadcast_inplace(other_tensor)
.unwrap();
});
});
result
})
}
fn hadamard_product_broadcast_cpu(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
for col in 0..self.cols {
@ -644,21 +884,6 @@ impl Tensor {
result
}
pub fn silu(&self) -> Tensor {
#[cfg(feature = "opencl")]
{
if self.is_on_gpu() {
self.silu_gpu()
} else {
self.silu_cpu()
}
}
#[cfg(not(feature = "opencl"))]
{
self.silu_cpu()
}
}
// with_opencl_data & with_opencl_data_mut are utilities to get access to the underlying
// OpenCLTensor, if the tensor is on gpu. Panics if they are not on GPU.
#[cfg(feature = "opencl")]
@ -681,6 +906,21 @@ impl Tensor {
f(opencl_data.unwrap())
}
pub fn silu(&self) -> Tensor {
#[cfg(feature = "opencl")]
{
if self.is_on_gpu() {
self.silu_gpu()
} else {
self.silu_cpu()
}
}
#[cfg(not(feature = "opencl"))]
{
self.silu_cpu()
}
}
#[cfg(feature = "opencl")]
fn silu_gpu(&self) -> Tensor {
self.assume_on_gpu();
@ -2212,6 +2452,181 @@ mod tests {
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_rsqrt_and_cpu_rsqrt_agree() {
let cl = OpenCL::new(false, 0).unwrap();
for _trial in 0..300 {
let mut rng = rand::thread_rng();
let a = rng.gen_range(1..=300);
let b = rng.gen_range(1..=300);
let mat1 = Tensor::random(a, b, TensorDType::Float16);
let mat2 = mat1.clone();
let mut mat2 = mat2.to_f16();
mat2.to_gpu(&cl).unwrap();
let mat1_result = mat1.rsqrt();
let mut mat2_result = mat2.rsqrt();
mat2_result.to_cpu().unwrap();
assert_eq!(mat1_result.rows(), mat2_result.rows());
assert_eq!(mat1_result.cols(), mat2_result.cols());
for row in 0..mat1_result.rows {
for col in 0..mat1_result.cols {
let mat1_v = mat1_result.get_f32(row, col);
let mat2_v = mat2_result.get_f32(row, col);
if mat1_v.is_nan() && mat2_v.is_nan() {
continue;
}
assert_relative_eq!(mat1_v, mat2_v, epsilon = 1e-2);
}
}
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_add_and_cpu_add_agree() {
let cl = OpenCL::new(false, 0).unwrap();
for _trial in 0..300 {
let mut rng = rand::thread_rng();
let a = rng.gen_range(1..=300);
let b = rng.gen_range(1..=300);
let mat1 = Tensor::random(a, b, TensorDType::Float16);
let mat2 = Tensor::random(a, b, TensorDType::Float16);
let mut mat1_gpu = mat1.to_f16();
let mut mat2_gpu = mat2.to_f16();
mat1_gpu.to_gpu(&cl).unwrap();
mat2_gpu.to_gpu(&cl).unwrap();
let mat1_result = mat1.add(&mat2);
let mut mat2_result = mat1_gpu.add(&mat2_gpu);
mat2_result.to_cpu().unwrap();
assert_eq!(mat1_result.rows(), mat2_result.rows());
assert_eq!(mat1_result.cols(), mat2_result.cols());
for row in 0..mat1_result.rows {
for col in 0..mat1_result.cols {
let mat1_v = mat1_result.get_f32(row, col);
let mat2_v = mat2_result.get_f32(row, col);
if mat1_v.is_nan() && mat2_v.is_nan() {
continue;
}
assert_relative_eq!(mat1_v, mat2_v, epsilon = 1e-2);
}
}
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_pow_and_cpu_pow_agree() {
let cl = OpenCL::new(false, 0).unwrap();
for _trial in 0..300 {
let mut rng = rand::thread_rng();
let a = rng.gen_range(1..=100);
let b = rng.gen_range(1..=100);
let c = rng.gen_range(-1.2..1.2);
let mat1 = Tensor::random(a, b, TensorDType::Float16);
let mat2 = mat1.clone();
let mut mat2 = mat2.to_f16();
mat2.to_gpu(&cl).unwrap();
let mat1_result = mat1.pow(c);
let mut mat2_result = mat2.pow(c);
mat2_result.to_cpu().unwrap();
assert_eq!(mat1_result.rows(), mat2_result.rows());
assert_eq!(mat1_result.cols(), mat2_result.cols());
for row in 0..mat1_result.rows {
for col in 0..mat1_result.cols {
let left = mat1_result.get_f32(row, col);
let right = mat2_result.get_f32(row, col);
if left.is_nan() && right.is_nan() {
continue;
}
assert_relative_eq!(left, right, epsilon = 1e-1);
}
}
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_add_scalar_and_cpu_add_scalar_agree() {
let cl = OpenCL::new(false, 0).unwrap();
for _trial in 0..300 {
let mut rng = rand::thread_rng();
let a = rng.gen_range(1..=100);
let b = rng.gen_range(1..=100);
let c = rng.gen_range(-3.0..3.0);
let mat1 = Tensor::random(a, b, TensorDType::Float16);
let mut mat1_gpu = mat1.clone();
mat1_gpu.to_gpu(&cl).unwrap();
let result1 = mat1.add_scalar(c);
let mut result2 = mat1_gpu.add_scalar(c);
result2.to_cpu().unwrap();
assert_eq!(result1.rows(), result2.rows());
assert_eq!(result1.cols(), result2.cols());
for row in 0..result1.rows {
for col in 0..result1.cols {
let left = result1.get_f32(row, col);
let right = result2.get_f32(row, col);
if left.is_nan() && right.is_nan() {
continue;
}
assert_relative_eq!(left, right, epsilon = 1e-2);
}
}
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_mean_cols_and_cpu_mean_cols_agree() {
let cl = OpenCL::new(false, 0).unwrap();
for _trial in 0..300 {
let mut rng = rand::thread_rng();
let a = rng.gen_range(1..=100);
let b = rng.gen_range(1..=100);
let mat1 = Tensor::random(a, b, TensorDType::Float16);
let mat2 = mat1.clone();
let mut mat2 = mat2.to_f16();
mat2.to_gpu(&cl).unwrap();
let mat1_result = mat1.mean_cols();
let mut mat2_result = mat2.mean_cols();
mat2_result.to_cpu().unwrap();
assert_eq!(mat1_result.rows(), mat2_result.rows());
assert_eq!(mat1_result.cols(), mat2_result.cols());
for row in 0..mat1_result.rows {
for col in 0..mat1_result.cols {
let left = mat1_result.get_f32(row, col);
let right = mat2_result.get_f32(row, col);
if left.is_nan() && right.is_nan() {
continue;
}
assert_relative_eq!(left, right, epsilon = 1e-2,);
}
}
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_hadamard_product_and_cpu_hadamard_product_agree() {
@ -2248,6 +2663,78 @@ mod tests {
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_hadamard_product_broadcast_and_cpu_hadamard_product_broadcast_agree() {
let cl = OpenCL::new(false, 0).unwrap();
for _trial in 0..300 {
let mut rng = rand::thread_rng();
let a = rng.gen_range(1..=300);
let b = rng.gen_range(1..=300);
let mat1 = Tensor::random(a, b, TensorDType::Float16);
let mat2 = Tensor::random(1, b, TensorDType::Float16);
let mut mat1_gpu = mat1.to_f16();
let mut mat2_gpu = mat2.to_f16();
mat1_gpu.to_gpu(&cl).unwrap();
mat2_gpu.to_gpu(&cl).unwrap();
let result1 = mat1.hadamard_product_broadcast(&mat2);
let mut result2 = mat1_gpu.hadamard_product_broadcast(&mat2_gpu);
result2.to_cpu().unwrap();
assert_eq!(result1.rows(), result2.rows());
assert_eq!(result1.cols(), result2.cols());
for row in 0..result1.rows() {
for col in 0..result2.cols() {
assert_relative_eq!(
result1.get_f32(row, col),
result2.get_f32(row, col),
epsilon = 1e-2
);
}
}
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_scalar_multiply_broadcast_and_cpu_scalar_multiply_broadcast_agree() {
let cl = OpenCL::new(false, 0).unwrap();
for _trial in 0..300 {
let mut rng = rand::thread_rng();
let a = rng.gen_range(1..=300);
let b = rng.gen_range(1..=300);
let mat1 = Tensor::random(a, b, TensorDType::Float16);
let mat2 = Tensor::random(a, 1, TensorDType::Float16);
let mut mat1_gpu = mat1.to_f16();
let mut mat2_gpu = mat2.to_f16();
mat1_gpu.to_gpu(&cl).unwrap();
mat2_gpu.to_gpu(&cl).unwrap();
let result1 = mat1.scalar_multiply_broadcast(&mat2);
let mut result2 = mat1_gpu.scalar_multiply_broadcast(&mat2_gpu);
result2.to_cpu().unwrap();
assert_eq!(result1.rows(), result2.rows());
assert_eq!(result1.cols(), result2.cols());
for row in 0..result1.rows() {
for col in 0..result2.cols() {
assert_relative_eq!(
result1.get_f32(row, col),
result2.get_f32(row, col),
epsilon = 1e-2
);
}
}
}
}
#[cfg(feature = "opencl")]
#[test]
fn gpu_transpose_and_cpu_transpose_agree() {
@ -2313,11 +2800,13 @@ mod tests {
for row in 0..mat3.rows {
for col in 0..mat3.cols {
assert_relative_eq!(
mat3.get_f32(row, col),
mat3_gpu.get_f32(row, col),
epsilon = 1e-2,
);
let left = mat3.get_f32(row, col);
let right = mat3_gpu.get_f32(row, col);
if left.is_nan() && right.is_nan() {
continue;
}
assert_relative_eq!(left, right, epsilon = 1e-2,);
}
}
}

@ -18,6 +18,20 @@ struct Programs {
hadamard_product_f16: Kernel,
transpose_f16_program: Program,
transpose_f16: Kernel,
pow_f16_program: Program,
pow_f16: Kernel,
mean_cols_f16_program: Program,
mean_cols_f16: Kernel,
add_scalar_f16_program: Program,
add_scalar_f16: Kernel,
scalar_multiply_broadcast_f16_program: Program,
scalar_multiply_broadcast_f16: Kernel,
hadamard_product_broadcast_f16_program: Program,
hadamard_product_broadcast_f16: Kernel,
rsqrt_f16_program: Program,
rsqrt_f16: Kernel,
add_f16_program: Program,
add_f16: Kernel,
}
#[derive(Debug, Clone)]
@ -217,6 +231,58 @@ impl OpenCLTensor {
Ok(OpenCLEvent { event })
}
pub fn add_scalar_inplace(&mut self, scalar: f32) -> Result<OpenCLEvent, OpenCLError> {
let prg = self.cl.programs.write().unwrap();
prg.add_scalar_f16.set_arg(0, self.buf.clone()).unwrap();
prg.add_scalar_f16
.set_arg(1, self.cols_capacity as i32)
.unwrap();
prg.add_scalar_f16.set_arg(2, scalar).unwrap();
let mut event = Event::empty();
unsafe {
let b = prg
.add_scalar_f16
.cmd()
.queue(&self.queue)
.global_work_size([self.rows as usize, self.cols as usize])
.enew(&mut event);
b.enq()?;
}
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
}
pub fn scalar_multiply_broadcast_inplace(
&mut self,
other: &OpenCLTensor,
) -> Result<OpenCLEvent, OpenCLError> {
let prg = self.cl.programs.write().unwrap();
prg.scalar_multiply_broadcast_f16
.set_arg(0, self.buf.clone())
.unwrap();
prg.scalar_multiply_broadcast_f16
.set_arg(1, other.buf.clone())
.unwrap();
prg.scalar_multiply_broadcast_f16
.set_arg(2, self.cols_capacity as i32)
.unwrap();
prg.scalar_multiply_broadcast_f16
.set_arg(3, other.cols_capacity as i32)
.unwrap();
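// The scalar_multiply_broadcast_f16 kernel handles 16 half values per work item
// (vload_half16/vstore_half16), which is why the second global work dimension below is
// cols_capacity / 16 rather than the raw column count.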
let mut event = Event::empty();
unsafe {
let b = prg
.scalar_multiply_broadcast_f16
.cmd()
.queue(&self.queue)
.global_work_size([self.rows as usize, (self.cols_capacity / 16) as usize])
.enew(&mut event);
b.enq()?;
}
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
}
pub fn transpose_from(&mut self, other: &OpenCLTensor) -> Result<OpenCLEvent, OpenCLError> {
let prg = self.cl.programs.write().unwrap();
prg.transpose_f16.set_arg(0, self.buf.clone()).unwrap();
@ -235,7 +301,7 @@ impl OpenCLTensor {
.queue(&self.queue)
.global_work_size([self.rows as usize, self.cols as usize])
.enew(&mut event);
b.enq().unwrap();
b.enq()?;
}
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
@ -266,6 +332,85 @@ impl OpenCLTensor {
Ok(OpenCLEvent { event })
}
pub fn hadamard_product_broadcast_inplace(
&mut self,
other: &OpenCLTensor,
) -> Result<OpenCLEvent, OpenCLError> {
let prg = self.cl.programs.write().unwrap();
prg.hadamard_product_broadcast_f16
.set_arg(0, self.buf.clone())?;
prg.hadamard_product_broadcast_f16
.set_arg(1, other.buf.clone())?;
prg.hadamard_product_broadcast_f16
.set_arg(2, self.cols_capacity as i32)?;
prg.hadamard_product_broadcast_f16
.set_arg(3, other.cols_capacity as i32)?;
let mut event = Event::empty();
unsafe {
let b = prg
.hadamard_product_broadcast_f16
.cmd()
.queue(&self.queue)
.global_work_size([self.rows as usize, (self.cols_capacity as usize) / 16])
.enew(&mut event);
b.enq()?;
}
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
}
pub fn mean_cols_from(&mut self, other: &OpenCLTensor) -> Result<OpenCLEvent, OpenCLError> {
if self.cols != 1 {
panic!(
"mean_cols_from: number of columns in target is not 1: {}",
self.cols
);
}
if self.rows != other.rows {
panic!(
"mean_cols_from: number of rows in target is not equal to number of rows in source: {} != {}",
self.rows, other.rows
);
}
let prg = self.cl.programs.write().unwrap();
prg.mean_cols_f16.set_arg(0, self.buf.clone())?;
prg.mean_cols_f16.set_arg(1, other.buf.clone())?;
prg.mean_cols_f16.set_arg(2, self.cols_capacity as i32)?;
prg.mean_cols_f16.set_arg(3, other.cols_capacity as i32)?;
prg.mean_cols_f16.set_arg(4, other.cols as i32)?;
let mut event = Event::empty();
unsafe {
let b = prg
.mean_cols_f16
.cmd()
.queue(&self.queue)
.global_work_size([self.rows as usize, 1])
.enew(&mut event);
b.enq()?;
}
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
}
pub fn pow_inplace(&mut self, scalar: f32) -> Result<OpenCLEvent, OpenCLError> {
let prg = self.cl.programs.write().unwrap();
prg.pow_f16.set_arg(0, self.buf.clone())?;
prg.pow_f16.set_arg(1, self.cols_capacity as i32)?;
prg.pow_f16.set_arg(2, scalar)?;
let mut event = Event::empty();
unsafe {
let b = prg
.pow_f16
.cmd()
.queue(&self.queue)
.global_work_size([self.rows as usize, self.cols as usize])
.enew(&mut event);
b.enq()?;
}
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
}
pub fn silu_inplace(&mut self) -> Result<OpenCLEvent, OpenCLError> {
let prg = self.cl.programs.write().unwrap();
prg.silu_f16.set_arg(0, self.buf.clone())?;
@ -284,6 +429,44 @@ impl OpenCLTensor {
Ok(OpenCLEvent { event })
}
pub fn add_inplace(&mut self, left: &OpenCLTensor) -> Result<OpenCLEvent, OpenCLError> {
let prg = self.cl.programs.write().unwrap();
prg.add_f16.set_arg(0, self.buf.clone())?;
prg.add_f16.set_arg(1, left.buf.clone())?;
prg.add_f16.set_arg(2, self.cols_capacity as i32)?;
prg.add_f16.set_arg(3, left.cols_capacity as i32)?;
let mut event = Event::empty();
unsafe {
let b = prg
.add_f16
.cmd()
.queue(&self.queue)
.global_work_size([self.rows as usize, self.cols as usize])
.enew(&mut event);
b.enq()?;
}
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
}
pub fn rsqrt_inplace(&mut self) -> Result<OpenCLEvent, OpenCLError> {
let prg = self.cl.programs.write().unwrap();
prg.rsqrt_f16.set_arg(0, self.buf.clone())?;
prg.rsqrt_f16.set_arg(1, self.cols_capacity as i32)?;
let mut event = Event::empty();
unsafe {
let b = prg
.rsqrt_f16
.cmd()
.queue(&self.queue)
.global_work_size([self.rows as usize, self.cols as usize])
.enew(&mut event);
b.enq()?;
}
self.last_event = Some(event.clone());
Ok(OpenCLEvent { event })
}
pub fn matrix_mul_inplace_transposed(
&mut self,
src: &OpenCLTensor,
@ -397,6 +580,75 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError>
.arg(&0)
.queue(queue.clone())
.build()?;
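// Note on the kernel builders below: the .arg(None::<&Buffer<u16>>) and .arg(&0) calls only
// declare placeholder arguments so the kernels can be built up front; the actual buffers and
// scalars are bound with set_arg() in the *_inplace/*_from methods right before each enqueue.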
let pow_f16_program = make_program_with_src(ctx, POW_F16_SRC)?;
let pow_f16 = Kernel::builder()
.program(&pow_f16_program)
.name("pow_f16")
.arg(None::<&Buffer<u16>>)
.arg(&0)
.arg(&0)
.queue(queue.clone())
.build()?;
let mean_cols_f16_program = make_program_with_src(ctx, MEAN_COLS_F16_SRC)?;
let mean_cols_f16 = Kernel::builder()
.program(&mean_cols_f16_program)
.name("mean_cols_f16")
.arg(None::<&Buffer<u16>>)
.arg(None::<&Buffer<u16>>)
.arg(&0)
.arg(&0)
.arg(&0)
.queue(queue.clone())
.build()?;
let add_scalar_f16_program = make_program_with_src(ctx, ADD_SCALAR_F16_SRC)?;
let add_scalar_f16 = Kernel::builder()
.program(&add_scalar_f16_program)
.name("add_scalar_f16")
.arg(None::<&Buffer<u16>>)
.arg(&0)
.arg(&0)
.queue(queue.clone())
.build()?;
let scalar_multiply_broadcast_f16_program =
make_program_with_src(ctx, SCALAR_MULTIPLY_BROADCAST_F16_SRC)?;
let scalar_multiply_broadcast_f16 = Kernel::builder()
.program(&scalar_multiply_broadcast_f16_program)
.name("scalar_multiply_broadcast_f16")
.arg(None::<&Buffer<u16>>)
.arg(None::<&Buffer<u16>>)
.arg(&0)
.arg(&0)
.queue(queue.clone())
.build()?;
let hadamard_product_broadcast_f16_program =
make_program_with_src(ctx, HADAMARD_PRODUCT_BROADCAST_F16_SRC)?;
let hadamard_product_broadcast_f16 = Kernel::builder()
.program(&hadamard_product_broadcast_f16_program)
.name("hadamard_product_broadcast_f16")
.arg(None::<&Buffer<u16>>)
.arg(None::<&Buffer<u16>>)
.arg(&0)
.arg(&0)
.queue(queue.clone())
.build()?;
let rsqrt_f16_program = make_program_with_src(ctx, RSQRT_F16_SRC)?;
let rsqrt_f16 = Kernel::builder()
.program(&rsqrt_f16_program)
.name("rsqrt_f16")
.arg(None::<&Buffer<u16>>)
.arg(&0)
.queue(queue.clone())
.build()?;
let add_f16_program = make_program_with_src(ctx, ADD_F16_SRC)?;
let add_f16 = Kernel::builder()
.program(&add_f16_program)
.name("add_f16")
.arg(None::<&Buffer<u16>>)
.arg(None::<&Buffer<u16>>)
.arg(&0)
.arg(&0)
.queue(queue.clone())
.build()?;
Ok(Programs {
matrix_mul_transposed_by_row_f16_program,
matrix_mul_transposed_by_row_f16,
@ -406,6 +658,20 @@ fn make_programs(ctx: &Context, queue: &Queue) -> Result<Programs, OpenCLError>
hadamard_product_f16,
transpose_f16_program,
transpose_f16,
pow_f16_program,
pow_f16,
mean_cols_f16_program,
mean_cols_f16,
add_scalar_f16_program,
add_scalar_f16,
scalar_multiply_broadcast_f16_program,
scalar_multiply_broadcast_f16,
hadamard_product_broadcast_f16_program,
hadamard_product_broadcast_f16,
rsqrt_f16_program,
rsqrt_f16,
add_f16_program,
add_f16,
})
}
@ -532,3 +798,131 @@ __kernel void transpose_f16(__global half *tgt,
vstore_half(val, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt);
}
"#;
/// Computes x^scalar for every f16 value in the tensor
const POW_F16_SRC: &str = r#"
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void pow_f16(__global half *tgt,
const int ncols_capacity,
const float scalar)
{
const int tgt_row = get_global_id(0);
const int tgt_col = get_global_id(1);
const float val = vload_half(tgt_row * ncols_capacity + tgt_col, (__global const half*) tgt);
const float result = pow(val, scalar);
vstore_half(result, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt);
}
"#;
/// Computes the mean of each row in a tensor (the mean across its columns); the result has one value per row
const MEAN_COLS_F16_SRC: &str = r#"
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void mean_cols_f16(__global half *tgt,
__global const half *left,
const int ncols_capacity,
const int left_cols_capacity,
const int ncolumns)
{
// global work group size is nrows x 1
const int row = get_global_id(0);
float16 src_value = 0.0;
for (int col16 = 0; col16 < left_cols_capacity; col16 += 16) {
const int actual_col = col16;
if (actual_col >= ncolumns) {
break;
}
src_value += vload_half16((row * left_cols_capacity)/16 + col16/16, (__global const half*) left);
}
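// Note: the loop above accumulates whole 16-wide chunks, so any padding columns between
// ncolumns and left_cols_capacity are included in the sum; this appears to rely on the
// padded part of the buffer being zero.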
float src_value_sum = src_value.s0 + src_value.s1 + src_value.s2 + src_value.s3 + src_value.s4 + src_value.s5 + src_value.s6 + src_value.s7 + src_value.s8 + src_value.s9 + src_value.sa + src_value.sb + src_value.sc + src_value.sd + src_value.se + src_value.sf;
src_value_sum = src_value_sum / (float) ncolumns;
vstore_half(src_value_sum, row * ncols_capacity, (__global half*) tgt);
}
"#;
/// Adds a scalar to a tensor
const ADD_SCALAR_F16_SRC: &str = r#"
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void add_scalar_f16(__global half *tgt, const int ncols_capacity, const float scalar)
{
const int tgt_row = get_global_id(0);
const int tgt_col = get_global_id(1);
const float val = vload_half(tgt_row * ncols_capacity + tgt_col, (__global const half*) tgt);
const float result = val + scalar;
vstore_half(result, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt);
}
"#;
/// Multiplies each row of a tensor by the corresponding scalar from a column vector
const SCALAR_MULTIPLY_BROADCAST_F16_SRC: &str = r#"
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void scalar_multiply_broadcast_f16(__global half *tgt,
__global const half *left,
const int ncols_capacity,
const int left_cols_capacity)
{
// global work group size is nrows x (ncols/16)
const int row = get_global_id(0);
const int col = get_global_id(1) * 16;
const float scalar = vload_half(row * left_cols_capacity, (__global const half*) left);
float16 src_value = vload_half16((row * ncols_capacity)/16 + col/16, (__global const half*) tgt) * scalar;
vstore_half16(src_value, (row * ncols_capacity)/16 + col/16, (__global half*) tgt);
}
"#;
/// Multiplies each row of a tensor elementwise by a row vector (Hadamard product broadcast across rows)
const HADAMARD_PRODUCT_BROADCAST_F16_SRC: &str = r#"
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void hadamard_product_broadcast_f16(__global half *tgt,
__global const half *left,
const int ncols_capacity,
const int left_cols_capacity)
{
// global work group size is nrows x (ncols/16)
const int row = get_global_id(0);
const int col16 = get_global_id(1) * 16;
const float16 product_value = vload_half16(col16/16, (__global const half*) left);
const float16 src_value = vload_half16((row * ncols_capacity)/16 + col16/16, (__global const half*) tgt);
const float16 result = src_value * product_value;
vstore_half16(result, (row * ncols_capacity)/16 + col16/16, (__global half*) tgt);
}
"#;
/// Computes 1/sqrt(x) for each f16 value in the tensor
const RSQRT_F16_SRC: &str = r#"
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void rsqrt_f16(__global half *tgt, const int ncols_capacity)
{
const int tgt_row = get_global_id(0);
const int tgt_col = get_global_id(1);
const float val = vload_half(tgt_row * ncols_capacity + tgt_col, (__global const half*) tgt);
const float result = rsqrt(val);
vstore_half(result, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt);
}
"#;
/// Computes sum of two tensors
const ADD_F16_SRC: &str = r#"
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void add_f16(__global half *tgt,
__global const half *left,
const int tgt_ncols_capacity,
const int left_ncols_capacity)
{
const int tgt_row = get_global_id(0);
const int tgt_col = get_global_id(1);
const float tgt_v = vload_half(tgt_row * tgt_ncols_capacity + tgt_col, (__global const half*) tgt);
const float left_v = vload_half(tgt_row * left_ncols_capacity + tgt_col, (__global const half*) left);
const float result = tgt_v + left_v;
vstore_half(result, tgt_row * tgt_ncols_capacity + tgt_col, (__global half*) tgt);
}
"#;

@ -8,6 +8,7 @@ use crate::unpickler::UnpicklingError;
use indicatif::ProgressBar;
use num_complex::Complex;
use rayon::prelude::*;
use std::mem::drop;
use std::path::Path;
use std::sync::{Arc, RwLock};
@ -38,6 +39,8 @@ pub struct DataSettings {
#[cfg(feature = "opencl")]
use_opencl_for_attention: bool,
#[cfg(feature = "opencl")]
use_opencl_for_rmsnorm: bool,
#[cfg(feature = "opencl")]
cl: Option<OpenCL>,
}
@ -51,6 +54,7 @@ impl DataSettings {
DataSettings {
use_opencl_for_feedforward: false,
use_opencl_for_attention: false,
use_opencl_for_rmsnorm: false,
cl: cl.clone(),
}
}
@ -67,6 +71,7 @@ impl DataSettings {
}
self.use_opencl_for_feedforward = true;
self.use_opencl_for_attention = true;
self.use_opencl_for_rmsnorm = true;
self
}
}
@ -137,6 +142,7 @@ impl TransformerCaches {
pub struct RMSNorm {
eps: f64,
weight: Tensor,
data_settings: DataSettings,
}
pub struct Attention {
@ -195,9 +201,15 @@ impl Transformer {
result
})
.collect::<Result<Vec<TransformerBlock>, UnpicklingError>>()?;
std::mem::drop(progress_bar);
drop(progress_bar);
let norm = RMSNorm::from_unpickled(unpickled, "norm.weight".to_string(), eps, data_dir)?;
let norm = RMSNorm::from_unpickled(
unpickled,
"norm.weight".to_string(),
eps,
data_settings.clone(),
data_dir,
)?;
let output = Tensor::from_unpickled_pieces(
unpickled,
"output.weight",
@ -261,18 +273,23 @@ impl Transformer {
embs.push(emb);
}
let mut emb_tensor: Tensor = Tensor::concat(&embs);
std::mem::drop(embs);
drop(embs);
for (idx, layer) in self.layers.iter().enumerate() {
emb_tensor = layer.forward(
&emb_tensor,
&mut emb_tensor,
start_pos,
&self.freqs_cis,
&mask,
&mut caches.layer_caches[idx],
);
}
let out = self.norm.forward(&emb_tensor);
let mut out = self.norm.forward(&mut emb_tensor);
#[cfg(feature = "opencl")]
if out.is_on_gpu() {
out.to_cpu().unwrap();
out = out.to_f32();
}
let out = out.row(out.rows() - 1);
self.output.matrix_mul_transposed(&out)
@ -296,19 +313,21 @@ impl TransformerBlock {
layer_id,
n_local_heads,
head_dim,
data_settings,
data_settings.clone(),
data_dir,
)?;
let ffn_norm = RMSNorm::from_unpickled(
unpickled,
format!("layers.{}.ffn_norm.weight", layer_id),
eps,
data_settings.clone(),
data_dir,
)?;
let attn_norm = RMSNorm::from_unpickled(
unpickled,
format!("layers.{}.attention_norm.weight", layer_id),
eps,
data_settings.clone(),
data_dir,
)?;
Ok(Self {
@ -321,26 +340,61 @@ impl TransformerBlock {
pub fn forward(
&self,
x: &Tensor,
x: &mut Tensor,
start_pos: usize,
freqs_cis: &FreqsCis,
mask: &Option<Tensor>,
attention_cache: &mut AttentionCache,
) -> Tensor {
let now = std::time::Instant::now();
let mut attnorm_out = self.attention_norm.forward(x);
let att_out = self.attn.forward(
let now = std::time::Instant::now();
let mut att_out = self.attn.forward(
&mut attnorm_out,
start_pos,
freqs_cis,
mask,
attention_cache,
);
std::mem::drop(attnorm_out);
let now = std::time::Instant::now();
drop(attnorm_out);
let h = x.add(&att_out);
let mut att_out = self.ffn_norm.forward(&h);
#[cfg(feature = "opencl")]
let mut x_was_on_cpu: bool;
#[cfg(feature = "opencl")]
{
x_was_on_cpu = x.is_on_cpu();
if x_was_on_cpu {
*x = x.to_f16();
x.to_gpu(self.attention_norm.data_settings.cl.as_ref().unwrap())
.unwrap();
}
if x.is_on_gpu() {
att_out = att_out.to_f16();
att_out
.to_gpu(self.attention_norm.data_settings.cl.as_ref().unwrap())
.unwrap();
}
}
let mut h = x.add(&att_out);
let now = std::time::Instant::now();
let mut att_out = self.ffn_norm.forward(&mut h);
let now = std::time::Instant::now();
let att_out = self.feed_forward.forward(&mut att_out).transpose();
h.add(&att_out)
let mut result = h.add(&att_out);
#[cfg(feature = "opencl")]
{
if x_was_on_cpu {
result.to_cpu().unwrap();
return result.to_f32();
} else {
result
}
}
#[cfg(not(feature = "opencl"))]
{
result
}
}
}
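For orientation, the residual structure implemented by TransformerBlock::forward above, read directly off the calls in the code, is
$$h = x + \mathrm{Attention}(\mathrm{RMSNorm}_{attn}(x)), \qquad \mathrm{out} = h + \mathrm{FeedForward}(\mathrm{RMSNorm}_{ffn}(h)),$$
where the feed-forward output is transposed before the final add; the surrounding cfg(opencl) blocks only move tensors to the OpenCL device (as f16) and back to the CPU (as f32) when needed.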
@ -349,26 +403,64 @@ impl RMSNorm {
unpickled: &[unpickler::Value],
name: String,
eps: f64,
data_settings: DataSettings,
data_dir: P,
) -> Result<RMSNorm, UnpicklingError> {
let data_dir: &Path = data_dir.as_ref();
let weights = Tensor::from_unpickled_pieces(
let mut weights = Tensor::from_unpickled_pieces(
&unpickled[0..=0],
name.clone(),
data_dir,
FromPiecesDirection::Rows,
)?
.to_f32();
)?;
#[cfg(feature = "opencl")]
{
if data_settings.use_opencl_for_rmsnorm {
weights = weights.to_f16();
let ds = data_settings.clone();
weights.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap();
} else {
weights = weights.to_f32();
}
}
#[cfg(not(feature = "opencl"))]
{
weights = weights.to_f32();
}
Ok(Self {
eps,
weight: weights,
data_settings,
})
}
fn forward(&self, x: &Tensor) -> Tensor {
fn forward(&self, x: &mut Tensor) -> Tensor {
#[cfg(feature = "opencl")]
let x_was_on_cpu: bool;
#[cfg(feature = "opencl")]
{
x_was_on_cpu = x.is_on_cpu();
if self.data_settings.use_opencl_for_rmsnorm && x_was_on_cpu {
*x = x.to_f16();
x.to_gpu(self.data_settings.cl.as_ref().unwrap()).unwrap();
}
}
let inner = x.pow(2.0).mean_cols().add_scalar(self.eps as f32);
let out1 = x.scalar_multiply_broadcast(&inner.rsqrt());
out1.hadamard_product_broadcast(&self.weight)
let mut result = out1.hadamard_product_broadcast(&self.weight);
#[cfg(feature = "opencl")]
{
if x_was_on_cpu {
result.to_cpu().unwrap();
}
result
}
#[cfg(not(feature = "opencl"))]
{
result
}
}
}
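The chain in forward above (pow(2.0).mean_cols().add_scalar(eps).rsqrt(), followed by the two broadcast multiplies) is the standard RMSNorm, applied row by row:
$$y_{i,j} = \frac{x_{i,j}}{\sqrt{\tfrac{1}{m}\sum_{k=1}^{m} x_{i,k}^2 + \epsilon}} \, w_j,$$
where $m$ is the number of columns and $w$ is self.weight broadcast across rows.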
@ -410,6 +502,10 @@ impl FeedForward {
w1.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap();
w2.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap();
w3.to_gpu(&ds.cl.unwrap()).unwrap();
} else {
w1 = w1.to_f32();
w2 = w2.to_f32();
w3 = w3.to_f32();
}
}
#[cfg(not(feature = "opencl"))]
@ -433,7 +529,7 @@ impl FeedForward {
#[cfg(feature = "opencl")]
{
x_was_on_cpu = x.is_on_cpu();
if self.data_settings.use_opencl_for_feedforward {
if self.data_settings.use_opencl_for_feedforward && x_was_on_cpu {
*x = x.to_f16();
x.to_gpu(self.data_settings.cl.as_ref().unwrap()).unwrap();
}
@ -514,6 +610,11 @@ impl Attention {
wk.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap();
wv.to_gpu(&ds.cl.as_ref().unwrap().clone()).unwrap();
wo.to_gpu(&ds.cl.unwrap()).unwrap();
} else {
wq = wq.to_f32();
wk = wk.to_f32();
wv = wv.to_f32();
wo = wo.to_f32();
}
}
#[cfg(not(feature = "opencl"))]
@ -548,7 +649,7 @@ impl Attention {
#[cfg(feature = "opencl")]
{
x_was_on_cpu = x.is_on_cpu();
if self.data_settings.use_opencl_for_attention {
if self.data_settings.use_opencl_for_attention && x_was_on_cpu {
*x = x.to_f16();
x.to_gpu(self.data_settings.cl.as_ref().unwrap()).unwrap();
}
@ -686,8 +787,21 @@ impl Attention {
.collect();
let output3: Vec<&Tensor> = output2.iter().collect();
let output2: Tensor = Tensor::concat(&output3);
output2
let mut output2: Tensor = Tensor::concat(&output3);
#[cfg(feature = "opencl")]
{
if x_was_on_cpu {
output2.to_cpu().unwrap();
return output2.to_f32();
} else {
return output2;
}
}
#[cfg(not(feature = "opencl"))]
{
output2
}
}
}
