Burn compute (#809)

Nathaniel Simard 2023-09-18 19:56:53 -04:00 committed by GitHub
parent d7e9e75099
commit ac4adb54ea
34 changed files with 1941 additions and 86 deletions


@ -7,6 +7,7 @@ members = [
"burn",
"burn-autodiff",
"burn-common",
"burn-compute",
"burn-core",
"burn-dataset",
"burn-derive",
@ -24,9 +25,7 @@ members = [
"examples/*",
]
exclude = [
"examples/notebook",
]
exclude = ["examples/notebook"]
[workspace.dependencies]
bytemuck = "1.13"
@ -37,7 +36,7 @@ dirs = "5.0.1"
fake = "2.6.1"
flate2 = "1.0.26"
float-cmp = "0.9.0"
gix-tempfile = {version = "8.0.0", features = ["signals"]}
gix-tempfile = { version = "8.0.0", features = ["signals"] }
hashbrown = "0.14.0"
indicatif = "0.17.5"
libm = "0.2.7"
@ -47,17 +46,17 @@ proc-macro2 = "1.0.60"
protobuf-codegen = "3.2"
quote = "1.0.28"
r2d2 = "0.8.10"
r2d2_sqlite = {version = "0.22.0"}
r2d2_sqlite = { version = "0.22.0" }
rayon = "1.7.0"
rmp-serde = "1.1.1"
rstest = "0.18.1"
rusqlite = {version = "0.29"}
rusqlite = { version = "0.29" }
sanitize-filename = "0.5.0"
serde_rusqlite = "0.33.1"
spin = {version = "0.9.8", features = ["mutex", "spin_mutex"]}
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
strum = "0.24"
strum_macros = "0.24"
syn = {version = "2.0", features = ["full", "extra-traits"]}
syn = { version = "2.0", features = ["full", "extra-traits"] }
tempfile = "3.6.0"
thiserror = "1.0.40"
tracing-subscriber = "0.3.17"
@ -67,19 +66,33 @@ tracing-appender = "0.2.2"
# WGPU stuff
futures-intrusive = "0.5"
pollster = "0.3"
text_placeholder = {version = "0.5.0", features = ["struct_context"]}
text_placeholder = { version = "0.5.0", features = ["struct_context"] }
wgpu = "0.17.0"
#
# The following packages disable the "std" feature for no_std compatibility
#
bincode = {version = "2.0.0-rc.3", features = ["alloc", "serde"], default-features = false}
derive-new = {version = "0.5.9", default-features = false}
half = {version = "2.3.1", features = ["alloc", "num-traits", "serde"], default-features = false}
ndarray = {version = "0.15.6", default-features = false}
num-traits = {version = "0.2.15", default-features = false, features = ["libm"]}# libm is for no_std
rand = {version = "0.8.5", default-features = false, features = ["std_rng"]}# std_rng is for no_std
rand_distr = {version = "0.4.3", default-features = false}
serde = {version = "1.0.164", default-features = false, features = ["derive", "alloc"]}# alloc is for no_std, derive is needed
serde_json = {version = "1.0.96", default-features = false}
uuid = {version = "1.3.4", default-features = false}
bincode = { version = "2.0.0-rc.3", features = [
"alloc",
"serde",
], default-features = false }
derive-new = { version = "0.5.9", default-features = false }
half = { version = "2.3.1", features = [
"alloc",
"num-traits",
"serde",
], default-features = false }
ndarray = { version = "0.15.6", default-features = false }
num-traits = { version = "0.2.15", default-features = false, features = [
"libm",
] } # libm is for no_std
rand = { version = "0.8.5", default-features = false, features = [
"std_rng",
] } # std_rng is for no_std
rand_distr = { version = "0.4.3", default-features = false }
serde = { version = "1.0.164", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed
serde_json = { version = "1.0.96", default-features = false }
uuid = { version = "1.3.4", default-features = false }


@ -1,7 +1,5 @@
use alloc::string::{String, ToString};
use crate::rand::{get_seeded_rng, Rng, SEED};
use alloc::string::{String, ToString};
use uuid::{Builder, Bytes};
/// Simple ID generator.

burn-compute/Cargo.toml Normal file (27 lines)

@ -0,0 +1,27 @@
[package]
authors = ["louisfd <louisfd94@gmail.com>", "Nathaniel Simard"]
categories = ["science"]
description = "Compute crate that helps creating high performance async backends."
edition = "2021"
keywords = ["deep-learning", "machine-learning", "data"]
license = "MIT OR Apache-2.0"
name = "burn-compute"
readme = "README.md"
repository = "https://github.com/burn-rs/burn/tree/main/burn-compute"
version = "0.10.0"
[features]
default = ["std", "channel-mutex", "channel-mpsc", "channel-cell", "storage-bytes"]
std = []
channel-mutex = []
channel-cell = []
channel-mpsc = [] # Assume std
storage-bytes = []
[dependencies]
burn-common = { path = "../burn-common", version = "0.10.0", default-features = false }
derive-new = { workspace = true }
spin = { workspace = true }
hashbrown = { workspace = true }
[dev-dependencies]

burn-compute/README.md Normal file (7 lines)

@ -0,0 +1,7 @@
# Burn Compute
This crate helps create high-performance async backends.
- [x] Asynchronous kernel executions
- [x] Memory allocation management
- [ ] Autotuning
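
A minimal sketch of how the pieces introduced in this commit fit together; the `make_client` helper is illustrative, not part of the crate:

```rust
use burn_compute::channel::MutexComputeChannel;
use burn_compute::client::ComputeClient;
use burn_compute::server::ComputeServer;

/// Wrap any `ComputeServer` in a mutex-protected channel and hand back a client.
fn make_client<S: ComputeServer>(server: S) -> ComputeClient<S, MutexComputeChannel<S>> {
    ComputeClient::new(MutexComputeChannel::new(server))
}
```

From there, `client.create(..)` uploads bytes, `client.execute(..)` launches a kernel, and `client.read(..)` fetches results, as the dummy-backend tests added in this commit demonstrate.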


@ -0,0 +1,21 @@
use crate::server::{ComputeServer, Handle};
use alloc::vec::Vec;
/// The ComputeChannel trait links the ComputeClient to the ComputeServer
/// while ensuring thread-safety
pub trait ComputeChannel<Server: ComputeServer>: Clone {
/// Given a handle, returns owned resource as bytes
fn read(&self, handle: &Handle<Server>) -> Vec<u8>;
/// Given a resource as bytes, stores it and returns the resource handle
fn create(&self, data: &[u8]) -> Handle<Server>;
/// Reserves `size` bytes in the storage, and returns a handle over them
fn empty(&self, size: usize) -> Handle<Server>;
/// Executes the `kernel` over the given `handles`.
fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]);
/// Wait for the completion of every task in the server.
fn sync(&self);
}


@ -0,0 +1,65 @@
use super::ComputeChannel;
use crate::server::{ComputeServer, Handle};
use alloc::sync::Arc;
use alloc::vec::Vec;
/// A channel using a [ref cell](core::cell::RefCell) to access the server with mutability.
///
/// # Important
///
/// Only use this channel if you don't use any threading in your application; otherwise it will
/// panic or cause undefined behavior.
///
/// This is mostly useful for `no-std` environments where threads aren't supported; otherwise prefer
/// the [mutex](super::MutexComputeChannel) or the [mpsc](super::MpscComputeChannel) channels.
pub struct RefCellComputeChannel<Server> {
server: Arc<core::cell::RefCell<Server>>,
}
impl<S> Clone for RefCellComputeChannel<S> {
fn clone(&self) -> Self {
Self {
server: self.server.clone(),
}
}
}
impl<Server> RefCellComputeChannel<Server>
where
Server: ComputeServer,
{
/// Create a new cell compute channel.
pub fn new(server: Server) -> Self {
Self {
server: Arc::new(core::cell::RefCell::new(server)),
}
}
}
impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
where
Server: ComputeServer,
{
fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
let mut server = self.server.borrow_mut();
server.read(handle)
}
fn create(&self, resource: &[u8]) -> Handle<Server> {
self.server.borrow_mut().create(resource)
}
fn empty(&self, size: usize) -> Handle<Server> {
self.server.borrow_mut().empty(size)
}
fn execute(&self, kernel_description: Server::Kernel, handles: &[&Handle<Server>]) {
self.server
.borrow_mut()
.execute(kernel_description, handles)
}
fn sync(&self) {
self.server.borrow_mut().sync()
}
}


@ -0,0 +1,17 @@
mod base;
pub use base::*;
#[cfg(feature = "channel-mutex")]
mod mutex;
#[cfg(feature = "channel-mutex")]
pub use mutex::*;
#[cfg(feature = "channel-mpsc")]
mod mpsc;
#[cfg(feature = "channel-mpsc")]
pub use mpsc::*;
#[cfg(feature = "channel-cell")]
mod cell;
#[cfg(feature = "channel-cell")]
pub use cell::*;


@ -0,0 +1,154 @@
use std::{
sync::{mpsc, Arc},
thread,
};
use super::ComputeChannel;
use crate::server::{ComputeServer, Handle};
/// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with
/// the compute server, which is spawned on its own thread.
pub struct MpscComputeChannel<Server>
where
Server: ComputeServer,
{
state: Arc<MpscComputeChannelState<Server>>,
}
struct MpscComputeChannelState<Server>
where
Server: ComputeServer,
{
_handle: thread::JoinHandle<()>,
sender: mpsc::SyncSender<Message<Server>>,
}
type Callback<Response> = mpsc::SyncSender<Response>;
enum Message<Server>
where
Server: ComputeServer,
{
Read(Handle<Server>, Callback<Vec<u8>>),
Create(Vec<u8>, Callback<Handle<Server>>),
Empty(usize, Callback<Handle<Server>>),
Execute(Server::Kernel, Vec<Handle<Server>>),
Sync(Callback<()>),
}
impl<Server> MpscComputeChannel<Server>
where
Server: ComputeServer + 'static,
{
/// Create a new mpsc compute channel.
pub fn new(mut server: Server, bound: usize) -> Self {
let (sender, receiver) = mpsc::sync_channel(bound);
let _handle = thread::spawn(move || {
while let Ok(message) = receiver.recv() {
match message {
Message::Read(handle, callback) => {
let data = server.read(&handle);
core::mem::drop(handle);
callback.send(data).unwrap();
}
Message::Create(data, callback) => {
let handle = server.create(&data);
callback.send(handle).unwrap();
}
Message::Empty(size, callback) => {
let handle = server.empty(size);
callback.send(handle).unwrap();
}
Message::Execute(kernel, handles) => {
server.execute(kernel, &handles.iter().collect::<Vec<_>>());
}
Message::Sync(callback) => {
server.sync();
callback.send(()).unwrap();
}
};
}
});
let state = Arc::new(MpscComputeChannelState { sender, _handle });
Self { state }
}
}
impl<Server: ComputeServer> Clone for MpscComputeChannel<Server> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
}
}
}
impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
where
Server: ComputeServer + 'static,
{
fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
let (callback, response) = mpsc::sync_channel(1);
self.state
.sender
.send(Message::Read(handle.clone(), callback))
.unwrap();
self.response(response)
}
fn create(&self, data: &[u8]) -> Handle<Server> {
let (callback, response) = mpsc::sync_channel(1);
self.state
.sender
.send(Message::Create(data.to_vec(), callback))
.unwrap();
self.response(response)
}
fn empty(&self, size: usize) -> Handle<Server> {
let (callback, response) = mpsc::sync_channel(1);
self.state
.sender
.send(Message::Empty(size, callback))
.unwrap();
self.response(response)
}
fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]) {
self.state
.sender
.send(Message::Execute(
kernel,
handles
.iter()
.map(|h| (*h).clone())
.collect::<Vec<Handle<Server>>>(),
))
.unwrap()
}
fn sync(&self) {
let (callback, response) = mpsc::sync_channel(1);
self.state.sender.send(Message::Sync(callback)).unwrap();
self.response(response)
}
}
impl<Server: ComputeServer> MpscComputeChannel<Server> {
fn response<Response>(&self, response: mpsc::Receiver<Response>) -> Response {
match response.recv() {
Ok(val) => val,
Err(err) => panic!("Can't connect to the server correctly {err:?}"),
}
}
}
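
A short usage sketch for the mpsc channel; the `demo` wrapper and the queue bound of 16 are illustrative:

```rust
use burn_compute::channel::{ComputeChannel, MpscComputeChannel};
use burn_compute::server::ComputeServer;

fn demo<S: ComputeServer + 'static>(server: S) {
    // The server moves onto a dedicated thread; 16 bounds the message queue.
    let channel = MpscComputeChannel::new(server, 16);

    // Each call below is a message round-trip to the server thread.
    let handle = channel.create(&[1u8, 2, 3]);
    assert_eq!(channel.read(&handle), vec![1u8, 2, 3]);
    channel.sync(); // block until the server has drained its queue
}
```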


@ -0,0 +1,55 @@
use super::ComputeChannel;
use crate::server::{ComputeServer, Handle};
use alloc::sync::Arc;
use alloc::vec::Vec;
use spin::Mutex;
/// The MutexComputeChannel ensures thread-safety by locking the server
/// on every operation
pub struct MutexComputeChannel<Server> {
server: Arc<Mutex<Server>>,
}
impl<S> Clone for MutexComputeChannel<S> {
fn clone(&self) -> Self {
Self {
server: self.server.clone(),
}
}
}
impl<Server> MutexComputeChannel<Server>
where
Server: ComputeServer,
{
/// Create a new mutex compute channel.
pub fn new(server: Server) -> Self {
Self {
server: Arc::new(Mutex::new(server)),
}
}
}
impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
where
Server: ComputeServer,
{
fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
self.server.lock().read(handle)
}
fn create(&self, data: &[u8]) -> Handle<Server> {
self.server.lock().create(data)
}
fn empty(&self, size: usize) -> Handle<Server> {
self.server.lock().empty(size)
}
fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]) {
self.server.lock().execute(kernel, handles)
}
fn sync(&self) {
self.server.lock().sync()
}
}


@ -0,0 +1,65 @@
use crate::{
channel::ComputeChannel,
server::{ComputeServer, Handle},
};
use alloc::vec::Vec;
use core::marker::PhantomData;
/// The ComputeClient is the entry point for requesting tasks from the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
pub struct ComputeClient<Server, Channel> {
channel: Channel,
_server: PhantomData<Server>,
}
impl<S, C> Clone for ComputeClient<S, C>
where
S: ComputeServer,
C: ComputeChannel<S>,
{
fn clone(&self) -> Self {
Self {
channel: self.channel.clone(),
_server: PhantomData,
}
}
}
impl<Server, Channel> ComputeClient<Server, Channel>
where
Server: ComputeServer,
Channel: ComputeChannel<Server>,
{
/// Create a new client.
pub fn new(channel: Channel) -> Self {
Self {
channel,
_server: PhantomData,
}
}
/// Given a handle, returns owned resource as bytes.
pub fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
self.channel.read(handle)
}
/// Given a resource, stores it and returns the resource handle.
pub fn create(&self, data: &[u8]) -> Handle<Server> {
self.channel.create(data)
}
/// Reserves `size` bytes in the storage, and returns a handle over them.
pub fn empty(&self, size: usize) -> Handle<Server> {
self.channel.empty(size)
}
/// Executes the `kernel` over the given `handles`.
pub fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]) {
self.channel.execute(kernel, handles)
}
/// Wait for the completion of every task in the server.
pub fn sync(&self) {
self.channel.sync()
}
}
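
A sketch of the client API in generic code; the `round_trip` helper is illustrative:

```rust
use burn_compute::channel::ComputeChannel;
use burn_compute::client::ComputeClient;
use burn_compute::server::ComputeServer;

/// Upload eight bytes, wait for the server, and download them again.
fn round_trip<S, C>(client: &ComputeClient<S, C>) -> Vec<u8>
where
    S: ComputeServer,
    C: ComputeChannel<S>,
{
    let handle = client.create(&[7u8; 8]);
    client.sync(); // wait for every outstanding task
    client.read(&handle)
}
```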


@ -0,0 +1,83 @@
use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
use core::ops::DerefMut;
use hashbrown::HashMap;
/// The Compute type is responsible for retrieving the correct compute client for the
/// given device.
pub struct Compute<Device, Server, Channel> {
clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
}
impl<Device, Server, Channel> Compute<Device, Server, Channel>
where
Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
Server: ComputeServer,
Channel: ComputeChannel<Server>,
{
/// Create a new compute.
pub const fn new() -> Self {
Self {
clients: spin::Mutex::new(None),
}
}
/// Get the compute client for the given device.
///
/// Provide the init function to create a new client if it isn't already initialized.
pub fn client<Init>(&self, device: &Device, init: Init) -> ComputeClient<Server, Channel>
where
Init: Fn() -> ComputeClient<Server, Channel>,
{
let mut clients = self.clients.lock();
if clients.is_none() {
Self::register_inner(device, init(), &mut clients);
}
match clients.deref_mut() {
Some(clients) => match clients.get(device) {
Some(client) => client.clone(),
None => {
let client = init();
clients.insert(device.clone(), client.clone());
client
}
},
_ => unreachable!(),
}
}
/// Register the compute client for the given device.
///
/// # Note
///
/// This function is mostly useful when the creation of the compute client can't be done
/// synchronously and requires a special context.
///
/// # Panics
///
/// If a client is already registered for the given device.
pub fn register(&self, device: &Device, client: ComputeClient<Server, Channel>) {
let mut clients = self.clients.lock();
Self::register_inner(device, client, &mut clients);
}
fn register_inner(
device: &Device,
client: ComputeClient<Server, Channel>,
clients: &mut Option<HashMap<Device, ComputeClient<Server, Channel>>>,
) {
if clients.is_none() {
*clients = Some(HashMap::new());
}
if let Some(clients) = clients {
if clients.contains_key(device) {
panic!("Client already created for device {:?}", device);
}
clients.insert(device.clone(), client);
}
}
}

burn-compute/src/id.rs Normal file (53 lines)

@ -0,0 +1,53 @@
#[macro_export(local_inner_macros)]
/// Create a new storage ID type.
macro_rules! storage_id_type {
($name:ident) => {
#[derive(Clone, Hash, PartialEq, Eq)]
/// Storage ID.
pub struct $name {
id: alloc::sync::Arc<alloc::string::String>,
}
impl $name {
/// Create a new ID.
pub fn new() -> Self {
Self {
id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()),
}
}
}
impl Default for $name {
fn default() -> Self {
Self::new()
}
}
};
}
#[macro_export(local_inner_macros)]
/// Create a new memory ID type.
macro_rules! memory_id_type {
($name:ident) => {
#[derive(Clone, Hash, PartialEq, Eq)]
/// Memory ID.
pub struct $name {
id: alloc::sync::Arc<alloc::string::String>,
}
impl $name {
/// Create a new ID.
pub(crate) fn new() -> Self {
Self {
id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()),
}
}
}
impl Default for $name {
fn default() -> Self {
Self::new()
}
}
};
}
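
A sketch of how the macros are meant to be invoked; `BufferId` is a hypothetical name (the commit itself declares `StorageId`, `ChunkId`, and `SliceId` this way). Note that the expansion refers to `alloc::sync::Arc`, so the calling crate needs `extern crate alloc;`:

```rust
extern crate alloc; // the macro expansion refers to `alloc::sync::Arc`

use burn_compute::storage_id_type;

storage_id_type!(BufferId);

fn demo() {
    let a = BufferId::new();
    let b = a.clone();
    assert!(a == b); // clones share the same Arc'd id string
    assert!(a != BufferId::new()); // every `new` draws a fresh id
}
```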

burn-compute/src/lib.rs Normal file (26 lines)

@ -0,0 +1,26 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
//! Burn compute crate that helps create high-performance async backends.
extern crate alloc;
#[macro_use]
extern crate derive_new;
mod id;
/// Compute channel module.
pub mod channel;
/// Compute client module.
pub mod client;
/// Memory management module.
pub mod memory_management;
/// Compute server module.
pub mod server;
/// Compute Storage module.
pub mod storage;
mod compute;
pub use compute::*;


@ -0,0 +1,27 @@
use crate::storage::ComputeStorage;
/// The MemoryHandle trait is an abstract way to refer to some memory segment.
/// It should not contain actual references to data.
///
/// It is responsible for determining if the memory segment can be mutated,
/// for instance by keeping track of a reference count
pub trait MemoryHandle: Clone + Send {
/// Checks if the underlying memory can be safely mutated.
fn can_mut(&self) -> bool;
}
/// The MemoryManagement trait encapsulates strategies for (de)allocating memory.
/// It is bound to the ComputeStorage trait, which does the actual (de)allocations.
///
/// The MemoryManagement can only reserve memory space or get the resource located at a space.
/// Modification of the resource data should be done directly on the resource.
pub trait MemoryManagement<Storage: ComputeStorage>: Send {
/// The associated type Handle must implement MemoryHandle
type Handle: MemoryHandle;
/// Returns the resource from the storage at the specified handle
fn get(&mut self, handle: &Self::Handle) -> Storage::Resource;
/// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it
fn reserve(&mut self, size: usize) -> Self::Handle;
}
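
A sketch tying the two traits together, using the `BytesStorage` and `SimpleMemoryManagement` implementations added later in this commit; the `demo` wrapper is illustrative:

```rust
use burn_compute::memory_management::{MemoryHandle, MemoryManagement, SimpleMemoryManagement};
use burn_compute::storage::BytesStorage;

fn demo() {
    let mut memory = SimpleMemoryManagement::never_dealloc(BytesStorage::default());

    // `reserve` finds (or allocates) space; the handle is an opaque ticket.
    let handle = memory.reserve(64);
    assert!(handle.can_mut()); // we hold the only tensor-side reference

    // `get` resolves the handle into a resource that can actually be written.
    memory.get(&handle).write().fill(0);
}
```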


@ -0,0 +1,5 @@
mod base;
mod simple;
pub use base::*;
pub use simple::*;


@ -0,0 +1,339 @@
use super::{MemoryHandle, MemoryManagement};
use crate::{
memory_id_type,
storage::{ComputeStorage, StorageHandle, StorageUtilization},
};
use alloc::{sync::Arc, vec::Vec};
use hashbrown::HashMap;
// The ChunkId makes it possible to keep track of how many references there are to a specific chunk.
memory_id_type!(ChunkId);
// The SliceId makes it possible to keep track of how many references there are to a specific slice.
memory_id_type!(SliceId);
impl ChunkId {
/// A chunk is free if it is only referred by the chunk hashmap.
fn is_free(&self) -> bool {
Arc::strong_count(&self.id) <= 1
}
}
impl SliceId {
/// A slice is free if it is only referred by the slice hashmap and the chunk it is in.
fn is_free(&self) -> bool {
Arc::strong_count(&self.id) <= 2
}
}
/// The SimpleHandle is a memory handle, referring to either a chunk or a slice.
#[derive(Clone)]
pub enum SimpleHandle {
/// A whole chunk of memory.
Chunk(ChunkId),
/// A slice of a chunk of memory.
Slice(SliceId),
}
/// The strategy defines the frequency at which deallocation of unused memory chunks should occur.
pub enum DeallocStrategy {
/// Once every n calls to reserve.
///
/// The first associated value is n; the second is the counter state and should start at 0.
PeriodTick(usize, usize),
#[cfg(feature = "std")]
/// Once every period of time
PeriodTime(std::time::Duration, std::time::Instant),
/// Never deallocate.
Never,
}
impl DeallocStrategy {
/// Create a new strategy with the given period.
pub fn new_period_tick(period: usize) -> Self {
DeallocStrategy::PeriodTick(period, 0)
}
fn should_dealloc(&mut self) -> bool {
match self {
DeallocStrategy::PeriodTick(period, last) => {
*last = (*last + 1) % *period;
*last == 0
}
#[cfg(feature = "std")]
DeallocStrategy::PeriodTime(period, last) => {
if &last.elapsed() > period {
*last = std::time::Instant::now();
true
} else {
false
}
}
DeallocStrategy::Never => false,
}
}
}
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
pub struct SimpleMemoryManagement<Storage> {
chunks: HashMap<ChunkId, (StorageHandle, Vec<SliceId>)>,
slices: HashMap<SliceId, (StorageHandle, ChunkId)>,
dealloc_strategy: DeallocStrategy,
storage: Storage,
}
impl MemoryHandle for SimpleHandle {
/// Returns true if referenced by only one tensor, and only once by the
/// memory management hashmaps
fn can_mut(&self) -> bool {
// One reference in the chunk hashmap, another owned by one tensor.
const REFERENCE_LIMIT_CHUNK: usize = 2;
// One reference in the chunk hashmap (for the chunk on which this slice is built),
// another in the slice hashmap for this slice, and another owned by one tensor.
const REFERENCE_LIMIT_SLICE: usize = 3;
match &self {
SimpleHandle::Chunk(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_CHUNK,
SimpleHandle::Slice(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_SLICE,
}
}
}
impl<Storage: ComputeStorage> MemoryManagement<Storage> for SimpleMemoryManagement<Storage> {
type Handle = SimpleHandle;
/// Returns the resource from the storage, for the specified handle.
fn get(&mut self, handle: &Self::Handle) -> Storage::Resource {
let resource = match &handle {
SimpleHandle::Chunk(id) => &self.chunks.get(id).unwrap().0,
SimpleHandle::Slice(id) => &self.slices.get(id).unwrap().0,
};
self.storage.get(resource)
}
/// Reserves memory of the specified size using the reserve algorithm, and returns
/// a handle to the reserved memory.
///
/// Also cleans up, removing unused slices, and chunks if permitted by the deallocation strategy.
fn reserve(&mut self, size: usize) -> Self::Handle {
self.cleanup_slices();
let handle = self.reserve_algorithm(size);
if self.dealloc_strategy.should_dealloc() {
self.cleanup_chunks();
}
handle
}
}
impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
/// Creates a new instance using the given storage and deallocation strategy.
pub fn new(storage: Storage, dealloc_strategy: DeallocStrategy) -> Self {
Self {
chunks: HashMap::new(),
slices: HashMap::new(),
dealloc_strategy,
storage,
}
}
/// Creates a new instance using the given storage, without deallocation.
pub fn never_dealloc(storage: Storage) -> Self {
Self::new(storage, DeallocStrategy::Never)
}
fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle {
// Looks for a large enough, existing but unused chunk of memory.
let chunk = self.find_free_chunk(size);
match chunk {
Some((chunk_id, chunk_size)) => {
if size == chunk_size {
// If there is one of exactly the same size, it reuses it.
SimpleHandle::Chunk(chunk_id.clone())
} else {
// Otherwise creates a slice of the right size upon it, always starting at zero.
self.create_slice(size, chunk_id)
}
}
// If no chunk available, creates one of exactly the right size.
None => self.create_chunk(size),
}
}
/// Finds the smallest free chunk that is large enough to fit `size`.
/// Returns the chunk's id and size.
fn find_free_chunk(&self, size: usize) -> Option<(ChunkId, usize)> {
let mut size_diff_current = usize::MAX;
let mut current = None;
self.chunks
.iter()
.for_each(|(chunk_id, (resource, slices))| {
let is_free = slices.is_empty() && chunk_id.is_free();
if is_free && resource.size() >= size {
let size_diff = resource.size() - size;
if size_diff < size_diff_current {
current = Some((chunk_id, resource));
size_diff_current = size_diff;
}
}
});
current.map(|(id, handle)| (id.clone(), handle.size()))
}
/// Creates a slice of size `size` upon the given chunk.
///
/// For now, slices must start at zero, so there can be only one per chunk.
fn create_slice(&mut self, size: usize, chunk_id: ChunkId) -> SimpleHandle {
let (handle, slices) = self.chunks.get_mut(&chunk_id).unwrap();
let slice_id = SliceId::new();
let storage = StorageHandle {
id: handle.id.clone(),
utilization: StorageUtilization::Slice(0, size),
};
if slices.is_empty() {
self.slices.insert(slice_id.clone(), (storage, chunk_id));
} else {
panic!("Can't have more than 1 slice yet.");
}
slices.push(slice_id.clone());
SimpleHandle::Slice(slice_id)
}
/// Creates a chunk of given size by allocating on the storage.
fn create_chunk(&mut self, size: usize) -> SimpleHandle {
let resource = self.storage.alloc(size);
let chunk_id = ChunkId::new();
self.chunks.insert(chunk_id.clone(), (resource, Vec::new()));
SimpleHandle::Chunk(chunk_id)
}
/// Deallocates free chunks and removes them from the chunks map.
fn cleanup_chunks(&mut self) {
let mut ids_to_remove = Vec::new();
self.chunks.iter().for_each(|(chunk_id, _resource)| {
if chunk_id.is_free() {
ids_to_remove.push(chunk_id.clone());
}
});
ids_to_remove
.iter()
.map(|chunk_id| self.chunks.remove(chunk_id).unwrap())
.for_each(|(resource, _slices)| {
self.storage.dealloc(resource.id);
});
}
/// Removes free slices from the slice map and from their corresponding chunks.
fn cleanup_slices(&mut self) {
let mut ids_to_remove = Vec::new();
self.slices.iter().for_each(|(slice_id, _resource)| {
if slice_id.is_free() {
ids_to_remove.push(slice_id.clone());
}
});
ids_to_remove
.iter()
.map(|slice_id| {
let value = self.slices.remove(slice_id).unwrap();
(slice_id, value.1)
})
.for_each(|(slice_id, chunk_id)| {
let (_chunk, slices) = self.chunks.get_mut(&chunk_id).unwrap();
slices.retain(|id| id != slice_id);
});
}
}
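
A sketch of the three `reserve` paths implemented above (fresh chunk, exact-size reuse, slice on a larger free chunk); the `demo` wrapper is illustrative:

```rust
use burn_compute::memory_management::{MemoryManagement, SimpleMemoryManagement};
use burn_compute::storage::BytesStorage;

fn demo() {
    let mut memory = SimpleMemoryManagement::never_dealloc(BytesStorage::default());

    // No free chunk yet: a fresh 16-byte chunk is allocated.
    let first = memory.reserve(16);
    drop(first); // the chunk stays in the map, but is now free

    // Exact size match: the free chunk itself is handed back.
    let reused = memory.reserve(16);
    drop(reused);

    // Smaller request: a slice is carved at offset 0 of the free chunk.
    let _slice = memory.reserve(8);
}
```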
#[cfg(test)]
mod tests {
use crate::{
memory_management::{MemoryHandle, MemoryManagement},
storage::BytesStorage,
};
use super::{DeallocStrategy, SimpleMemoryManagement};
#[test]
fn can_mut_with_single_tensor_reference() {
let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
let chunk_size = 4;
let simple_handle = memory_management.create_chunk(chunk_size);
let x = simple_handle.clone();
core::mem::drop(simple_handle);
assert!(x.can_mut());
}
#[test]
fn two_tensor_references_remove_mutability() {
let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
let chunk_size = 4;
let simple_handle = memory_management.create_chunk(chunk_size);
let x = simple_handle.clone();
assert!(!simple_handle.can_mut());
assert!(!x.can_mut())
}
#[test]
fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() {
let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
let chunk_size = 4;
let _chunk_handle = memory_management.reserve(chunk_size);
let _new_handle = memory_management.reserve(chunk_size);
assert_eq!(memory_management.chunks.len(), 2);
}
#[test]
fn when_empty_chunk_is_cleaned_up_it_disappears() {
let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
let chunk_size = 4;
let chunk_handle = memory_management.reserve(chunk_size);
drop(chunk_handle);
memory_management.cleanup_chunks();
assert_eq!(memory_management.chunks.len(), 0);
}
#[test]
fn never_dealloc_strategy_never_deallocs() {
let mut never_dealloc = DeallocStrategy::Never;
for _ in 0..20 {
assert!(!never_dealloc.should_dealloc())
}
}
#[test]
fn period_tick_dealloc_strategy_should_dealloc_after_period() {
let period = 3;
let mut period_tick_dealloc = DeallocStrategy::new_period_tick(period);
for _ in 0..3 {
for _ in 0..period - 1 {
assert!(!period_tick_dealloc.should_dealloc());
}
assert!(period_tick_dealloc.should_dealloc());
}
}
}


@ -0,0 +1,40 @@
use alloc::vec::Vec;
use crate::{memory_management::MemoryManagement, storage::ComputeStorage};
type _Storage<Server> = <Server as ComputeServer>::Storage;
type _MemoryManagement<Server> = <Server as ComputeServer>::MemoryManagement;
/// This is an alias for a [memory handle](MemoryManagement::Handle).
pub type Handle<Server> = <_MemoryManagement<Server> as MemoryManagement<_Storage<Server>>>::Handle;
/// The compute server is responsible for handling resources and computations over resources.
///
/// Everything in the server is mutable, therefore it should be solely accessed through the
/// [compute channel](crate::channel::ComputeChannel) for thread safety.
pub trait ComputeServer: Send {
/// The kernel type defines the computation algorithms.
type Kernel: Send;
/// The [storage](ComputeStorage) type defines how data is stored and accessed.
type Storage: ComputeStorage;
/// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
type MemoryManagement: MemoryManagement<Self::Storage>;
/// Given a handle, returns the owned resource as bytes.
fn read(&mut self, handle: &Handle<Self>) -> Vec<u8>;
/// Given a resource as bytes, stores it and returns the memory handle.
fn create(&mut self, data: &[u8]) -> Handle<Self>;
/// Reserves `size` bytes in the storage, and returns a handle over them.
fn empty(&mut self, size: usize) -> Handle<Self>;
/// Executes the `kernel` over the given memory `handles`.
///
/// Kernels have mutable access to every resource they are given
/// and are responsible for determining which should be read or written.
fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle<Self>]);
/// Wait for the completion of every task in the server.
fn sync(&mut self);
}


@ -0,0 +1,48 @@
use crate::storage_id_type;
// This ID is used to map a handle to its actual data.
storage_id_type!(StorageId);
/// Defines if data uses a full memory chunk or a slice of it.
#[derive(Clone)]
pub enum StorageUtilization {
/// Full memory chunk of specified size
Full(usize),
/// Slice of memory chunk with start index and size.
Slice(usize, usize),
}
/// Contains the [storage id](StorageId) of a resource and the way it is used.
#[derive(new)]
pub struct StorageHandle {
/// Storage id.
pub id: StorageId,
/// How the storage is used.
pub utilization: StorageUtilization,
}
impl StorageHandle {
/// Returns the size the handle is pointing to in memory.
pub fn size(&self) -> usize {
match self.utilization {
StorageUtilization::Full(size) => size,
StorageUtilization::Slice(_, size) => size,
}
}
}
/// Storage types are responsible for allocating and deallocating memory.
pub trait ComputeStorage: Send {
/// The resource associated type determines the way data is implemented and how
/// it can be accessed by kernels.
type Resource: Send;
/// Returns the underlying resource for a specified storage handle
fn get(&mut self, handle: &StorageHandle) -> Self::Resource;
/// Allocates `size` units of memory and returns a handle to it
fn alloc(&mut self, size: usize) -> StorageHandle;
/// Deallocates the memory pointed to by the given storage id.
fn dealloc(&mut self, id: StorageId);
}


@ -0,0 +1,121 @@
use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use alloc::alloc::{alloc, dealloc, Layout};
use hashbrown::HashMap;
/// The bytes storage maps ids to pointers of bytes in a contiguous layout.
#[derive(Default)]
pub struct BytesStorage {
memory: HashMap<StorageId, AllocatedBytes>,
}
/// Can be sent to other threads, but can't be shared between threads (`Send` without `Sync`).
unsafe impl Send for BytesStorage {}
unsafe impl Send for BytesResource {}
/// This struct is a pointer to a memory chunk or slice.
pub struct BytesResource {
ptr: *mut u8,
utilization: StorageUtilization,
}
/// This struct refers to a specific (contiguous) layout of bytes.
struct AllocatedBytes {
ptr: *mut u8,
layout: Layout,
}
impl BytesResource {
fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
match self.utilization {
StorageUtilization::Full(len) => (self.ptr, len),
StorageUtilization::Slice(location, len) => unsafe { (self.ptr.add(location), len) },
}
}
/// Returns the resource as a mutable slice of bytes.
pub fn write<'a>(&self) -> &'a mut [u8] {
let (ptr, len) = self.get_exact_location_and_length();
unsafe { core::slice::from_raw_parts_mut(ptr, len) }
}
/// Returns the resource as an immutable slice of bytes.
pub fn read<'a>(&self) -> &'a [u8] {
let (ptr, len) = self.get_exact_location_and_length();
unsafe { core::slice::from_raw_parts(ptr, len) }
}
}
impl ComputeStorage for BytesStorage {
type Resource = BytesResource;
fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
let allocated_bytes = self.memory.get_mut(&handle.id).unwrap();
BytesResource {
ptr: allocated_bytes.ptr,
utilization: handle.utilization.clone(),
}
}
fn alloc(&mut self, size: usize) -> StorageHandle {
let id = StorageId::new();
let handle = StorageHandle {
id: id.clone(),
utilization: StorageUtilization::Full(size),
};
unsafe {
let layout = Layout::array::<u8>(size).unwrap();
let ptr = alloc(layout);
let memory = AllocatedBytes { ptr, layout };
self.memory.insert(id, memory);
}
handle
}
fn dealloc(&mut self, id: StorageId) {
if let Some(memory) = self.memory.remove(&id) {
unsafe {
dealloc(memory.ptr, memory.layout);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_can_alloc_and_dealloc() {
let mut storage = BytesStorage::default();
let handle_1 = storage.alloc(64);
assert_eq!(handle_1.size(), 64);
storage.dealloc(handle_1.id);
}
#[test]
fn test_slices() {
let mut storage = BytesStorage::default();
let handle_1 = storage.alloc(64);
let handle_2 = StorageHandle::new(handle_1.id.clone(), StorageUtilization::Slice(24, 8));
storage
.get(&handle_1)
.write()
.iter_mut()
.enumerate()
.for_each(|(i, b)| {
*b = i as u8;
});
let bytes = storage.get(&handle_2).read().to_vec();
storage.dealloc(handle_1.id);
assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]);
}
}


@ -0,0 +1,8 @@
mod base;
pub use base::*;
#[cfg(feature = "storage-bytes")]
mod bytes_cpu;
#[cfg(feature = "storage-bytes")]
pub use bytes_cpu::*;


@ -0,0 +1,26 @@
use super::DummyServer;
use burn_compute::channel::MutexComputeChannel;
use burn_compute::client::ComputeClient;
use burn_compute::memory_management::SimpleMemoryManagement;
use burn_compute::storage::BytesStorage;
use burn_compute::Compute;
/// The dummy device.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct DummyDevice;
static COMPUTE: Compute<DummyDevice, DummyServer, MutexComputeChannel<DummyServer>> =
Compute::new();
pub fn client(
device: &DummyDevice,
) -> ComputeClient<DummyServer, MutexComputeChannel<DummyServer>> {
COMPUTE.client(device, || {
let storage = BytesStorage::default();
let memory_management = SimpleMemoryManagement::never_dealloc(storage);
let server = DummyServer::new(memory_management);
let channel = MutexComputeChannel::new(server);
ComputeClient::new(channel)
})
}


@ -0,0 +1,25 @@
use burn_compute::storage::BytesResource;
/// The DummyKernel trait should be implemented for every supported operation
pub trait DummyKernel: Send {
fn compute<'a>(&self, resources: &mut [BytesResource]);
}
/// Contains the algorithm for element-wise addition
pub struct DummyElementwiseAddition;
impl DummyKernel for DummyElementwiseAddition {
fn compute<'a>(&self, inputs: &mut [BytesResource]) {
// Notice how the kernel is responsible for determining which inputs
// are read-only and which are writable.
let lhs = &inputs[0].read();
let rhs = &inputs[1].read();
let out = &mut inputs[2].write();
let size = lhs.len();
for i in 0..size {
out[i] = lhs[i] + rhs[i];
}
}
}
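
For contrast, a hypothetical second kernel (not part of this commit) showing that each kernel decides for itself which resources it reads and which it writes:

```rust
use burn_compute::storage::BytesResource;

/// Hypothetical companion kernel: element-wise multiplication.
pub struct DummyElementwiseMultiplication;

impl DummyKernel for DummyElementwiseMultiplication {
    fn compute<'a>(&self, inputs: &mut [BytesResource]) {
        // Same convention as the addition kernel: two read-only inputs, one output.
        let lhs = &inputs[0].read();
        let rhs = &inputs[1].read();
        let out = &mut inputs[2].write();

        for i in 0..lhs.len() {
            out[i] = lhs[i] * rhs[i];
        }
    }
}
```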


@ -0,0 +1,7 @@
mod compute;
mod kernel;
mod server;
pub use compute::*;
pub use kernel::*;
pub use server::*;


@ -0,0 +1,60 @@
use burn_compute::{
memory_management::{MemoryManagement, SimpleMemoryManagement},
server::{ComputeServer, Handle},
storage::BytesStorage,
};
use derive_new::new;
use super::DummyKernel;
/// The dummy server is used to test the burn-compute infrastructure.
/// It uses simple memory management with a bytes storage on CPU, without asynchronous tasks.
#[derive(new)]
pub struct DummyServer<MM = SimpleMemoryManagement<BytesStorage>> {
memory_management: MM,
}
impl<MM> ComputeServer for DummyServer<MM>
where
MM: MemoryManagement<BytesStorage>,
{
type Kernel = Box<dyn DummyKernel>;
type Storage = BytesStorage;
type MemoryManagement = MM;
fn read(&mut self, handle: &Handle<Self>) -> Vec<u8> {
let bytes = self.memory_management.get(handle);
bytes.read().to_vec()
}
fn create(&mut self, data: &[u8]) -> Handle<Self> {
let handle = self.memory_management.reserve(data.len());
let resource = self.memory_management.get(&handle);
let bytes = resource.write();
for (i, val) in data.iter().enumerate() {
bytes[i] = *val;
}
handle
}
fn empty(&mut self, size: usize) -> Handle<Self> {
self.memory_management.reserve(size)
}
fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle<Self>]) {
let mut resources = handles
.iter()
.map(|handle| self.memory_management.get(handle))
.collect::<Vec<_>>();
kernel.compute(&mut resources);
}
fn sync(&mut self) {
// Nothing to do with dummy backend.
}
}


@ -0,0 +1,38 @@
mod dummy;
use dummy::{client, DummyDevice, DummyElementwiseAddition};
#[test]
fn created_resource_is_the_same_when_read() {
let client = client(&DummyDevice);
let resource = Vec::from([0, 1, 2]);
let resource_description = client.create(&resource);
let obtained_resource = client.read(&resource_description);
assert_eq!(resource, obtained_resource)
}
#[test]
fn empty_allocates_memory() {
let client = client(&DummyDevice);
let size = 4;
let resource_description = client.empty(size);
let empty_resource = client.read(&resource_description);
assert_eq!(empty_resource.len(), 4);
}
#[test]
fn execute_elementwise_addition() {
let client = client(&DummyDevice);
let lhs = client.create(&[0, 1, 2]);
let rhs = client.create(&[4, 4, 4]);
let out = client.empty(3);
client.execute(Box::new(DummyElementwiseAddition), &[&lhs, &rhs, &out]);
let obtained_resource = client.read(&out);
assert_eq!(obtained_resource, Vec::from([4, 5, 6]))
}


@ -45,6 +45,10 @@ burn-tensor = { path = "../burn-tensor", version = "0.10.0", default-features =
burn-ndarray = { path = "../burn-ndarray", version = "0.10.0" }
serial_test = "2.0.0"
# Still only in dev mode
hashbrown = { workspace = true }
burn-compute = { path = "../burn-compute", version = "0.10.0", default-features = false, features=["channel-mutex", "std"] }
[[bench]]
name = "unary"
harness = false


@ -0,0 +1,129 @@
use super::{Kernel, WgpuServer};
use crate::{
compute::WgpuStorage,
context::{select_device, WorkGroup},
kernel::{DynamicKernel, SourceTemplate, StaticKernel},
GraphicsApi, WgpuDevice,
};
use burn_compute::{
channel::MutexComputeChannel,
client::ComputeClient,
memory_management::{DeallocStrategy, SimpleMemoryManagement},
Compute,
};
use std::{marker::PhantomData, sync::Arc};
type WgpuChannel = MutexComputeChannel<WgpuServer>;
/// Compute handle for the wgpu backend.
static COMPUTE: Compute<WgpuDevice, WgpuServer, WgpuChannel> = Compute::new();
pub fn compute_client<G: GraphicsApi>(
device: &WgpuDevice,
) -> ComputeClient<WgpuServer, WgpuChannel> {
let device = Arc::new(device);
COMPUTE.client(&device, move || {
let (device_wgpu, queue, info) = pollster::block_on(select_device::<G>(&device));
log::info!(
"Created wgpu compute server on device {:?} => {:?}",
device,
info
);
// TODO: Support a way to modify max_tasks without std.
let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") {
Ok(value) => value
.parse::<usize>()
.expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
Err(_) => 16, // 16 tasks by default
};
let device = Arc::new(device_wgpu);
let storage = WgpuStorage::new(device.clone());
// Maximum reusability.
let memory_management = SimpleMemoryManagement::new(storage, DeallocStrategy::Never);
let server = WgpuServer::new(memory_management, device, queue, max_tasks);
let channel = WgpuChannel::new(server);
ComputeClient::new(channel)
})
}
pub struct DynamicComputeKernel<K> {
kernel: K,
workgroup: WorkGroup,
}
impl<K> Kernel for DynamicComputeKernel<K>
where
K: DynamicKernel + 'static,
{
fn source_template(self: Box<Self>) -> SourceTemplate {
self.kernel.source_template()
}
fn id(&self) -> String {
self.kernel.id()
}
fn workgroup(&self) -> WorkGroup {
self.workgroup.clone()
}
}
#[derive(new)]
pub struct StaticComputeKernel<K> {
workgroup: WorkGroup,
_kernel: PhantomData<K>,
}
impl<K> Kernel for StaticComputeKernel<K>
where
K: StaticKernel + 'static,
{
fn source_template(self: Box<Self>) -> SourceTemplate {
K::source_template()
}
fn id(&self) -> String {
format!("{:?}", core::any::TypeId::of::<K>())
}
fn workgroup(&self) -> WorkGroup {
self.workgroup.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{binary_elemwise, kernel::KernelSettings, AutoGraphicsApi};
#[test]
fn can_run_kernel() {
binary_elemwise!(Add, "+");
let client = compute_client::<AutoGraphicsApi>(&WgpuDevice::default());
let lhs: Vec<f32> = vec![0., 1., 2., 3., 4., 5., 6., 7.];
let rhs: Vec<f32> = vec![10., 11., 12., 6., 7., 3., 1., 0.];
let info: Vec<u32> = vec![1, 1, 1, 1, 8, 8, 8];
let lhs = client.create(bytemuck::cast_slice(&lhs));
let rhs = client.create(bytemuck::cast_slice(&rhs));
let out = client.empty(core::mem::size_of::<f32>() * 8);
let info = client.create(bytemuck::cast_slice(&info));
type Kernel = KernelSettings<Add, f32, i32, 16, 16, 1>;
let kernel = Box::new(StaticComputeKernel::<Kernel>::new(WorkGroup::new(1, 1, 1)));
client.execute(kernel, &[&lhs, &rhs, &out, &info]);
let data = client.read(&out);
let output: &[f32] = bytemuck::cast_slice(&data);
assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]);
}
}


@ -0,0 +1,7 @@
mod base;
mod server;
mod storage;
pub use base::*;
pub use server::*;
pub use storage::*;


@ -0,0 +1,253 @@
use std::{borrow::Cow, sync::Arc};
use super::WgpuStorage;
use crate::{context::WorkGroup, kernel::SourceTemplate};
use burn_compute::{
memory_management::{MemoryManagement, SimpleMemoryManagement},
server::{self, ComputeServer},
};
use hashbrown::HashMap;
use wgpu::{
util::{BufferInitDescriptor, DeviceExt},
BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor,
};
/// Wgpu compute server.
pub struct WgpuServer<MM = SimpleMemoryManagement<WgpuStorage>> {
memory_management: MM,
device: Arc<wgpu::Device>,
queue: wgpu::Queue,
encoder: CommandEncoder,
pipelines: HashMap<String, Arc<ComputePipeline>>,
tasks: Vec<ComputeTask>,
max_tasks: usize,
}
#[derive(new)]
struct ComputeTask {
pipeline: Arc<ComputePipeline>,
bind_group: BindGroup,
work_group: WorkGroup,
}
/// A compiled-source kernel that can be executed by the wgpu compute server.
pub trait Kernel: 'static + Send {
/// Source template for the kernel.
fn source_template(self: Box<Self>) -> SourceTemplate;
/// Identifier for the kernel, used for caching kernel compilation.
fn id(&self) -> String;
/// The workgroup used when launching the kernel.
fn workgroup(&self) -> WorkGroup;
}
impl<MM> WgpuServer<MM>
where
MM: MemoryManagement<WgpuStorage>,
{
pub fn new(
memory_management: MM,
device: Arc<wgpu::Device>,
queue: wgpu::Queue,
max_tasks: usize,
) -> Self {
let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Command Encoder"),
});
Self {
memory_management,
device,
queue,
encoder,
pipelines: HashMap::new(),
tasks: Vec::new(),
max_tasks,
}
}
fn submit(&mut self) {
assert!(
self.tasks.is_empty(),
"Tasks should be completed before submitting the current encoder."
);
println!("Submit");
let mut new_encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
core::mem::swap(&mut new_encoder, &mut self.encoder);
self.queue.submit(Some(new_encoder.finish()));
}
fn register_tasks(&mut self) {
if self.tasks.is_empty() {
return;
}
let mut compute = self
.encoder
.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
for task in self.tasks.iter() {
compute.set_pipeline(&task.pipeline);
compute.set_bind_group(0, &task.bind_group, &[]);
compute.dispatch_workgroups(task.work_group.x, task.work_group.y, task.work_group.z);
}
std::mem::drop(compute);
self.tasks.clear();
}
fn pipeline(&mut self, kernel: Box<dyn Kernel>) -> Arc<ComputePipeline> {
let kernel_id = kernel.id();
if let Some(pipeline) = self.pipelines.get(&kernel_id) {
return pipeline.clone();
}
let pipeline = self.compile_source(&kernel.source_template().complete());
self.pipelines.insert(kernel_id.clone(), pipeline.clone());
pipeline
}
fn compile_source(&self, source: &str) -> Arc<ComputePipeline> {
let module = self.device.create_shader_module(ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
});
Arc::new(
self.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module: &module,
entry_point: "main",
}),
)
}
}
impl<MM> ComputeServer for WgpuServer<MM>
where
MM: MemoryManagement<WgpuStorage>,
{
type Kernel = Box<dyn Kernel>;
type Storage = WgpuStorage;
type MemoryManagement = MM;
fn read(&mut self, handle: &server::Handle<Self>) -> Vec<u8> {
// Register previous tasks before reading the buffer so that it is up to date.
self.register_tasks();
let resource = self.memory_management.get(handle);
let size = resource.size();
let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
self.encoder.copy_buffer_to_buffer(
&resource.buffer,
resource.offset(),
&buffer_dest,
0,
size,
);
self.submit();
let buffer_slice = buffer_dest.slice(..);
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
sender
.send(v)
.expect("Unable to send buffer slice result to async channel.")
});
self.device.poll(wgpu::Maintain::Wait);
let result = pollster::block_on(receiver.receive());
if let Some(Ok(())) = result {
let data = buffer_slice.get_mapped_range();
let result = bytemuck::cast_slice(&data).to_vec();
drop(data);
buffer_dest.unmap();
result
} else {
panic!("Unable to read buffer {:?}", result)
}
}
fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
let handle = self.empty(data.len());
let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor {
label: Some("Buffer Src"),
contents: data,
usage: wgpu::BufferUsages::COPY_SRC,
}));
let resource = self.memory_management.get(&handle);
self.register_tasks();
self.encoder.copy_buffer_to_buffer(
&buffer_src,
0,
&resource.buffer,
resource.offset(),
buffer_src.size(),
);
handle
}
fn empty(&mut self, size: usize) -> server::Handle<Self> {
self.memory_management.reserve(size)
}
fn execute(&mut self, kernel: Self::Kernel, handles: &[&server::Handle<Self>]) {
let work_group = kernel.workgroup();
let pipeline = self.pipeline(kernel);
let group_layout = pipeline.get_bind_group_layout(0);
let handles = handles
.iter()
.map(|handle| self.memory_management.get(handle))
.collect::<Vec<_>>();
let entries = handles
.iter()
.enumerate()
.map(|(i, buffer)| wgpu::BindGroupEntry {
binding: i as u32,
resource: buffer.as_binding(),
})
.collect::<Vec<_>>();
let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &group_layout,
entries: &entries,
});
self.tasks
.push(ComputeTask::new(pipeline, bind_group, work_group));
if self.tasks.len() >= self.max_tasks {
self.register_tasks();
self.submit();
}
}
fn sync(&mut self) {
if !self.tasks.is_empty() {
self.register_tasks();
self.submit();
}
}
}


@ -0,0 +1,102 @@
use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use hashbrown::HashMap;
use std::{num::NonZeroU64, sync::Arc};
/// Keeps actual wgpu buffer references in a hashmap with ids as keys.
pub struct WgpuStorage {
memory: HashMap<StorageId, Arc<wgpu::Buffer>>,
device: Arc<wgpu::Device>,
}
#[derive(new, Debug)]
pub struct WgpuResource {
pub buffer: Arc<wgpu::Buffer>,
pub kind: WgpuResourceKind,
}
impl WgpuResource {
/// Return the binding view of the buffer.
pub fn as_binding(&self) -> wgpu::BindingResource {
let binding = match &self.kind {
WgpuResourceKind::Full => self.buffer.as_entire_buffer_binding(),
WgpuResourceKind::Slice(offs, size) => wgpu::BufferBinding {
buffer: &self.buffer,
offset: *offs,
size: Some(*size),
},
};
wgpu::BindingResource::Buffer(binding)
}
/// Return the buffer size.
pub fn size(&self) -> u64 {
match self.kind {
WgpuResourceKind::Full => self.buffer.size(),
WgpuResourceKind::Slice(_, size) => size.get(),
}
}
/// Return the buffer offset.
pub fn offset(&self) -> u64 {
match self.kind {
WgpuResourceKind::Full => 0,
WgpuResourceKind::Slice(offset, _) => offset,
}
}
}
#[derive(Debug)]
pub enum WgpuResourceKind {
/// Represents an entire buffer.
Full,
/// A slice over a buffer.
Slice(wgpu::BufferAddress, wgpu::BufferSize),
}
impl WgpuStorage {
pub fn new(device: Arc<wgpu::Device>) -> Self {
Self {
memory: HashMap::new(),
device,
}
}
}
impl ComputeStorage for WgpuStorage {
type Resource = WgpuResource;
fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
let buffer = self.memory.get(&handle.id).unwrap();
match handle.utilization {
StorageUtilization::Full(_) => {
WgpuResource::new(buffer.clone(), WgpuResourceKind::Full)
}
StorageUtilization::Slice(offset, size) => WgpuResource::new(
buffer.clone(),
WgpuResourceKind::Slice(offset as u64, NonZeroU64::new(size as u64).unwrap()),
),
}
}
fn alloc(&mut self, size: usize) -> StorageHandle {
let id = StorageId::new();
let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: size as u64,
usage: wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
}));
self.memory.insert(id.clone(), buffer);
StorageHandle::new(id, StorageUtilization::Full(size))
}
fn dealloc(&mut self, id: StorageId) {
self.memory.get(&id).unwrap().destroy();
let _ = self.memory.remove(&id);
}
}


@ -270,7 +270,7 @@ impl PartialEq for Context {
}
}
async fn select_device<G: GraphicsApi>(
pub(crate) async fn select_device<G: GraphicsApi>(
device: &WgpuDevice,
) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) {
let adapter = select_adapter::<G>(device);


@ -3,13 +3,13 @@ use crate::{context::WorkGroup, element::WgpuElement, tensor::WgpuTensor};
use std::marker::PhantomData;
/// Static wgpu kernel to create a [source template](SourceTemplate).
pub trait StaticKernel: 'static {
pub trait StaticKernel: Send + 'static {
/// Source template for the kernel.
fn source_template() -> SourceTemplate;
}
/// Dynamic wgpu kernel to create a [source template](SourceTemplate).
pub trait DynamicKernel {
pub trait DynamicKernel: Send {
/// Source template for the kernel.
fn source_template(self) -> SourceTemplate;
/// Identifier for the kernel, used for caching kernel compilation.


@ -16,6 +16,10 @@ pub mod kernel;
/// Tensor module.
pub mod tensor;
#[cfg(test)] // Only enabled for dev for now.
/// Compute related module.
pub mod compute;
pub(crate) mod pool;
pub(crate) mod tune;


@ -61,20 +61,14 @@ fn rustup(target: &str) {
}
// Define and run a cargo command
fn run_cargo(command: &str, first_params: &[&str], second_params: &[&str], error: &str) {
fn run_cargo(command: &str, params: Params, error: &str) {
// Print cargo command
println!(
"\ncargo {} {} {}\n",
command,
first_params.join(" "),
second_params.join(" ")
);
println!("\ncargo {} {}\n", command, params);
// Run cargo
let cargo = Command::new("cargo")
.arg(command)
.args(first_params)
.args(second_params)
.args(params.params)
.stdout(Stdio::inherit()) // Send stdout directly to terminal
.stderr(Stdio::inherit()) // Send stderr directly to terminal
.spawn()
@ -85,34 +79,31 @@ fn run_cargo(command: &str, first_params: &[&str], second_params: &[&str], error
}
// Run cargo build command
fn cargo_build(params: &[&str]) {
fn cargo_build(params: Params) {
// Run cargo build
run_cargo(
"build",
params,
&["--color=always"],
params + "--color=always",
"Failed to run cargo build",
);
}
// Run cargo install command
fn cargo_install(params: &[&str]) {
fn cargo_install(params: Params) {
// Run cargo install
run_cargo(
"install",
params,
&["--color=always"],
params + "--color=always",
"Failed to run cargo install",
);
}
// Run cargo test command
fn cargo_test(params: &[&str]) {
fn cargo_test(params: Params) {
// Run cargo test
run_cargo(
"test",
params,
&["--color=always", "--", "--color=always"],
params + "--color=always" + "--" + "--color=always",
"Failed to run cargo test",
);
}
@ -122,8 +113,7 @@ fn cargo_fmt() {
// Run cargo fmt
run_cargo(
"fmt",
&["--check", "--all"],
&["--", "--color=always"],
["--check", "--all", "--", "--color=always"].into(),
"Failed to run cargo fmt",
);
}
@ -136,50 +126,48 @@ fn cargo_clippy() {
// Run cargo clippy
run_cargo(
"clippy",
&["--color=always"],
&["--", "-D", "warnings"],
["--color=always", "--", "-D", "warnings"].into(),
"Failed to run cargo clippy",
);
}
// Run cargo doc command
fn cargo_doc(params: &[&str]) {
fn cargo_doc(params: Params) {
// Run cargo doc
run_cargo(
"doc",
params,
&["--color=always"],
"Failed to run cargo doc",
);
run_cargo("doc", params + "--color=always", "Failed to run cargo doc");
}
// Build and test a crate in a no_std environment
fn build_and_test_no_std(crate_name: &str) {
fn build_and_test_no_std<const N: usize>(crate_name: &str, extra_args: [&str; N]) {
println!("\nRun checks for `{}` crate", crate_name);
// Run cargo build --no-default-features
cargo_build(&["-p", crate_name, "--no-default-features"]);
cargo_build(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args);
// Run cargo test --no-default-features
cargo_test(&["-p", crate_name, "--no-default-features"]);
cargo_test(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args);
// Run cargo build --no-default-features --target wasm32-unknown-unknown
cargo_build(&[
"-p",
crate_name,
"--no-default-features",
"--target",
WASM32_TARGET,
]);
cargo_build(
Params::from([
"-p",
crate_name,
"--no-default-features",
"--target",
WASM32_TARGET,
]) + extra_args,
);
// Run cargo build --no-default-features --target thumbv7m-none-eabi
cargo_build(&[
"-p",
crate_name,
"--no-default-features",
"--target",
ARM_TARGET,
]);
cargo_build(
Params::from([
"-p",
crate_name,
"--no-default-features",
"--target",
ARM_TARGET,
]) + extra_args,
);
}
// Run no_std checks
@ -193,12 +181,16 @@ fn no_std_checks() {
rustup(ARM_TARGET);
// Run checks for the following crates
build_and_test_no_std("burn");
build_and_test_no_std("burn-core");
build_and_test_no_std("burn-common");
build_and_test_no_std("burn-tensor");
build_and_test_no_std("burn-ndarray");
build_and_test_no_std("burn-no-std-tests");
build_and_test_no_std("burn", []);
build_and_test_no_std("burn-core", []);
build_and_test_no_std(
"burn-compute",
["--features", "channel-mutex storage-bytes"],
);
build_and_test_no_std("burn-common", []);
build_and_test_no_std("burn-tensor", []);
build_and_test_no_std("burn-ndarray", []);
build_and_test_no_std("burn-no-std-tests", []);
}
// Test burn-core with tch and wgpu backend
@ -206,10 +198,10 @@ fn burn_core_std() {
println!("\n\nRun checks for burn-core crate with tch and wgpu backend");
// Run cargo test --features test-tch
cargo_test(&["-p", "burn-core", "--features", "test-tch"]);
cargo_test(["-p", "burn-core", "--features", "test-tch"].into());
// Run cargo test --features test-wgpu
cargo_test(&["-p", "burn-core", "--features", "test-wgpu"]);
cargo_test(["-p", "burn-core", "--features", "test-wgpu"].into());
}
// Test burn-dataset features
@ -217,13 +209,13 @@ fn burn_dataset_features_std() {
println!("\n\nRun checks for burn-dataset features");
// Run cargo build --all-features
cargo_build(&["-p", "burn-dataset", "--all-features"]);
cargo_build(["-p", "burn-dataset", "--all-features"].into());
// Run cargo test --all-features
cargo_test(&["-p", "burn-dataset", "--all-features"]);
cargo_test(["-p", "burn-dataset", "--all-features"].into());
// Run cargo doc --all-features
cargo_doc(&["-p", "burn-dataset", "--all-features"]);
cargo_doc(["-p", "burn-dataset", "--all-features"].into());
}
fn std_checks() {
@ -234,10 +226,10 @@ fn std_checks() {
println!("Running std checks");
// Build each workspace
cargo_build(&["--workspace", "--exclude=xtask"]);
cargo_build(["--workspace", "--exclude=xtask"].into());
// Test each workspace
cargo_test(&["--workspace"]);
cargo_test(["--workspace"].into());
// Check format
cargo_fmt();
@ -246,7 +238,7 @@ fn std_checks() {
cargo_clippy();
// Produce documentation for each workspace
cargo_doc(&["--workspace"]);
cargo_doc(["--workspace"].into());
// Test burn-dataset features
burn_dataset_features_std();
@ -257,7 +249,7 @@ fn std_checks() {
fn check_typos() {
// Install typos-cli
cargo_install(&["typos-cli", "--version", "1.16.5"]);
cargo_install(["typos-cli", "--version", "1.16.5"].into());
println!("Running typos check \n\n");
@ -349,3 +341,39 @@ pub fn run(env: CheckType) -> anyhow::Result<()> {
Ok(())
}
struct Params {
params: Vec<String>,
}
impl<const N: usize> From<[&str; N]> for Params {
fn from(value: [&str; N]) -> Self {
Self {
params: value.iter().map(|v| v.to_string()).collect(),
}
}
}
impl From<&str> for Params {
fn from(value: &str) -> Self {
Self {
params: vec![value.to_string()],
}
}
}
impl std::fmt::Display for Params {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.params.join(" ").as_str())
}
}
impl<Rhs: Into<Params>> std::ops::Add<Rhs> for Params {
type Output = Params;
fn add(mut self, rhs: Rhs) -> Self::Output {
let rhs: Params = rhs.into();
self.params.extend(rhs.params);
self
}
}
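
A sketch of how the new `Params` type composes through its `From` and `Add` impls; the `demo` wrapper is illustrative:

```rust
fn demo() {
    let params = Params::from(["-p", "burn-compute"])
        + "--no-default-features"
        + ["--target", "wasm32-unknown-unknown"];
    assert_eq!(
        params.to_string(),
        "-p burn-compute --no-default-features --target wasm32-unknown-unknown"
    );
}
```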