Add the beginnings of an OpenCL implementation.

I think I'll try to get the smaller modules to run faster.
broken-opencl-code
Mikko Juola 3 years ago
parent 846759b277
commit 53d367e6fa

Cargo.lock (generated)

@@ -29,7 +29,7 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
"num-traits 0.2.15",
]
[[package]]
@@ -106,6 +106,15 @@ dependencies = [
"half 1.8.2",
]
[[package]]
name = "cl-sys"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8573fa3ff8acd6c49e8e113296c54277e82376b96c6ca6307848632cce38e44"
dependencies = [
"libc",
]
[[package]]
name = "clap"
version = "3.2.23"
@@ -202,7 +211,7 @@ dependencies = [
"criterion-plot",
"itertools",
"lazy_static",
"num-traits",
"num-traits 0.2.15",
"oorandom",
"plotters",
"rayon",
@@ -224,6 +233,20 @@ dependencies = [
"itertools",
]
[[package]]
name = "crossbeam"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
dependencies = [
"cfg-if",
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
"crossbeam-queue",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.7"
@@ -258,6 +281,16 @@ dependencies = [
"scopeguard",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.15"
@@ -294,6 +327,15 @@ version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
[[package]]
name = "enum_primitive"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4551092f4d519593039259a9ed8daedf0da12e5109c5280338073eaeb81180"
dependencies = [
"num-traits 0.1.43",
]
[[package]]
name = "errno"
version = "0.2.8"
@@ -333,6 +375,12 @@ dependencies = [
"gcd",
]
[[package]]
name = "futures"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a471a38ef8ed83cd6e40aa59c1ffe17db6855c18e3604d9c4ed8c08ebc28678"
[[package]]
name = "gcd"
version = "2.3.0"
@@ -520,13 +568,28 @@ dependencies = [
"autocfg",
]
[[package]]
name = "nodrop"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb"
[[package]]
name = "num-complex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d"
dependencies = [
"num-traits",
"num-traits 0.2.15",
]
[[package]]
name = "num-traits"
version = "0.1.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92e5113e9fd4cc14ded8e499429f396a20f98c772a47cc8622a736e1ec843c31"
dependencies = [
"num-traits 0.2.15",
]
[[package]]
@@ -554,6 +617,45 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "ocl"
version = "0.19.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e8b2e511640775a4d2f0f408501ffdd8813d6c6bcceafdb4e3867d2c98471c6"
dependencies = [
"futures",
"nodrop",
"num-traits 0.2.15",
"ocl-core",
"qutex",
"thiserror",
]
[[package]]
name = "ocl-core"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "66d46a57dfd1bb6fbee28986e92d39db9b941719dcf6b539c64242006d3c030d"
dependencies = [
"bitflags",
"cl-sys",
"enum_primitive",
"num-complex",
"num-traits 0.2.15",
"ocl-core-vector",
"rustc_version",
"thiserror",
]
[[package]]
name = "ocl-core-vector"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f562279e046ca160aeed5eaf6f7c4eb9fa56cb8fd9d038dbdbf56225caeb8074"
dependencies = [
"num-traits 0.2.15",
]
[[package]]
name = "once_cell"
version = "1.17.1"
@@ -578,7 +680,7 @@ version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2538b639e642295546c50fcd545198c9d64ee2a38620a628724a3b266d5fbf97"
dependencies = [
"num-traits",
"num-traits 0.2.15",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
@@ -705,6 +807,16 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "qutex"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cda4a51ba3d773c196f9450a6b239077ad8dda608b15263b4c9f29e58909883f"
dependencies = [
"crossbeam",
"futures",
]
[[package]]
name = "rand"
version = "0.8.5"
@@ -795,6 +907,7 @@ dependencies = [
"half 2.2.1",
"indicatif",
"num-complex",
"ocl",
"protobuf",
"protobuf-codegen",
"protobuf-parse",
@@ -805,6 +918,15 @@ dependencies = [
"thiserror",
]
[[package]]
name = "rustc_version"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "0.36.8"
@@ -840,6 +962,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "semver"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a"
[[package]]
name = "serde"
version = "1.0.152"

Cargo.toml

@@ -24,6 +24,10 @@ indicatif = "0.17"
colored = "2"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
ocl = "0.19"
[features]
opencl = []
# We need protobuf compiler
[build-dependencies]
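Note: OpenCL support is opt-in behind the new `opencl` cargo feature, so default CPU-only builds are unaffected; enabling it uses cargo's standard feature flag, e.g. `cargo build --release --features opencl`.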

@@ -1,9 +1,24 @@
extern crate rllama;
#[cfg(feature = "opencl")]
use rllama::tensor_opencl_support::OpenCL;
use rllama::tensor::{Tensor, TensorDType};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
#[cfg(feature = "opencl")]
pub fn opencl_benchmarks(c: &mut Criterion) {
let orig16 = Tensor::random(1024, 1024, TensorDType::Float16);
let cl = OpenCL::new(false, 0).unwrap();
c.bench_function("1024x1024 matrix from CPU to OpenCL device", |b| {
b.iter(|| {
let mut orig16 = orig16.clone();
let _ = orig16.to_gpu(&cl);
})
});
}
pub fn tensor_benchmarks(c: &mut Criterion) {
let orig16_1 = Tensor::full(16, 32, TensorDType::Float16, 3.0);
let orig16_2 = Tensor::full(32, 512, TensorDType::Float16, -1.33);
@@ -96,5 +111,8 @@ pub fn tensor_benchmarks(c: &mut Criterion) {
});
}
#[cfg(feature = "opencl")]
criterion_group!(benches, opencl_benchmarks, tensor_benchmarks);
#[cfg(not(feature = "opencl"))]
criterion_group!(benches, tensor_benchmarks);
criterion_main!(benches);
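Note: with the feature enabled (`cargo bench --features opencl`), the new benchmark clones the tensor on every iteration because to_gpu() hands the host buffer over to the OpenCLTensor; each iteration therefore measures a fresh host-to-device upload of a 1024x1024 Float16 tensor (2 MiB).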

src/lib.rs

@@ -4,6 +4,8 @@ pub mod embedding;
pub mod protomodels;
pub mod rllama_main;
pub mod tensor;
#[cfg(feature = "opencl")]
pub mod tensor_opencl_support;
pub mod token_sampler;
pub mod tokenizer;
pub mod transformer;

src/rllama_main.rs

@@ -1,4 +1,6 @@
use crate::embedding::Embedding;
#[cfg(feature = "opencl")]
use crate::tensor_opencl_support::OpenCL;
use crate::token_sampler::TokenSampler;
use crate::tokenizer::{TokenId, Tokenizer};
use crate::transformer::Transformer;
@@ -34,6 +36,9 @@ struct Cli {
top_p: Option<f32>,
#[arg(long)]
top_k: Option<i32>,
#[cfg(feature = "opencl")]
#[arg(long)]
opencl_device: Option<usize>,
}
#[derive(Clone, Serialize, Deserialize)]
@@ -57,6 +62,21 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
be_quiet = true;
}
#[cfg(feature = "opencl")]
let opencl: Option<OpenCL> = {
let opencl_device = cli.opencl_device.unwrap_or(0);
match OpenCL::new(!be_quiet, opencl_device) {
Err(openclerr) => {
eprintln!("OpenCL error: {}", openclerr);
None
}
Ok(opencl) => {
println!("OpenCL initialized.");
Some(opencl)
}
}
};
// Custom println-like macro that respects be_quiet
macro_rules! pln {
($($arg:tt)*) => {
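Note: a failure in OpenCL::new is deliberately non-fatal here; the error is printed and `opencl` is left as None, so the CPU-only path keeps working.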

src/tensor.rs

@@ -1,3 +1,5 @@
#[cfg(feature = "opencl")]
use crate::tensor_opencl_support::{OpenCL, OpenCLError, OpenCLTensor};
use crate::unpickler;
use crate::unpickler::UnpicklingError;
use half::f16;
@@ -40,6 +42,9 @@ pub enum TensorError {
TensorBuilderRowsMismatch(i64, i64),
#[error("Tried to build a tensor from multiple files but the data types do not agree between the files. {0:?} != {1:?}")]
TensorBuilderDTypeMismatch(TensorDType, TensorDType),
#[cfg(feature = "opencl")]
#[error("OpenCL error")]
OpenCLError(#[from] OpenCLError),
}
impl TensorDType {
@@ -54,6 +59,9 @@ impl TensorDType {
#[derive(Debug)]
pub struct Tensor {
data: *mut u8,
#[cfg(feature = "opencl")]
opencl_data: Option<OpenCLTensor>,
dtype: TensorDType,
layout: Layout,
rows: i64,
@@ -125,6 +133,16 @@ fn horizontal_sum(mut ymm: __m256) -> f32 {
}
impl Tensor {
#[inline]
pub fn assume_on_cpu(&self) {
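// Guard used by the CPU-only operations below: panic early if the data has been moved to the GPU.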
#[cfg(feature = "opencl")]
{
if self.opencl_data.is_some() {
panic!("Tried to assume_on_cpu on a tensor that is on the GPU");
}
}
}
pub fn from_unpickled<P: AsRef<Path>, S: AsRef<str>>(
unpickled: &unpickler::Value,
name: S,
@@ -175,6 +193,7 @@ impl Tensor {
// Gets a value as f32 from the tensor.
#[inline]
pub fn get_f32(&self, row: i64, col: i64) -> f32 {
self.assume_on_cpu();
assert!(
row >= 0 && col >= 0 && row < self.rows && col < self.cols,
"Invalid index: {}, {} Size: {}, {}",
@@ -183,6 +202,7 @@
self.rows,
self.cols
);
let idx = row * self.capacity_cols + col;
match self.dtype {
TensorDType::Float16 => {
@@ -199,6 +219,7 @@
// Sets a value from f32. The value is cast into whatever the tensor's dtype is.
#[inline]
pub fn set_f32(&mut self, row: i64, col: i64, val: f32) {
self.assume_on_cpu();
let idx = row * self.capacity_cols + col;
match self.dtype {
TensorDType::Float16 => {
@@ -214,6 +235,7 @@
// Converts the tensor to two-dimensional Vec<f32>.
// Meant for debugging and making it easy to print tensors.
pub fn to_vec(&self) -> Vec<Vec<f32>> {
self.assume_on_cpu();
let mut result = Vec::new();
for row in 0..self.rows {
let mut row_vec = Vec::new();
@@ -229,6 +251,8 @@
pub fn empty() -> Self {
Self {
data: std::ptr::null_mut(),
#[cfg(feature = "opencl")]
opencl_data: None,
dtype: TensorDType::Float16,
layout: Layout::from_size_align(0, 0).unwrap(),
rows: 0,
@@ -274,6 +298,8 @@
Self {
data,
#[cfg(feature = "opencl")]
opencl_data: None,
dtype,
rows,
cols,
@@ -294,6 +320,7 @@
// Runs softmax on row dimension.
pub fn softmax(&self) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
let mut sum = 0.0;
@@ -325,6 +352,7 @@
// Computes mean for each row, so that columns become 1.
pub fn mean_cols(&self) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, 1, self.dtype) };
for row in 0..self.rows {
let mut sum = 0.0;
@@ -337,6 +365,7 @@
}
pub fn mean(&self) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(1, 1, self.dtype) };
let mut sum = 0.0;
for row in 0..self.rows {
@@ -349,6 +378,7 @@
}
pub fn pow(&self, power: f32) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
for col in 0..self.cols {
@@ -360,6 +390,7 @@
}
pub fn sqrt(&self) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
for col in 0..self.cols {
@@ -371,6 +402,7 @@
}
pub fn rsqrt(&self) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
for col in 0..self.cols {
@@ -382,6 +414,8 @@
}
pub fn add(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
if self.rows() != other.rows() || self.cols() != other.cols() {
panic!(
"add: Tensors must have the same shape, left: {}x{} right: {}x{}",
@@ -402,6 +436,7 @@
}
pub fn add_scalar(&self, scalar: f32) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
for col in 0..self.cols {
@@ -413,6 +448,7 @@
}
pub fn scalar_multiply_f32(&self, scalar: f32) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
for col in 0..self.cols {
@@ -424,6 +460,7 @@
}
pub fn scalar_multiply_broadcast(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
if other.cols != 1 {
panic!("Invalid scalar broadcast");
}
@@ -442,6 +479,7 @@
}
pub fn scalar_product(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
if other.cols != 1 || other.rows != 1 {
panic!("Invalid scalar product");
}
@@ -457,6 +495,8 @@
}
pub fn hadamard_product_broadcast(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
if self.cols != other.cols {
panic!(
"Invalid hadamard product broadcast: {}x{} vs {}x{}",
@@ -480,6 +520,8 @@
}
pub fn hadamard_product(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
if self.cols != other.cols || self.rows != other.rows {
panic!(
"Invalid hadamard product: incompatible shapes, {}x{} vs {}x{}",
@@ -516,6 +558,7 @@
unsafe { Tensor::uninitialized(total_rows, expected_cols, pieces[0].dtype) };
let mut row_offset = 0;
for piece in pieces {
piece.assume_on_cpu();
for row in 0..piece.rows {
for col in 0..piece.cols {
let val = piece.get_f32(row, col);
@@ -528,6 +571,7 @@
}
pub fn silu(&self) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) };
for row in 0..self.rows {
for col in 0..self.cols {
@@ -540,6 +584,7 @@
}
pub fn transpose(&self) -> Tensor {
self.assume_on_cpu();
let mut result = unsafe { Tensor::uninitialized(self.cols, self.rows, self.dtype) };
for row in 0..self.rows {
for col in 0..self.cols {
@@ -554,6 +599,8 @@
///
/// This is used as a reference to test correctness of other matrix multiplications.
pub fn matrix_mul_naive(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
if self.cols != other.rows {
panic!(
"Invalid matrix multiplication {}x{} vs {}x{}",
@@ -574,6 +621,8 @@
}
pub fn matrix_mul(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
if self.cols != other.rows {
panic!(
"Invalid matrix multiplication {}x{} vs {}x{}",
@@ -592,6 +641,8 @@
}
pub fn matrix_mul_transposed(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
if self.cols != other.cols {
panic!(
"Invalid matrix transposed multiplication {}x{} vs {}x{}",
@@ -608,6 +659,9 @@
/// Matrix multiplication done in-place
pub fn matrix_mul_inplace(&mut self, src: &Tensor, other: &Tensor) {
self.assume_on_cpu();
src.assume_on_cpu();
other.assume_on_cpu();
if src.cols != other.rows {
panic!(
"Invalid matrix multiplication {}x{} vs {}x{}",
@@ -752,6 +806,9 @@
/// Matrix multiplication done in-place, but the second matrix is transposed.
/// With this, you can avoid using .transpose() on the second matrix.
pub fn matrix_mul_inplace_transposed(&mut self, src: &Tensor, other: &Tensor) {
self.assume_on_cpu();
src.assume_on_cpu();
other.assume_on_cpu();
if src.cols != other.cols {
panic!(
"Invalid matrix multiplication {}x{} vs {}x{}",
@@ -827,6 +884,8 @@
//
// AxB @ Cx1 = Ax1
pub fn matrix_vector_mul(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
// TODO: this function is not optimized.
if self.cols != other.rows {
panic!(
@@ -851,6 +910,8 @@
/// Same as matrix_vector_mul, but right side is assumed to be transposed.
pub fn matrix_vector_mul_transposed(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
if self.cols != other.cols {
panic!(
"Invalid matrix-vector transposed multiplication {}x{} vs {}x{}",
@@ -888,6 +949,8 @@
/// Same as matrix_vector_mul but uses threading.
pub fn matrix_vector_mul_transposed_multithreaded(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
if self.cols != other.cols {
panic!(
"Invalid matrix-vector transposed multiplication {}x{} vs {}x{}",
@@ -957,6 +1020,8 @@
#[allow(clippy::erasing_op)]
#[allow(clippy::identity_op)]
pub fn vector_matrix_mul(&self, other: &Tensor) -> Tensor {
self.assume_on_cpu();
other.assume_on_cpu();
if self.cols != other.rows {
panic!(
"Invalid matrix-vector multiplication {}x{} vs {}x{}",
@@ -1055,6 +1120,8 @@
}
Self {
data,
#[cfg(feature = "opencl")]
opencl_data: None,
dtype,
rows,
cols,
@@ -1064,6 +1131,7 @@
}
pub fn clip_cols(&self, cols: usize) -> Tensor {
self.assume_on_cpu();
if cols == 0 {
return Self::empty();
}
@@ -1087,6 +1155,7 @@
}
pub fn view(&self, rows: i64, cols: i64) -> Tensor {
self.assume_on_cpu();
if rows * cols != self.rows * self.cols {
panic!(
"Invalid tensor view, requested {}x{} but tensor is {}x{}",
@@ -1144,8 +1213,40 @@
}
}
/// Sends a tensor to the GPU. This is a no-op if the tensor is already on the GPU.
/// (The method only exists when the crate is built with the `opencl` feature.)
///
/// The move is asynchronous; call wait_until_on_gpu() to block until it completes.
#[cfg(feature = "opencl")]
pub fn to_gpu(&mut self, cl: &OpenCL) -> Result<(), TensorError> {
if self.opencl_data.is_some() {
return Ok(());
}
if self.dtype != TensorDType::Float16 {
panic!("Only float16 tensors are supported on the GPU");
}
let cl_tensor = cl.data_u16_to_gpu(
self.data as *const u16,
self.layout,
(self.rows * self.capacity_cols) as usize,
)?;
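// Ownership of the host allocation moves to the OpenCLTensor, which frees it
// once the upload has completed; null the CPU-side pointer so this Tensor
// no longer touches that memory.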
self.data = std::ptr::null_mut();
self.opencl_data = Some(cl_tensor);
Ok(())
}
/// Make sure that the tensor has finished moving to the GPU. Used mostly for benchmarking.
#[cfg(feature = "opencl")]
pub fn wait_until_on_gpu(&mut self) {
if self.opencl_data.is_none() {
panic!("wait_until_on_gpu: Tensor is not on GPU");
}
self.opencl_data.as_mut().unwrap().wait_until_ready();
}
/// Naive implementation of to_f32, used for testing that the faster methods are correct.
pub fn to_f32_naive(&self) -> Tensor {
self.assume_on_cpu();
if self.dtype == TensorDType::Float32 {
return self.clone();
}
@@ -1162,6 +1263,7 @@
}
pub fn to_f32(&self) -> Tensor {
self.assume_on_cpu();
if self.dtype == TensorDType::Float32 {
return self.clone();
}
@@ -1196,6 +1298,7 @@
/// Naive implementation of to_f16, used for testing that the faster methods are correct.
pub fn to_f16_naive(&self) -> Tensor {
self.assume_on_cpu();
if self.dtype == TensorDType::Float16 {
return self.clone();
}
@@ -1212,6 +1315,7 @@
}
pub fn to_f16(&self) -> Tensor {
self.assume_on_cpu();
if self.dtype == TensorDType::Float16 {
return self.clone();
}
@@ -1245,6 +1349,7 @@
}
pub fn row(&self, row: i64) -> Tensor {
self.assume_on_cpu();
if row < 0 || row >= self.rows {
panic!("Invalid row index");
}
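The upload path is therefore two-step: enqueue, then optionally block. A minimal usage sketch, assuming a build with the `opencl` feature (all names are taken from this diff; the device index 0 mirrors the benchmark above):

    let cl = OpenCL::new(false, 0).unwrap(); // quiet init, first OpenCL device
    let mut t = Tensor::random(1024, 1024, TensorDType::Float16);
    t.to_gpu(&cl).unwrap();  // enqueues a non-blocking host-to-device write
    t.wait_until_on_gpu();   // blocks until the write event has completed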

src/tensor_opencl_support.rs

@@ -0,0 +1,123 @@
/*
* OpenCL support for running (some of) the tensor operations.
*/
use ocl::{Buffer, Context, Device, Event, Platform, Queue};
use std::alloc::Layout;
use thiserror::Error;
#[derive(Debug)]
pub struct OpenCL {
ctx: Context,
queue: Queue,
}
#[derive(Debug)]
pub struct OpenCLTensor {
buf: Buffer<u16>, // really is f16
write_event: Option<ocl::Event>, // if Some, the buffer is being written to
data: *const u16, // if non-null, is host pointer that should be freed
data_layout: Layout,
}
impl Drop for OpenCLTensor {
fn drop(&mut self) {
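// Wait out any in-flight upload before freeing the host buffer it still reads from.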
if !self.data.is_null() {
if self.write_event.is_some() {
self.write_event.as_ref().unwrap().wait_for().unwrap();
}
unsafe {
std::alloc::dealloc(self.data as *mut u8, self.data_layout);
}
}
}
}
#[derive(Error, Debug)]
pub enum OpenCLError {
#[error("OpenCL error: {0}")]
OpenCL(#[from] ocl::Error),
#[error("Cannot select device")]
OpenCLDeviceSelection,
}
impl OpenCL {
pub fn new(verbose: bool, nth_device: usize) -> Result<OpenCL, OpenCLError> {
let platforms = Platform::list();
let mut devices: Vec<(Platform, Device)> = Vec::new();
for platform in platforms {
for device in Device::list_all(platform)? {
devices.push((platform, device));
}
}
if verbose {
println!("Enumerating OpenCL devices:");
}
for (idx, (_, device)) in devices.iter().enumerate() {
if verbose {
println!("OpenCL {} device: {}", idx, device.name()?);
}
}
if nth_device >= devices.len() {
return Err(OpenCLError::OpenCLDeviceSelection);
}
if verbose {
println!("---");
println!("Selected OpenCL device: {}", devices[nth_device].1.name()?);
}
let ctx = Context::builder()
.platform(devices[nth_device].0)
.devices(devices[nth_device].1)
.build()?;
let queue = Queue::new(&ctx, devices[nth_device].1, None)?;
Ok(OpenCL { ctx, queue })
}
pub fn data_u16_to_gpu(
&self,
data: *const u16,
data_layout: Layout,
nitems: usize,
) -> Result<OpenCLTensor, OpenCLError> {
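// The write below is enqueued non-blocking, so the host allocation must stay
// alive until the event signals completion; the returned OpenCLTensor takes
// ownership of `data` and frees it in wait_until_ready() or on drop.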
unsafe {
let buf = Buffer::builder()
.queue(self.queue.clone())
.len(nitems)
.build()?;
let mut event = Event::empty();
let data_slice: &[u16] = std::slice::from_raw_parts(data, nitems);
buf.cmd()
.write(data_slice)
.block(false)
.enew(&mut event)
.enq()?;
Ok(OpenCLTensor {
buf,
write_event: Some(event),
data,
data_layout,
})
}
}
}
impl OpenCLTensor {
pub fn wait_until_ready(&mut self) {
if let Some(event) = self.write_event.take() {
event.wait_for().unwrap();
}
// Once the upload has completed, the host-side staging copy is no longer
// needed; free it and null the pointer so Drop does not free it a second time.
if !self.data.is_null() {
unsafe {
std::alloc::dealloc(self.data as *mut u8, self.data_layout);
}
self.data = std::ptr::null();
}
}
}
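Note: so far the OpenCL side only covers uploading Float16 data to the device; no kernels are compiled yet, so every tensor operation still requires CPU-resident data via assume_on_cpu().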