From ac4adb54eaeadb0fa55fa4b984af01e1ac4dfd2c Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Mon, 18 Sep 2023 19:56:53 -0400 Subject: [PATCH] Burn compute (#809) --- Cargo.toml | 51 +-- burn-common/src/id.rs | 4 +- burn-compute/Cargo.toml | 27 ++ burn-compute/README.md | 7 + burn-compute/src/channel/base.rs | 21 ++ burn-compute/src/channel/cell.rs | 65 ++++ burn-compute/src/channel/mod.rs | 17 + burn-compute/src/channel/mpsc.rs | 154 +++++++++ burn-compute/src/channel/mutex.rs | 55 +++ burn-compute/src/client.rs | 65 ++++ burn-compute/src/compute.rs | 83 +++++ burn-compute/src/id.rs | 53 +++ burn-compute/src/lib.rs | 26 ++ burn-compute/src/memory_management/base.rs | 27 ++ burn-compute/src/memory_management/mod.rs | 5 + burn-compute/src/memory_management/simple.rs | 339 +++++++++++++++++++ burn-compute/src/server.rs | 40 +++ burn-compute/src/storage/base.rs | 48 +++ burn-compute/src/storage/bytes_cpu.rs | 121 +++++++ burn-compute/src/storage/mod.rs | 8 + burn-compute/tests/dummy/compute.rs | 26 ++ burn-compute/tests/dummy/kernel.rs | 25 ++ burn-compute/tests/dummy/mod.rs | 7 + burn-compute/tests/dummy/server.rs | 60 ++++ burn-compute/tests/integration_test.rs | 38 +++ burn-wgpu/Cargo.toml | 4 + burn-wgpu/src/compute/base.rs | 129 +++++++ burn-wgpu/src/compute/mod.rs | 7 + burn-wgpu/src/compute/server.rs | 253 ++++++++++++++ burn-wgpu/src/compute/storage.rs | 102 ++++++ burn-wgpu/src/context/base.rs | 2 +- burn-wgpu/src/kernel/base.rs | 4 +- burn-wgpu/src/lib.rs | 4 + xtask/src/runchecks.rs | 150 ++++---- 34 files changed, 1941 insertions(+), 86 deletions(-) create mode 100644 burn-compute/Cargo.toml create mode 100644 burn-compute/README.md create mode 100644 burn-compute/src/channel/base.rs create mode 100644 burn-compute/src/channel/cell.rs create mode 100644 burn-compute/src/channel/mod.rs create mode 100644 burn-compute/src/channel/mpsc.rs create mode 100644 burn-compute/src/channel/mutex.rs create mode 100644 burn-compute/src/client.rs create mode 100644 burn-compute/src/compute.rs create mode 100644 burn-compute/src/id.rs create mode 100644 burn-compute/src/lib.rs create mode 100644 burn-compute/src/memory_management/base.rs create mode 100644 burn-compute/src/memory_management/mod.rs create mode 100644 burn-compute/src/memory_management/simple.rs create mode 100644 burn-compute/src/server.rs create mode 100644 burn-compute/src/storage/base.rs create mode 100644 burn-compute/src/storage/bytes_cpu.rs create mode 100644 burn-compute/src/storage/mod.rs create mode 100644 burn-compute/tests/dummy/compute.rs create mode 100644 burn-compute/tests/dummy/kernel.rs create mode 100644 burn-compute/tests/dummy/mod.rs create mode 100644 burn-compute/tests/dummy/server.rs create mode 100644 burn-compute/tests/integration_test.rs create mode 100644 burn-wgpu/src/compute/base.rs create mode 100644 burn-wgpu/src/compute/mod.rs create mode 100644 burn-wgpu/src/compute/server.rs create mode 100644 burn-wgpu/src/compute/storage.rs diff --git a/Cargo.toml b/Cargo.toml index c893cc05e..387068b7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "burn", "burn-autodiff", "burn-common", + "burn-compute", "burn-core", "burn-dataset", "burn-derive", @@ -24,9 +25,7 @@ members = [ "examples/*", ] -exclude = [ - "examples/notebook", -] +exclude = ["examples/notebook"] [workspace.dependencies] bytemuck = "1.13" @@ -37,7 +36,7 @@ dirs = "5.0.1" fake = "2.6.1" flate2 = "1.0.26" float-cmp = "0.9.0" -gix-tempfile = {version = "8.0.0", features = ["signals"]} +gix-tempfile = { version = 
"8.0.0", features = ["signals"] } hashbrown = "0.14.0" indicatif = "0.17.5" libm = "0.2.7" @@ -47,17 +46,17 @@ proc-macro2 = "1.0.60" protobuf-codegen = "3.2" quote = "1.0.28" r2d2 = "0.8.10" -r2d2_sqlite = {version = "0.22.0"} +r2d2_sqlite = { version = "0.22.0" } rayon = "1.7.0" rmp-serde = "1.1.1" rstest = "0.18.1" -rusqlite = {version = "0.29"} +rusqlite = { version = "0.29" } sanitize-filename = "0.5.0" serde_rusqlite = "0.33.1" -spin = {version = "0.9.8", features = ["mutex", "spin_mutex"]} +spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] } strum = "0.24" strum_macros = "0.24" -syn = {version = "2.0", features = ["full", "extra-traits"]} +syn = { version = "2.0", features = ["full", "extra-traits"] } tempfile = "3.6.0" thiserror = "1.0.40" tracing-subscriber = "0.3.17" @@ -67,19 +66,33 @@ tracing-appender = "0.2.2" # WGPU stuff futures-intrusive = "0.5" pollster = "0.3" -text_placeholder = {version = "0.5.0", features = ["struct_context"]} +text_placeholder = { version = "0.5.0", features = ["struct_context"] } wgpu = "0.17.0" # # The following packages disable the "std" feature for no_std compatibility # -bincode = {version = "2.0.0-rc.3", features = ["alloc", "serde"], default-features = false} -derive-new = {version = "0.5.9", default-features = false} -half = {version = "2.3.1", features = ["alloc", "num-traits", "serde"], default-features = false} -ndarray = {version = "0.15.6", default-features = false} -num-traits = {version = "0.2.15", default-features = false, features = ["libm"]}# libm is for no_std -rand = {version = "0.8.5", default-features = false, features = ["std_rng"]}# std_rng is for no_std -rand_distr = {version = "0.4.3", default-features = false} -serde = {version = "1.0.164", default-features = false, features = ["derive", "alloc"]}# alloc is for no_std, derive is needed -serde_json = {version = "1.0.96", default-features = false} -uuid = {version = "1.3.4", default-features = false} +bincode = { version = "2.0.0-rc.3", features = [ + "alloc", + "serde", +], default-features = false } +derive-new = { version = "0.5.9", default-features = false } +half = { version = "2.3.1", features = [ + "alloc", + "num-traits", + "serde", +], default-features = false } +ndarray = { version = "0.15.6", default-features = false } +num-traits = { version = "0.2.15", default-features = false, features = [ + "libm", +] } # libm is for no_std +rand = { version = "0.8.5", default-features = false, features = [ + "std_rng", +] } # std_rng is for no_std +rand_distr = { version = "0.4.3", default-features = false } +serde = { version = "1.0.164", default-features = false, features = [ + "derive", + "alloc", +] } # alloc is for no_std, derive is needed +serde_json = { version = "1.0.96", default-features = false } +uuid = { version = "1.3.4", default-features = false } diff --git a/burn-common/src/id.rs b/burn-common/src/id.rs index 6ac50e082..7c5ed7ff9 100644 --- a/burn-common/src/id.rs +++ b/burn-common/src/id.rs @@ -1,7 +1,5 @@ -use alloc::string::{String, ToString}; - use crate::rand::{get_seeded_rng, Rng, SEED}; - +use alloc::string::{String, ToString}; use uuid::{Builder, Bytes}; /// Simple ID generator. diff --git a/burn-compute/Cargo.toml b/burn-compute/Cargo.toml new file mode 100644 index 000000000..60e1804f0 --- /dev/null +++ b/burn-compute/Cargo.toml @@ -0,0 +1,27 @@ +[package] +authors = ["louisfd ", "Nathaniel Simard"] +categories = ["science"] +description = "Compute crate that helps creating high performance async backends." 
+edition = "2021" +keywords = ["deep-learning", "machine-learning", "data"] +license = "MIT OR Apache-2.0" +name = "burn-compute" +readme = "README.md" +repository = "https://github.com/burn-rs/burn/tree/main/burn-compute" +version = "0.10.0" + +[features] +default = ["std", "channel-mutex", "channel-mpsc", "channel-cell", "storage-bytes"] +std = [] +channel-mutex = [] +channel-cell = [] +channel-mpsc = [] # Assume std +storage-bytes = [] + +[dependencies] +burn-common = { path = "../burn-common", version = "0.10.0", default-features = false } +derive-new = { workspace = true } +spin = { workspace = true } +hashbrown = { workspace = true } + +[dev-dependencies] diff --git a/burn-compute/README.md b/burn-compute/README.md new file mode 100644 index 000000000..f986b1b16 --- /dev/null +++ b/burn-compute/README.md @@ -0,0 +1,7 @@ +# Burn Compute + +This crate helps creating high performance async backends. + +- [x] Asynchronous kernel executions +- [x] Memory allocation management +- [ ] Autotuning diff --git a/burn-compute/src/channel/base.rs b/burn-compute/src/channel/base.rs new file mode 100644 index 000000000..23180cdb7 --- /dev/null +++ b/burn-compute/src/channel/base.rs @@ -0,0 +1,21 @@ +use crate::server::{ComputeServer, Handle}; +use alloc::vec::Vec; + +/// The ComputeChannel trait links the ComputeClient to the ComputeServer +/// while ensuring thread-safety +pub trait ComputeChannel: Clone { + /// Given a handle, returns owned resource as bytes + fn read(&self, handle: &Handle) -> Vec; + + /// Given a resource as bytes, stores it and returns the resource handle + fn create(&self, data: &[u8]) -> Handle; + + /// Reserves `size` bytes in the storage, and returns a handle over them + fn empty(&self, size: usize) -> Handle; + + /// Executes the `kernel` over the given `handles`. + fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]); + + /// Wait for the completion of every task in the server. + fn sync(&self); +} diff --git a/burn-compute/src/channel/cell.rs b/burn-compute/src/channel/cell.rs new file mode 100644 index 000000000..ada9f0665 --- /dev/null +++ b/burn-compute/src/channel/cell.rs @@ -0,0 +1,65 @@ +use super::ComputeChannel; +use crate::server::{ComputeServer, Handle}; +use alloc::sync::Arc; +use alloc::vec::Vec; + +/// A channel using a [ref cell](core::cell::RefCell) to access the server with mutability. +/// +/// # Important +/// +/// Only use this channel if you don't use any threading in your application, otherwise it will +/// panic or cause undefined behaviors. +/// +/// This is mosly useful for `no-std` environments where threads aren't supported, otherwise prefer +/// the [mutex](super::MutexComputeChannel) or the [mpsc](super::MpscComputeChannel) channels. +pub struct RefCellComputeChannel { + server: Arc>, +} + +impl Clone for RefCellComputeChannel { + fn clone(&self) -> Self { + Self { + server: self.server.clone(), + } + } +} +impl RefCellComputeChannel +where + Server: ComputeServer, +{ + /// Create a new cell compute channel. 
+ pub fn new(server: Server) -> Self { + Self { + server: Arc::new(core::cell::RefCell::new(server)), + } + } +} + +impl ComputeChannel for RefCellComputeChannel +where + Server: ComputeServer, +{ + fn read(&self, handle: &Handle) -> Vec { + let mut server = self.server.borrow_mut(); + + server.read(handle) + } + + fn create(&self, resource: &[u8]) -> Handle { + self.server.borrow_mut().create(resource) + } + + fn empty(&self, size: usize) -> Handle { + self.server.borrow_mut().empty(size) + } + + fn execute(&self, kernel_description: Server::Kernel, handles: &[&Handle]) { + self.server + .borrow_mut() + .execute(kernel_description, handles) + } + + fn sync(&self) { + self.server.borrow_mut().sync() + } +} diff --git a/burn-compute/src/channel/mod.rs b/burn-compute/src/channel/mod.rs new file mode 100644 index 000000000..881f52566 --- /dev/null +++ b/burn-compute/src/channel/mod.rs @@ -0,0 +1,17 @@ +mod base; +pub use base::*; + +#[cfg(feature = "channel-mutex")] +mod mutex; +#[cfg(feature = "channel-mutex")] +pub use mutex::*; + +#[cfg(feature = "channel-mpsc")] +mod mpsc; +#[cfg(feature = "channel-mpsc")] +pub use mpsc::*; + +#[cfg(feature = "channel-cell")] +mod cell; +#[cfg(feature = "channel-cell")] +pub use cell::*; diff --git a/burn-compute/src/channel/mpsc.rs b/burn-compute/src/channel/mpsc.rs new file mode 100644 index 000000000..63392e0f0 --- /dev/null +++ b/burn-compute/src/channel/mpsc.rs @@ -0,0 +1,154 @@ +use std::{ + sync::{mpsc, Arc}, + thread, +}; + +use super::ComputeChannel; +use crate::server::{ComputeServer, Handle}; + +/// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with +/// the compute server spawn on its own thread. +pub struct MpscComputeChannel +where + Server: ComputeServer, +{ + state: Arc>, +} + +struct MpscComputeChannelState +where + Server: ComputeServer, +{ + _handle: thread::JoinHandle<()>, + sender: mpsc::SyncSender>, +} + +type Callback = mpsc::SyncSender; + +enum Message +where + Server: ComputeServer, +{ + Read(Handle, Callback>), + Create(Vec, Callback>), + Empty(usize, Callback>), + Execute(Server::Kernel, Vec>), + Sync(Callback<()>), +} + +impl MpscComputeChannel +where + Server: ComputeServer + 'static, +{ + /// Create a new mpsc compute channel. 
+ pub fn new(mut server: Server, bound: usize) -> Self { + let (sender, receiver) = mpsc::sync_channel(bound); + + let _handle = thread::spawn(move || { + while let Ok(message) = receiver.recv() { + match message { + Message::Read(handle, callback) => { + let data = server.read(&handle); + core::mem::drop(handle); + callback.send(data).unwrap(); + } + Message::Create(data, callback) => { + let handle = server.create(&data); + callback.send(handle).unwrap(); + } + Message::Empty(size, callback) => { + let handle = server.empty(size); + callback.send(handle).unwrap(); + } + Message::Execute(kernel, handles) => { + server.execute(kernel, &handles.iter().collect::>()); + } + Message::Sync(callback) => { + server.sync(); + callback.send(()).unwrap(); + } + }; + } + }); + + let state = Arc::new(MpscComputeChannelState { sender, _handle }); + + Self { state } + } +} + +impl Clone for MpscComputeChannel { + fn clone(&self) -> Self { + Self { + state: self.state.clone(), + } + } +} + +impl ComputeChannel for MpscComputeChannel +where + Server: ComputeServer + 'static, +{ + fn read(&self, handle: &Handle) -> Vec { + let (callback, response) = mpsc::sync_channel(1); + + self.state + .sender + .send(Message::Read(handle.clone(), callback)) + .unwrap(); + + self.response(response) + } + + fn create(&self, data: &[u8]) -> Handle { + let (callback, response) = mpsc::sync_channel(1); + + self.state + .sender + .send(Message::Create(data.to_vec(), callback)) + .unwrap(); + + self.response(response) + } + + fn empty(&self, size: usize) -> Handle { + let (callback, response) = mpsc::sync_channel(1); + + self.state + .sender + .send(Message::Empty(size, callback)) + .unwrap(); + + self.response(response) + } + + fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { + self.state + .sender + .send(Message::Execute( + kernel, + handles + .iter() + .map(|h| (*h).clone()) + .collect::>>(), + )) + .unwrap() + } + + fn sync(&self) { + let (callback, response) = mpsc::sync_channel(1); + + self.state.sender.send(Message::Sync(callback)).unwrap(); + + self.response(response) + } +} + +impl MpscComputeChannel { + fn response(&self, response: mpsc::Receiver) -> Response { + match response.recv() { + Ok(val) => val, + Err(err) => panic!("Can't connect to the server correctly {err:?}"), + } + } +} diff --git a/burn-compute/src/channel/mutex.rs b/burn-compute/src/channel/mutex.rs new file mode 100644 index 000000000..369b365fa --- /dev/null +++ b/burn-compute/src/channel/mutex.rs @@ -0,0 +1,55 @@ +use super::ComputeChannel; +use crate::server::{ComputeServer, Handle}; +use alloc::sync::Arc; +use alloc::vec::Vec; +use spin::Mutex; + +/// The MutexComputeChannel ensures thread-safety by locking the server +/// on every operation +pub struct MutexComputeChannel { + server: Arc>, +} + +impl Clone for MutexComputeChannel { + fn clone(&self) -> Self { + Self { + server: self.server.clone(), + } + } +} +impl MutexComputeChannel +where + Server: ComputeServer, +{ + /// Create a new mutex compute channel. 
+ pub fn new(server: Server) -> Self { + Self { + server: Arc::new(Mutex::new(server)), + } + } +} + +impl ComputeChannel for MutexComputeChannel +where + Server: ComputeServer, +{ + fn read(&self, handle: &Handle) -> Vec { + self.server.lock().read(handle) + } + + fn create(&self, data: &[u8]) -> Handle { + self.server.lock().create(data) + } + + fn empty(&self, size: usize) -> Handle { + self.server.lock().empty(size) + } + + fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { + self.server.lock().execute(kernel, handles) + } + + fn sync(&self) { + self.server.lock().sync() + } +} diff --git a/burn-compute/src/client.rs b/burn-compute/src/client.rs new file mode 100644 index 000000000..422603118 --- /dev/null +++ b/burn-compute/src/client.rs @@ -0,0 +1,65 @@ +use crate::{ + channel::ComputeChannel, + server::{ComputeServer, Handle}, +}; +use alloc::vec::Vec; +use core::marker::PhantomData; + +/// The ComputeClient is the entry point to require tasks from the ComputeServer. +/// It should be obtained for a specific device via the Compute struct. +pub struct ComputeClient { + channel: Channel, + _server: PhantomData, +} + +impl Clone for ComputeClient +where + S: ComputeServer, + C: ComputeChannel, +{ + fn clone(&self) -> Self { + Self { + channel: self.channel.clone(), + _server: PhantomData, + } + } +} + +impl ComputeClient +where + Server: ComputeServer, + Channel: ComputeChannel, +{ + /// Create a new client. + pub fn new(channel: Channel) -> Self { + Self { + channel, + _server: PhantomData, + } + } + + /// Given a handle, returns owned resource as bytes. + pub fn read(&self, handle: &Handle) -> Vec { + self.channel.read(handle) + } + + /// Given a resource, stores it and returns the resource handle. + pub fn create(&self, data: &[u8]) -> Handle { + self.channel.create(data) + } + + /// Reserves `size` bytes in the storage, and returns a handle over them. + pub fn empty(&self, size: usize) -> Handle { + self.channel.empty(size) + } + + /// Executes the `kernel` over the given `handles`. + pub fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { + self.channel.execute(kernel, handles) + } + + /// Wait for the completion of every task in the server. + pub fn sync(&self) { + self.channel.sync() + } +} diff --git a/burn-compute/src/compute.rs b/burn-compute/src/compute.rs new file mode 100644 index 000000000..2f935e8ed --- /dev/null +++ b/burn-compute/src/compute.rs @@ -0,0 +1,83 @@ +use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer}; +use core::ops::DerefMut; +use hashbrown::HashMap; + +/// The compute type has the responsibility to retrieve the correct compute client based on the +/// given device. +pub struct Compute { + clients: spin::Mutex>>>, +} + +impl Compute +where + Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, + Server: ComputeServer, + Channel: ComputeChannel, +{ + /// Create a new compute. + pub const fn new() -> Self { + Self { + clients: spin::Mutex::new(None), + } + } + + /// Get the compute client for the given device. + /// + /// Provide the init function to create a new client if it isn't already initialized. 
+ pub fn client(&self, device: &Device, init: Init) -> ComputeClient + where + Init: Fn() -> ComputeClient, + { + let mut clients = self.clients.lock(); + + if clients.is_none() { + Self::register_inner(device, init(), &mut clients); + } + + match clients.deref_mut() { + Some(clients) => match clients.get(device) { + Some(client) => client.clone(), + None => { + let client = init(); + clients.insert(device.clone(), client.clone()); + client + } + }, + _ => unreachable!(), + } + } + + /// Register the compute client for the given device. + /// + /// # Note + /// + /// This function is mostly useful when the creation of the compute client can't be done + /// synchronously and require special context. + /// + /// # Panics + /// + /// If a client is already registered for the given device. + pub fn register(&self, device: &Device, client: ComputeClient) { + let mut clients = self.clients.lock(); + + Self::register_inner(device, client, &mut clients); + } + + fn register_inner( + device: &Device, + client: ComputeClient, + clients: &mut Option>>, + ) { + if clients.is_none() { + *clients = Some(HashMap::new()); + } + + if let Some(clients) = clients { + if clients.contains_key(device) { + panic!("Client already created for device {:?}", device); + } + + clients.insert(device.clone(), client); + } + } +} diff --git a/burn-compute/src/id.rs b/burn-compute/src/id.rs new file mode 100644 index 000000000..1c71ccd84 --- /dev/null +++ b/burn-compute/src/id.rs @@ -0,0 +1,53 @@ +#[macro_export(local_inner_macros)] +/// Create a new storage ID type. +macro_rules! storage_id_type { + ($name:ident) => { + #[derive(Clone, Hash, PartialEq, Eq)] + /// Storage ID. + pub struct $name { + id: alloc::sync::Arc, + } + + impl $name { + /// Create a new ID. + pub fn new() -> Self { + Self { + id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()), + } + } + } + + impl Default for $name { + fn default() -> Self { + Self::new() + } + } + }; +} + +#[macro_export(local_inner_macros)] +/// Create a new memory ID type. +macro_rules! memory_id_type { + ($name:ident) => { + #[derive(Clone, Hash, PartialEq, Eq)] + /// Memory ID. + pub struct $name { + id: alloc::sync::Arc, + } + + impl $name { + /// Create a new ID. + pub(crate) fn new() -> Self { + Self { + id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()), + } + } + } + + impl Default for $name { + fn default() -> Self { + Self::new() + } + } + }; +} diff --git a/burn-compute/src/lib.rs b/burn-compute/src/lib.rs new file mode 100644 index 000000000..d93502a7d --- /dev/null +++ b/burn-compute/src/lib.rs @@ -0,0 +1,26 @@ +#![cfg_attr(not(feature = "std"), no_std)] +#![warn(missing_docs)] + +//! Burn compute crate that helps creating high performance async backends. + +extern crate alloc; + +#[macro_use] +extern crate derive_new; + +mod id; + +/// Compute channel module. +pub mod channel; +/// Compute client module. +pub mod client; + +/// Memory management module. +pub mod memory_management; +/// Compute server module. +pub mod server; +/// Compute Storage module. +pub mod storage; + +mod compute; +pub use compute::*; diff --git a/burn-compute/src/memory_management/base.rs b/burn-compute/src/memory_management/base.rs new file mode 100644 index 000000000..bf0203290 --- /dev/null +++ b/burn-compute/src/memory_management/base.rs @@ -0,0 +1,27 @@ +use crate::storage::ComputeStorage; + +/// The MemoryHandle trait is an abstract way to refer to some memory segment. +/// It should not contain actual references to data. 
+/// +/// It is responsible for determining if the memory segment can be mutated, +/// for instance by keeping track of a reference count +pub trait MemoryHandle: Clone + Send { + /// Checks if the underlying memory can be safely mutated. + fn can_mut(&self) -> bool; +} + +/// The MemoryManagement trait encapsulates strategies for (de)allocating memory. +/// It is bound to the ComputeStorage trait, which does the actual (de)allocations. +/// +/// The MemoryManagement can only reserve memory space or get the resource located at a space. +/// Modification of the resource data should be done directly on the resource. +pub trait MemoryManagement: Send { + /// The associated type Handle must implement MemoryHandle + type Handle: MemoryHandle; + + /// Returns the resource from the storage at the specified handle + fn get(&mut self, handle: &Self::Handle) -> Storage::Resource; + + /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it + fn reserve(&mut self, size: usize) -> Self::Handle; +} diff --git a/burn-compute/src/memory_management/mod.rs b/burn-compute/src/memory_management/mod.rs new file mode 100644 index 000000000..5adbb7fd8 --- /dev/null +++ b/burn-compute/src/memory_management/mod.rs @@ -0,0 +1,5 @@ +mod base; +mod simple; + +pub use base::*; +pub use simple::*; diff --git a/burn-compute/src/memory_management/simple.rs b/burn-compute/src/memory_management/simple.rs new file mode 100644 index 000000000..6cad5a062 --- /dev/null +++ b/burn-compute/src/memory_management/simple.rs @@ -0,0 +1,339 @@ +use super::{MemoryHandle, MemoryManagement}; +use crate::{ + memory_id_type, + storage::{ComputeStorage, StorageHandle, StorageUtilization}, +}; +use alloc::{sync::Arc, vec::Vec}; +use hashbrown::HashMap; + +// The ChunkId allows to keep track of how many references there are to a specific chunk. +memory_id_type!(ChunkId); +// The SliceId allows to keep track of how many references there are to a specific slice. +memory_id_type!(SliceId); + +impl ChunkId { + /// A chunk is free if it is only referred by the chunk hashmap. + fn is_free(&self) -> bool { + Arc::strong_count(&self.id) <= 1 + } +} + +impl SliceId { + /// A slice is free if it is only referred by the slice hashmap and the chunk it is in. + fn is_free(&self) -> bool { + Arc::strong_count(&self.id) <= 2 + } +} + +/// The SimpleHandle is a memory handle, referring to either a chunk or a slice. +#[derive(Clone)] +pub enum SimpleHandle { + /// A whole chunk of memory. + Chunk(ChunkId), + /// A slice of a chunk of memory. + Slice(SliceId), +} + +/// The strategy defines the frequency at which deallocation of unused memory chunks should occur. +pub enum DeallocStrategy { + /// Once every n calls to reserve. + /// + /// First associated data is n, second is the state and should start at 0 + PeriodTick(usize, usize), + #[cfg(feature = "std")] + /// Once every period of time + PeriodTime(std::time::Duration, std::time::Instant), + /// Never deallocate. + Never, +} + +impl DeallocStrategy { + /// Create a new strategy with the given period. 
+ pub fn new_period_tick(period: usize) -> Self { + DeallocStrategy::PeriodTick(period, 0) + } + + fn should_dealloc(&mut self) -> bool { + match self { + DeallocStrategy::PeriodTick(period, last) => { + *last = (*last + 1) % *period; + *last == 0 + } + #[cfg(feature = "std")] + DeallocStrategy::PeriodTime(period, last) => { + if &last.elapsed() > period { + *last = std::time::Instant::now(); + true + } else { + false + } + } + DeallocStrategy::Never => false, + } + } +} + +/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks. +pub struct SimpleMemoryManagement { + chunks: HashMap)>, + slices: HashMap, + dealloc_strategy: DeallocStrategy, + storage: Storage, +} + +impl MemoryHandle for SimpleHandle { + /// Returns true if referenced by only one tensor, and only once by the + /// memory management hashmaps + fn can_mut(&self) -> bool { + // One reference in the chunk hashmap, another owned by one tensor. + const REFERENCE_LIMIT_CHUNK: usize = 2; + // One reference in the chunk hashmap (for the chunk on which this slice is built), + // another in the slice hashmap for this slice, and another owned by one tensor. + const REFERENCE_LIMIT_SLICE: usize = 3; + + match &self { + SimpleHandle::Chunk(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_CHUNK, + SimpleHandle::Slice(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_SLICE, + } + } +} + +impl MemoryManagement for SimpleMemoryManagement { + type Handle = SimpleHandle; + + /// Returns the resource from the storage, for the specified handle. + fn get(&mut self, handle: &Self::Handle) -> Storage::Resource { + let resource = match &handle { + SimpleHandle::Chunk(id) => &self.chunks.get(id).unwrap().0, + SimpleHandle::Slice(id) => &self.slices.get(id).unwrap().0, + }; + + self.storage.get(resource) + } + + /// Reserves memory of specified size using the reserve algorithm, and return + /// a handle to the reserved memory. + /// + /// Also clean ups, removing unused slices, and chunks if permitted by deallocation strategy. + fn reserve(&mut self, size: usize) -> Self::Handle { + self.cleanup_slices(); + + let handle = self.reserve_algorithm(size); + + if self.dealloc_strategy.should_dealloc() { + self.cleanup_chunks(); + } + + handle + } +} + +impl SimpleMemoryManagement { + /// Creates a new instance using the given storage and deallocation strategy. + pub fn new(storage: Storage, dealloc_strategy: DeallocStrategy) -> Self { + Self { + chunks: HashMap::new(), + slices: HashMap::new(), + dealloc_strategy, + storage, + } + } + + /// Creates an new instance using the given storage without deallocation. + pub fn never_dealloc(storage: Storage) -> Self { + Self::new(storage, DeallocStrategy::Never) + } + + fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle { + // Looks for a large enough, existing but unused chunk of memory. + let chunk = self.find_free_chunk(size); + + match chunk { + Some((chunk_id, chunk_size)) => { + if size == chunk_size { + // If there is one of exactly the same size, it reuses it. + SimpleHandle::Chunk(chunk_id.clone()) + } else { + // Otherwise creates a slice of the right size upon it, always starting at zero. + self.create_slice(size, chunk_id) + } + } + // If no chunk available, creates one of exactly the right size. + None => self.create_chunk(size), + } + } + + /// Finds the smallest of the free and large enough chunks to fit `size` + /// Returns the chunk's id and size. 
+ fn find_free_chunk(&self, size: usize) -> Option<(ChunkId, usize)> { + let mut size_diff_current = usize::MAX; + let mut current = None; + + self.chunks + .iter() + .for_each(|(chunk_id, (resource, slices))| { + let is_free = slices.is_empty() && chunk_id.is_free(); + + if is_free && resource.size() > size { + let size_diff = resource.size() - size; + if size_diff < size_diff_current { + current = Some((chunk_id, resource)); + size_diff_current = size_diff; + } + } + }); + + current.map(|(id, handle)| (id.clone(), handle.size())) + } + + /// Creates a slice of size `size` upon the given chunk. + /// + /// For now slices must start at zero, therefore there can be only one per chunk + fn create_slice(&mut self, size: usize, chunk_id: ChunkId) -> SimpleHandle { + let (handle, slices) = self.chunks.get_mut(&chunk_id).unwrap(); + let slice_id = SliceId::new(); + + let storage = StorageHandle { + id: handle.id.clone(), + utilization: StorageUtilization::Slice(0, size), + }; + + if slices.is_empty() { + self.slices.insert(slice_id.clone(), (storage, chunk_id)); + } else { + panic!("Can't have more than 1 slice yet."); + } + + slices.push(slice_id.clone()); + + SimpleHandle::Slice(slice_id) + } + + /// Creates a chunk of given size by allocating on the storage. + fn create_chunk(&mut self, size: usize) -> SimpleHandle { + let resource = self.storage.alloc(size); + let chunk_id = ChunkId::new(); + + self.chunks.insert(chunk_id.clone(), (resource, Vec::new())); + + SimpleHandle::Chunk(chunk_id) + } + + /// Deallocates free chunks and remove them from chunks map. + fn cleanup_chunks(&mut self) { + let mut ids_to_remove = Vec::new(); + + self.chunks.iter().for_each(|(chunk_id, _resource)| { + if chunk_id.is_free() { + ids_to_remove.push(chunk_id.clone()); + } + }); + + ids_to_remove + .iter() + .map(|chunk_id| self.chunks.remove(chunk_id).unwrap()) + .for_each(|(resource, _slices)| { + self.storage.dealloc(resource.id); + }); + } + + /// Removes free slices from slice map and corresponding chunks. 
+ fn cleanup_slices(&mut self) { + let mut ids_to_remove = Vec::new(); + + self.slices.iter().for_each(|(slice_id, _resource)| { + if slice_id.is_free() { + ids_to_remove.push(slice_id.clone()); + } + }); + + ids_to_remove + .iter() + .map(|slice_id| { + let value = self.slices.remove(slice_id).unwrap(); + (slice_id, value.1) + }) + .for_each(|(slice_id, chunk_id)| { + let (_chunk, slices) = self.chunks.get_mut(&chunk_id).unwrap(); + slices.retain(|id| id != slice_id); + }); + } +} + +#[cfg(test)] +mod tests { + use crate::{ + memory_management::{MemoryHandle, MemoryManagement}, + storage::BytesStorage, + }; + + use super::{DeallocStrategy, SimpleMemoryManagement}; + + #[test] + fn can_mut_with_single_tensor_reference() { + let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default()); + + let chunk_size = 4; + let simple_handle = memory_management.create_chunk(chunk_size); + + let x = simple_handle.clone(); + core::mem::drop(simple_handle); + + assert!(x.can_mut()); + } + + #[test] + fn two_tensor_references_remove_mutability() { + let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default()); + + let chunk_size = 4; + let simple_handle = memory_management.create_chunk(chunk_size); + + let x = simple_handle.clone(); + + assert!(!simple_handle.can_mut()); + assert!(!x.can_mut()) + } + + #[test] + fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() { + let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default()); + let chunk_size = 4; + let _chunk_handle = memory_management.reserve(chunk_size); + let _new_handle = memory_management.reserve(chunk_size); + + assert_eq!(memory_management.chunks.len(), 2); + } + + #[test] + fn when_empty_chunk_is_cleaned_upexists_it_disappears() { + let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default()); + let chunk_size = 4; + let chunk_handle = memory_management.reserve(chunk_size); + drop(chunk_handle); + memory_management.cleanup_chunks(); + + assert_eq!(memory_management.chunks.len(), 0); + } + + #[test] + fn never_dealloc_strategy_never_deallocs() { + let mut never_dealloc = DeallocStrategy::Never; + for _ in 0..20 { + assert!(!never_dealloc.should_dealloc()) + } + } + + #[test] + fn period_tick_dealloc_strategy_should_dealloc_after_period() { + let period = 3; + let mut period_tick_dealloc = DeallocStrategy::new_period_tick(period); + + for _ in 0..3 { + for _ in 0..period - 1 { + assert!(!period_tick_dealloc.should_dealloc()); + } + assert!(period_tick_dealloc.should_dealloc()); + } + } +} diff --git a/burn-compute/src/server.rs b/burn-compute/src/server.rs new file mode 100644 index 000000000..24fd5a59d --- /dev/null +++ b/burn-compute/src/server.rs @@ -0,0 +1,40 @@ +use alloc::vec::Vec; + +use crate::{memory_management::MemoryManagement, storage::ComputeStorage}; + +type _Storage = ::Storage; +type _MemoryManagement = ::MemoryManagement; + +/// This alias for a [memory handle](MemoryManagement::Handle). +pub type Handle = <_MemoryManagement as MemoryManagement<_Storage>>::Handle; + +/// The compute server is responsible for handling resources and computations over resources. +/// +/// Everything in the server is mutable, therefore it should be solely accessed through the +/// [compute channel](crate::channel::ComputeChannel) for thread safety. +pub trait ComputeServer: Send { + /// The kernel type defines the computation algorithms. 
+ type Kernel: Send; + /// The [storage](ComputeStorage) type defines how data is stored and accessed. + type Storage: ComputeStorage; + /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type. + type MemoryManagement: MemoryManagement; + + /// Given a handle, returns the owned resource as bytes. + fn read(&mut self, handle: &Handle) -> Vec; + + /// Given a resource as bytes, stores it and returns the memory handle. + fn create(&mut self, data: &[u8]) -> Handle; + + /// Reserves `size` bytes in the storage, and returns a handle over them. + fn empty(&mut self, size: usize) -> Handle; + + /// Executes the `kernel` over the given memory `handles`. + /// + /// Kernels have mutable access to every resource they are given + /// and are responsible of determining which should be read or written. + fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle]); + + /// Wait for the completion of every task in the server. + fn sync(&mut self); +} diff --git a/burn-compute/src/storage/base.rs b/burn-compute/src/storage/base.rs new file mode 100644 index 000000000..ce6be5bce --- /dev/null +++ b/burn-compute/src/storage/base.rs @@ -0,0 +1,48 @@ +use crate::storage_id_type; + +// This ID is used to map a handle to its actual data. +storage_id_type!(StorageId); + +/// Defines if data uses a full memory chunk or a slice of it. +#[derive(Clone)] +pub enum StorageUtilization { + /// Full memory chunk of specified size + Full(usize), + /// Slice of memory chunk with start index and size. + Slice(usize, usize), +} + +/// Contains the [storage id](StorageId) of a resource and the way it is used. +#[derive(new)] +pub struct StorageHandle { + /// Storage id. + pub id: StorageId, + /// How the storage is used. + pub utilization: StorageUtilization, +} + +impl StorageHandle { + /// Returns the size the handle is pointing to in memory. + pub fn size(&self) -> usize { + match self.utilization { + StorageUtilization::Full(size) => size, + StorageUtilization::Slice(_, size) => size, + } + } +} + +/// Storage types are responsible for allocating and deallocating memory. +pub trait ComputeStorage: Send { + /// The resource associated type determines the way data is implemented and how + /// it can be accessed by kernels. + type Resource: Send; + + /// Returns the underlying resource for a specified storage handle + fn get(&mut self, handle: &StorageHandle) -> Self::Resource; + + /// Allocates `size` units of memory and returns a handle to it + fn alloc(&mut self, size: usize) -> StorageHandle; + + /// Deallocates the memory pointed by the given storage id. + fn dealloc(&mut self, id: StorageId); +} diff --git a/burn-compute/src/storage/bytes_cpu.rs b/burn-compute/src/storage/bytes_cpu.rs new file mode 100644 index 000000000..a9b3f3753 --- /dev/null +++ b/burn-compute/src/storage/bytes_cpu.rs @@ -0,0 +1,121 @@ +use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}; +use alloc::alloc::{alloc, dealloc, Layout}; +use hashbrown::HashMap; + +/// The bytes storage maps ids to pointers of bytes in a contiguous layout. +#[derive(Default)] +pub struct BytesStorage { + memory: HashMap, +} + +/// Can send to other threads, but can't sync. +unsafe impl Send for BytesStorage {} +unsafe impl Send for BytesResource {} + +/// This struct is a pointer to a memory chunk or slice. +pub struct BytesResource { + ptr: *mut u8, + utilization: StorageUtilization, +} + +/// This struct refers to a specific (contiguous) layout of bytes. 
+struct AllocatedBytes { + ptr: *mut u8, + layout: Layout, +} + +impl BytesResource { + fn get_exact_location_and_length(&self) -> (*mut u8, usize) { + match self.utilization { + StorageUtilization::Full(len) => (self.ptr, len), + StorageUtilization::Slice(location, len) => unsafe { (self.ptr.add(location), len) }, + } + } + + /// Returns the resource as a mutable slice of bytes. + pub fn write<'a>(&self) -> &'a mut [u8] { + let (ptr, len) = self.get_exact_location_and_length(); + + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + } + + /// Returns the resource as an immutable slice of bytes. + pub fn read<'a>(&self) -> &'a [u8] { + let (ptr, len) = self.get_exact_location_and_length(); + + unsafe { core::slice::from_raw_parts(ptr, len) } + } +} + +impl ComputeStorage for BytesStorage { + type Resource = BytesResource; + + fn get(&mut self, handle: &StorageHandle) -> Self::Resource { + let allocated_bytes = self.memory.get_mut(&handle.id).unwrap(); + + BytesResource { + ptr: allocated_bytes.ptr, + utilization: handle.utilization.clone(), + } + } + + fn alloc(&mut self, size: usize) -> StorageHandle { + let id = StorageId::new(); + let handle = StorageHandle { + id: id.clone(), + utilization: StorageUtilization::Full(size), + }; + + unsafe { + let layout = Layout::array::(size).unwrap(); + let ptr = alloc(layout); + let memory = AllocatedBytes { ptr, layout }; + + self.memory.insert(id, memory); + } + + handle + } + + fn dealloc(&mut self, id: StorageId) { + if let Some(memory) = self.memory.remove(&id) { + unsafe { + dealloc(memory.ptr, memory.layout); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_can_alloc_and_dealloc() { + let mut storage = BytesStorage::default(); + let handle_1 = storage.alloc(64); + + assert_eq!(handle_1.size(), 64); + storage.dealloc(handle_1.id); + } + + #[test] + fn test_slices() { + let mut storage = BytesStorage::default(); + let handle_1 = storage.alloc(64); + let handle_2 = StorageHandle::new(handle_1.id.clone(), StorageUtilization::Slice(24, 8)); + + storage + .get(&handle_1) + .write() + .iter_mut() + .enumerate() + .for_each(|(i, b)| { + *b = i as u8; + }); + + let bytes = storage.get(&handle_2).read().to_vec(); + storage.dealloc(handle_1.id); + assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]); + } +} diff --git a/burn-compute/src/storage/mod.rs b/burn-compute/src/storage/mod.rs new file mode 100644 index 000000000..0bf21bd31 --- /dev/null +++ b/burn-compute/src/storage/mod.rs @@ -0,0 +1,8 @@ +mod base; + +pub use base::*; + +#[cfg(feature = "storage-bytes")] +mod bytes_cpu; +#[cfg(feature = "storage-bytes")] +pub use bytes_cpu::*; diff --git a/burn-compute/tests/dummy/compute.rs b/burn-compute/tests/dummy/compute.rs new file mode 100644 index 000000000..dbb4d7555 --- /dev/null +++ b/burn-compute/tests/dummy/compute.rs @@ -0,0 +1,26 @@ +use super::DummyServer; +use burn_compute::channel::MutexComputeChannel; +use burn_compute::client::ComputeClient; +use burn_compute::memory_management::SimpleMemoryManagement; +use burn_compute::storage::BytesStorage; +use burn_compute::Compute; + +/// The dummy device. 
+#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct DummyDevice; + +static COMPUTE: Compute> = + Compute::new(); + +pub fn client( + device: &DummyDevice, +) -> ComputeClient> { + COMPUTE.client(device, || { + let storage = BytesStorage::default(); + let memory_management = SimpleMemoryManagement::never_dealloc(storage); + let server = DummyServer::new(memory_management); + let channel = MutexComputeChannel::new(server); + + ComputeClient::new(channel) + }) +} diff --git a/burn-compute/tests/dummy/kernel.rs b/burn-compute/tests/dummy/kernel.rs new file mode 100644 index 000000000..b8212c6c1 --- /dev/null +++ b/burn-compute/tests/dummy/kernel.rs @@ -0,0 +1,25 @@ +use burn_compute::storage::BytesResource; + +/// The DummyKernel trait should be implemented for every supported operation +pub trait DummyKernel: Send { + fn compute<'a>(&self, resources: &mut [BytesResource]); +} + +/// Contains the algorithm for element-wise addition +pub struct DummyElementwiseAddition; + +impl DummyKernel for DummyElementwiseAddition { + fn compute<'a>(&self, inputs: &mut [BytesResource]) { + // Notice how the kernel is responsible for determining which inputs + // are read-only and which are writable. + let lhs = &inputs[0].read(); + let rhs = &inputs[1].read(); + let out = &mut inputs[2].write(); + + let size = lhs.len(); + + for i in 0..size { + out[i] = lhs[i] + rhs[i]; + } + } +} diff --git a/burn-compute/tests/dummy/mod.rs b/burn-compute/tests/dummy/mod.rs new file mode 100644 index 000000000..347fb1c40 --- /dev/null +++ b/burn-compute/tests/dummy/mod.rs @@ -0,0 +1,7 @@ +mod compute; +mod kernel; +mod server; + +pub use compute::*; +pub use kernel::*; +pub use server::*; diff --git a/burn-compute/tests/dummy/server.rs b/burn-compute/tests/dummy/server.rs new file mode 100644 index 000000000..d7b0ea786 --- /dev/null +++ b/burn-compute/tests/dummy/server.rs @@ -0,0 +1,60 @@ +use burn_compute::{ + memory_management::{MemoryManagement, SimpleMemoryManagement}, + server::{ComputeServer, Handle}, + storage::BytesStorage, +}; +use derive_new::new; + +use super::DummyKernel; + +/// The dummy server is used to test the burn-compute infrastructure. +/// It uses simple memory management with a bytes storage on CPU, without asynchronous tasks. +#[derive(new)] +pub struct DummyServer> { + memory_management: MM, +} + +impl ComputeServer for DummyServer +where + MM: MemoryManagement, +{ + type Kernel = Box; + type Storage = BytesStorage; + type MemoryManagement = MM; + + fn read(&mut self, handle: &Handle) -> Vec { + let bytes = self.memory_management.get(handle); + + bytes.read().to_vec() + } + + fn create(&mut self, data: &[u8]) -> Handle { + let handle = self.memory_management.reserve(data.len()); + let resource = self.memory_management.get(&handle); + + let bytes = resource.write(); + + for (i, val) in data.iter().enumerate() { + bytes[i] = *val; + } + + handle + } + + fn empty(&mut self, size: usize) -> Handle { + self.memory_management.reserve(size) + } + + fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle]) { + let mut resources = handles + .iter() + .map(|handle| self.memory_management.get(handle)) + .collect::>(); + + kernel.compute(&mut resources); + } + + fn sync(&mut self) { + // Nothing to do with dummy backend. 
+ } +} diff --git a/burn-compute/tests/integration_test.rs b/burn-compute/tests/integration_test.rs new file mode 100644 index 000000000..c951962d6 --- /dev/null +++ b/burn-compute/tests/integration_test.rs @@ -0,0 +1,38 @@ +mod dummy; + +use dummy::{client, DummyDevice, DummyElementwiseAddition}; + +#[test] +fn created_resource_is_the_same_when_read() { + let client = client(&DummyDevice); + let resource = Vec::from([0, 1, 2]); + let resource_description = client.create(&resource); + + let obtained_resource = client.read(&resource_description); + + assert_eq!(resource, obtained_resource) +} + +#[test] +fn empty_allocates_memory() { + let client = client(&DummyDevice); + let size = 4; + let resource_description = client.empty(size); + let empty_resource = client.read(&resource_description); + + assert_eq!(empty_resource.len(), 4); +} + +#[test] +fn execute_elementwise_addition() { + let client = client(&DummyDevice); + let lhs = client.create(&[0, 1, 2]); + let rhs = client.create(&[4, 4, 4]); + let out = client.empty(3); + + client.execute(Box::new(DummyElementwiseAddition), &[&lhs, &rhs, &out]); + + let obtained_resource = client.read(&out); + + assert_eq!(obtained_resource, Vec::from([4, 5, 6])) +} diff --git a/burn-wgpu/Cargo.toml b/burn-wgpu/Cargo.toml index f6f66c6a2..aac01e056 100644 --- a/burn-wgpu/Cargo.toml +++ b/burn-wgpu/Cargo.toml @@ -45,6 +45,10 @@ burn-tensor = { path = "../burn-tensor", version = "0.10.0", default-features = burn-ndarray = { path = "../burn-ndarray", version = "0.10.0" } serial_test = "2.0.0" +# Still only in dev mode +hashbrown = { workspace = true } +burn-compute = { path = "../burn-compute", version = "0.10.0", default-features = false, features=["channel-mutex", "std"] } + [[bench]] name = "unary" harness = false diff --git a/burn-wgpu/src/compute/base.rs b/burn-wgpu/src/compute/base.rs new file mode 100644 index 000000000..75f42ef27 --- /dev/null +++ b/burn-wgpu/src/compute/base.rs @@ -0,0 +1,129 @@ +use super::{Kernel, WgpuServer}; +use crate::{ + compute::WgpuStorage, + context::{select_device, WorkGroup}, + kernel::{DynamicKernel, SourceTemplate, StaticKernel}, + GraphicsApi, WgpuDevice, +}; +use burn_compute::{ + channel::MutexComputeChannel, + client::ComputeClient, + memory_management::{DeallocStrategy, SimpleMemoryManagement}, + Compute, +}; +use std::{marker::PhantomData, sync::Arc}; + +type WgpuChannel = MutexComputeChannel; + +/// Compute handle for the wgpu backend. +static COMPUTE: Compute = Compute::new(); + +pub fn compute_client( + device: &WgpuDevice, +) -> ComputeClient { + let device = Arc::new(device); + + COMPUTE.client(&device, move || { + let (device_wgpu, queue, info) = pollster::block_on(select_device::(&device)); + + log::info!( + "Created wgpu compute server on device {:?} => {:?}", + device, + info + ); + + // TODO: Support a way to modify max_tasks without std. + let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") { + Ok(value) => value + .parse::() + .expect("BURN_WGPU_MAX_TASKS should be a positive integer."), + Err(_) => 16, // 16 tasks by default + }; + + let device = Arc::new(device_wgpu); + let storage = WgpuStorage::new(device.clone()); + // Maximum reusability. 
+ let memory_management = SimpleMemoryManagement::new(storage, DeallocStrategy::Never); + let server = WgpuServer::new(memory_management, device, queue, max_tasks); + let channel = WgpuChannel::new(server); + + ComputeClient::new(channel) + }) +} + +pub struct DynamicComputeKernel { + kernel: K, + workgroup: WorkGroup, +} + +impl Kernel for DynamicComputeKernel +where + K: DynamicKernel + 'static, +{ + fn source_template(self: Box) -> SourceTemplate { + self.kernel.source_template() + } + + fn id(&self) -> String { + self.kernel.id() + } + + fn workgroup(&self) -> WorkGroup { + self.workgroup.clone() + } +} + +#[derive(new)] +pub struct StaticComputeKernel { + workgroup: WorkGroup, + _kernel: PhantomData, +} + +impl Kernel for StaticComputeKernel +where + K: StaticKernel + 'static, +{ + fn source_template(self: Box) -> SourceTemplate { + K::source_template() + } + + fn id(&self) -> String { + format!("{:?}", core::any::TypeId::of::()) + } + + fn workgroup(&self) -> WorkGroup { + self.workgroup.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{binary_elemwise, kernel::KernelSettings, AutoGraphicsApi}; + + #[test] + fn can_run_kernel() { + binary_elemwise!(Add, "+"); + + let client = compute_client::(&WgpuDevice::default()); + + let lhs: Vec = vec![0., 1., 2., 3., 4., 5., 6., 7.]; + let rhs: Vec = vec![10., 11., 12., 6., 7., 3., 1., 0.]; + let info: Vec = vec![1, 1, 1, 1, 8, 8, 8]; + + let lhs = client.create(bytemuck::cast_slice(&lhs)); + let rhs = client.create(bytemuck::cast_slice(&rhs)); + let out = client.empty(core::mem::size_of::() * 8); + let info = client.create(bytemuck::cast_slice(&info)); + + type Kernel = KernelSettings; + let kernel = Box::new(StaticComputeKernel::::new(WorkGroup::new(1, 1, 1))); + + client.execute(kernel, &[&lhs, &rhs, &out, &info]); + + let data = client.read(&out); + let output: &[f32] = bytemuck::cast_slice(&data); + + assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]); + } +} diff --git a/burn-wgpu/src/compute/mod.rs b/burn-wgpu/src/compute/mod.rs new file mode 100644 index 000000000..3b14e2686 --- /dev/null +++ b/burn-wgpu/src/compute/mod.rs @@ -0,0 +1,7 @@ +mod base; +mod server; +mod storage; + +pub use base::*; +pub use server::*; +pub use storage::*; diff --git a/burn-wgpu/src/compute/server.rs b/burn-wgpu/src/compute/server.rs new file mode 100644 index 000000000..01893c255 --- /dev/null +++ b/burn-wgpu/src/compute/server.rs @@ -0,0 +1,253 @@ +use std::{borrow::Cow, sync::Arc}; + +use super::WgpuStorage; +use crate::{context::WorkGroup, kernel::SourceTemplate}; +use burn_compute::{ + memory_management::{MemoryManagement, SimpleMemoryManagement}, + server::{self, ComputeServer}, +}; +use hashbrown::HashMap; +use wgpu::{ + util::{BufferInitDescriptor, DeviceExt}, + BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor, +}; + +/// Wgpu compute server. +pub struct WgpuServer> { + memory_management: MM, + device: Arc, + queue: wgpu::Queue, + encoder: CommandEncoder, + pipelines: HashMap>, + tasks: Vec, + max_tasks: usize, +} + +#[derive(new)] +struct ComputeTask { + pipeline: Arc, + bind_group: BindGroup, + work_group: WorkGroup, +} + +pub trait Kernel: 'static + Send { + /// Source template for the kernel. + fn source_template(self: Box) -> SourceTemplate; + /// Identifier for the kernel, used for caching kernel compilation. 
+ fn id(&self) -> String; + fn workgroup(&self) -> WorkGroup; +} + +impl WgpuServer +where + MM: MemoryManagement, +{ + pub fn new( + memory_management: MM, + device: Arc, + queue: wgpu::Queue, + max_tasks: usize, + ) -> Self { + let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Command Encoder"), + }); + + Self { + memory_management, + device, + queue, + encoder, + pipelines: HashMap::new(), + tasks: Vec::new(), + max_tasks, + } + } + + fn submit(&mut self) { + assert!( + self.tasks.is_empty(), + "Tasks should be completed before submitting the current encoder." + ); + println!("Submit"); + let mut new_encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + core::mem::swap(&mut new_encoder, &mut self.encoder); + + self.queue.submit(Some(new_encoder.finish())); + } + + fn register_tasks(&mut self) { + if self.tasks.is_empty() { + return; + } + + let mut compute = self + .encoder + .begin_compute_pass(&wgpu::ComputePassDescriptor { label: None }); + + for task in self.tasks.iter() { + compute.set_pipeline(&task.pipeline); + compute.set_bind_group(0, &task.bind_group, &[]); + compute.dispatch_workgroups(task.work_group.x, task.work_group.y, task.work_group.z); + } + + std::mem::drop(compute); + self.tasks.clear(); + } + + fn pipeline(&mut self, kernel: Box) -> Arc { + let kernel_id = kernel.id(); + if let Some(pipeline) = self.pipelines.get(&kernel_id) { + return pipeline.clone(); + } + + let pipeline = self.compile_source(&kernel.source_template().complete()); + self.pipelines.insert(kernel_id.clone(), pipeline.clone()); + + pipeline + } + + fn compile_source(&self, source: &str) -> Arc { + let module = self.device.create_shader_module(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + }); + + Arc::new( + self.device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: "main", + }), + ) + } +} + +impl ComputeServer for WgpuServer +where + MM: MemoryManagement, +{ + type Kernel = Box; + type Storage = WgpuStorage; + type MemoryManagement = MM; + + fn read(&mut self, handle: &server::Handle) -> Vec { + // Register previous tasks before reading the buffer so that it is up to date. 
+ self.register_tasks(); + + let resource = self.memory_management.get(handle); + + let size = resource.size(); + let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + self.encoder.copy_buffer_to_buffer( + &resource.buffer, + resource.offset(), + &buffer_dest, + 0, + size, + ); + + self.submit(); + + let buffer_slice = buffer_dest.slice(..); + let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| { + sender + .send(v) + .expect("Unable to send buffer slice result to async channel.") + }); + + self.device.poll(wgpu::Maintain::Wait); + + let result = pollster::block_on(receiver.receive()); + + if let Some(Ok(())) = result { + let data = buffer_slice.get_mapped_range(); + let result = bytemuck::cast_slice(&data).to_vec(); + + drop(data); + buffer_dest.unmap(); + result + } else { + panic!("Unable to read buffer {:?}", result) + } + } + + fn create(&mut self, data: &[u8]) -> server::Handle { + let handle = self.empty(data.len()); + + let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor { + label: Some("Buffer Src"), + contents: data, + usage: wgpu::BufferUsages::COPY_SRC, + })); + + let resource = self.memory_management.get(&handle); + + self.register_tasks(); + + self.encoder.copy_buffer_to_buffer( + &buffer_src, + 0, + &resource.buffer, + resource.offset(), + buffer_src.size(), + ); + + handle + } + + fn empty(&mut self, size: usize) -> server::Handle { + self.memory_management.reserve(size) + } + + fn execute(&mut self, kernel: Self::Kernel, handles: &[&server::Handle]) { + let work_group = kernel.workgroup(); + let pipeline = self.pipeline(kernel); + let group_layout = pipeline.get_bind_group_layout(0); + + let handles = handles + .iter() + .map(|handle| self.memory_management.get(handle)) + .collect::>(); + + let entries = handles + .iter() + .enumerate() + .map(|(i, buffer)| wgpu::BindGroupEntry { + binding: i as u32, + resource: buffer.as_binding(), + }) + .collect::>(); + + let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &group_layout, + entries: &entries, + }); + + self.tasks + .push(ComputeTask::new(pipeline, bind_group, work_group)); + + if self.tasks.len() >= self.max_tasks { + self.register_tasks(); + self.submit(); + } + } + + fn sync(&mut self) { + if !self.tasks.is_empty() { + self.register_tasks(); + self.submit(); + } + } +} diff --git a/burn-wgpu/src/compute/storage.rs b/burn-wgpu/src/compute/storage.rs new file mode 100644 index 000000000..5a8f09669 --- /dev/null +++ b/burn-wgpu/src/compute/storage.rs @@ -0,0 +1,102 @@ +use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}; +use hashbrown::HashMap; +use std::{num::NonZeroU64, sync::Arc}; + +pub struct WgpuStorage { + memory: HashMap>, + device: Arc, +} + +#[derive(new, Debug)] +pub struct WgpuResource { + pub buffer: Arc, + pub kind: WgpuResourceKind, +} + +impl WgpuResource { + /// Return the binding view of the buffer. + pub fn as_binding(&self) -> wgpu::BindingResource { + let binding = match &self.kind { + WgpuResourceKind::Full => self.buffer.as_entire_buffer_binding(), + WgpuResourceKind::Slice(offs, size) => wgpu::BufferBinding { + buffer: &self.buffer, + offset: *offs, + size: Some(*size), + }, + }; + wgpu::BindingResource::Buffer(binding) + } + + /// Return the buffer size. 
+ pub fn size(&self) -> u64 { + match self.kind { + WgpuResourceKind::Full => self.buffer.size(), + WgpuResourceKind::Slice(_, size) => size.get(), + } + } + + /// Return the buffer offset. + pub fn offset(&self) -> u64 { + match self.kind { + WgpuResourceKind::Full => 0, + WgpuResourceKind::Slice(offset, _) => offset, + } + } +} + +#[derive(Debug)] +pub enum WgpuResourceKind { + /// Represents an entire buffer. + Full, + /// A slice over a buffer. + Slice(wgpu::BufferAddress, wgpu::BufferSize), +} + +/// Keeps actual wgpu buffer references in a hashmap with ids as key. +impl WgpuStorage { + pub fn new(device: Arc) -> Self { + Self { + memory: HashMap::new(), + device, + } + } +} + +impl ComputeStorage for WgpuStorage { + type Resource = WgpuResource; + + fn get(&mut self, handle: &StorageHandle) -> Self::Resource { + let buffer = self.memory.get(&handle.id).unwrap(); + + match handle.utilization { + StorageUtilization::Full(_) => { + WgpuResource::new(buffer.clone(), WgpuResourceKind::Full) + } + StorageUtilization::Slice(offset, size) => WgpuResource::new( + buffer.clone(), + WgpuResourceKind::Slice(offset as u64, NonZeroU64::new(size as u64).unwrap()), + ), + } + } + + fn alloc(&mut self, size: usize) -> StorageHandle { + let id = StorageId::new(); + let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: size as u64, + usage: wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + })); + + self.memory.insert(id.clone(), buffer); + + StorageHandle::new(id, StorageUtilization::Full(size)) + } + + fn dealloc(&mut self, id: StorageId) { + self.memory.get(&id).unwrap().destroy(); + let _ = self.memory.remove(&id); + } +} diff --git a/burn-wgpu/src/context/base.rs b/burn-wgpu/src/context/base.rs index 47fe4e05e..ee207dbf3 100644 --- a/burn-wgpu/src/context/base.rs +++ b/burn-wgpu/src/context/base.rs @@ -270,7 +270,7 @@ impl PartialEq for Context { } } -async fn select_device( +pub(crate) async fn select_device( device: &WgpuDevice, ) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) { let adapter = select_adapter::(device); diff --git a/burn-wgpu/src/kernel/base.rs b/burn-wgpu/src/kernel/base.rs index 62535de46..bf68b623d 100644 --- a/burn-wgpu/src/kernel/base.rs +++ b/burn-wgpu/src/kernel/base.rs @@ -3,13 +3,13 @@ use crate::{context::WorkGroup, element::WgpuElement, tensor::WgpuTensor}; use std::marker::PhantomData; /// Static wgpu kernel to create a [source template](SourceTemplate). -pub trait StaticKernel: 'static { +pub trait StaticKernel: Send + 'static { /// Source template for the kernel. fn source_template() -> SourceTemplate; } /// Dynamic wgpu kernel to create a [source template](SourceTemplate). -pub trait DynamicKernel { +pub trait DynamicKernel: Send { /// Source template for the kernel. fn source_template(self) -> SourceTemplate; /// Identifier for the kernel, used for caching kernel compilation. diff --git a/burn-wgpu/src/lib.rs b/burn-wgpu/src/lib.rs index ce43182aa..650ce17c6 100644 --- a/burn-wgpu/src/lib.rs +++ b/burn-wgpu/src/lib.rs @@ -16,6 +16,10 @@ pub mod kernel; /// Tensor module. pub mod tensor; +#[cfg(test)] // Only enabled for dev for now. +/// Compute related module. 
+pub mod compute;
+
 pub(crate) mod pool;
 pub(crate) mod tune;
 
diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs
index 63e9f0938..461e4737a 100644
--- a/xtask/src/runchecks.rs
+++ b/xtask/src/runchecks.rs
@@ -61,20 +61,14 @@ fn rustup(target: &str) {
 }
 
 // Define and run a cargo command
-fn run_cargo(command: &str, first_params: &[&str], second_params: &[&str], error: &str) {
+fn run_cargo(command: &str, params: Params, error: &str) {
     // Print cargo command
-    println!(
-        "\ncargo {} {} {}\n",
-        command,
-        first_params.join(" "),
-        second_params.join(" ")
-    );
+    println!("\ncargo {} {}\n", command, params);
 
     // Run cargo
     let cargo = Command::new("cargo")
         .arg(command)
-        .args(first_params)
-        .args(second_params)
+        .args(params.params)
         .stdout(Stdio::inherit()) // Send stdout directly to terminal
         .stderr(Stdio::inherit()) // Send stderr directly to terminal
         .spawn()
@@ -85,34 +79,31 @@
 }
 
 // Run cargo build command
-fn cargo_build(params: &[&str]) {
+fn cargo_build(params: Params) {
     // Run cargo build
     run_cargo(
         "build",
-        params,
-        &["--color=always"],
+        params + "--color=always",
         "Failed to run cargo build",
     );
 }
 
 // Run cargo install command
-fn cargo_install(params: &[&str]) {
+fn cargo_install(params: Params) {
     // Run cargo install
     run_cargo(
         "install",
-        params,
-        &["--color=always"],
+        params + "--color=always",
         "Failed to run cargo install",
     );
 }
 
 // Run cargo test command
-fn cargo_test(params: &[&str]) {
+fn cargo_test(params: Params) {
     // Run cargo test
     run_cargo(
         "test",
-        params,
-        &["--color=always", "--", "--color=always"],
+        params + "--color=always" + "--" + "--color=always",
         "Failed to run cargo test",
     );
 }
@@ -122,8 +113,7 @@ fn cargo_fmt() {
     // Run cargo fmt
     run_cargo(
         "fmt",
-        &["--check", "--all"],
-        &["--", "--color=always"],
+        ["--check", "--all", "--", "--color=always"].into(),
         "Failed to run cargo fmt",
     );
 }
@@ -136,50 +126,48 @@ fn cargo_clippy() {
     // Run cargo clippy
     run_cargo(
         "clippy",
-        &["--color=always"],
-        &["--", "-D", "warnings"],
+        ["--color=always", "--", "-D", "warnings"].into(),
         "Failed to run cargo clippy",
     );
 }
 
 // Run cargo doc command
-fn cargo_doc(params: &[&str]) {
+fn cargo_doc(params: Params) {
     // Run cargo doc
-    run_cargo(
-        "doc",
-        params,
-        &["--color=always"],
-        "Failed to run cargo doc",
-    );
+    run_cargo("doc", params + "--color=always", "Failed to run cargo doc");
 }
 
 // Build and test a crate in a no_std environment
-fn build_and_test_no_std(crate_name: &str) {
+fn build_and_test_no_std<const N: usize>(crate_name: &str, extra_args: [&str; N]) {
     println!("\nRun checks for `{}` crate", crate_name);
 
     // Run cargo build --no-default-features
-    cargo_build(&["-p", crate_name, "--no-default-features"]);
+    cargo_build(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args);
 
     // Run cargo test --no-default-features
-    cargo_test(&["-p", crate_name, "--no-default-features"]);
+    cargo_test(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args);
 
     // Run cargo build --no-default-features --target wasm32-unknown-unknowns
-    cargo_build(&[
-        "-p",
-        crate_name,
-        "--no-default-features",
-        "--target",
-        WASM32_TARGET,
-    ]);
+    cargo_build(
+        Params::from([
+            "-p",
+            crate_name,
+            "--no-default-features",
+            "--target",
+            WASM32_TARGET,
+        ]) + extra_args,
+    );
 
     // Run cargo build --no-default-features --target thumbv7m-none-eabi
-    cargo_build(&[
-        "-p",
-        crate_name,
-        "--no-default-features",
-        "--target",
-        ARM_TARGET,
-    ]);
+    cargo_build(
+        Params::from([
"-p", + crate_name, + "--no-default-features", + "--target", + ARM_TARGET, + ]) + extra_args, + ); } // Run no_std checks @@ -193,12 +181,16 @@ fn no_std_checks() { rustup(ARM_TARGET); // Run checks for the following crates - build_and_test_no_std("burn"); - build_and_test_no_std("burn-core"); - build_and_test_no_std("burn-common"); - build_and_test_no_std("burn-tensor"); - build_and_test_no_std("burn-ndarray"); - build_and_test_no_std("burn-no-std-tests"); + build_and_test_no_std("burn", []); + build_and_test_no_std("burn-core", []); + build_and_test_no_std( + "burn-compute", + ["--features", "channel-mutex storage-bytes"], + ); + build_and_test_no_std("burn-common", []); + build_and_test_no_std("burn-tensor", []); + build_and_test_no_std("burn-ndarray", []); + build_and_test_no_std("burn-no-std-tests", []); } // Test burn-core with tch and wgpu backend @@ -206,10 +198,10 @@ fn burn_core_std() { println!("\n\nRun checks for burn-core crate with tch and wgpu backend"); // Run cargo test --features test-tch - cargo_test(&["-p", "burn-core", "--features", "test-tch"]); + cargo_test(["-p", "burn-core", "--features", "test-tch"].into()); // Run cargo test --features test-wgpu - cargo_test(&["-p", "burn-core", "--features", "test-wgpu"]); + cargo_test(["-p", "burn-core", "--features", "test-wgpu"].into()); } // Test burn-dataset features @@ -217,13 +209,13 @@ fn burn_dataset_features_std() { println!("\n\nRun checks for burn-dataset features"); // Run cargo build --all-features - cargo_build(&["-p", "burn-dataset", "--all-features"]); + cargo_build(["-p", "burn-dataset", "--all-features"].into()); // Run cargo test --all-features - cargo_test(&["-p", "burn-dataset", "--all-features"]); + cargo_test(["-p", "burn-dataset", "--all-features"].into()); // Run cargo doc --all-features - cargo_doc(&["-p", "burn-dataset", "--all-features"]); + cargo_doc(["-p", "burn-dataset", "--all-features"].into()); } fn std_checks() { @@ -234,10 +226,10 @@ fn std_checks() { println!("Running std checks"); // Build each workspace - cargo_build(&["--workspace", "--exclude=xtask"]); + cargo_build(["--workspace", "--exclude=xtask"].into()); // Test each workspace - cargo_test(&["--workspace"]); + cargo_test(["--workspace"].into()); // Check format cargo_fmt(); @@ -246,7 +238,7 @@ fn std_checks() { cargo_clippy(); // Produce documentation for each workspace - cargo_doc(&["--workspace"]); + cargo_doc(["--workspace"].into()); // Test burn-dataset features burn_dataset_features_std(); @@ -257,7 +249,7 @@ fn std_checks() { fn check_typos() { // Install typos-cli - cargo_install(&["typos-cli", "--version", "1.16.5"]); + cargo_install(["typos-cli", "--version", "1.16.5"].into()); println!("Running typos check \n\n"); @@ -349,3 +341,39 @@ pub fn run(env: CheckType) -> anyhow::Result<()> { Ok(()) } + +struct Params { + params: Vec, +} + +impl From<[&str; N]> for Params { + fn from(value: [&str; N]) -> Self { + Self { + params: value.iter().map(|v| v.to_string()).collect(), + } + } +} + +impl From<&str> for Params { + fn from(value: &str) -> Self { + Self { + params: vec![value.to_string()], + } + } +} + +impl std::fmt::Display for Params { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.params.join(" ").as_str()) + } +} + +impl> std::ops::Add for Params { + type Output = Params; + + fn add(mut self, rhs: Rhs) -> Self::Output { + let rhs: Params = rhs.into(); + self.params.extend(rhs.params); + self + } +}