mirror of https://github.com/tracel-ai/burn.git
Burn compute (#809)
This commit is contained in:
parent
d7e9e75099
commit
ac4adb54ea
51
Cargo.toml
51
Cargo.toml
|
@ -7,6 +7,7 @@ members = [
|
|||
"burn",
|
||||
"burn-autodiff",
|
||||
"burn-common",
|
||||
"burn-compute",
|
||||
"burn-core",
|
||||
"burn-dataset",
|
||||
"burn-derive",
|
||||
|
@ -24,9 +25,7 @@ members = [
|
|||
"examples/*",
|
||||
]
|
||||
|
||||
exclude = [
|
||||
"examples/notebook",
|
||||
]
|
||||
exclude = ["examples/notebook"]
|
||||
|
||||
[workspace.dependencies]
|
||||
bytemuck = "1.13"
|
||||
|
@ -37,7 +36,7 @@ dirs = "5.0.1"
|
|||
fake = "2.6.1"
|
||||
flate2 = "1.0.26"
|
||||
float-cmp = "0.9.0"
|
||||
gix-tempfile = {version = "8.0.0", features = ["signals"]}
|
||||
gix-tempfile = { version = "8.0.0", features = ["signals"] }
|
||||
hashbrown = "0.14.0"
|
||||
indicatif = "0.17.5"
|
||||
libm = "0.2.7"
|
||||
|
@ -47,17 +46,17 @@ proc-macro2 = "1.0.60"
|
|||
protobuf-codegen = "3.2"
|
||||
quote = "1.0.28"
|
||||
r2d2 = "0.8.10"
|
||||
r2d2_sqlite = {version = "0.22.0"}
|
||||
r2d2_sqlite = { version = "0.22.0" }
|
||||
rayon = "1.7.0"
|
||||
rmp-serde = "1.1.1"
|
||||
rstest = "0.18.1"
|
||||
rusqlite = {version = "0.29"}
|
||||
rusqlite = { version = "0.29" }
|
||||
sanitize-filename = "0.5.0"
|
||||
serde_rusqlite = "0.33.1"
|
||||
spin = {version = "0.9.8", features = ["mutex", "spin_mutex"]}
|
||||
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
|
||||
strum = "0.24"
|
||||
strum_macros = "0.24"
|
||||
syn = {version = "2.0", features = ["full", "extra-traits"]}
|
||||
syn = { version = "2.0", features = ["full", "extra-traits"] }
|
||||
tempfile = "3.6.0"
|
||||
thiserror = "1.0.40"
|
||||
tracing-subscriber = "0.3.17"
|
||||
|
@ -67,19 +66,33 @@ tracing-appender = "0.2.2"
|
|||
# WGPU stuff
|
||||
futures-intrusive = "0.5"
|
||||
pollster = "0.3"
|
||||
text_placeholder = {version = "0.5.0", features = ["struct_context"]}
|
||||
text_placeholder = { version = "0.5.0", features = ["struct_context"] }
|
||||
wgpu = "0.17.0"
|
||||
|
||||
#
|
||||
# The following packages disable the "std" feature for no_std compatibility
|
||||
#
|
||||
bincode = {version = "2.0.0-rc.3", features = ["alloc", "serde"], default-features = false}
|
||||
derive-new = {version = "0.5.9", default-features = false}
|
||||
half = {version = "2.3.1", features = ["alloc", "num-traits", "serde"], default-features = false}
|
||||
ndarray = {version = "0.15.6", default-features = false}
|
||||
num-traits = {version = "0.2.15", default-features = false, features = ["libm"]}# libm is for no_std
|
||||
rand = {version = "0.8.5", default-features = false, features = ["std_rng"]}# std_rng is for no_std
|
||||
rand_distr = {version = "0.4.3", default-features = false}
|
||||
serde = {version = "1.0.164", default-features = false, features = ["derive", "alloc"]}# alloc is for no_std, derive is needed
|
||||
serde_json = {version = "1.0.96", default-features = false}
|
||||
uuid = {version = "1.3.4", default-features = false}
|
||||
bincode = { version = "2.0.0-rc.3", features = [
|
||||
"alloc",
|
||||
"serde",
|
||||
], default-features = false }
|
||||
derive-new = { version = "0.5.9", default-features = false }
|
||||
half = { version = "2.3.1", features = [
|
||||
"alloc",
|
||||
"num-traits",
|
||||
"serde",
|
||||
], default-features = false }
|
||||
ndarray = { version = "0.15.6", default-features = false }
|
||||
num-traits = { version = "0.2.15", default-features = false, features = [
|
||||
"libm",
|
||||
] } # libm is for no_std
|
||||
rand = { version = "0.8.5", default-features = false, features = [
|
||||
"std_rng",
|
||||
] } # std_rng is for no_std
|
||||
rand_distr = { version = "0.4.3", default-features = false }
|
||||
serde = { version = "1.0.164", default-features = false, features = [
|
||||
"derive",
|
||||
"alloc",
|
||||
] } # alloc is for no_std, derive is needed
|
||||
serde_json = { version = "1.0.96", default-features = false }
|
||||
uuid = { version = "1.3.4", default-features = false }
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
use alloc::string::{String, ToString};
|
||||
|
||||
use crate::rand::{get_seeded_rng, Rng, SEED};
|
||||
|
||||
use alloc::string::{String, ToString};
|
||||
use uuid::{Builder, Bytes};
|
||||
|
||||
/// Simple ID generator.
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
[package]
|
||||
authors = ["louisfd <louisfd94@gmail.com>", "Nathaniel Simard"]
|
||||
categories = ["science"]
|
||||
description = "Compute crate that helps creating high performance async backends."
|
||||
edition = "2021"
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
name = "burn-compute"
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/burn-rs/burn/tree/main/burn-compute"
|
||||
version = "0.10.0"
|
||||
|
||||
[features]
|
||||
default = ["std", "channel-mutex", "channel-mpsc", "channel-cell", "storage-bytes"]
|
||||
std = []
|
||||
channel-mutex = []
|
||||
channel-cell = []
|
||||
channel-mpsc = [] # Assume std
|
||||
storage-bytes = []
|
||||
|
||||
[dependencies]
|
||||
burn-common = { path = "../burn-common", version = "0.10.0", default-features = false }
|
||||
derive-new = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
hashbrown = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
|
@ -0,0 +1,7 @@
|
|||
# Burn Compute
|
||||
|
||||
This crate helps creating high performance async backends.
|
||||
|
||||
- [x] Asynchronous kernel executions
|
||||
- [x] Memory allocation management
|
||||
- [ ] Autotuning
|
|
@ -0,0 +1,21 @@
|
|||
use crate::server::{ComputeServer, Handle};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// The ComputeChannel trait links the ComputeClient to the ComputeServer
|
||||
/// while ensuring thread-safety
|
||||
pub trait ComputeChannel<Server: ComputeServer>: Clone {
|
||||
/// Given a handle, returns owned resource as bytes
|
||||
fn read(&self, handle: &Handle<Server>) -> Vec<u8>;
|
||||
|
||||
/// Given a resource as bytes, stores it and returns the resource handle
|
||||
fn create(&self, data: &[u8]) -> Handle<Server>;
|
||||
|
||||
/// Reserves `size` bytes in the storage, and returns a handle over them
|
||||
fn empty(&self, size: usize) -> Handle<Server>;
|
||||
|
||||
/// Executes the `kernel` over the given `handles`.
|
||||
fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]);
|
||||
|
||||
/// Wait for the completion of every task in the server.
|
||||
fn sync(&self);
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
use super::ComputeChannel;
|
||||
use crate::server::{ComputeServer, Handle};
|
||||
use alloc::sync::Arc;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// A channel using a [ref cell](core::cell::RefCell) to access the server with mutability.
|
||||
///
|
||||
/// # Important
|
||||
///
|
||||
/// Only use this channel if you don't use any threading in your application, otherwise it will
|
||||
/// panic or cause undefined behaviors.
|
||||
///
|
||||
/// This is mosly useful for `no-std` environments where threads aren't supported, otherwise prefer
|
||||
/// the [mutex](super::MutexComputeChannel) or the [mpsc](super::MpscComputeChannel) channels.
|
||||
pub struct RefCellComputeChannel<Server> {
|
||||
server: Arc<core::cell::RefCell<Server>>,
|
||||
}
|
||||
|
||||
impl<S> Clone for RefCellComputeChannel<S> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
server: self.server.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<Server> RefCellComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
/// Create a new cell compute channel.
|
||||
pub fn new(server: Server) -> Self {
|
||||
Self {
|
||||
server: Arc::new(core::cell::RefCell::new(server)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
|
||||
let mut server = self.server.borrow_mut();
|
||||
|
||||
server.read(handle)
|
||||
}
|
||||
|
||||
fn create(&self, resource: &[u8]) -> Handle<Server> {
|
||||
self.server.borrow_mut().create(resource)
|
||||
}
|
||||
|
||||
fn empty(&self, size: usize) -> Handle<Server> {
|
||||
self.server.borrow_mut().empty(size)
|
||||
}
|
||||
|
||||
fn execute(&self, kernel_description: Server::Kernel, handles: &[&Handle<Server>]) {
|
||||
self.server
|
||||
.borrow_mut()
|
||||
.execute(kernel_description, handles)
|
||||
}
|
||||
|
||||
fn sync(&self) {
|
||||
self.server.borrow_mut().sync()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
mod base;
|
||||
pub use base::*;
|
||||
|
||||
#[cfg(feature = "channel-mutex")]
|
||||
mod mutex;
|
||||
#[cfg(feature = "channel-mutex")]
|
||||
pub use mutex::*;
|
||||
|
||||
#[cfg(feature = "channel-mpsc")]
|
||||
mod mpsc;
|
||||
#[cfg(feature = "channel-mpsc")]
|
||||
pub use mpsc::*;
|
||||
|
||||
#[cfg(feature = "channel-cell")]
|
||||
mod cell;
|
||||
#[cfg(feature = "channel-cell")]
|
||||
pub use cell::*;
|
|
@ -0,0 +1,154 @@
|
|||
use std::{
|
||||
sync::{mpsc, Arc},
|
||||
thread,
|
||||
};
|
||||
|
||||
use super::ComputeChannel;
|
||||
use crate::server::{ComputeServer, Handle};
|
||||
|
||||
/// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with
|
||||
/// the compute server spawn on its own thread.
|
||||
pub struct MpscComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
state: Arc<MpscComputeChannelState<Server>>,
|
||||
}
|
||||
|
||||
struct MpscComputeChannelState<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
_handle: thread::JoinHandle<()>,
|
||||
sender: mpsc::SyncSender<Message<Server>>,
|
||||
}
|
||||
|
||||
type Callback<Response> = mpsc::SyncSender<Response>;
|
||||
|
||||
enum Message<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
Read(Handle<Server>, Callback<Vec<u8>>),
|
||||
Create(Vec<u8>, Callback<Handle<Server>>),
|
||||
Empty(usize, Callback<Handle<Server>>),
|
||||
Execute(Server::Kernel, Vec<Handle<Server>>),
|
||||
Sync(Callback<()>),
|
||||
}
|
||||
|
||||
impl<Server> MpscComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer + 'static,
|
||||
{
|
||||
/// Create a new mpsc compute channel.
|
||||
pub fn new(mut server: Server, bound: usize) -> Self {
|
||||
let (sender, receiver) = mpsc::sync_channel(bound);
|
||||
|
||||
let _handle = thread::spawn(move || {
|
||||
while let Ok(message) = receiver.recv() {
|
||||
match message {
|
||||
Message::Read(handle, callback) => {
|
||||
let data = server.read(&handle);
|
||||
core::mem::drop(handle);
|
||||
callback.send(data).unwrap();
|
||||
}
|
||||
Message::Create(data, callback) => {
|
||||
let handle = server.create(&data);
|
||||
callback.send(handle).unwrap();
|
||||
}
|
||||
Message::Empty(size, callback) => {
|
||||
let handle = server.empty(size);
|
||||
callback.send(handle).unwrap();
|
||||
}
|
||||
Message::Execute(kernel, handles) => {
|
||||
server.execute(kernel, &handles.iter().collect::<Vec<_>>());
|
||||
}
|
||||
Message::Sync(callback) => {
|
||||
server.sync();
|
||||
callback.send(()).unwrap();
|
||||
}
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
let state = Arc::new(MpscComputeChannelState { sender, _handle });
|
||||
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server: ComputeServer> Clone for MpscComputeChannel<Server> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
state: self.state.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer + 'static,
|
||||
{
|
||||
fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
|
||||
let (callback, response) = mpsc::sync_channel(1);
|
||||
|
||||
self.state
|
||||
.sender
|
||||
.send(Message::Read(handle.clone(), callback))
|
||||
.unwrap();
|
||||
|
||||
self.response(response)
|
||||
}
|
||||
|
||||
fn create(&self, data: &[u8]) -> Handle<Server> {
|
||||
let (callback, response) = mpsc::sync_channel(1);
|
||||
|
||||
self.state
|
||||
.sender
|
||||
.send(Message::Create(data.to_vec(), callback))
|
||||
.unwrap();
|
||||
|
||||
self.response(response)
|
||||
}
|
||||
|
||||
fn empty(&self, size: usize) -> Handle<Server> {
|
||||
let (callback, response) = mpsc::sync_channel(1);
|
||||
|
||||
self.state
|
||||
.sender
|
||||
.send(Message::Empty(size, callback))
|
||||
.unwrap();
|
||||
|
||||
self.response(response)
|
||||
}
|
||||
|
||||
fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]) {
|
||||
self.state
|
||||
.sender
|
||||
.send(Message::Execute(
|
||||
kernel,
|
||||
handles
|
||||
.iter()
|
||||
.map(|h| (*h).clone())
|
||||
.collect::<Vec<Handle<Server>>>(),
|
||||
))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn sync(&self) {
|
||||
let (callback, response) = mpsc::sync_channel(1);
|
||||
|
||||
self.state.sender.send(Message::Sync(callback)).unwrap();
|
||||
|
||||
self.response(response)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server: ComputeServer> MpscComputeChannel<Server> {
|
||||
fn response<Response>(&self, response: mpsc::Receiver<Response>) -> Response {
|
||||
match response.recv() {
|
||||
Ok(val) => val,
|
||||
Err(err) => panic!("Can't connect to the server correctly {err:?}"),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
use super::ComputeChannel;
|
||||
use crate::server::{ComputeServer, Handle};
|
||||
use alloc::sync::Arc;
|
||||
use alloc::vec::Vec;
|
||||
use spin::Mutex;
|
||||
|
||||
/// The MutexComputeChannel ensures thread-safety by locking the server
|
||||
/// on every operation
|
||||
pub struct MutexComputeChannel<Server> {
|
||||
server: Arc<Mutex<Server>>,
|
||||
}
|
||||
|
||||
impl<S> Clone for MutexComputeChannel<S> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
server: self.server.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<Server> MutexComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
/// Create a new mutex compute channel.
|
||||
pub fn new(server: Server) -> Self {
|
||||
Self {
|
||||
server: Arc::new(Mutex::new(server)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
|
||||
self.server.lock().read(handle)
|
||||
}
|
||||
|
||||
fn create(&self, data: &[u8]) -> Handle<Server> {
|
||||
self.server.lock().create(data)
|
||||
}
|
||||
|
||||
fn empty(&self, size: usize) -> Handle<Server> {
|
||||
self.server.lock().empty(size)
|
||||
}
|
||||
|
||||
fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]) {
|
||||
self.server.lock().execute(kernel, handles)
|
||||
}
|
||||
|
||||
fn sync(&self) {
|
||||
self.server.lock().sync()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
use crate::{
|
||||
channel::ComputeChannel,
|
||||
server::{ComputeServer, Handle},
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// The ComputeClient is the entry point to require tasks from the ComputeServer.
|
||||
/// It should be obtained for a specific device via the Compute struct.
|
||||
pub struct ComputeClient<Server, Channel> {
|
||||
channel: Channel,
|
||||
_server: PhantomData<Server>,
|
||||
}
|
||||
|
||||
impl<S, C> Clone for ComputeClient<S, C>
|
||||
where
|
||||
S: ComputeServer,
|
||||
C: ComputeChannel<S>,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
channel: self.channel.clone(),
|
||||
_server: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server, Channel> ComputeClient<Server, Channel>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
Channel: ComputeChannel<Server>,
|
||||
{
|
||||
/// Create a new client.
|
||||
pub fn new(channel: Channel) -> Self {
|
||||
Self {
|
||||
channel,
|
||||
_server: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Given a handle, returns owned resource as bytes.
|
||||
pub fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
|
||||
self.channel.read(handle)
|
||||
}
|
||||
|
||||
/// Given a resource, stores it and returns the resource handle.
|
||||
pub fn create(&self, data: &[u8]) -> Handle<Server> {
|
||||
self.channel.create(data)
|
||||
}
|
||||
|
||||
/// Reserves `size` bytes in the storage, and returns a handle over them.
|
||||
pub fn empty(&self, size: usize) -> Handle<Server> {
|
||||
self.channel.empty(size)
|
||||
}
|
||||
|
||||
/// Executes the `kernel` over the given `handles`.
|
||||
pub fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]) {
|
||||
self.channel.execute(kernel, handles)
|
||||
}
|
||||
|
||||
/// Wait for the completion of every task in the server.
|
||||
pub fn sync(&self) {
|
||||
self.channel.sync()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
|
||||
use core::ops::DerefMut;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// The compute type has the responsibility to retrieve the correct compute client based on the
|
||||
/// given device.
|
||||
pub struct Compute<Device, Server, Channel> {
|
||||
clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
|
||||
}
|
||||
|
||||
impl<Device, Server, Channel> Compute<Device, Server, Channel>
|
||||
where
|
||||
Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
|
||||
Server: ComputeServer,
|
||||
Channel: ComputeChannel<Server>,
|
||||
{
|
||||
/// Create a new compute.
|
||||
pub const fn new() -> Self {
|
||||
Self {
|
||||
clients: spin::Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the compute client for the given device.
|
||||
///
|
||||
/// Provide the init function to create a new client if it isn't already initialized.
|
||||
pub fn client<Init>(&self, device: &Device, init: Init) -> ComputeClient<Server, Channel>
|
||||
where
|
||||
Init: Fn() -> ComputeClient<Server, Channel>,
|
||||
{
|
||||
let mut clients = self.clients.lock();
|
||||
|
||||
if clients.is_none() {
|
||||
Self::register_inner(device, init(), &mut clients);
|
||||
}
|
||||
|
||||
match clients.deref_mut() {
|
||||
Some(clients) => match clients.get(device) {
|
||||
Some(client) => client.clone(),
|
||||
None => {
|
||||
let client = init();
|
||||
clients.insert(device.clone(), client.clone());
|
||||
client
|
||||
}
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register the compute client for the given device.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This function is mostly useful when the creation of the compute client can't be done
|
||||
/// synchronously and require special context.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If a client is already registered for the given device.
|
||||
pub fn register(&self, device: &Device, client: ComputeClient<Server, Channel>) {
|
||||
let mut clients = self.clients.lock();
|
||||
|
||||
Self::register_inner(device, client, &mut clients);
|
||||
}
|
||||
|
||||
fn register_inner(
|
||||
device: &Device,
|
||||
client: ComputeClient<Server, Channel>,
|
||||
clients: &mut Option<HashMap<Device, ComputeClient<Server, Channel>>>,
|
||||
) {
|
||||
if clients.is_none() {
|
||||
*clients = Some(HashMap::new());
|
||||
}
|
||||
|
||||
if let Some(clients) = clients {
|
||||
if clients.contains_key(device) {
|
||||
panic!("Client already created for device {:?}", device);
|
||||
}
|
||||
|
||||
clients.insert(device.clone(), client);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
#[macro_export(local_inner_macros)]
|
||||
/// Create a new storage ID type.
|
||||
macro_rules! storage_id_type {
|
||||
($name:ident) => {
|
||||
#[derive(Clone, Hash, PartialEq, Eq)]
|
||||
/// Storage ID.
|
||||
pub struct $name {
|
||||
id: alloc::sync::Arc<alloc::string::String>,
|
||||
}
|
||||
|
||||
impl $name {
|
||||
/// Create a new ID.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for $name {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export(local_inner_macros)]
|
||||
/// Create a new memory ID type.
|
||||
macro_rules! memory_id_type {
|
||||
($name:ident) => {
|
||||
#[derive(Clone, Hash, PartialEq, Eq)]
|
||||
/// Memory ID.
|
||||
pub struct $name {
|
||||
id: alloc::sync::Arc<alloc::string::String>,
|
||||
}
|
||||
|
||||
impl $name {
|
||||
/// Create a new ID.
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for $name {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
|
||||
//! Burn compute crate that helps creating high performance async backends.
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
mod id;
|
||||
|
||||
/// Compute channel module.
|
||||
pub mod channel;
|
||||
/// Compute client module.
|
||||
pub mod client;
|
||||
|
||||
/// Memory management module.
|
||||
pub mod memory_management;
|
||||
/// Compute server module.
|
||||
pub mod server;
|
||||
/// Compute Storage module.
|
||||
pub mod storage;
|
||||
|
||||
mod compute;
|
||||
pub use compute::*;
|
|
@ -0,0 +1,27 @@
|
|||
use crate::storage::ComputeStorage;
|
||||
|
||||
/// The MemoryHandle trait is an abstract way to refer to some memory segment.
|
||||
/// It should not contain actual references to data.
|
||||
///
|
||||
/// It is responsible for determining if the memory segment can be mutated,
|
||||
/// for instance by keeping track of a reference count
|
||||
pub trait MemoryHandle: Clone + Send {
|
||||
/// Checks if the underlying memory can be safely mutated.
|
||||
fn can_mut(&self) -> bool;
|
||||
}
|
||||
|
||||
/// The MemoryManagement trait encapsulates strategies for (de)allocating memory.
|
||||
/// It is bound to the ComputeStorage trait, which does the actual (de)allocations.
|
||||
///
|
||||
/// The MemoryManagement can only reserve memory space or get the resource located at a space.
|
||||
/// Modification of the resource data should be done directly on the resource.
|
||||
pub trait MemoryManagement<Storage: ComputeStorage>: Send {
|
||||
/// The associated type Handle must implement MemoryHandle
|
||||
type Handle: MemoryHandle;
|
||||
|
||||
/// Returns the resource from the storage at the specified handle
|
||||
fn get(&mut self, handle: &Self::Handle) -> Storage::Resource;
|
||||
|
||||
/// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it
|
||||
fn reserve(&mut self, size: usize) -> Self::Handle;
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
mod base;
|
||||
mod simple;
|
||||
|
||||
pub use base::*;
|
||||
pub use simple::*;
|
|
@ -0,0 +1,339 @@
|
|||
use super::{MemoryHandle, MemoryManagement};
|
||||
use crate::{
|
||||
memory_id_type,
|
||||
storage::{ComputeStorage, StorageHandle, StorageUtilization},
|
||||
};
|
||||
use alloc::{sync::Arc, vec::Vec};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
// The ChunkId allows to keep track of how many references there are to a specific chunk.
|
||||
memory_id_type!(ChunkId);
|
||||
// The SliceId allows to keep track of how many references there are to a specific slice.
|
||||
memory_id_type!(SliceId);
|
||||
|
||||
impl ChunkId {
|
||||
/// A chunk is free if it is only referred by the chunk hashmap.
|
||||
fn is_free(&self) -> bool {
|
||||
Arc::strong_count(&self.id) <= 1
|
||||
}
|
||||
}
|
||||
|
||||
impl SliceId {
|
||||
/// A slice is free if it is only referred by the slice hashmap and the chunk it is in.
|
||||
fn is_free(&self) -> bool {
|
||||
Arc::strong_count(&self.id) <= 2
|
||||
}
|
||||
}
|
||||
|
||||
/// The SimpleHandle is a memory handle, referring to either a chunk or a slice.
|
||||
#[derive(Clone)]
|
||||
pub enum SimpleHandle {
|
||||
/// A whole chunk of memory.
|
||||
Chunk(ChunkId),
|
||||
/// A slice of a chunk of memory.
|
||||
Slice(SliceId),
|
||||
}
|
||||
|
||||
/// The strategy defines the frequency at which deallocation of unused memory chunks should occur.
|
||||
pub enum DeallocStrategy {
|
||||
/// Once every n calls to reserve.
|
||||
///
|
||||
/// First associated data is n, second is the state and should start at 0
|
||||
PeriodTick(usize, usize),
|
||||
#[cfg(feature = "std")]
|
||||
/// Once every period of time
|
||||
PeriodTime(std::time::Duration, std::time::Instant),
|
||||
/// Never deallocate.
|
||||
Never,
|
||||
}
|
||||
|
||||
impl DeallocStrategy {
|
||||
/// Create a new strategy with the given period.
|
||||
pub fn new_period_tick(period: usize) -> Self {
|
||||
DeallocStrategy::PeriodTick(period, 0)
|
||||
}
|
||||
|
||||
fn should_dealloc(&mut self) -> bool {
|
||||
match self {
|
||||
DeallocStrategy::PeriodTick(period, last) => {
|
||||
*last = (*last + 1) % *period;
|
||||
*last == 0
|
||||
}
|
||||
#[cfg(feature = "std")]
|
||||
DeallocStrategy::PeriodTime(period, last) => {
|
||||
if &last.elapsed() > period {
|
||||
*last = std::time::Instant::now();
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
DeallocStrategy::Never => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
|
||||
pub struct SimpleMemoryManagement<Storage> {
|
||||
chunks: HashMap<ChunkId, (StorageHandle, Vec<SliceId>)>,
|
||||
slices: HashMap<SliceId, (StorageHandle, ChunkId)>,
|
||||
dealloc_strategy: DeallocStrategy,
|
||||
storage: Storage,
|
||||
}
|
||||
|
||||
impl MemoryHandle for SimpleHandle {
|
||||
/// Returns true if referenced by only one tensor, and only once by the
|
||||
/// memory management hashmaps
|
||||
fn can_mut(&self) -> bool {
|
||||
// One reference in the chunk hashmap, another owned by one tensor.
|
||||
const REFERENCE_LIMIT_CHUNK: usize = 2;
|
||||
// One reference in the chunk hashmap (for the chunk on which this slice is built),
|
||||
// another in the slice hashmap for this slice, and another owned by one tensor.
|
||||
const REFERENCE_LIMIT_SLICE: usize = 3;
|
||||
|
||||
match &self {
|
||||
SimpleHandle::Chunk(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_CHUNK,
|
||||
SimpleHandle::Slice(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_SLICE,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Storage: ComputeStorage> MemoryManagement<Storage> for SimpleMemoryManagement<Storage> {
|
||||
type Handle = SimpleHandle;
|
||||
|
||||
/// Returns the resource from the storage, for the specified handle.
|
||||
fn get(&mut self, handle: &Self::Handle) -> Storage::Resource {
|
||||
let resource = match &handle {
|
||||
SimpleHandle::Chunk(id) => &self.chunks.get(id).unwrap().0,
|
||||
SimpleHandle::Slice(id) => &self.slices.get(id).unwrap().0,
|
||||
};
|
||||
|
||||
self.storage.get(resource)
|
||||
}
|
||||
|
||||
/// Reserves memory of specified size using the reserve algorithm, and return
|
||||
/// a handle to the reserved memory.
|
||||
///
|
||||
/// Also clean ups, removing unused slices, and chunks if permitted by deallocation strategy.
|
||||
fn reserve(&mut self, size: usize) -> Self::Handle {
|
||||
self.cleanup_slices();
|
||||
|
||||
let handle = self.reserve_algorithm(size);
|
||||
|
||||
if self.dealloc_strategy.should_dealloc() {
|
||||
self.cleanup_chunks();
|
||||
}
|
||||
|
||||
handle
|
||||
}
|
||||
}
|
||||
|
||||
impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
|
||||
/// Creates a new instance using the given storage and deallocation strategy.
|
||||
pub fn new(storage: Storage, dealloc_strategy: DeallocStrategy) -> Self {
|
||||
Self {
|
||||
chunks: HashMap::new(),
|
||||
slices: HashMap::new(),
|
||||
dealloc_strategy,
|
||||
storage,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an new instance using the given storage without deallocation.
|
||||
pub fn never_dealloc(storage: Storage) -> Self {
|
||||
Self::new(storage, DeallocStrategy::Never)
|
||||
}
|
||||
|
||||
fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle {
|
||||
// Looks for a large enough, existing but unused chunk of memory.
|
||||
let chunk = self.find_free_chunk(size);
|
||||
|
||||
match chunk {
|
||||
Some((chunk_id, chunk_size)) => {
|
||||
if size == chunk_size {
|
||||
// If there is one of exactly the same size, it reuses it.
|
||||
SimpleHandle::Chunk(chunk_id.clone())
|
||||
} else {
|
||||
// Otherwise creates a slice of the right size upon it, always starting at zero.
|
||||
self.create_slice(size, chunk_id)
|
||||
}
|
||||
}
|
||||
// If no chunk available, creates one of exactly the right size.
|
||||
None => self.create_chunk(size),
|
||||
}
|
||||
}
|
||||
|
||||
/// Finds the smallest of the free and large enough chunks to fit `size`
|
||||
/// Returns the chunk's id and size.
|
||||
fn find_free_chunk(&self, size: usize) -> Option<(ChunkId, usize)> {
|
||||
let mut size_diff_current = usize::MAX;
|
||||
let mut current = None;
|
||||
|
||||
self.chunks
|
||||
.iter()
|
||||
.for_each(|(chunk_id, (resource, slices))| {
|
||||
let is_free = slices.is_empty() && chunk_id.is_free();
|
||||
|
||||
if is_free && resource.size() > size {
|
||||
let size_diff = resource.size() - size;
|
||||
if size_diff < size_diff_current {
|
||||
current = Some((chunk_id, resource));
|
||||
size_diff_current = size_diff;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
current.map(|(id, handle)| (id.clone(), handle.size()))
|
||||
}
|
||||
|
||||
/// Creates a slice of size `size` upon the given chunk.
|
||||
///
|
||||
/// For now slices must start at zero, therefore there can be only one per chunk
|
||||
fn create_slice(&mut self, size: usize, chunk_id: ChunkId) -> SimpleHandle {
|
||||
let (handle, slices) = self.chunks.get_mut(&chunk_id).unwrap();
|
||||
let slice_id = SliceId::new();
|
||||
|
||||
let storage = StorageHandle {
|
||||
id: handle.id.clone(),
|
||||
utilization: StorageUtilization::Slice(0, size),
|
||||
};
|
||||
|
||||
if slices.is_empty() {
|
||||
self.slices.insert(slice_id.clone(), (storage, chunk_id));
|
||||
} else {
|
||||
panic!("Can't have more than 1 slice yet.");
|
||||
}
|
||||
|
||||
slices.push(slice_id.clone());
|
||||
|
||||
SimpleHandle::Slice(slice_id)
|
||||
}
|
||||
|
||||
/// Creates a chunk of given size by allocating on the storage.
|
||||
fn create_chunk(&mut self, size: usize) -> SimpleHandle {
|
||||
let resource = self.storage.alloc(size);
|
||||
let chunk_id = ChunkId::new();
|
||||
|
||||
self.chunks.insert(chunk_id.clone(), (resource, Vec::new()));
|
||||
|
||||
SimpleHandle::Chunk(chunk_id)
|
||||
}
|
||||
|
||||
/// Deallocates free chunks and remove them from chunks map.
|
||||
fn cleanup_chunks(&mut self) {
|
||||
let mut ids_to_remove = Vec::new();
|
||||
|
||||
self.chunks.iter().for_each(|(chunk_id, _resource)| {
|
||||
if chunk_id.is_free() {
|
||||
ids_to_remove.push(chunk_id.clone());
|
||||
}
|
||||
});
|
||||
|
||||
ids_to_remove
|
||||
.iter()
|
||||
.map(|chunk_id| self.chunks.remove(chunk_id).unwrap())
|
||||
.for_each(|(resource, _slices)| {
|
||||
self.storage.dealloc(resource.id);
|
||||
});
|
||||
}
|
||||
|
||||
/// Removes free slices from slice map and corresponding chunks.
|
||||
fn cleanup_slices(&mut self) {
|
||||
let mut ids_to_remove = Vec::new();
|
||||
|
||||
self.slices.iter().for_each(|(slice_id, _resource)| {
|
||||
if slice_id.is_free() {
|
||||
ids_to_remove.push(slice_id.clone());
|
||||
}
|
||||
});
|
||||
|
||||
ids_to_remove
|
||||
.iter()
|
||||
.map(|slice_id| {
|
||||
let value = self.slices.remove(slice_id).unwrap();
|
||||
(slice_id, value.1)
|
||||
})
|
||||
.for_each(|(slice_id, chunk_id)| {
|
||||
let (_chunk, slices) = self.chunks.get_mut(&chunk_id).unwrap();
|
||||
slices.retain(|id| id != slice_id);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
memory_management::{MemoryHandle, MemoryManagement},
|
||||
storage::BytesStorage,
|
||||
};
|
||||
|
||||
use super::{DeallocStrategy, SimpleMemoryManagement};
|
||||
|
||||
#[test]
|
||||
fn can_mut_with_single_tensor_reference() {
|
||||
let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
|
||||
|
||||
let chunk_size = 4;
|
||||
let simple_handle = memory_management.create_chunk(chunk_size);
|
||||
|
||||
let x = simple_handle.clone();
|
||||
core::mem::drop(simple_handle);
|
||||
|
||||
assert!(x.can_mut());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn two_tensor_references_remove_mutability() {
|
||||
let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
|
||||
|
||||
let chunk_size = 4;
|
||||
let simple_handle = memory_management.create_chunk(chunk_size);
|
||||
|
||||
let x = simple_handle.clone();
|
||||
|
||||
assert!(!simple_handle.can_mut());
|
||||
assert!(!x.can_mut())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() {
|
||||
let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
|
||||
let chunk_size = 4;
|
||||
let _chunk_handle = memory_management.reserve(chunk_size);
|
||||
let _new_handle = memory_management.reserve(chunk_size);
|
||||
|
||||
assert_eq!(memory_management.chunks.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn when_empty_chunk_is_cleaned_upexists_it_disappears() {
|
||||
let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
|
||||
let chunk_size = 4;
|
||||
let chunk_handle = memory_management.reserve(chunk_size);
|
||||
drop(chunk_handle);
|
||||
memory_management.cleanup_chunks();
|
||||
|
||||
assert_eq!(memory_management.chunks.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn never_dealloc_strategy_never_deallocs() {
|
||||
let mut never_dealloc = DeallocStrategy::Never;
|
||||
for _ in 0..20 {
|
||||
assert!(!never_dealloc.should_dealloc())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn period_tick_dealloc_strategy_should_dealloc_after_period() {
|
||||
let period = 3;
|
||||
let mut period_tick_dealloc = DeallocStrategy::new_period_tick(period);
|
||||
|
||||
for _ in 0..3 {
|
||||
for _ in 0..period - 1 {
|
||||
assert!(!period_tick_dealloc.should_dealloc());
|
||||
}
|
||||
assert!(period_tick_dealloc.should_dealloc());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
use alloc::vec::Vec;
|
||||
|
||||
use crate::{memory_management::MemoryManagement, storage::ComputeStorage};
|
||||
|
||||
type _Storage<Server> = <Server as ComputeServer>::Storage;
|
||||
type _MemoryManagement<Server> = <Server as ComputeServer>::MemoryManagement;
|
||||
|
||||
/// This alias for a [memory handle](MemoryManagement::Handle).
|
||||
pub type Handle<Server> = <_MemoryManagement<Server> as MemoryManagement<_Storage<Server>>>::Handle;
|
||||
|
||||
/// The compute server is responsible for handling resources and computations over resources.
|
||||
///
|
||||
/// Everything in the server is mutable, therefore it should be solely accessed through the
|
||||
/// [compute channel](crate::channel::ComputeChannel) for thread safety.
|
||||
pub trait ComputeServer: Send {
|
||||
/// The kernel type defines the computation algorithms.
|
||||
type Kernel: Send;
|
||||
/// The [storage](ComputeStorage) type defines how data is stored and accessed.
|
||||
type Storage: ComputeStorage;
|
||||
/// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
|
||||
type MemoryManagement: MemoryManagement<Self::Storage>;
|
||||
|
||||
/// Given a handle, returns the owned resource as bytes.
|
||||
fn read(&mut self, handle: &Handle<Self>) -> Vec<u8>;
|
||||
|
||||
/// Given a resource as bytes, stores it and returns the memory handle.
|
||||
fn create(&mut self, data: &[u8]) -> Handle<Self>;
|
||||
|
||||
/// Reserves `size` bytes in the storage, and returns a handle over them.
|
||||
fn empty(&mut self, size: usize) -> Handle<Self>;
|
||||
|
||||
/// Executes the `kernel` over the given memory `handles`.
|
||||
///
|
||||
/// Kernels have mutable access to every resource they are given
|
||||
/// and are responsible of determining which should be read or written.
|
||||
fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle<Self>]);
|
||||
|
||||
/// Wait for the completion of every task in the server.
|
||||
fn sync(&mut self);
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
use crate::storage_id_type;
|
||||
|
||||
// This ID is used to map a handle to its actual data.
|
||||
storage_id_type!(StorageId);
|
||||
|
||||
/// Defines if data uses a full memory chunk or a slice of it.
|
||||
#[derive(Clone)]
|
||||
pub enum StorageUtilization {
|
||||
/// Full memory chunk of specified size
|
||||
Full(usize),
|
||||
/// Slice of memory chunk with start index and size.
|
||||
Slice(usize, usize),
|
||||
}
|
||||
|
||||
/// Contains the [storage id](StorageId) of a resource and the way it is used.
|
||||
#[derive(new)]
|
||||
pub struct StorageHandle {
|
||||
/// Storage id.
|
||||
pub id: StorageId,
|
||||
/// How the storage is used.
|
||||
pub utilization: StorageUtilization,
|
||||
}
|
||||
|
||||
impl StorageHandle {
|
||||
/// Returns the size the handle is pointing to in memory.
|
||||
pub fn size(&self) -> usize {
|
||||
match self.utilization {
|
||||
StorageUtilization::Full(size) => size,
|
||||
StorageUtilization::Slice(_, size) => size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Storage types are responsible for allocating and deallocating memory.
|
||||
pub trait ComputeStorage: Send {
|
||||
/// The resource associated type determines the way data is implemented and how
|
||||
/// it can be accessed by kernels.
|
||||
type Resource: Send;
|
||||
|
||||
/// Returns the underlying resource for a specified storage handle
|
||||
fn get(&mut self, handle: &StorageHandle) -> Self::Resource;
|
||||
|
||||
/// Allocates `size` units of memory and returns a handle to it
|
||||
fn alloc(&mut self, size: usize) -> StorageHandle;
|
||||
|
||||
/// Deallocates the memory pointed by the given storage id.
|
||||
fn dealloc(&mut self, id: StorageId);
|
||||
}
|
|
@ -0,0 +1,121 @@
|
|||
use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
|
||||
use alloc::alloc::{alloc, dealloc, Layout};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// The bytes storage maps ids to pointers of bytes in a contiguous layout.
|
||||
#[derive(Default)]
|
||||
pub struct BytesStorage {
|
||||
memory: HashMap<StorageId, AllocatedBytes>,
|
||||
}
|
||||
|
||||
/// Can send to other threads, but can't sync.
|
||||
unsafe impl Send for BytesStorage {}
|
||||
unsafe impl Send for BytesResource {}
|
||||
|
||||
/// This struct is a pointer to a memory chunk or slice.
|
||||
pub struct BytesResource {
|
||||
ptr: *mut u8,
|
||||
utilization: StorageUtilization,
|
||||
}
|
||||
|
||||
/// This struct refers to a specific (contiguous) layout of bytes.
|
||||
struct AllocatedBytes {
|
||||
ptr: *mut u8,
|
||||
layout: Layout,
|
||||
}
|
||||
|
||||
impl BytesResource {
|
||||
fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
|
||||
match self.utilization {
|
||||
StorageUtilization::Full(len) => (self.ptr, len),
|
||||
StorageUtilization::Slice(location, len) => unsafe { (self.ptr.add(location), len) },
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the resource as a mutable slice of bytes.
|
||||
pub fn write<'a>(&self) -> &'a mut [u8] {
|
||||
let (ptr, len) = self.get_exact_location_and_length();
|
||||
|
||||
unsafe { core::slice::from_raw_parts_mut(ptr, len) }
|
||||
}
|
||||
|
||||
/// Returns the resource as an immutable slice of bytes.
|
||||
pub fn read<'a>(&self) -> &'a [u8] {
|
||||
let (ptr, len) = self.get_exact_location_and_length();
|
||||
|
||||
unsafe { core::slice::from_raw_parts(ptr, len) }
|
||||
}
|
||||
}
|
||||
|
||||
impl ComputeStorage for BytesStorage {
|
||||
type Resource = BytesResource;
|
||||
|
||||
fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
|
||||
let allocated_bytes = self.memory.get_mut(&handle.id).unwrap();
|
||||
|
||||
BytesResource {
|
||||
ptr: allocated_bytes.ptr,
|
||||
utilization: handle.utilization.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn alloc(&mut self, size: usize) -> StorageHandle {
|
||||
let id = StorageId::new();
|
||||
let handle = StorageHandle {
|
||||
id: id.clone(),
|
||||
utilization: StorageUtilization::Full(size),
|
||||
};
|
||||
|
||||
unsafe {
|
||||
let layout = Layout::array::<u8>(size).unwrap();
|
||||
let ptr = alloc(layout);
|
||||
let memory = AllocatedBytes { ptr, layout };
|
||||
|
||||
self.memory.insert(id, memory);
|
||||
}
|
||||
|
||||
handle
|
||||
}
|
||||
|
||||
fn dealloc(&mut self, id: StorageId) {
|
||||
if let Some(memory) = self.memory.remove(&id) {
|
||||
unsafe {
|
||||
dealloc(memory.ptr, memory.layout);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_can_alloc_and_dealloc() {
|
||||
let mut storage = BytesStorage::default();
|
||||
let handle_1 = storage.alloc(64);
|
||||
|
||||
assert_eq!(handle_1.size(), 64);
|
||||
storage.dealloc(handle_1.id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slices() {
|
||||
let mut storage = BytesStorage::default();
|
||||
let handle_1 = storage.alloc(64);
|
||||
let handle_2 = StorageHandle::new(handle_1.id.clone(), StorageUtilization::Slice(24, 8));
|
||||
|
||||
storage
|
||||
.get(&handle_1)
|
||||
.write()
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, b)| {
|
||||
*b = i as u8;
|
||||
});
|
||||
|
||||
let bytes = storage.get(&handle_2).read().to_vec();
|
||||
storage.dealloc(handle_1.id);
|
||||
assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
mod base;
|
||||
|
||||
pub use base::*;
|
||||
|
||||
#[cfg(feature = "storage-bytes")]
|
||||
mod bytes_cpu;
|
||||
#[cfg(feature = "storage-bytes")]
|
||||
pub use bytes_cpu::*;
|
|
@ -0,0 +1,26 @@
|
|||
use super::DummyServer;
|
||||
use burn_compute::channel::MutexComputeChannel;
|
||||
use burn_compute::client::ComputeClient;
|
||||
use burn_compute::memory_management::SimpleMemoryManagement;
|
||||
use burn_compute::storage::BytesStorage;
|
||||
use burn_compute::Compute;
|
||||
|
||||
/// The dummy device.
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
|
||||
pub struct DummyDevice;
|
||||
|
||||
static COMPUTE: Compute<DummyDevice, DummyServer, MutexComputeChannel<DummyServer>> =
|
||||
Compute::new();
|
||||
|
||||
pub fn client(
|
||||
device: &DummyDevice,
|
||||
) -> ComputeClient<DummyServer, MutexComputeChannel<DummyServer>> {
|
||||
COMPUTE.client(device, || {
|
||||
let storage = BytesStorage::default();
|
||||
let memory_management = SimpleMemoryManagement::never_dealloc(storage);
|
||||
let server = DummyServer::new(memory_management);
|
||||
let channel = MutexComputeChannel::new(server);
|
||||
|
||||
ComputeClient::new(channel)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
use burn_compute::storage::BytesResource;
|
||||
|
||||
/// The DummyKernel trait should be implemented for every supported operation
|
||||
pub trait DummyKernel: Send {
|
||||
fn compute<'a>(&self, resources: &mut [BytesResource]);
|
||||
}
|
||||
|
||||
/// Contains the algorithm for element-wise addition
|
||||
pub struct DummyElementwiseAddition;
|
||||
|
||||
impl DummyKernel for DummyElementwiseAddition {
|
||||
fn compute<'a>(&self, inputs: &mut [BytesResource]) {
|
||||
// Notice how the kernel is responsible for determining which inputs
|
||||
// are read-only and which are writable.
|
||||
let lhs = &inputs[0].read();
|
||||
let rhs = &inputs[1].read();
|
||||
let out = &mut inputs[2].write();
|
||||
|
||||
let size = lhs.len();
|
||||
|
||||
for i in 0..size {
|
||||
out[i] = lhs[i] + rhs[i];
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
mod compute;
|
||||
mod kernel;
|
||||
mod server;
|
||||
|
||||
pub use compute::*;
|
||||
pub use kernel::*;
|
||||
pub use server::*;
|
|
@ -0,0 +1,60 @@
|
|||
use burn_compute::{
|
||||
memory_management::{MemoryManagement, SimpleMemoryManagement},
|
||||
server::{ComputeServer, Handle},
|
||||
storage::BytesStorage,
|
||||
};
|
||||
use derive_new::new;
|
||||
|
||||
use super::DummyKernel;
|
||||
|
||||
/// The dummy server is used to test the burn-compute infrastructure.
|
||||
/// It uses simple memory management with a bytes storage on CPU, without asynchronous tasks.
|
||||
#[derive(new)]
|
||||
pub struct DummyServer<MM = SimpleMemoryManagement<BytesStorage>> {
|
||||
memory_management: MM,
|
||||
}
|
||||
|
||||
impl<MM> ComputeServer for DummyServer<MM>
|
||||
where
|
||||
MM: MemoryManagement<BytesStorage>,
|
||||
{
|
||||
type Kernel = Box<dyn DummyKernel>;
|
||||
type Storage = BytesStorage;
|
||||
type MemoryManagement = MM;
|
||||
|
||||
fn read(&mut self, handle: &Handle<Self>) -> Vec<u8> {
|
||||
let bytes = self.memory_management.get(handle);
|
||||
|
||||
bytes.read().to_vec()
|
||||
}
|
||||
|
||||
fn create(&mut self, data: &[u8]) -> Handle<Self> {
|
||||
let handle = self.memory_management.reserve(data.len());
|
||||
let resource = self.memory_management.get(&handle);
|
||||
|
||||
let bytes = resource.write();
|
||||
|
||||
for (i, val) in data.iter().enumerate() {
|
||||
bytes[i] = *val;
|
||||
}
|
||||
|
||||
handle
|
||||
}
|
||||
|
||||
fn empty(&mut self, size: usize) -> Handle<Self> {
|
||||
self.memory_management.reserve(size)
|
||||
}
|
||||
|
||||
fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle<Self>]) {
|
||||
let mut resources = handles
|
||||
.iter()
|
||||
.map(|handle| self.memory_management.get(handle))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
kernel.compute(&mut resources);
|
||||
}
|
||||
|
||||
fn sync(&mut self) {
|
||||
// Nothing to do with dummy backend.
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
mod dummy;
|
||||
|
||||
use dummy::{client, DummyDevice, DummyElementwiseAddition};
|
||||
|
||||
#[test]
|
||||
fn created_resource_is_the_same_when_read() {
|
||||
let client = client(&DummyDevice);
|
||||
let resource = Vec::from([0, 1, 2]);
|
||||
let resource_description = client.create(&resource);
|
||||
|
||||
let obtained_resource = client.read(&resource_description);
|
||||
|
||||
assert_eq!(resource, obtained_resource)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_allocates_memory() {
|
||||
let client = client(&DummyDevice);
|
||||
let size = 4;
|
||||
let resource_description = client.empty(size);
|
||||
let empty_resource = client.read(&resource_description);
|
||||
|
||||
assert_eq!(empty_resource.len(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn execute_elementwise_addition() {
|
||||
let client = client(&DummyDevice);
|
||||
let lhs = client.create(&[0, 1, 2]);
|
||||
let rhs = client.create(&[4, 4, 4]);
|
||||
let out = client.empty(3);
|
||||
|
||||
client.execute(Box::new(DummyElementwiseAddition), &[&lhs, &rhs, &out]);
|
||||
|
||||
let obtained_resource = client.read(&out);
|
||||
|
||||
assert_eq!(obtained_resource, Vec::from([4, 5, 6]))
|
||||
}
|
|
@ -45,6 +45,10 @@ burn-tensor = { path = "../burn-tensor", version = "0.10.0", default-features =
|
|||
burn-ndarray = { path = "../burn-ndarray", version = "0.10.0" }
|
||||
serial_test = "2.0.0"
|
||||
|
||||
# Still only in dev mode
|
||||
hashbrown = { workspace = true }
|
||||
burn-compute = { path = "../burn-compute", version = "0.10.0", default-features = false, features=["channel-mutex", "std"] }
|
||||
|
||||
[[bench]]
|
||||
name = "unary"
|
||||
harness = false
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
use super::{Kernel, WgpuServer};
|
||||
use crate::{
|
||||
compute::WgpuStorage,
|
||||
context::{select_device, WorkGroup},
|
||||
kernel::{DynamicKernel, SourceTemplate, StaticKernel},
|
||||
GraphicsApi, WgpuDevice,
|
||||
};
|
||||
use burn_compute::{
|
||||
channel::MutexComputeChannel,
|
||||
client::ComputeClient,
|
||||
memory_management::{DeallocStrategy, SimpleMemoryManagement},
|
||||
Compute,
|
||||
};
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
|
||||
type WgpuChannel = MutexComputeChannel<WgpuServer>;
|
||||
|
||||
/// Compute handle for the wgpu backend.
|
||||
static COMPUTE: Compute<WgpuDevice, WgpuServer, WgpuChannel> = Compute::new();
|
||||
|
||||
pub fn compute_client<G: GraphicsApi>(
|
||||
device: &WgpuDevice,
|
||||
) -> ComputeClient<WgpuServer, WgpuChannel> {
|
||||
let device = Arc::new(device);
|
||||
|
||||
COMPUTE.client(&device, move || {
|
||||
let (device_wgpu, queue, info) = pollster::block_on(select_device::<G>(&device));
|
||||
|
||||
log::info!(
|
||||
"Created wgpu compute server on device {:?} => {:?}",
|
||||
device,
|
||||
info
|
||||
);
|
||||
|
||||
// TODO: Support a way to modify max_tasks without std.
|
||||
let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") {
|
||||
Ok(value) => value
|
||||
.parse::<usize>()
|
||||
.expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
|
||||
Err(_) => 16, // 16 tasks by default
|
||||
};
|
||||
|
||||
let device = Arc::new(device_wgpu);
|
||||
let storage = WgpuStorage::new(device.clone());
|
||||
// Maximum reusability.
|
||||
let memory_management = SimpleMemoryManagement::new(storage, DeallocStrategy::Never);
|
||||
let server = WgpuServer::new(memory_management, device, queue, max_tasks);
|
||||
let channel = WgpuChannel::new(server);
|
||||
|
||||
ComputeClient::new(channel)
|
||||
})
|
||||
}
|
||||
|
||||
pub struct DynamicComputeKernel<K> {
|
||||
kernel: K,
|
||||
workgroup: WorkGroup,
|
||||
}
|
||||
|
||||
impl<K> Kernel for DynamicComputeKernel<K>
|
||||
where
|
||||
K: DynamicKernel + 'static,
|
||||
{
|
||||
fn source_template(self: Box<Self>) -> SourceTemplate {
|
||||
self.kernel.source_template()
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
self.kernel.id()
|
||||
}
|
||||
|
||||
fn workgroup(&self) -> WorkGroup {
|
||||
self.workgroup.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub struct StaticComputeKernel<K> {
|
||||
workgroup: WorkGroup,
|
||||
_kernel: PhantomData<K>,
|
||||
}
|
||||
|
||||
impl<K> Kernel for StaticComputeKernel<K>
|
||||
where
|
||||
K: StaticKernel + 'static,
|
||||
{
|
||||
fn source_template(self: Box<Self>) -> SourceTemplate {
|
||||
K::source_template()
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}", core::any::TypeId::of::<K>())
|
||||
}
|
||||
|
||||
fn workgroup(&self) -> WorkGroup {
|
||||
self.workgroup.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{binary_elemwise, kernel::KernelSettings, AutoGraphicsApi};
|
||||
|
||||
#[test]
|
||||
fn can_run_kernel() {
|
||||
binary_elemwise!(Add, "+");
|
||||
|
||||
let client = compute_client::<AutoGraphicsApi>(&WgpuDevice::default());
|
||||
|
||||
let lhs: Vec<f32> = vec![0., 1., 2., 3., 4., 5., 6., 7.];
|
||||
let rhs: Vec<f32> = vec![10., 11., 12., 6., 7., 3., 1., 0.];
|
||||
let info: Vec<u32> = vec![1, 1, 1, 1, 8, 8, 8];
|
||||
|
||||
let lhs = client.create(bytemuck::cast_slice(&lhs));
|
||||
let rhs = client.create(bytemuck::cast_slice(&rhs));
|
||||
let out = client.empty(core::mem::size_of::<f32>() * 8);
|
||||
let info = client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
type Kernel = KernelSettings<Add, f32, i32, 16, 16, 1>;
|
||||
let kernel = Box::new(StaticComputeKernel::<Kernel>::new(WorkGroup::new(1, 1, 1)));
|
||||
|
||||
client.execute(kernel, &[&lhs, &rhs, &out, &info]);
|
||||
|
||||
let data = client.read(&out);
|
||||
let output: &[f32] = bytemuck::cast_slice(&data);
|
||||
|
||||
assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
mod base;
|
||||
mod server;
|
||||
mod storage;
|
||||
|
||||
pub use base::*;
|
||||
pub use server::*;
|
||||
pub use storage::*;
|
|
@ -0,0 +1,253 @@
|
|||
use std::{borrow::Cow, sync::Arc};
|
||||
|
||||
use super::WgpuStorage;
|
||||
use crate::{context::WorkGroup, kernel::SourceTemplate};
|
||||
use burn_compute::{
|
||||
memory_management::{MemoryManagement, SimpleMemoryManagement},
|
||||
server::{self, ComputeServer},
|
||||
};
|
||||
use hashbrown::HashMap;
|
||||
use wgpu::{
|
||||
util::{BufferInitDescriptor, DeviceExt},
|
||||
BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor,
|
||||
};
|
||||
|
||||
/// Wgpu compute server.
|
||||
pub struct WgpuServer<MM = SimpleMemoryManagement<WgpuStorage>> {
|
||||
memory_management: MM,
|
||||
device: Arc<wgpu::Device>,
|
||||
queue: wgpu::Queue,
|
||||
encoder: CommandEncoder,
|
||||
pipelines: HashMap<String, Arc<ComputePipeline>>,
|
||||
tasks: Vec<ComputeTask>,
|
||||
max_tasks: usize,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct ComputeTask {
|
||||
pipeline: Arc<ComputePipeline>,
|
||||
bind_group: BindGroup,
|
||||
work_group: WorkGroup,
|
||||
}
|
||||
|
||||
pub trait Kernel: 'static + Send {
|
||||
/// Source template for the kernel.
|
||||
fn source_template(self: Box<Self>) -> SourceTemplate;
|
||||
/// Identifier for the kernel, used for caching kernel compilation.
|
||||
fn id(&self) -> String;
|
||||
fn workgroup(&self) -> WorkGroup;
|
||||
}
|
||||
|
||||
impl<MM> WgpuServer<MM>
|
||||
where
|
||||
MM: MemoryManagement<WgpuStorage>,
|
||||
{
|
||||
pub fn new(
|
||||
memory_management: MM,
|
||||
device: Arc<wgpu::Device>,
|
||||
queue: wgpu::Queue,
|
||||
max_tasks: usize,
|
||||
) -> Self {
|
||||
let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||
label: Some("Command Encoder"),
|
||||
});
|
||||
|
||||
Self {
|
||||
memory_management,
|
||||
device,
|
||||
queue,
|
||||
encoder,
|
||||
pipelines: HashMap::new(),
|
||||
tasks: Vec::new(),
|
||||
max_tasks,
|
||||
}
|
||||
}
|
||||
|
||||
fn submit(&mut self) {
|
||||
assert!(
|
||||
self.tasks.is_empty(),
|
||||
"Tasks should be completed before submitting the current encoder."
|
||||
);
|
||||
println!("Submit");
|
||||
let mut new_encoder = self
|
||||
.device
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
|
||||
core::mem::swap(&mut new_encoder, &mut self.encoder);
|
||||
|
||||
self.queue.submit(Some(new_encoder.finish()));
|
||||
}
|
||||
|
||||
fn register_tasks(&mut self) {
|
||||
if self.tasks.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut compute = self
|
||||
.encoder
|
||||
.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
|
||||
|
||||
for task in self.tasks.iter() {
|
||||
compute.set_pipeline(&task.pipeline);
|
||||
compute.set_bind_group(0, &task.bind_group, &[]);
|
||||
compute.dispatch_workgroups(task.work_group.x, task.work_group.y, task.work_group.z);
|
||||
}
|
||||
|
||||
std::mem::drop(compute);
|
||||
self.tasks.clear();
|
||||
}
|
||||
|
||||
fn pipeline(&mut self, kernel: Box<dyn Kernel>) -> Arc<ComputePipeline> {
|
||||
let kernel_id = kernel.id();
|
||||
if let Some(pipeline) = self.pipelines.get(&kernel_id) {
|
||||
return pipeline.clone();
|
||||
}
|
||||
|
||||
let pipeline = self.compile_source(&kernel.source_template().complete());
|
||||
self.pipelines.insert(kernel_id.clone(), pipeline.clone());
|
||||
|
||||
pipeline
|
||||
}
|
||||
|
||||
fn compile_source(&self, source: &str) -> Arc<ComputePipeline> {
|
||||
let module = self.device.create_shader_module(ShaderModuleDescriptor {
|
||||
label: None,
|
||||
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
|
||||
});
|
||||
|
||||
Arc::new(
|
||||
self.device
|
||||
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: None,
|
||||
layout: None,
|
||||
module: &module,
|
||||
entry_point: "main",
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<MM> ComputeServer for WgpuServer<MM>
|
||||
where
|
||||
MM: MemoryManagement<WgpuStorage>,
|
||||
{
|
||||
type Kernel = Box<dyn Kernel>;
|
||||
type Storage = WgpuStorage;
|
||||
type MemoryManagement = MM;
|
||||
|
||||
fn read(&mut self, handle: &server::Handle<Self>) -> Vec<u8> {
|
||||
// Register previous tasks before reading the buffer so that it is up to date.
|
||||
self.register_tasks();
|
||||
|
||||
let resource = self.memory_management.get(handle);
|
||||
|
||||
let size = resource.size();
|
||||
let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: None,
|
||||
size,
|
||||
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
self.encoder.copy_buffer_to_buffer(
|
||||
&resource.buffer,
|
||||
resource.offset(),
|
||||
&buffer_dest,
|
||||
0,
|
||||
size,
|
||||
);
|
||||
|
||||
self.submit();
|
||||
|
||||
let buffer_slice = buffer_dest.slice(..);
|
||||
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
|
||||
buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
|
||||
sender
|
||||
.send(v)
|
||||
.expect("Unable to send buffer slice result to async channel.")
|
||||
});
|
||||
|
||||
self.device.poll(wgpu::Maintain::Wait);
|
||||
|
||||
let result = pollster::block_on(receiver.receive());
|
||||
|
||||
if let Some(Ok(())) = result {
|
||||
let data = buffer_slice.get_mapped_range();
|
||||
let result = bytemuck::cast_slice(&data).to_vec();
|
||||
|
||||
drop(data);
|
||||
buffer_dest.unmap();
|
||||
result
|
||||
} else {
|
||||
panic!("Unable to read buffer {:?}", result)
|
||||
}
|
||||
}
|
||||
|
||||
fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
|
||||
let handle = self.empty(data.len());
|
||||
|
||||
let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor {
|
||||
label: Some("Buffer Src"),
|
||||
contents: data,
|
||||
usage: wgpu::BufferUsages::COPY_SRC,
|
||||
}));
|
||||
|
||||
let resource = self.memory_management.get(&handle);
|
||||
|
||||
self.register_tasks();
|
||||
|
||||
self.encoder.copy_buffer_to_buffer(
|
||||
&buffer_src,
|
||||
0,
|
||||
&resource.buffer,
|
||||
resource.offset(),
|
||||
buffer_src.size(),
|
||||
);
|
||||
|
||||
handle
|
||||
}
|
||||
|
||||
fn empty(&mut self, size: usize) -> server::Handle<Self> {
|
||||
self.memory_management.reserve(size)
|
||||
}
|
||||
|
||||
fn execute(&mut self, kernel: Self::Kernel, handles: &[&server::Handle<Self>]) {
|
||||
let work_group = kernel.workgroup();
|
||||
let pipeline = self.pipeline(kernel);
|
||||
let group_layout = pipeline.get_bind_group_layout(0);
|
||||
|
||||
let handles = handles
|
||||
.iter()
|
||||
.map(|handle| self.memory_management.get(handle))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let entries = handles
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, buffer)| wgpu::BindGroupEntry {
|
||||
binding: i as u32,
|
||||
resource: buffer.as_binding(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||
label: None,
|
||||
layout: &group_layout,
|
||||
entries: &entries,
|
||||
});
|
||||
|
||||
self.tasks
|
||||
.push(ComputeTask::new(pipeline, bind_group, work_group));
|
||||
|
||||
if self.tasks.len() >= self.max_tasks {
|
||||
self.register_tasks();
|
||||
self.submit();
|
||||
}
|
||||
}
|
||||
|
||||
fn sync(&mut self) {
|
||||
if !self.tasks.is_empty() {
|
||||
self.register_tasks();
|
||||
self.submit();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,102 @@
|
|||
use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
|
||||
use hashbrown::HashMap;
|
||||
use std::{num::NonZeroU64, sync::Arc};
|
||||
|
||||
pub struct WgpuStorage {
|
||||
memory: HashMap<StorageId, Arc<wgpu::Buffer>>,
|
||||
device: Arc<wgpu::Device>,
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
pub struct WgpuResource {
|
||||
pub buffer: Arc<wgpu::Buffer>,
|
||||
pub kind: WgpuResourceKind,
|
||||
}
|
||||
|
||||
impl WgpuResource {
|
||||
/// Return the binding view of the buffer.
|
||||
pub fn as_binding(&self) -> wgpu::BindingResource {
|
||||
let binding = match &self.kind {
|
||||
WgpuResourceKind::Full => self.buffer.as_entire_buffer_binding(),
|
||||
WgpuResourceKind::Slice(offs, size) => wgpu::BufferBinding {
|
||||
buffer: &self.buffer,
|
||||
offset: *offs,
|
||||
size: Some(*size),
|
||||
},
|
||||
};
|
||||
wgpu::BindingResource::Buffer(binding)
|
||||
}
|
||||
|
||||
/// Return the buffer size.
|
||||
pub fn size(&self) -> u64 {
|
||||
match self.kind {
|
||||
WgpuResourceKind::Full => self.buffer.size(),
|
||||
WgpuResourceKind::Slice(_, size) => size.get(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the buffer offset.
|
||||
pub fn offset(&self) -> u64 {
|
||||
match self.kind {
|
||||
WgpuResourceKind::Full => 0,
|
||||
WgpuResourceKind::Slice(offset, _) => offset,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum WgpuResourceKind {
|
||||
/// Represents an entire buffer.
|
||||
Full,
|
||||
/// A slice over a buffer.
|
||||
Slice(wgpu::BufferAddress, wgpu::BufferSize),
|
||||
}
|
||||
|
||||
/// Keeps actual wgpu buffer references in a hashmap with ids as key.
|
||||
impl WgpuStorage {
|
||||
pub fn new(device: Arc<wgpu::Device>) -> Self {
|
||||
Self {
|
||||
memory: HashMap::new(),
|
||||
device,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ComputeStorage for WgpuStorage {
|
||||
type Resource = WgpuResource;
|
||||
|
||||
fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
|
||||
let buffer = self.memory.get(&handle.id).unwrap();
|
||||
|
||||
match handle.utilization {
|
||||
StorageUtilization::Full(_) => {
|
||||
WgpuResource::new(buffer.clone(), WgpuResourceKind::Full)
|
||||
}
|
||||
StorageUtilization::Slice(offset, size) => WgpuResource::new(
|
||||
buffer.clone(),
|
||||
WgpuResourceKind::Slice(offset as u64, NonZeroU64::new(size as u64).unwrap()),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn alloc(&mut self, size: usize) -> StorageHandle {
|
||||
let id = StorageId::new();
|
||||
let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: None,
|
||||
size: size as u64,
|
||||
usage: wgpu::BufferUsages::COPY_DST
|
||||
| wgpu::BufferUsages::STORAGE
|
||||
| wgpu::BufferUsages::COPY_SRC,
|
||||
mapped_at_creation: false,
|
||||
}));
|
||||
|
||||
self.memory.insert(id.clone(), buffer);
|
||||
|
||||
StorageHandle::new(id, StorageUtilization::Full(size))
|
||||
}
|
||||
|
||||
fn dealloc(&mut self, id: StorageId) {
|
||||
self.memory.get(&id).unwrap().destroy();
|
||||
let _ = self.memory.remove(&id);
|
||||
}
|
||||
}
|
|
@ -270,7 +270,7 @@ impl PartialEq for Context {
|
|||
}
|
||||
}
|
||||
|
||||
async fn select_device<G: GraphicsApi>(
|
||||
pub(crate) async fn select_device<G: GraphicsApi>(
|
||||
device: &WgpuDevice,
|
||||
) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) {
|
||||
let adapter = select_adapter::<G>(device);
|
||||
|
|
|
@ -3,13 +3,13 @@ use crate::{context::WorkGroup, element::WgpuElement, tensor::WgpuTensor};
|
|||
use std::marker::PhantomData;
|
||||
|
||||
/// Static wgpu kernel to create a [source template](SourceTemplate).
|
||||
pub trait StaticKernel: 'static {
|
||||
pub trait StaticKernel: Send + 'static {
|
||||
/// Source template for the kernel.
|
||||
fn source_template() -> SourceTemplate;
|
||||
}
|
||||
|
||||
/// Dynamic wgpu kernel to create a [source template](SourceTemplate).
|
||||
pub trait DynamicKernel {
|
||||
pub trait DynamicKernel: Send {
|
||||
/// Source template for the kernel.
|
||||
fn source_template(self) -> SourceTemplate;
|
||||
/// Identifier for the kernel, used for caching kernel compilation.
|
||||
|
|
|
@ -16,6 +16,10 @@ pub mod kernel;
|
|||
/// Tensor module.
|
||||
pub mod tensor;
|
||||
|
||||
#[cfg(test)] // Only enabled for dev for now.
|
||||
/// Compute related module.
|
||||
pub mod compute;
|
||||
|
||||
pub(crate) mod pool;
|
||||
pub(crate) mod tune;
|
||||
|
||||
|
|
|
@ -61,20 +61,14 @@ fn rustup(target: &str) {
|
|||
}
|
||||
|
||||
// Define and run a cargo command
|
||||
fn run_cargo(command: &str, first_params: &[&str], second_params: &[&str], error: &str) {
|
||||
fn run_cargo(command: &str, params: Params, error: &str) {
|
||||
// Print cargo command
|
||||
println!(
|
||||
"\ncargo {} {} {}\n",
|
||||
command,
|
||||
first_params.join(" "),
|
||||
second_params.join(" ")
|
||||
);
|
||||
println!("\ncargo {} {}\n", command, params);
|
||||
|
||||
// Run cargo
|
||||
let cargo = Command::new("cargo")
|
||||
.arg(command)
|
||||
.args(first_params)
|
||||
.args(second_params)
|
||||
.args(params.params)
|
||||
.stdout(Stdio::inherit()) // Send stdout directly to terminal
|
||||
.stderr(Stdio::inherit()) // Send stderr directly to terminal
|
||||
.spawn()
|
||||
|
@ -85,34 +79,31 @@ fn run_cargo(command: &str, first_params: &[&str], second_params: &[&str], error
|
|||
}
|
||||
|
||||
// Run cargo build command
|
||||
fn cargo_build(params: &[&str]) {
|
||||
fn cargo_build(params: Params) {
|
||||
// Run cargo build
|
||||
run_cargo(
|
||||
"build",
|
||||
params,
|
||||
&["--color=always"],
|
||||
params + "--color=always",
|
||||
"Failed to run cargo build",
|
||||
);
|
||||
}
|
||||
|
||||
// Run cargo install command
|
||||
fn cargo_install(params: &[&str]) {
|
||||
fn cargo_install(params: Params) {
|
||||
// Run cargo install
|
||||
run_cargo(
|
||||
"install",
|
||||
params,
|
||||
&["--color=always"],
|
||||
params + "--color=always",
|
||||
"Failed to run cargo install",
|
||||
);
|
||||
}
|
||||
|
||||
// Run cargo test command
|
||||
fn cargo_test(params: &[&str]) {
|
||||
fn cargo_test(params: Params) {
|
||||
// Run cargo test
|
||||
run_cargo(
|
||||
"test",
|
||||
params,
|
||||
&["--color=always", "--", "--color=always"],
|
||||
params + "--color=always" + "--" + "--color=always",
|
||||
"Failed to run cargo test",
|
||||
);
|
||||
}
|
||||
|
@ -122,8 +113,7 @@ fn cargo_fmt() {
|
|||
// Run cargo fmt
|
||||
run_cargo(
|
||||
"fmt",
|
||||
&["--check", "--all"],
|
||||
&["--", "--color=always"],
|
||||
["--check", "--all", "--", "--color=always"].into(),
|
||||
"Failed to run cargo fmt",
|
||||
);
|
||||
}
|
||||
|
@ -136,50 +126,48 @@ fn cargo_clippy() {
|
|||
// Run cargo clippy
|
||||
run_cargo(
|
||||
"clippy",
|
||||
&["--color=always"],
|
||||
&["--", "-D", "warnings"],
|
||||
["--color=always", "--", "-D", "warnings"].into(),
|
||||
"Failed to run cargo clippy",
|
||||
);
|
||||
}
|
||||
|
||||
// Run cargo doc command
|
||||
fn cargo_doc(params: &[&str]) {
|
||||
fn cargo_doc(params: Params) {
|
||||
// Run cargo doc
|
||||
run_cargo(
|
||||
"doc",
|
||||
params,
|
||||
&["--color=always"],
|
||||
"Failed to run cargo doc",
|
||||
);
|
||||
run_cargo("doc", params + "--color=always", "Failed to run cargo doc");
|
||||
}
|
||||
|
||||
// Build and test a crate in a no_std environment
|
||||
fn build_and_test_no_std(crate_name: &str) {
|
||||
fn build_and_test_no_std<const N: usize>(crate_name: &str, extra_args: [&str; N]) {
|
||||
println!("\nRun checks for `{}` crate", crate_name);
|
||||
|
||||
// Run cargo build --no-default-features
|
||||
cargo_build(&["-p", crate_name, "--no-default-features"]);
|
||||
cargo_build(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args);
|
||||
|
||||
// Run cargo test --no-default-features
|
||||
cargo_test(&["-p", crate_name, "--no-default-features"]);
|
||||
cargo_test(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args);
|
||||
|
||||
// Run cargo build --no-default-features --target wasm32-unknown-unknowns
|
||||
cargo_build(&[
|
||||
"-p",
|
||||
crate_name,
|
||||
"--no-default-features",
|
||||
"--target",
|
||||
WASM32_TARGET,
|
||||
]);
|
||||
cargo_build(
|
||||
Params::from([
|
||||
"-p",
|
||||
crate_name,
|
||||
"--no-default-features",
|
||||
"--target",
|
||||
WASM32_TARGET,
|
||||
]) + extra_args,
|
||||
);
|
||||
|
||||
// Run cargo build --no-default-features --target thumbv7m-none-eabi
|
||||
cargo_build(&[
|
||||
"-p",
|
||||
crate_name,
|
||||
"--no-default-features",
|
||||
"--target",
|
||||
ARM_TARGET,
|
||||
]);
|
||||
cargo_build(
|
||||
Params::from([
|
||||
"-p",
|
||||
crate_name,
|
||||
"--no-default-features",
|
||||
"--target",
|
||||
ARM_TARGET,
|
||||
]) + extra_args,
|
||||
);
|
||||
}
|
||||
|
||||
// Run no_std checks
|
||||
|
@ -193,12 +181,16 @@ fn no_std_checks() {
|
|||
rustup(ARM_TARGET);
|
||||
|
||||
// Run checks for the following crates
|
||||
build_and_test_no_std("burn");
|
||||
build_and_test_no_std("burn-core");
|
||||
build_and_test_no_std("burn-common");
|
||||
build_and_test_no_std("burn-tensor");
|
||||
build_and_test_no_std("burn-ndarray");
|
||||
build_and_test_no_std("burn-no-std-tests");
|
||||
build_and_test_no_std("burn", []);
|
||||
build_and_test_no_std("burn-core", []);
|
||||
build_and_test_no_std(
|
||||
"burn-compute",
|
||||
["--features", "channel-mutex storage-bytes"],
|
||||
);
|
||||
build_and_test_no_std("burn-common", []);
|
||||
build_and_test_no_std("burn-tensor", []);
|
||||
build_and_test_no_std("burn-ndarray", []);
|
||||
build_and_test_no_std("burn-no-std-tests", []);
|
||||
}
|
||||
|
||||
// Test burn-core with tch and wgpu backend
|
||||
|
@ -206,10 +198,10 @@ fn burn_core_std() {
|
|||
println!("\n\nRun checks for burn-core crate with tch and wgpu backend");
|
||||
|
||||
// Run cargo test --features test-tch
|
||||
cargo_test(&["-p", "burn-core", "--features", "test-tch"]);
|
||||
cargo_test(["-p", "burn-core", "--features", "test-tch"].into());
|
||||
|
||||
// Run cargo test --features test-wgpu
|
||||
cargo_test(&["-p", "burn-core", "--features", "test-wgpu"]);
|
||||
cargo_test(["-p", "burn-core", "--features", "test-wgpu"].into());
|
||||
}
|
||||
|
||||
// Test burn-dataset features
|
||||
|
@ -217,13 +209,13 @@ fn burn_dataset_features_std() {
|
|||
println!("\n\nRun checks for burn-dataset features");
|
||||
|
||||
// Run cargo build --all-features
|
||||
cargo_build(&["-p", "burn-dataset", "--all-features"]);
|
||||
cargo_build(["-p", "burn-dataset", "--all-features"].into());
|
||||
|
||||
// Run cargo test --all-features
|
||||
cargo_test(&["-p", "burn-dataset", "--all-features"]);
|
||||
cargo_test(["-p", "burn-dataset", "--all-features"].into());
|
||||
|
||||
// Run cargo doc --all-features
|
||||
cargo_doc(&["-p", "burn-dataset", "--all-features"]);
|
||||
cargo_doc(["-p", "burn-dataset", "--all-features"].into());
|
||||
}
|
||||
|
||||
fn std_checks() {
|
||||
|
@ -234,10 +226,10 @@ fn std_checks() {
|
|||
println!("Running std checks");
|
||||
|
||||
// Build each workspace
|
||||
cargo_build(&["--workspace", "--exclude=xtask"]);
|
||||
cargo_build(["--workspace", "--exclude=xtask"].into());
|
||||
|
||||
// Test each workspace
|
||||
cargo_test(&["--workspace"]);
|
||||
cargo_test(["--workspace"].into());
|
||||
|
||||
// Check format
|
||||
cargo_fmt();
|
||||
|
@ -246,7 +238,7 @@ fn std_checks() {
|
|||
cargo_clippy();
|
||||
|
||||
// Produce documentation for each workspace
|
||||
cargo_doc(&["--workspace"]);
|
||||
cargo_doc(["--workspace"].into());
|
||||
|
||||
// Test burn-dataset features
|
||||
burn_dataset_features_std();
|
||||
|
@ -257,7 +249,7 @@ fn std_checks() {
|
|||
|
||||
fn check_typos() {
|
||||
// Install typos-cli
|
||||
cargo_install(&["typos-cli", "--version", "1.16.5"]);
|
||||
cargo_install(["typos-cli", "--version", "1.16.5"].into());
|
||||
|
||||
println!("Running typos check \n\n");
|
||||
|
||||
|
@ -349,3 +341,39 @@ pub fn run(env: CheckType) -> anyhow::Result<()> {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct Params {
|
||||
params: Vec<String>,
|
||||
}
|
||||
|
||||
impl<const N: usize> From<[&str; N]> for Params {
|
||||
fn from(value: [&str; N]) -> Self {
|
||||
Self {
|
||||
params: value.iter().map(|v| v.to_string()).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for Params {
|
||||
fn from(value: &str) -> Self {
|
||||
Self {
|
||||
params: vec![value.to_string()],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Params {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.params.join(" ").as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rhs: Into<Params>> std::ops::Add<Rhs> for Params {
|
||||
type Output = Params;
|
||||
|
||||
fn add(mut self, rhs: Rhs) -> Self::Output {
|
||||
let rhs: Params = rhs.into();
|
||||
self.params.extend(rhs.params);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue