diff --git a/Cargo.lock b/Cargo.lock index 87a9224..56d4bfa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,7 +29,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" dependencies = [ - "num-traits", + "num-traits 0.2.15", ] [[package]] @@ -106,6 +106,15 @@ dependencies = [ "half 1.8.2", ] +[[package]] +name = "cl-sys" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8573fa3ff8acd6c49e8e113296c54277e82376b96c6ca6307848632cce38e44" +dependencies = [ + "libc", +] + [[package]] name = "clap" version = "3.2.23" @@ -202,7 +211,7 @@ dependencies = [ "criterion-plot", "itertools", "lazy_static", - "num-traits", + "num-traits 0.2.15", "oorandom", "plotters", "rayon", @@ -224,6 +233,20 @@ dependencies = [ "itertools", ] +[[package]] +name = "crossbeam" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c" +dependencies = [ + "cfg-if", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + [[package]] name = "crossbeam-channel" version = "0.5.7" @@ -258,6 +281,16 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.15" @@ -294,6 +327,15 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +[[package]] +name = "enum_primitive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4551092f4d519593039259a9ed8daedf0da12e5109c5280338073eaeb81180" +dependencies = [ + "num-traits 0.1.43", +] + [[package]] name = "errno" version = "0.2.8" @@ -333,6 +375,12 @@ dependencies = [ "gcd", ] +[[package]] +name = "futures" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a471a38ef8ed83cd6e40aa59c1ffe17db6855c18e3604d9c4ed8c08ebc28678" + [[package]] name = "gcd" version = "2.3.0" @@ -520,13 +568,28 @@ dependencies = [ "autocfg", ] +[[package]] +name = "nodrop" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb" + [[package]] name = "num-complex" version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" dependencies = [ - "num-traits", + "num-traits 0.2.15", +] + +[[package]] +name = "num-traits" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92e5113e9fd4cc14ded8e499429f396a20f98c772a47cc8622a736e1ec843c31" +dependencies = [ + "num-traits 0.2.15", ] [[package]] @@ -554,6 +617,45 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "ocl" +version = "0.19.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e8b2e511640775a4d2f0f408501ffdd8813d6c6bcceafdb4e3867d2c98471c6" +dependencies = [ + "futures", + "nodrop", + "num-traits 0.2.15", + "ocl-core", + "qutex", + "thiserror", +] + +[[package]] +name = "ocl-core" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66d46a57dfd1bb6fbee28986e92d39db9b941719dcf6b539c64242006d3c030d" +dependencies = [ + "bitflags", + "cl-sys", + "enum_primitive", + "num-complex", + "num-traits 0.2.15", + "ocl-core-vector", + "rustc_version", + "thiserror", +] + +[[package]] +name = "ocl-core-vector" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f562279e046ca160aeed5eaf6f7c4eb9fa56cb8fd9d038dbdbf56225caeb8074" +dependencies = [ + "num-traits 0.2.15", +] + [[package]] name = "once_cell" version = "1.17.1" @@ -578,7 +680,7 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2538b639e642295546c50fcd545198c9d64ee2a38620a628724a3b266d5fbf97" dependencies = [ - "num-traits", + "num-traits 0.2.15", "plotters-backend", "plotters-svg", "wasm-bindgen", @@ -705,6 +807,16 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "qutex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cda4a51ba3d773c196f9450a6b239077ad8dda608b15263b4c9f29e58909883f" +dependencies = [ + "crossbeam", + "futures", +] + [[package]] name = "rand" version = "0.8.5" @@ -795,6 +907,7 @@ dependencies = [ "half 2.2.1", "indicatif", "num-complex", + "ocl", "protobuf", "protobuf-codegen", "protobuf-parse", @@ -805,6 +918,15 @@ dependencies = [ "thiserror", ] +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.36.8" @@ -840,6 +962,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "semver" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" + [[package]] name = "serde" version = "1.0.152" diff --git a/Cargo.toml b/Cargo.toml index 7ea2f16..fa9f64b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,10 @@ indicatif = "0.17" colored = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" +ocl = "0.19" + +[features] +opencl = [] # We need protobuf compiler [build-dependencies] diff --git a/src/benches/benchmark.rs b/src/benches/benchmark.rs index 9a4c5fa..908fb46 100644 --- a/src/benches/benchmark.rs +++ b/src/benches/benchmark.rs @@ -1,9 +1,24 @@ extern crate rllama; +#[cfg(feature = "opencl")] +use rllama::tensor_opencl_support::OpenCL; use rllama::tensor::{Tensor, TensorDType}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +#[cfg(feature = "opencl")] +pub fn opencl_benchmarks(c: &mut Criterion) { + let orig16 = Tensor::random(1024, 1024, TensorDType::Float16); + let cl = OpenCL::new(false, 0).unwrap(); + + c.bench_function("1024x1024 matrix from CPU to OpenCL device", |b| { + b.iter(|| { + let mut orig16 = orig16.clone(); + let _ = orig16.to_gpu(&cl); + }) + }); +} + pub fn tensor_benchmarks(c: &mut Criterion) { let orig16_1 = Tensor::full(16, 32, TensorDType::Float16, 3.0); let orig16_2 = Tensor::full(32, 512, TensorDType::Float16, -1.33); @@ -96,5 +111,8 @@ pub fn tensor_benchmarks(c: &mut Criterion) { }); } +#[cfg(feature = "opencl")] +criterion_group!(benches, opencl_benchmarks, tensor_benchmarks); +#[cfg(not(feature = "opencl"))] criterion_group!(benches, tensor_benchmarks); criterion_main!(benches); diff --git a/src/lib.rs b/src/lib.rs index 4ce1c6b..5e1cf29 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,8 @@ pub mod embedding; pub mod protomodels; pub mod rllama_main; pub mod tensor; +#[cfg(feature = "opencl")] +pub mod tensor_opencl_support; pub mod token_sampler; pub mod tokenizer; pub mod transformer; diff --git a/src/rllama_main.rs b/src/rllama_main.rs index 7ee7f21..c9863fd 100644 --- a/src/rllama_main.rs +++ b/src/rllama_main.rs @@ -1,4 +1,6 @@ use crate::embedding::Embedding; +#[cfg(feature = "opencl")] +use crate::tensor_opencl_support::OpenCL; use crate::token_sampler::TokenSampler; use crate::tokenizer::{TokenId, Tokenizer}; use crate::transformer::Transformer; @@ -34,6 +36,9 @@ struct Cli { top_p: Option, #[arg(long)] top_k: Option, + + #[cfg(feature = "opencl")] + opencl_device: Option, } #[derive(Clone, Serialize, Deserialize)] @@ -57,6 +62,21 @@ pub fn main() -> Result<(), Box> { be_quiet = true; } + #[cfg(feature = "opencl")] + let opencl: Option = { + let opencl_device = cli.opencl_device.unwrap_or(0); + match OpenCL::new(!be_quiet, opencl_device) { + Err(openclerr) => { + eprintln!("OpenCL error: {}", openclerr); + None + } + Ok(opencl) => { + println!("OpenCL initialized."); + Some(opencl) + } + } + }; + // Custom println-like macro that respects be_quiet macro_rules! pln { ($($arg:tt)*) => { diff --git a/src/tensor.rs b/src/tensor.rs index f1f93e9..6857b2b 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "opencl")] +use crate::tensor_opencl_support::{OpenCL, OpenCLError, OpenCLTensor}; use crate::unpickler; use crate::unpickler::UnpicklingError; use half::f16; @@ -40,6 +42,9 @@ pub enum TensorError { TensorBuilderRowsMismatch(i64, i64), #[error("Tried to build a tensor from multiple files but the data types do not agree between the files. {0:?} != {1:?}")] TensorBuilderDTypeMismatch(TensorDType, TensorDType), + #[cfg(feature = "opencl")] + #[error("OpenCL error")] + OpenCLError(#[from] OpenCLError), } impl TensorDType { @@ -54,6 +59,9 @@ impl TensorDType { #[derive(Debug)] pub struct Tensor { data: *mut u8, + #[cfg(feature = "opencl")] + opencl_data: Option, + dtype: TensorDType, layout: Layout, rows: i64, @@ -125,6 +133,16 @@ fn horizontal_sum(mut ymm: __m256) -> f32 { } impl Tensor { + #[inline] + pub fn assume_on_cpu(&self) { + #[cfg(feature = "opencl")] + { + if self.opencl_data.is_some() { + panic!("Tried to assume_on_cpu on a tensor that is on the GPU"); + } + } + } + pub fn from_unpickled, S: AsRef>( unpickled: &unpickler::Value, name: S, @@ -175,6 +193,7 @@ impl Tensor { // Gets a value as f32 from the tensor. #[inline] pub fn get_f32(&self, row: i64, col: i64) -> f32 { + self.assume_on_cpu(); assert!( row >= 0 && col >= 0 && row < self.rows && col < self.cols, "Invalid index: {}, {} Size: {}, {}", @@ -183,6 +202,7 @@ impl Tensor { self.rows, self.cols ); + let idx = row * self.capacity_cols + col; match self.dtype { TensorDType::Float16 => { @@ -199,6 +219,7 @@ impl Tensor { // Sets a value from f32. The value is cast into whatever the tensor's dtype is. #[inline] pub fn set_f32(&mut self, row: i64, col: i64, val: f32) { + self.assume_on_cpu(); let idx = row * self.capacity_cols + col; match self.dtype { TensorDType::Float16 => { @@ -214,6 +235,7 @@ impl Tensor { // Converts the tensor to two-dimensional Vec. // Meant for debugging and making it easy to print tensors. pub fn to_vec(&self) -> Vec> { + self.assume_on_cpu(); let mut result = Vec::new(); for row in 0..self.rows { let mut row_vec = Vec::new(); @@ -229,6 +251,8 @@ impl Tensor { pub fn empty() -> Self { Self { data: std::ptr::null_mut(), + #[cfg(feature = "opencl")] + opencl_data: None, dtype: TensorDType::Float16, layout: Layout::from_size_align(0, 0).unwrap(), rows: 0, @@ -274,6 +298,8 @@ impl Tensor { Self { data, + #[cfg(feature = "opencl")] + opencl_data: None, dtype, rows, cols, @@ -294,6 +320,7 @@ impl Tensor { // Runs softmax on row dimension. pub fn softmax(&self) -> Tensor { + self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { let mut sum = 0.0; @@ -325,6 +352,7 @@ impl Tensor { // Computes mean for each row, so that columns become 1. pub fn mean_cols(&self) -> Tensor { + self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, 1, self.dtype) }; for row in 0..self.rows { let mut sum = 0.0; @@ -337,6 +365,7 @@ impl Tensor { } pub fn mean(&self) -> Tensor { + self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(1, 1, self.dtype) }; let mut sum = 0.0; for row in 0..self.rows { @@ -349,6 +378,7 @@ impl Tensor { } pub fn pow(&self, power: f32) -> Tensor { + self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { for col in 0..self.cols { @@ -360,6 +390,7 @@ impl Tensor { } pub fn sqrt(&self) -> Tensor { + self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { for col in 0..self.cols { @@ -371,6 +402,7 @@ impl Tensor { } pub fn rsqrt(&self) -> Tensor { + self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { for col in 0..self.cols { @@ -382,6 +414,8 @@ impl Tensor { } pub fn add(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); + other.assume_on_cpu(); if self.rows() != other.rows() || self.cols() != other.cols() { panic!( "add: Tensors must have the same shape, left: {}x{} right: {}x{}", @@ -402,6 +436,7 @@ impl Tensor { } pub fn add_scalar(&self, scalar: f32) -> Tensor { + self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { for col in 0..self.cols { @@ -413,6 +448,7 @@ impl Tensor { } pub fn scalar_multiply_f32(&self, scalar: f32) -> Tensor { + self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { for col in 0..self.cols { @@ -424,6 +460,7 @@ impl Tensor { } pub fn scalar_multiply_broadcast(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); if other.cols != 1 { panic!("Invalid scalar broadcast"); } @@ -442,6 +479,7 @@ impl Tensor { } pub fn scalar_product(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); if other.cols != 1 || other.rows != 1 { panic!("Invalid scalar product"); } @@ -457,6 +495,8 @@ impl Tensor { } pub fn hadamard_product_broadcast(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); + other.assume_on_cpu(); if self.cols != other.cols { panic!( "Invalid hadamard product broadcast: {}x{} vs {}x{}", @@ -480,6 +520,8 @@ impl Tensor { } pub fn hadamard_product(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); + other.assume_on_cpu(); if self.cols != other.cols || self.rows != other.rows { panic!( "Invalid hadamard product: incompatible shapes, {}x{} vs {}x{}", @@ -516,6 +558,7 @@ impl Tensor { unsafe { Tensor::uninitialized(total_rows, expected_cols, pieces[0].dtype) }; let mut row_offset = 0; for piece in pieces { + piece.assume_on_cpu(); for row in 0..piece.rows { for col in 0..piece.cols { let val = piece.get_f32(row, col); @@ -528,6 +571,7 @@ impl Tensor { } pub fn silu(&self) -> Tensor { + self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; for row in 0..self.rows { for col in 0..self.cols { @@ -540,6 +584,7 @@ impl Tensor { } pub fn transpose(&self) -> Tensor { + self.assume_on_cpu(); let mut result = unsafe { Tensor::uninitialized(self.cols, self.rows, self.dtype) }; for row in 0..self.rows { for col in 0..self.cols { @@ -554,6 +599,8 @@ impl Tensor { /// /// This is used as a reference to test correctness of other matrix multiplications. pub fn matrix_mul_naive(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); + other.assume_on_cpu(); if self.cols != other.rows { panic!( "Invalid matrix multiplication {}x{} vs {}x{}", @@ -574,6 +621,8 @@ impl Tensor { } pub fn matrix_mul(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); + other.assume_on_cpu(); if self.cols != other.rows { panic!( "Invalid matrix multiplication {}x{} vs {}x{}", @@ -592,6 +641,8 @@ impl Tensor { } pub fn matrix_mul_transposed(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); + other.assume_on_cpu(); if self.cols != other.cols { panic!( "Invalid matrix transposed multiplication {}x{} vs {}x{}", @@ -608,6 +659,9 @@ impl Tensor { /// Matrix multiplication done in-place pub fn matrix_mul_inplace(&mut self, src: &Tensor, other: &Tensor) { + self.assume_on_cpu(); + src.assume_on_cpu(); + other.assume_on_cpu(); if src.cols != other.rows { panic!( "Invalid matrix multiplication {}x{} vs {}x{}", @@ -752,6 +806,9 @@ impl Tensor { /// Matrix multiplication done in-place, but the second matrix is transposed. /// With this, you can avoid using .transpose() on the second matrix. pub fn matrix_mul_inplace_transposed(&mut self, src: &Tensor, other: &Tensor) { + self.assume_on_cpu(); + src.assume_on_cpu(); + other.assume_on_cpu(); if src.cols != other.cols { panic!( "Invalid matrix multiplication {}x{} vs {}x{}", @@ -827,6 +884,8 @@ impl Tensor { // // AxB @ Cx1 = Ax1 pub fn matrix_vector_mul(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); + other.assume_on_cpu(); // TODO: this function is not optimized. if self.cols != other.rows { panic!( @@ -851,6 +910,8 @@ impl Tensor { /// Same as matrix_vector_mul, but right side is assumed to be transposed. pub fn matrix_vector_mul_transposed(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); + other.assume_on_cpu(); if self.cols != other.cols { panic!( "Invalid matrix-vector transposed multiplication {}x{} vs {}x{}", @@ -888,6 +949,8 @@ impl Tensor { /// Same as matrix_vector_mul but uses threading. pub fn matrix_vector_mul_transposed_multithreaded(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); + other.assume_on_cpu(); if self.cols != other.cols { panic!( "Invalid matrix-vector transposed multiplication {}x{} vs {}x{}", @@ -957,6 +1020,8 @@ impl Tensor { #[allow(clippy::erasing_op)] #[allow(clippy::identity_op)] pub fn vector_matrix_mul(&self, other: &Tensor) -> Tensor { + self.assume_on_cpu(); + other.assume_on_cpu(); if self.cols != other.rows { panic!( "Invalid matrix-vector multiplication {}x{} vs {}x{}", @@ -1055,6 +1120,8 @@ impl Tensor { } Self { data, + #[cfg(feature = "opencl")] + opencl_data: None, dtype, rows, cols, @@ -1064,6 +1131,7 @@ impl Tensor { } pub fn clip_cols(&self, cols: usize) -> Tensor { + self.assume_on_cpu(); if cols == 0 { return Self::empty(); } @@ -1087,6 +1155,7 @@ impl Tensor { } pub fn view(&self, rows: i64, cols: i64) -> Tensor { + self.assume_on_cpu(); if rows * cols != self.rows * self.cols { panic!( "Invalid tensor view, requested {}x{} but tensor is {}x{}", @@ -1144,8 +1213,40 @@ impl Tensor { } } + /// Sends a tensor to the GPU. This is a no-op if GPU support is not enabled, or if the tensor + /// is already on the GPU. + /// + /// The tensor is moved asynchronously. + #[cfg(feature = "opencl")] + pub fn to_gpu(&mut self, cl: &OpenCL) -> Result<(), TensorError> { + if self.opencl_data.is_some() { + return Ok(()); + } + if self.dtype != TensorDType::Float16 { + panic!("Only float16 tensors are supported on the GPU"); + } + let cl_tensor = cl.data_u16_to_gpu( + self.data as *const u16, + self.layout, + (self.rows * self.capacity_cols) as usize, + )?; + self.data = std::ptr::null_mut(); + self.opencl_data = Some(cl_tensor); + Ok(()) + } + + /// Make sure that the tensor has finished going to GPU. Used mostly for benchmarking. + #[cfg(feature = "opencl")] + pub fn wait_until_on_gpu(&mut self) { + if self.opencl_data.is_none() { + panic!("wait_until_on_gpu: Tensor is not on GPU"); + } + self.opencl_data.as_mut().unwrap().wait_until_ready(); + } + /// Naive implementation of to_f32, used for testing that the faster methods are correct. pub fn to_f32_naive(&self) -> Tensor { + self.assume_on_cpu(); if self.dtype == TensorDType::Float32 { return self.clone(); } @@ -1162,6 +1263,7 @@ impl Tensor { } pub fn to_f32(&self) -> Tensor { + self.assume_on_cpu(); if self.dtype == TensorDType::Float32 { return self.clone(); } @@ -1196,6 +1298,7 @@ impl Tensor { /// Naive implementation of to_f16, used for testing that the faster methods are correct. pub fn to_f16_naive(&self) -> Tensor { + self.assume_on_cpu(); if self.dtype == TensorDType::Float16 { return self.clone(); } @@ -1212,6 +1315,7 @@ impl Tensor { } pub fn to_f16(&self) -> Tensor { + self.assume_on_cpu(); if self.dtype == TensorDType::Float16 { return self.clone(); } @@ -1245,6 +1349,7 @@ impl Tensor { } pub fn row(&self, row: i64) -> Tensor { + self.assume_on_cpu(); if row < 0 || row > self.rows { panic!("Invalid row index"); } diff --git a/src/tensor_opencl_support.rs b/src/tensor_opencl_support.rs new file mode 100644 index 0000000..fddb502 --- /dev/null +++ b/src/tensor_opencl_support.rs @@ -0,0 +1,123 @@ +/* + * OpenCL stuff to run (some) of the tensor operations. + */ + +use ocl::{Buffer, Context, Device, Event, Platform, Queue}; +use std::alloc::Layout; +use thiserror::Error; + +#[derive(Debug)] +pub struct OpenCL { + ctx: Context, + queue: Queue, +} + +#[derive(Debug)] +pub struct OpenCLTensor { + buf: Buffer, // really is f16 + write_event: Option, // if Some, the buffer is being written to + data: *const u16, // if non-null, is host pointer that should be freed + data_layout: Layout, +} + +impl Drop for OpenCLTensor { + fn drop(&mut self) { + if !self.data.is_null() { + if self.write_event.is_some() { + self.write_event.as_ref().unwrap().wait_for().unwrap(); + } + unsafe { + std::alloc::dealloc(self.data as *mut u8, self.data_layout); + } + } + } +} + +#[derive(Error, Debug)] +pub enum OpenCLError { + #[error("OpenCL error: {0}")] + OpenCL(#[from] ocl::Error), + #[error("Cannot select device")] + OpenCLDeviceSelection, +} + +impl OpenCL { + pub fn new(verbose: bool, nth_device: usize) -> Result { + let platforms = Platform::list(); + let mut devices: Vec<(Platform, Device)> = Vec::new(); + for platform in platforms { + for device in Device::list_all(platform)? { + devices.push((platform, device)); + } + } + if verbose { + println!("Enumerating OpenCL devices:"); + } + for (idx, (_, device)) in devices.iter().enumerate() { + if verbose { + println!("OpenCL {} device: {}", idx, device.name()?); + } + } + if nth_device > devices.len() { + return Err(OpenCLError::OpenCLDeviceSelection); + } + if verbose { + println!("---"); + println!("Selected OpenCL device: {}", devices[nth_device].1.name()?); + } + + let ctx = Context::builder() + .platform(devices[nth_device].0) + .devices(devices[nth_device].1) + .build()?; + + let queue = Queue::new(&ctx, devices[nth_device].1, None)?; + + Ok(OpenCL { ctx, queue }) + } + + pub fn data_u16_to_gpu( + &self, + data: *const u16, + data_layout: Layout, + nitems: usize, + ) -> Result { + unsafe { + let buf = Buffer::builder() + .queue(self.queue.clone()) + .len(nitems) + .build()?; + let mut event = Event::empty(); + let data_slice: &[u16] = std::slice::from_raw_parts(data, nitems); + buf.cmd() + .write(data_slice) + .block(false) + .enew(&mut event) + .enq()?; + Ok(OpenCLTensor { + buf, + write_event: Some(event), + data, + data_layout, + }) + } + } +} + +impl OpenCLTensor { + pub fn wait_until_ready(&mut self) { + if self.write_event.is_none() { + return; + } + self.write_event.as_ref().unwrap().wait_for().unwrap(); + self.write_event = None; + if !self.data.is_null() { + if self.write_event.is_some() { + self.write_event.as_ref().unwrap().wait_for().unwrap(); + } + unsafe { + std::alloc::dealloc(self.data as *mut u8, self.data_layout); + } + } + } +}