Move HandleContainer and Tensor Ops descriptions from burn-fusion to burn-tensor (#1654)

* Move HandleContainer and Tensor Ops descriptions to burn-tensor

Move the HandleContainer and tensor operation descriptions to the burn-tensor crate.
Remove FusionDevice and replace it with a DeviceOps trait bound on Backend::Device.

For now, the modules added to burn-tensor are excluded from no-std builds, as they rely on Arc.

* [burn-tensor] Flatten module hierarchy for tensor representation

+ Add a new repr feature to Cargo.toml.

* Remove prefix on docstring

* [burn-fusion] Require default features of burn-tensor
Author: Sylvain Benner, 2024-04-23 11:27:54 -04:00 (committed by GitHub)
Parent: e6b1b7a317
Commit: c579686a8a
51 changed files with 776 additions and 739 deletions
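
The central API change in this commit replaces burn-fusion's private FusionDevice trait with a DeviceOps bound that every Backend::Device must now satisfy. A minimal sketch of the moved types, assuming they keep the shape of the DeviceId struct and id() method visible in the burn-fusion diff below (extra bounds such as Default are an assumption):

```rust
// Sketch of the types as they land in burn_tensor::backend.

/// The device id.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct DeviceId {
    /// Identifies the kind of device (e.g. CPU vs. GPU).
    pub type_id: u16,
    /// Identifies the device number within that kind.
    pub index_id: u32,
}

impl DeviceId {
    /// Create a new device id (the original derives this via `derive-new`).
    pub fn new(type_id: u16, index_id: u32) -> Self {
        Self { type_id, index_id }
    }
}

/// Allows a backend device to report a stable id.
pub trait DeviceOps: Clone + Default + PartialEq + Send + Sync {
    /// Return the [device id](DeviceId).
    fn id(&self) -> DeviceId;
}
```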

Cargo.lock (generated)
View File

@ -3696,6 +3696,14 @@ dependencies = [
"thiserror",
]
[[package]]
name = "refactor"
version = "0.14.0"
dependencies = [
"burn",
"serde",
]
[[package]]
name = "regex"
version = "1.10.4"

View File

@ -1,6 +1,9 @@
use std::marker::PhantomData;
use burn_tensor::{backend::Backend, Device};
use burn_tensor::{
backend::{Backend, DeviceId, DeviceOps},
Device,
};
use candle_core::DeviceLocation;
use crate::{
@ -60,6 +63,16 @@ impl From<candle_core::Device> for CandleDevice {
}
}
impl DeviceOps for CandleDevice {
fn id(&self) -> burn_tensor::backend::DeviceId {
match self {
CandleDevice::Cpu => DeviceId::new(0, 0),
CandleDevice::Cuda(index) => DeviceId::new(1, *index as u32),
CandleDevice::Metal(index) => DeviceId::new(2, *index as u32),
}
}
}
impl Default for CandleDevice {
fn default() -> Self {
Self::Cpu
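
With this impl, every Candle device maps to a stable (type_id, index_id) pair that burn-fusion can use to key its clients. A small illustration, following the match arms above (the concrete values are implied by the mapping, not spelled out elsewhere in the diff):

```rust
use burn_tensor::backend::{DeviceId, DeviceOps};

fn device_id_examples() {
    // type_id 0 is reserved for the CPU, which has a single index.
    assert_eq!(CandleDevice::Cpu.id(), DeviceId::new(0, 0));
    // type_id 1 is CUDA; the index is the card number.
    assert_eq!(CandleDevice::Cuda(1).id(), DeviceId::new(1, 1));
}
```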

View File

@ -16,7 +16,7 @@ std = ["serde/std"]
doc = ["default"]
[dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.14.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.14.0" }
burn-common = { path = "../burn-common", version = "0.14.0" }
hashbrown = { workspace = true }
derive-new = {workspace = true }
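
Dropping default-features = false means burn-fusion now compiles burn-tensor with its default feature set, which per the commit message gains a new repr feature. How the moved modules are gated is not shown in this diff; a hedged sketch of the likely arrangement in burn-tensor's lib.rs, given that the commit message excludes them from no-std because they rely on Arc:

```rust
// Assumed gating in burn-tensor/src/lib.rs (not part of this diff):
// the representation module needs `Arc`, so it stays out of no-std builds.
#[cfg(feature = "repr")]
pub mod repr;
```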

View File

@ -1,15 +1,17 @@
use crate::{
client::FusionClient,
stream::{Context, OperationDescription},
FusionClientLocator, FusionTensor, PrecisionBridge,
client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge,
};
use burn_tensor::{
backend::Backend,
repr::{OperationDescription, ReprBackend},
Device,
};
use burn_tensor::{backend::Backend, Device, Shape};
use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;
pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();
pub(crate) fn get_client<B: FusionBackend>(device: &B::FusionDevice) -> B::FusionClient {
pub(crate) fn get_client<B: FusionBackend>(device: &B::Device) -> B::FusionClient {
CLIENTS.client(device)
}
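
Because the locator is now keyed by the plain backend device, callers pass a &B::Device directly; the device.clone().into() round-trips through FusionDevice disappear throughout the ops files below. A sketch of the resulting call pattern:

```rust
// Sketch: fetching the client for a device and flushing its stream.
fn drain_device<B: FusionBackend>(device: &B::Device) {
    // `Backend::Device` itself implements `DeviceOps`, so no conversion is needed.
    let client = get_client::<B>(&device.clone());
    client.drain();
}
```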
@ -43,7 +45,7 @@ impl<B: FusionBackend> Backend for Fusion<B> {
}
fn sync(device: &Self::Device) {
let client = CLIENTS.client::<B::FusionClient>(&device.clone().into());
let client = CLIENTS.client::<B::FusionClient>(&device.clone());
client.drain();
B::sync(device)
}
@ -114,62 +116,17 @@ pub trait Optimization<B: FusionBackend>: Send {
fn from_state(device: &B::Device, state: B::OptimizationState) -> Self;
}
/// The device id.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
pub struct DeviceId {
/// The type id identifies the type of the device.
pub type_id: u16,
/// The index id identifies the device number.
pub index_id: u32,
}
/// The handle device trait allows to get an id for a backend device.
pub trait FusionDevice: Clone + Send + Sync + PartialEq {
/// Return the [device id](DeviceId).
fn id(&self) -> DeviceId;
}
/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
/// [operation builder](crate::OptimizationBuilder).
pub trait FusionBackend: Backend {
pub trait FusionBackend: Backend + ReprBackend {
/// The state that can be serialized for an optimization.
type OptimizationState: Serialize + DeserializeOwned;
/// Optimization type for the backend.
type Optimization: Optimization<Self>;
/// The device type that can return an ID.
///
/// It can be the same as (Backend::Device), but must implement (FusionDevice).
type FusionDevice: FusionDevice + From<Self::Device> + Into<Self::Device> + core::fmt::Debug;
/// The type that can be used to point to a tensor of any kind.
type Handle: Sync + Send + Clone;
/// What kind of client should be used.
type FusionClient: FusionClient<FusionBackend = Self>;
/// The list of optimizations that will be used to optimize the computational graph.
fn optimizations(device: Device<Self>)
-> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
/// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
fn float_tensor<const D: usize>(
handle: Self::Handle,
shape: Shape<D>,
) -> Self::FloatTensorPrimitive<D>;
/// Convert a [handle](FusionBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive).
fn int_tensor<const D: usize>(
handle: Self::Handle,
shape: Shape<D>,
) -> Self::IntTensorPrimitive<D>;
/// Convert a [handle](FusionBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
fn bool_tensor<const D: usize>(
handle: Self::Handle,
shape: Shape<D>,
) -> Self::BoolTensorPrimitive<D>;
/// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](FusionBackend::Handle).
fn float_tensor_handle<const D: usize>(tensor: Self::FloatTensorPrimitive<D>) -> Self::Handle;
/// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](FusionBackend::Handle).
fn int_tensor_handle<const D: usize>(tensor: Self::IntTensorPrimitive<D>) -> Self::Handle;
/// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](FusionBackend::Handle).
fn bool_tensor_handle<const D: usize>(tensor: Self::BoolTensorPrimitive<D>) -> Self::Handle;
}
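
The handle plumbing deleted above does not disappear: it moves behind the new ReprBackend supertrait in burn_tensor::repr. A sketch of that trait, assuming it keeps the associated Handle type and the six conversion methods removed from FusionBackend here:

```rust
// Assumed shape of burn_tensor::repr::ReprBackend, mirroring the removed methods.
pub trait ReprBackend: Backend {
    /// The type that can be used to point to a tensor of any kind.
    type Handle: Sync + Send + Clone;

    /// Convert a handle to a float tensor.
    fn float_tensor<const D: usize>(
        handle: Self::Handle,
        shape: Shape<D>,
    ) -> Self::FloatTensorPrimitive<D>;

    /// Convert a float tensor back to a handle.
    fn float_tensor_handle<const D: usize>(tensor: Self::FloatTensorPrimitive<D>) -> Self::Handle;

    // ...plus the analogous int/bool constructors and conversions listed above.
}
```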

View File

@ -1,10 +1,12 @@
use crate::{
stream::{Operation, OperationDescription, StreamId},
FusionBackend, FusionTensor, Handle, TensorDescription, TensorId,
stream::{execution::Operation, StreamId},
FusionBackend, FusionTensor, Handle,
};
use burn_tensor::{
backend::Backend,
ops::{FloatElem, IntElem},
Data, Reader,
repr::{OperationDescription, TensorDescription, TensorId},
Data, Device, Reader,
};
/// Define how to interact with the fusion server.
@ -12,8 +14,8 @@ pub trait FusionClient: Send + Sync + Clone {
/// The [fusion backend](FusionBackend) associated type.
type FusionBackend: FusionBackend;
/// Create a new client for the given [fusion device](FusionBackend::FusionDevice).
fn new(device: <Self::FusionBackend as FusionBackend>::FusionDevice) -> Self;
/// Create a new client for the given [device](Backend::Device).
fn new(device: Device<Self::FusionBackend>) -> Self;
/// Register a new [tensor operation description](OperationDescription).
fn register<O: Operation<Self::FusionBackend> + 'static>(
&self,
@ -24,7 +26,7 @@ pub trait FusionClient: Send + Sync + Clone {
/// Register all lazy computation.
fn drain(&self);
/// Get the current device used by all operations handled by this client.
fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice;
fn device(&self) -> &<Self::FusionBackend as Backend>::Device;
/// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it.
fn tensor_uninitialized(&self, shape: Vec<usize>) -> FusionTensor<Self>;
/// Create a tensor with the given handle and shape.

View File

@ -1,9 +1,13 @@
use super::FusionClient;
use crate::{
stream::{Operation, OperationDescription, StreamId},
stream::{execution::Operation, StreamId},
FusionBackend, FusionServer, FusionTensor, Handle,
};
use burn_tensor::ops::FloatElem;
use burn_tensor::{
backend::Backend,
ops::FloatElem,
repr::{OperationDescription, TensorDescription, TensorId},
};
use spin::Mutex;
use std::sync::Arc;
@ -13,7 +17,7 @@ where
B: FusionBackend,
{
server: Arc<Mutex<FusionServer<B>>>,
device: B::FusionDevice,
device: B::Device,
}
impl<B> Clone for MutexFusionClient<B>
@ -34,7 +38,7 @@ where
{
type FusionBackend = B;
fn new(device: B::FusionDevice) -> Self {
fn new(device: B::Device) -> Self {
Self {
device: device.clone(),
server: Arc::new(Mutex::new(FusionServer::new(device))),
@ -63,7 +67,7 @@ where
FusionTensor::new(id, shape, self.clone(), StreamId::current())
}
fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice {
fn device(&self) -> &<Self::FusionBackend as Backend>::Device {
&self.device
}
fn register_tensor(
@ -82,7 +86,7 @@ where
fn read_tensor_float<const D: usize>(
&self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
stream: StreamId,
) -> burn_tensor::Reader<burn_tensor::Data<FloatElem<Self::FusionBackend>, D>> {
self.server.lock().read_float(tensor, stream)
@ -90,7 +94,7 @@ where
fn read_tensor_int<const D: usize>(
&self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::Reader<burn_tensor::Data<burn_tensor::ops::IntElem<Self::FusionBackend>, D>>
{
@ -99,7 +103,7 @@ where
fn read_tensor_bool<const D: usize>(
&self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
stream: StreamId,
) -> burn_tensor::Reader<burn_tensor::Data<bool, D>> {
self.server.lock().read_bool(tensor, stream)
@ -107,17 +111,16 @@ where
fn change_client_float<const D: usize>(
&self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
client: Self,
stream: StreamId,
) -> FusionTensor<Self> {
let device = client.device.clone().into();
let mut server_other = client.server.lock();
let mut server_current = self.server.lock();
server_current.drain_stream(stream);
let id = server_current.change_server_float::<D>(&tensor, &device, &mut server_other);
let id =
server_current.change_server_float::<D>(&tensor, &client.device, &mut server_other);
core::mem::drop(server_other);
core::mem::drop(server_current);
@ -127,17 +130,15 @@ where
fn change_client_int<const D: usize>(
&self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
client: Self,
stream: StreamId,
) -> FusionTensor<Self> {
let device = client.device.clone().into();
let mut server_other = client.server.lock();
let mut server_current = self.server.lock();
server_current.drain_stream(stream);
let id = server_current.change_server_int::<D>(&tensor, &device, &mut server_other);
let id = server_current.change_server_int::<D>(&tensor, &client.device, &mut server_other);
core::mem::drop(server_other);
core::mem::drop(server_current);
@ -147,17 +148,15 @@ where
fn change_client_bool<const D: usize>(
&self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
client: Self,
stream: StreamId,
) -> FusionTensor<Self> {
let device = client.device.clone().into();
let mut server_other = client.server.lock();
let mut server_current = self.server.lock();
server_current.drain_stream(stream);
let id = server_current.change_server_bool::<D>(&tensor, &device, &mut server_other);
let id = server_current.change_server_bool::<D>(&tensor, &client.device, &mut server_other);
core::mem::drop(server_other);
core::mem::drop(server_current);
@ -165,7 +164,7 @@ where
FusionTensor::new(id, tensor.shape, client, StreamId::current())
}
fn register_orphan(&self, id: &crate::TensorId) {
fn register_orphan(&self, id: &TensorId) {
self.server.lock().drop_tensor_handle(*id);
}
}
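
With the FusionDevice indirection gone, a client is built from the backend device as-is. A sketch of the resulting construction (that B::Device implements Default is an assumption here, though backend devices in this codebase conventionally do):

```rust
// Sketch: spinning up a mutex-guarded client for the default device.
fn new_client<B: FusionBackend>() -> MutexFusionClient<B> {
    // DeviceOps requires Clone + PartialEq, so the device is stored directly.
    let device = B::Device::default();
    MutexFusionClient::new(device)
}
```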

View File

@ -1,8 +1,14 @@
use crate::{client::FusionClient, DeviceId, FusionBackend, FusionDevice};
use burn_tensor::{
backend::{Backend, DeviceId, DeviceOps},
repr::ReprBackend,
};
use crate::client::FusionClient;
use std::{any::Any, collections::HashMap, ops::DerefMut};
/// Type alias for [fusion backend handle](FusionBackend::Handle).
pub type Handle<B> = <B as FusionBackend>::Handle;
/// Type alias for [representation backend handle](burn_tensor::repr::ReprBackend::Handle).
pub type Handle<B> = <B as ReprBackend>::Handle;
type Key = (core::any::TypeId, DeviceId);
pub(crate) struct FusionClientLocator {
@ -22,7 +28,7 @@ impl FusionClientLocator {
/// Provide the init function to create a new client if it isn't already initialized.
pub fn client<C: FusionClient + 'static>(
&self,
device: &<C::FusionBackend as FusionBackend>::FusionDevice,
device: &<C::FusionBackend as Backend>::Device,
) -> C {
let device_id = device.id();
let client_id = (core::any::TypeId::of::<C>(), device_id);
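
Clients are therefore cached per (client type, device id) pair: lookups for the same device return clones of one shared client, while each distinct device id gets its own entry. A sketched consequence, using the crate-private CLIENTS locator from the backend module:

```rust
fn shared_clients<B: FusionBackend>(device: &B::Device) {
    // Sketch: repeated lookups for one device id return clones of one client.
    let a = CLIENTS.client::<B::FusionClient>(device);
    let b = CLIENTS.client::<B::FusionClient>(device);
    // `a` and `b` wrap the same underlying server state.
    let _ = (a, b);
}
```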

View File

@ -16,7 +16,6 @@ pub mod stream;
mod backend;
mod bridge;
mod fusion;
mod handle;
mod ops;
mod server;
mod tensor;
@ -26,5 +25,4 @@ pub(crate) use server::*;
pub use backend::*;
pub use bridge::*;
pub use fusion::*;
pub use handle::*;
pub use tensor::*;

View File

@ -11,7 +11,7 @@ macro_rules! binary_float_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
let rhs = handles.get_float_tensor(&self.desc.rhs);
let output = $ops(lhs, rhs);
@ -35,7 +35,7 @@ macro_rules! binary_float_cmp_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
let rhs = handles.get_float_tensor(&self.desc.rhs);
let output = $ops(lhs, rhs);
@ -59,7 +59,7 @@ macro_rules! binary_int_cmp_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
let rhs = handles.get_int_tensor(&self.desc.rhs);
let output = $ops(lhs, rhs);
@ -93,7 +93,7 @@ macro_rules! binary_int_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
let rhs = handles.get_int_tensor(&self.desc.rhs);
let output = $ops(lhs, rhs);
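
Each of these macros expands to an operation struct whose execute pulls its operands out of the HandleContainer, now named directly from burn_tensor::repr instead of through $crate. A hypothetical invocation (the call sites live in the ops files, not in this diff):

```rust
// Hypothetical use: generates `AddOps<D>` whose `execute` fetches both float
// operands from the HandleContainer and registers the output of B::float_add.
binary_float_ops!(AddOps, B::float_add);
```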

View File

@ -2,23 +2,24 @@ use crate::{
client::FusionClient,
get_client,
ops::binary::binary_ops_shape,
stream::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, Operation,
OperationDescription, PermuteOperationDescription, RepeatOperationDescription,
ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, StreamId,
SwapDimsDescription, UnaryOperationDescription,
},
stream::{execution::Operation, StreamId},
Fusion, FusionBackend,
};
use burn_tensor::{
ops::{BoolTensor, BoolTensorOps},
repr::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
HandleContainer, OperationDescription, PermuteOperationDescription,
RepeatOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
},
Device, Shape,
};
impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let tensor = B::bool_empty(shape.clone(), device);
client.register_tensor(
@ -42,7 +43,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
data: burn_tensor::Data<bool, D>,
device: &Device<Self>,
) -> BoolTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let tensor = B::bool_from_data(data, device);
let shape = B::bool_shape(&tensor);
@ -62,7 +63,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for IntoIntOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let output = B::bool_into_int(input);
handles.register_int_tensor(&self.desc.out.id, output);
@ -95,7 +96,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for IntoFloatOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let output = B::bool_into_float(input);
handles.register_float_tensor(&self.desc.out.id, output);
@ -119,15 +120,15 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
fn bool_device<const D: usize>(tensor: &BoolTensor<Self, D>) -> Device<Self> {
tensor.client.device().clone().into()
tensor.client.device().clone()
}
fn bool_to_device<const D: usize>(
tensor: BoolTensor<Self, D>,
device: &Device<Self>,
) -> BoolTensor<Self, D> {
let device_original: &B::FusionDevice = tensor.client.device();
let device_target: B::FusionDevice = device.clone().into();
let device_original: &B::Device = tensor.client.device();
let device_target: B::Device = device.clone();
if device_original == &device_target {
return tensor;
@ -154,7 +155,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for ReshapeDimsOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_bool_tensor::<D1>(&self.desc.input);
let output = B::bool_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
handles.register_bool_tensor(&self.desc.out.id, output);
@ -188,7 +189,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_bool_tensor::<D1>(&self.desc.tensor);
let output =
@ -232,7 +233,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceAssignOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_bool_tensor::<D1>(&self.desc.tensor);
let value = handles.get_bool_tensor::<D1>(&self.desc.value);
@ -277,7 +278,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for CatOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensors = self
.desc
.tensors
@ -329,7 +330,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for EqualOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let lhs = handles.get_bool_tensor::<D>(&self.desc.lhs);
let rhs = handles.get_bool_tensor(&self.desc.rhs);
let output = B::bool_equal(lhs, rhs);
@ -364,7 +365,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for NotOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let output = B::bool_not(input);
handles.register_bool_tensor(&self.desc.out.id, output);
@ -381,7 +382,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
out.client.register(
vec![stream],
OperationDescription::Bool(crate::stream::BoolOperationDescription::Not(desc.clone())),
OperationDescription::Bool(BoolOperationDescription::Not(desc.clone())),
NotOps::<D>::new(desc),
);
@ -399,7 +400,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for SwapDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2);
handles.register_bool_tensor(&self.desc.out.id, output);
@ -438,7 +439,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for PermuteDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let axes: [usize; D] = self.desc.axes.try_into().unwrap();
let output = B::bool_permute(input, axes);
@ -478,7 +479,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, const D2: usize, B: FusionBackend> Operation<B> for ExpandOps<D, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let shape: [usize; D2] = self.desc.shape.try_into().unwrap();
let output = B::bool_expand(input, shape.into());
@ -516,7 +517,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for FlipOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let output = B::bool_flip(input, self.desc.axes.as_slice());
handles.register_bool_tensor(&self.desc.out.id, output);
@ -552,7 +553,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for RepeatOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_bool_tensor::<D>(&self.desc.tensor);
let output = B::bool_repeat::<D>(tensor, self.desc.dim, self.desc.times);

View File

@ -4,21 +4,23 @@ use crate::{
get_client,
ops::binary::binary_ops_shape,
scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
stream::{
stream::{execution::Operation, StreamId},
unary_float_ops, Fusion, FusionBackend,
};
use burn_tensor::{
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
repr::{
BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription,
MaskWhereOperationDescription, NumericOperationDescription, Operation,
FloatOperationDescription, GatherOperationDescription, HandleContainer,
MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
OperationDescription, PermuteOperationDescription, RandomOperationDescription,
ReduceDimWithIndicesDescription, RepeatOperationDescription, ReshapeDescription,
ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
StreamId, SwapDimsDescription, UnaryOperationDescription,
SwapDimsDescription, TensorDescription, UnaryOperationDescription,
},
unary_float_ops, Fusion, FusionBackend, TensorDescription,
};
use burn_tensor::{
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
Data, Device, Distribution, ElementConversion, Reader, Shape,
};
use std::ops::Range;
@ -28,7 +30,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
data: Data<FloatElem<Self>, D>,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let tensor = B::float_from_data(data, device);
let shape = B::float_shape(&tensor);
@ -50,7 +52,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for RandomOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let shape = Shape::from(self.desc.out.shape.clone());
let output: B::FloatTensorPrimitive<D> =
B::float_random(shape, self.desc.distribution, &handles.device);
@ -60,7 +62,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
let stream = StreamId::current();
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(shape);
let desc = RandomOperationDescription {
@ -83,7 +85,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for ZerosOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let shape = Shape::from(self.out.shape.clone());
let output = B::float_zeros::<D>(shape, &handles.device);
handles.register_float_tensor(&self.out.id, output);
@ -92,7 +94,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
let stream = StreamId::current();
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(shape);
let desc = out.to_description_out();
@ -112,7 +114,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for OnesOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let shape = Shape::from(self.out.shape.clone());
let output = B::float_ones::<D>(shape, &handles.device);
handles.register_float_tensor(&self.out.id, output);
@ -121,7 +123,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
let stream = StreamId::current();
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(shape);
let desc = out.to_description_out();
@ -146,7 +148,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for FullOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let shape = Shape::from(self.out.shape.clone());
let output: B::FloatTensorPrimitive<D> =
B::float_full(shape, self.elem.elem(), &handles.device);
@ -156,7 +158,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
let stream = StreamId::current();
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(shape);
let desc = (out.to_description_out(), fill_value.elem::<f32>());
@ -180,15 +182,15 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
fn float_device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
tensor.client.device().clone().into()
tensor.client.device().clone()
}
fn float_to_device<const D: usize>(
tensor: FloatTensor<Self, D>,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
let device_original: &B::FusionDevice = tensor.client.device();
let device_target: B::FusionDevice = device.clone().into();
let device_original: &B::Device = tensor.client.device();
let device_target: B::Device = device.clone();
if device_original == &device_target {
return tensor;
@ -212,7 +214,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for IntoIntOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let output = B::float_into_int(input);
@ -237,7 +239,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
fn float_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let stream = StreamId::current();
let tensor = B::float_empty(shape.clone(), device);
@ -307,7 +309,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for ClampOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.tensor);
let output = B::float_clamp(input, self.desc.min.elem(), self.desc.max.elem());
@ -551,7 +553,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for SwapDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2);
handles.register_float_tensor(&self.desc.out.id, output);
@ -591,7 +593,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for ReshapeDimsOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_float_tensor::<D1>(&self.desc.input);
let output = B::float_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
handles.register_float_tensor(&self.desc.out.id, output);
@ -626,7 +628,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for GatherOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
@ -667,7 +669,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for ScatterOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
let value = handles.get_float_tensor(&self.desc.value);
@ -712,7 +714,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for SelectOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
@ -754,7 +756,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for SelectAssignOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
let value = handles.get_float_tensor(&self.desc.value);
@ -799,7 +801,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D1>(&self.desc.tensor);
let output =
@ -842,7 +844,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceAssignOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D1>(&self.desc.tensor);
let value = handles.get_float_tensor::<D1>(&self.desc.value);
@ -887,7 +889,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for MaskWhereOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let value = handles.get_float_tensor(&self.desc.value);
let mask = handles.get_bool_tensor(&self.desc.mask);
@ -932,7 +934,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for MaskFillOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let mask = handles.get_bool_tensor(&self.desc.mask);
@ -1530,7 +1532,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for CatOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensors = self
.desc
.tensors
@ -1607,7 +1609,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for RepeatOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let output = B::float_repeat::<D>(tensor, self.desc.dim, self.desc.times);
@ -1715,7 +1717,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for MaxDimWithIndicesOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let (output, indices) = B::float_max_dim_with_indices(tensor, self.desc.dim);
@ -1802,7 +1804,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for MinDimWithIndicesOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
let (output, indices) = B::float_min_dim_with_indices(tensor, self.desc.dim);
@ -1871,7 +1873,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for PermuteDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let axes: [usize; D] = self.desc.axes.try_into().unwrap();
let output = B::float_permute(input, axes);
@ -1911,7 +1913,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, const D2: usize, B: FusionBackend> Operation<B> for ExpandOps<D, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let shape: [usize; D2] = self.desc.shape.try_into().unwrap();
let output = B::float_expand(input, shape.into());
@ -1949,7 +1951,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for FlipOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let output = B::float_flip(input, &self.desc.axes);
handles.register_float_tensor(&self.desc.out.id, output);

View File

@ -4,28 +4,29 @@ use crate::{
get_client,
ops::binary::binary_ops_shape,
scalar_int_cmp_ops, scalar_int_ops,
stream::{
self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
RandomOperationDescription, ReduceDimWithIndicesDescription, RepeatOperationDescription,
ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
SelectAssignOperationDescription, SelectOperationDescription,
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
UnaryOperationDescription,
},
unary_int_ops, Fusion, FusionBackend, TensorDescription,
stream::{execution::Operation, StreamId},
unary_int_ops, Fusion, FusionBackend,
};
use burn_tensor::{
ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
repr::{
self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
GatherOperationDescription, HandleContainer, MaskFillOperationDescription,
MaskWhereOperationDescription, NumericOperationDescription, OperationDescription,
PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
RepeatOperationDescription, ReshapeDescription, ScalarOperationDescription,
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription,
TensorDescription, UnaryOperationDescription,
},
Data, Device, Distribution, ElementConversion, Reader, Shape,
};
use core::ops::Range;
impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let tensor = B::int_empty(shape.clone(), device);
let stream = StreamId::current();
@ -44,7 +45,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
data: Data<IntElem<Self>, D>,
device: &Device<Self>,
) -> IntTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let tensor = B::int_from_data(data, device);
let shape = B::int_shape(&tensor);
let stream = StreamId::current();
@ -53,15 +54,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
fn int_device<const D: usize>(tensor: &IntTensor<Self, D>) -> Device<Self> {
tensor.client.device().clone().into()
tensor.client.device().clone()
}
fn int_to_device<const D: usize>(
tensor: IntTensor<Self, D>,
device: &Device<Self>,
) -> IntTensor<Self, D> {
let device_original: &B::FusionDevice = tensor.client.device();
let device_target: B::FusionDevice = device.clone().into();
let device_original: &B::Device = tensor.client.device();
let device_target: B::Device = device.clone();
if device_original == &device_target {
return tensor;
@ -86,7 +87,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for ReshapeDimsOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_int_tensor::<D1>(&self.desc.input);
let output = B::int_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
handles.register_int_tensor(&self.desc.out.id, output);
@ -120,7 +121,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D1>(&self.desc.tensor);
let output =
@ -164,7 +165,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceAssignOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D1>(&self.desc.tensor);
let value = handles.get_int_tensor::<D1>(&self.desc.value);
@ -208,7 +209,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for MaskWhereOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let value = handles.get_int_tensor(&self.desc.value);
let mask = handles.get_bool_tensor(&self.desc.mask);
@ -251,7 +252,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for MaskFillOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let mask = handles.get_bool_tensor(&self.desc.mask);
@ -291,7 +292,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for GatherOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
@ -331,7 +332,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for ScatterOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
let value = handles.get_int_tensor(&self.desc.value);
@ -374,7 +375,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for SelectOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
@ -416,7 +417,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for SelectAssignOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
let value = handles.get_int_tensor(&self.desc.value);
@ -457,7 +458,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for CatOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensors = self
.desc
.tensors
@ -770,9 +771,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
stream::OperationDescription::NumericInt(NumericOperationDescription::Add(
desc.clone(),
)),
repr::OperationDescription::NumericInt(NumericOperationDescription::Add(desc.clone())),
AddOps::<D>::new(desc),
);
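
The remaining changes in this file are mechanical: the operation-description enums now live in burn_tensor::repr, so only the qualified path changes while the registered value stays the same. The import that makes these repr:: call sites resolve is the one shown in the hunk at the top of the file; roughly:

```rust
// From the import hunk above: `repr::OperationDescription::NumericInt(...)`
// replaces the old `stream::OperationDescription::NumericInt(...)` paths.
use burn_tensor::repr::{self, NumericOperationDescription, OperationDescription};
```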
@ -795,7 +794,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
stream::OperationDescription::NumericInt(NumericOperationDescription::AddScalar(
repr::OperationDescription::NumericInt(NumericOperationDescription::AddScalar(
desc.clone(),
)),
AddOps::<D>::new(desc),
@ -823,9 +822,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
stream::OperationDescription::NumericInt(NumericOperationDescription::Sub(
desc.clone(),
)),
repr::OperationDescription::NumericInt(NumericOperationDescription::Sub(desc.clone())),
SubOps::<D>::new(desc),
);
@ -848,7 +845,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
stream::OperationDescription::NumericInt(NumericOperationDescription::SubScalar(
repr::OperationDescription::NumericInt(NumericOperationDescription::SubScalar(
desc.clone(),
)),
SubOps::<D>::new(desc),
@ -876,9 +873,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
stream::OperationDescription::NumericInt(NumericOperationDescription::Mul(
desc.clone(),
)),
repr::OperationDescription::NumericInt(NumericOperationDescription::Mul(desc.clone())),
MulOps::<D>::new(desc),
);
@ -901,7 +896,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
stream::OperationDescription::NumericInt(NumericOperationDescription::MulScalar(
repr::OperationDescription::NumericInt(NumericOperationDescription::MulScalar(
desc.clone(),
)),
MulOps::<D>::new(desc),
@ -929,9 +924,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
stream::OperationDescription::NumericInt(NumericOperationDescription::Div(
desc.clone(),
)),
repr::OperationDescription::NumericInt(NumericOperationDescription::Div(desc.clone())),
DivOps::<D>::new(desc),
);
@ -954,7 +947,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
stream::OperationDescription::NumericInt(NumericOperationDescription::DivScalar(
repr::OperationDescription::NumericInt(NumericOperationDescription::DivScalar(
desc.clone(),
)),
DivOps::<D>::new(desc),
@ -979,7 +972,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
stream::OperationDescription::NumericInt(NumericOperationDescription::RemScalar(
repr::OperationDescription::NumericInt(NumericOperationDescription::RemScalar(
desc.clone(),
)),
ModOps::<D>::new(desc),
@ -995,7 +988,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for ZerosOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let shape = Shape::from(self.desc.shape.clone());
let output = B::int_zeros::<D>(shape, &handles.device);
handles.register_int_tensor(&self.desc.id, output);
@ -1004,7 +997,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let stream = StreamId::current();
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(shape);
let desc = out.to_description_out();
client.register(
@ -1023,7 +1016,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for OnesOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let shape = Shape::from(self.desc.shape.clone());
let output = B::int_ones::<D>(shape, &handles.device);
handles.register_int_tensor(&self.desc.id, output);
@ -1032,7 +1025,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let stream = StreamId::current();
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(shape);
let desc = out.to_description_out();
@ -1223,7 +1216,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for ClampOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.tensor);
let output = B::int_clamp(input, self.desc.min.elem(), self.desc.max.elem());
@ -1274,7 +1267,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for IntoFloatOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.input);
let output = B::int_into_float(input);
handles.register_float_tensor(&self.desc.out.id, output);
@ -1289,7 +1282,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Int(stream::IntOperationDescription::IntoFloat(desc.clone())),
OperationDescription::Int(repr::IntOperationDescription::IntoFloat(desc.clone())),
IntoFloatOps::<D>::new(desc),
);
@ -1307,7 +1300,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for SwapDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.input);
let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2);
handles.register_int_tensor(&self.desc.out.id, output);
@ -1387,7 +1380,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for MaxDimWithIndicesOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim);
@ -1470,7 +1463,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for MinDimWithIndicesOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim);
@ -1513,7 +1506,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for IntRandomOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let shape = Shape::from(self.desc.out.shape.clone());
let output: B::IntTensorPrimitive<D> =
B::int_random(shape, self.desc.distribution, &handles.device);
@ -1523,7 +1516,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let stream = StreamId::current();
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(shape);
let desc = RandomOperationDescription {
@ -1549,7 +1542,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for PermuteDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.input);
let axes: [usize; D] = self.desc.axes.try_into().unwrap();
let output = B::int_permute(input, axes);
@ -1589,7 +1582,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, const D2: usize, B: FusionBackend> Operation<B> for ExpandOps<D, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let shape: [usize; D2] = self.desc.shape.try_into().unwrap();
let output = B::bool_expand(input, shape.into());
@ -1623,7 +1616,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for FlipDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.input);
let axes = &self.desc.axes;
let output = B::int_flip(input, axes);
@ -1661,7 +1654,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
impl<const D: usize, B: FusionBackend> Operation<B> for RepeatOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let output = B::int_repeat::<D>(tensor, self.desc.dim, self.desc.times);

View File

@ -1,25 +1,25 @@
use crate::stream::InterpolateBackwardDescription;
use crate::{
client::FusionClient,
stream::{
use crate::{client::FusionClient, stream::execution::Operation, Fusion, FusionBackend};
use burn_tensor::{
ops::{
conv::{
calculate_conv_output_size, calculate_conv_transpose_output_size,
calculate_pool_output_size,
},
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateOptions,
MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices,
ModuleOps,
},
repr::{
AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription,
AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription,
AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription,
AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription,
ConvTranspose2dDescription, InterpolateDescription, MaxPool1dDescription,
MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription,
MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription,
MaxPool2dWithIndicesDescription, Operation, OperationDescription,
ConvTranspose2dDescription, HandleContainer, InterpolateBackwardDescription,
InterpolateDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
MaxPool1dWithIndicesDescription, MaxPool2dDescription,
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription,
ModuleOperationDescription, OperationDescription,
},
Fusion, FusionBackend, HandleContainer,
};
use burn_tensor::ops::{
conv::{
calculate_conv_output_size, calculate_conv_transpose_output_size,
calculate_pool_output_size,
},
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateOptions,
MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
};
macro_rules! make_ops {
@ -30,7 +30,7 @@ macro_rules! make_ops {
}
impl<B: FusionBackend> Operation<B> for $name {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
#[allow(clippy::redundant_closure_call)]
$fn(self.desc, handles)
}
@ -88,9 +88,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.clone().register(
streams,
OperationDescription::Module(crate::stream::ModuleOperationDescription::Conv1d(
description.clone(),
)),
OperationDescription::Module(ModuleOperationDescription::Conv1d(description.clone())),
Conv1dOps::new(description),
);
@ -155,9 +153,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
streams,
OperationDescription::Module(crate::stream::ModuleOperationDescription::Conv2d(
desc.clone(),
)),
OperationDescription::Module(ModuleOperationDescription::Conv2d(desc.clone())),
Conv2dOps::new(desc),
);
@ -216,9 +212,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
streams,
OperationDescription::Module(
crate::stream::ModuleOperationDescription::ConvTranspose1d(desc.clone()),
),
OperationDescription::Module(ModuleOperationDescription::ConvTranspose1d(desc.clone())),
ConvTranspose1dOps::new(desc),
);
@ -285,9 +279,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
streams,
OperationDescription::Module(
crate::stream::ModuleOperationDescription::ConvTranspose2d(desc.clone()),
),
OperationDescription::Module(ModuleOperationDescription::ConvTranspose2d(desc.clone())),
ConvTranspose2dOps::new(desc),
);
@ -333,9 +325,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Module(crate::stream::ModuleOperationDescription::AvgPool1d(
desc.clone(),
)),
OperationDescription::Module(ModuleOperationDescription::AvgPool1d(desc.clone())),
AvgPool1dOps::new(desc),
);
@ -385,9 +375,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Module(crate::stream::ModuleOperationDescription::AvgPool2d(
desc.clone(),
)),
OperationDescription::Module(ModuleOperationDescription::AvgPool2d(desc.clone())),
AvgPool2dOps::new(desc),
);
@ -436,9 +424,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::Module(
crate::stream::ModuleOperationDescription::AvgPool1dBackward(desc.clone()),
),
OperationDescription::Module(ModuleOperationDescription::AvgPool1dBackward(
desc.clone(),
)),
AvgPool1dBackwardOps::new(desc),
);
@ -487,9 +475,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::Module(
crate::stream::ModuleOperationDescription::AvgPool2dBackward(desc.clone()),
),
OperationDescription::Module(ModuleOperationDescription::AvgPool2dBackward(
desc.clone(),
)),
AvgPool2dBackwardOps::new(desc),
);
@ -536,9 +524,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Module(crate::stream::ModuleOperationDescription::MaxPool1d(
desc.clone(),
)),
OperationDescription::Module(ModuleOperationDescription::MaxPool1d(desc.clone())),
MaxPool1dOps::new(desc),
);
@ -598,9 +584,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Module(crate::stream::ModuleOperationDescription::MaxPool2d(
desc.clone(),
)),
OperationDescription::Module(ModuleOperationDescription::MaxPool2d(desc.clone())),
MaxPool2dOps::new(desc),
);
@ -649,9 +633,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Module(
crate::stream::ModuleOperationDescription::MaxPool1dWithIndices(desc.clone()),
),
OperationDescription::Module(ModuleOperationDescription::MaxPool1dWithIndices(
desc.clone(),
)),
MaxPool1dWithIndicesOps::new(desc),
);
@ -714,9 +698,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Module(
crate::stream::ModuleOperationDescription::MaxPool2dWithIndices(desc.clone()),
),
OperationDescription::Module(ModuleOperationDescription::MaxPool2dWithIndices(
desc.clone(),
)),
MaxPool2dWithIndicesOps::new(desc),
);
@ -770,11 +754,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2, stream_3],
OperationDescription::Module(
crate::stream::ModuleOperationDescription::MaxPool1dWithIndicesBackward(
desc.clone(),
),
),
OperationDescription::Module(ModuleOperationDescription::MaxPool1dWithIndicesBackward(
desc.clone(),
)),
MaxPool1dWithIndicesBackwardOps::new(desc),
);
@ -828,11 +810,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2, stream_3],
OperationDescription::Module(
crate::stream::ModuleOperationDescription::MaxPool2dWithIndicesBackward(
desc.clone(),
),
),
OperationDescription::Module(ModuleOperationDescription::MaxPool2dWithIndicesBackward(
desc.clone(),
)),
MaxPool2dWithIndicesBackwardOps::new(desc),
);
@ -862,9 +842,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Module(
crate::stream::ModuleOperationDescription::AdaptiveAvgPool1d(desc.clone()),
),
OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool1d(
desc.clone(),
)),
AdaptiveAvgPool1dOps::new(desc),
);
@ -897,9 +877,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Module(
crate::stream::ModuleOperationDescription::AdaptiveAvgPool2d(desc.clone()),
),
OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool2d(
desc.clone(),
)),
AdaptiveAvgPool2dOps::new(desc),
);
@ -933,9 +913,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out.client.register(
vec![stream_1, stream_2],
OperationDescription::Module(
crate::stream::ModuleOperationDescription::AdaptiveAvgPool1dBackward(desc.clone()),
),
OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool1dBackward(
desc.clone(),
)),
AdaptiveAvgPool1dBackwardOps::new(desc),
);
@ -969,9 +949,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::Module(
crate::stream::ModuleOperationDescription::AdaptiveAvgPool2dBackward(desc.clone()),
),
OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool2dBackward(
desc.clone(),
)),
AdaptiveAvgPool2dBackwardOps::new(desc),
);
@ -1006,9 +986,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out.client.register(
vec![stream],
OperationDescription::Module(crate::stream::ModuleOperationDescription::Interpolate(
desc.clone(),
)),
OperationDescription::Module(ModuleOperationDescription::Interpolate(desc.clone())),
InterpolateOps::new(desc),
);
@ -1047,9 +1025,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::Module(
crate::stream::ModuleOperationDescription::InterpolateBackward(desc.clone()),
),
OperationDescription::Module(ModuleOperationDescription::InterpolateBackward(
desc.clone(),
)),
InterpolateBackwardOps::new(desc),
);
out
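Note: every hunk in this file makes the same mechanical change: the operation descriptions now come from `burn_tensor::repr`, so the long `crate::stream::ModuleOperationDescription::...` paths collapse. The registration shape itself is unchanged; a minimal sketch of that shared pattern (AvgPool1d shown, using the imports above):

    // One copy of the description is recorded in the stream for fusion
    // analysis; the other is moved into the op that runs when the stream
    // is executed.
    out.client.register(
        vec![stream],
        OperationDescription::Module(ModuleOperationDescription::AvgPool1d(desc.clone())),
        AvgPool1dOps::new(desc),
    );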

View File

@ -18,7 +18,7 @@ macro_rules! scalar_float_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
@ -38,7 +38,7 @@ macro_rules! scalar_float_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, self.desc.rhs);
@ -62,7 +62,7 @@ macro_rules! scalar_float2int_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, self.desc.rhs.clone());
@ -85,7 +85,7 @@ macro_rules! unary_float_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let output = $ops(input);
@ -108,7 +108,7 @@ macro_rules! unary_int_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.input);
let output = $ops(input);
@ -131,7 +131,7 @@ macro_rules! scalar_float_cmp_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
@ -154,7 +154,7 @@ macro_rules! scalar_int_cmp_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
@ -184,7 +184,7 @@ macro_rules! scalar_int_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
@ -204,7 +204,7 @@ macro_rules! scalar_int_ops {
}
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, self.desc.rhs);
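Note: these macros still generate the same wrapper struct plus `Operation` impl; only the `HandleContainer` path changes from `$crate::` to the `burn_tensor::repr` import. A hedged sketch of a typical invocation elsewhere in the crate (the op name is illustrative):

    // Generates `AddOps<D>` holding a ScalarOperationDescription whose
    // execute() applies B::float_add_scalar to the resolved lhs tensor.
    scalar_float_ops!(AddOps, B::float_add_scalar);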

View File

@ -1,8 +1,11 @@
use crate::{
stream::{MultiStream, Operation, OperationDescription, StreamId},
FusionBackend, HandleContainer, TensorId,
stream::{execution::Operation, MultiStream, StreamId},
FusionBackend,
};
use burn_tensor::{
ops::{FloatElem, IntElem},
repr::{HandleContainer, OperationDescription, TensorDescription, TensorId},
};
use burn_tensor::ops::{FloatElem, IntElem};
use std::sync::Arc;
pub struct FusionServer<B>
@ -11,14 +14,14 @@ where
{
streams: MultiStream<B>,
pub(crate) handles: HandleContainer<B>,
pub device: B::FusionDevice,
pub device: B::Device,
}
impl<B> FusionServer<B>
where
B: FusionBackend,
{
pub fn new(device: B::FusionDevice) -> Self {
pub fn new(device: B::Device) -> Self {
Self {
streams: MultiStream::new(device.clone()),
handles: HandleContainer::new(device.clone()),
@ -46,7 +49,7 @@ where
pub fn read_float<const D: usize>(
&mut self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::Reader<burn_tensor::Data<FloatElem<B>, D>> {
// Make sure all registered operations are executed.
@ -59,7 +62,7 @@ where
pub fn read_int<const D: usize>(
&mut self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::Reader<burn_tensor::Data<IntElem<B>, D>> {
// Make sure all registered operations are executed.
@ -72,7 +75,7 @@ where
pub fn read_bool<const D: usize>(
&mut self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::Reader<burn_tensor::Data<bool, D>> {
// Make sure all registered operations are executed.
@ -85,7 +88,7 @@ where
pub fn change_server_float<const D: usize>(
&mut self,
tensor: &crate::TensorDescription,
tensor: &TensorDescription,
device: &B::Device,
server_device: &mut Self,
) -> Arc<TensorId> {
@ -101,7 +104,7 @@ where
}
pub fn change_server_int<const D: usize>(
&mut self,
tensor: &crate::TensorDescription,
tensor: &TensorDescription,
device: &B::Device,
server_device: &mut Self,
) -> Arc<TensorId> {
@ -117,7 +120,7 @@ where
}
pub fn change_server_bool<const D: usize>(
&mut self,
tensor: &crate::TensorDescription,
tensor: &TensorDescription,
device: &B::Device,
server_device: &mut Self,
) -> Arc<TensorId> {

View File

@ -1,8 +1,9 @@
use super::Operation;
use super::OperationConverter;
use super::OperationDescription;
use burn_tensor::repr::OperationDescription;
use crate::FusionBackend;
use super::{execution::Operation, OperationConverter, RelativeOps};
/// A growing list of [tensor operation descriptions](OperationDescription).
pub struct OperationQueue<B: FusionBackend> {
pub(crate) global: Vec<OperationDescription>,

View File

@ -1,23 +1,5 @@
use super::{
AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription,
AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription,
AvgPool2dBackwardDescription, AvgPool2dDescription, BaseOperationDescription,
BinaryOperationDescription, BoolOperationDescription, ClampOperationDescription,
Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, ConvTranspose2dDescription,
EmbeddingBackwardDescription, EmbeddingDescription, ExpandOperationDescription,
FlipOperationDescription, FloatOperationDescription, GatherOperationDescription,
IntOperationDescription, InterpolateBackwardDescription, InterpolateDescription,
MaskFillOperationDescription, MaskWhereOperationDescription, MaxPool1dDescription,
MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription, MaxPool2dDescription,
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription,
ModuleOperationDescription, NumericOperationDescription, OperationDescription,
PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
SelectAssignOperationDescription, SelectOperationDescription, SliceOperationDescription,
SwapDimsDescription, UnaryOperationDescription,
};
use crate::{FusionBackend, HandleContainer, TensorDescription, TensorId};
use burn_tensor::{Element, ElementConversion};
use crate::FusionBackend;
use burn_tensor::{repr::*, Element, ElementConversion};
use hashbrown::HashMap;
/// The context contains the relative graph tensor mapping so that a relative tensor id can be
@ -49,6 +31,16 @@ pub(crate) struct OperationConverter {
scalar_ints: Vec<i32>,
}
pub(crate) trait RelativeOps {
fn to_relative(&self, converter: &mut OperationConverter) -> Self;
}
trait RelativeOpsScalar<E: Element> {
fn to_relative<F>(&self, converter: &mut OperationConverter, local_elem: F) -> Self
where
F: Fn(&mut OperationConverter, &E) -> E;
}
impl OperationConverter {
pub(crate) fn context<'a, B: FusionBackend>(
&'a self,
@ -85,8 +77,8 @@ impl OperationConverter {
}
}
impl OperationDescription {
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
impl RelativeOps for OperationDescription {
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
match self {
OperationDescription::BaseFloat(ops) => {
OperationDescription::BaseFloat(ops.to_relative(converter))
@ -117,8 +109,8 @@ impl OperationDescription {
}
}
impl ModuleOperationDescription {
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
impl RelativeOps for ModuleOperationDescription {
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
match self {
ModuleOperationDescription::Embedding(desc) => {
ModuleOperationDescription::Embedding(EmbeddingDescription {
@ -172,7 +164,7 @@ impl ModuleOperationDescription {
})
}
ModuleOperationDescription::AvgPool1d(desc) => {
ModuleOperationDescription::AvgPool1d(super::AvgPool1dDescription {
ModuleOperationDescription::AvgPool1d(AvgPool1dDescription {
x: desc.x.to_relative(converter),
kernel_size: desc.kernel_size,
stride: desc.stride,
@ -192,7 +184,7 @@ impl ModuleOperationDescription {
})
}
ModuleOperationDescription::AvgPool1dBackward(desc) => {
ModuleOperationDescription::AvgPool1dBackward(super::AvgPool1dBackwardDescription {
ModuleOperationDescription::AvgPool1dBackward(AvgPool1dBackwardDescription {
x: desc.x.to_relative(converter),
grad: desc.grad.to_relative(converter),
kernel_size: desc.kernel_size,
@ -336,8 +328,8 @@ impl ModuleOperationDescription {
}
}
impl FloatOperationDescription {
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
impl RelativeOps for FloatOperationDescription {
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
match self {
FloatOperationDescription::Exp(desc) => {
FloatOperationDescription::Exp(UnaryOperationDescription {
@ -423,8 +415,8 @@ impl FloatOperationDescription {
}
}
impl BoolOperationDescription {
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
impl RelativeOps for BoolOperationDescription {
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
match self {
BoolOperationDescription::IntoFloat(desc) => {
BoolOperationDescription::IntoFloat(UnaryOperationDescription {
@ -448,8 +440,8 @@ impl BoolOperationDescription {
}
}
impl IntOperationDescription {
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
impl RelativeOps for IntOperationDescription {
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
match self {
IntOperationDescription::IntoFloat(desc) => {
IntOperationDescription::IntoFloat(UnaryOperationDescription {
@ -461,8 +453,8 @@ impl IntOperationDescription {
}
}
impl<E: Element> NumericOperationDescription<E> {
pub(crate) fn to_relative<F>(&self, converter: &mut OperationConverter, local_elem: F) -> Self
impl<E: Element> RelativeOpsScalar<E> for NumericOperationDescription<E> {
fn to_relative<F>(&self, converter: &mut OperationConverter, local_elem: F) -> Self
where
F: Fn(&mut OperationConverter, &E) -> E,
{
@ -779,8 +771,8 @@ impl<E: Element> NumericOperationDescription<E> {
}
}
impl BaseOperationDescription {
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
impl RelativeOps for BaseOperationDescription {
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
match self {
BaseOperationDescription::ToDevice(desc) => {
BaseOperationDescription::ToDevice(desc.to_relative(converter))
@ -828,7 +820,7 @@ impl BaseOperationDescription {
})
}
BaseOperationDescription::SliceAssign(desc) => {
BaseOperationDescription::SliceAssign(super::SliceAssignOperationDescription {
BaseOperationDescription::SliceAssign(SliceAssignOperationDescription {
tensor: desc.tensor.to_relative(converter),
ranges: desc.ranges.iter().map(|_range| 0..1).collect(),
value: desc.value.to_relative(converter),
@ -836,14 +828,14 @@ impl BaseOperationDescription {
})
}
BaseOperationDescription::Equal(desc) => {
BaseOperationDescription::Equal(super::BinaryOperationDescription {
BaseOperationDescription::Equal(BinaryOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: desc.rhs.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
BaseOperationDescription::Repeat(desc) => {
BaseOperationDescription::Repeat(super::RepeatOperationDescription {
BaseOperationDescription::Repeat(RepeatOperationDescription {
tensor: desc.tensor.to_relative(converter),
dim: desc.dim,
times: desc.times,
@ -851,7 +843,7 @@ impl BaseOperationDescription {
})
}
BaseOperationDescription::Cat(desc) => {
BaseOperationDescription::Cat(super::CatOperationDescription {
BaseOperationDescription::Cat(CatOperationDescription {
tensors: desc
.tensors
.iter()
@ -865,8 +857,8 @@ impl BaseOperationDescription {
}
}
impl TensorDescription {
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
impl RelativeOps for TensorDescription {
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
let relative_id = if let Some(value) = converter.tensors_global2relative.get(&self.id) {
// If we already have the same tensor registered, we have to update its value, but not
// its id.
@ -911,9 +903,8 @@ impl TensorDescription {
#[cfg(test)]
mod tests {
use crate::TensorStatus;
use super::*;
use burn_tensor::repr::{TensorDescription, TensorId, TensorStatus};
#[test]
fn tensor_description_to_relative() {
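Note: the new `RelativeOps` and `RelativeOpsScalar` traits formalize what were previously inherent `to_relative` methods. A minimal sketch of the contract for a hypothetical description type (`MyDescription` is illustrative, not part of the crate):

    use burn_tensor::repr::TensorDescription;

    struct MyDescription {
        input: TensorDescription,
        out: TensorDescription,
    }

    impl RelativeOps for MyDescription {
        fn to_relative(&self, converter: &mut OperationConverter) -> Self {
            // Remap every tensor to a stream-local id so that two streams
            // recording the same sequence of operations compare equal.
            Self {
                input: self.input.to_relative(converter),
                out: self.out.to_relative(converter),
            }
        }
    }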

View File

@ -1,9 +1,11 @@
use burn_tensor::repr::HandleContainer;
use crate::{
stream::{
store::{ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy},
OperationQueue,
OperationQueue, RelativeOps,
},
FusionBackend, HandleContainer, Optimization,
FusionBackend, Optimization,
};
/// The mode in which the execution is done.
@ -13,6 +15,12 @@ pub(crate) enum ExecutionMode {
Sync,
}
/// General trait to abstract how a single operation is executed.
pub trait Operation<B: FusionBackend>: Send + Sync {
/// Execute the operation.
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>);
}
impl<B: FusionBackend> OperationQueue<B> {
/// Execute the queue partially following the execution strategy from the plan.
pub(crate) fn execute(
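Note: `Operation` now lives in the execution module and borrows the relocated `HandleContainer` from `burn_tensor::repr`. A minimal sketch of an implementor (the no-op type is illustrative):

    use burn_tensor::repr::HandleContainer;

    struct NoopOp;

    impl<B: FusionBackend> Operation<B> for NoopOp {
        fn execute(self: Box<Self>, _handles: &mut HandleContainer<B>) {
            // A real op resolves its descriptions here, e.g.
            // handles.get_float_tensor::<D>(&desc), runs the backend op,
            // then handles.register_float_tensor::<D>(&id, output).
        }
    }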

View File

@ -1,5 +1,7 @@
use burn_tensor::repr::OperationDescription;
use super::ExecutionMode;
use crate::{stream::OperationDescription, OptimizationBuilder, OptimizationStatus};
use crate::{OptimizationBuilder, OptimizationStatus};
/// Explore and create new optimization.
pub struct Explorer<O> {

View File

@ -1,13 +1,12 @@
use burn_tensor::repr::OperationDescription;
use super::validator::{
ExecutionPlanOperationsStore, TriggerOperationsStore, TriggerProgress, TriggerValidator,
ValidatorState,
};
use super::ExecutionMode;
use crate::stream::execution::validator::OperationsValidator;
use crate::stream::{
store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger, SearchQuery},
OperationDescription,
};
use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger, SearchQuery};
use std::marker::PhantomData;
/// The policy keeps track of all possible execution plans for the current operations.
@ -266,14 +265,13 @@ impl<O> Policy<O> {
#[cfg(test)]
mod tests {
use super::*;
use crate::{
stream::{
store::{ExecutionPlan, ExecutionStrategy, ExecutionTrigger},
FloatOperationDescription, UnaryOperationDescription,
},
TensorDescription, TensorId, TensorStatus,
use burn_tensor::repr::{
FloatOperationDescription, TensorDescription, TensorId, TensorStatus,
UnaryOperationDescription,
};
use super::*;
use crate::stream::store::{ExecutionPlan, ExecutionStrategy, ExecutionTrigger};
use std::ops::Range;
#[test]

View File

@ -1,9 +1,10 @@
use burn_tensor::repr::OperationDescription;
use super::{ExecutionMode, Exploration, Explorer};
use crate::stream::execution::{Action, Policy};
use crate::stream::store::{
ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,
};
use crate::stream::OperationDescription;
use crate::OptimizationBuilder;
/// Process a [stream segment](StreamSegment) following a [policy](Policy).

View File

@ -6,16 +6,17 @@
//! To test these components effectively, we create mock types for the stream, optimization,
//! optimization builder, and stream segment. These mock types aid in comprehensively
//! understanding the process of optimizing streams.
use burn_tensor::repr::{
BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription,
OperationDescription, ScalarOperationDescription, TensorDescription, TensorId, TensorStatus,
UnaryOperationDescription,
};
use crate::{
stream::{
store::{
ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,
},
BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription,
OperationDescription, ScalarOperationDescription,
stream::store::{
ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,
},
OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription, TensorId,
TensorStatus,
OptimizationBuilder, OptimizationProperties, OptimizationStatus,
};
use super::*;
@ -558,18 +559,16 @@ fn operation_2() -> OperationDescription {
/// Just a simple operation.
fn operation_3() -> OperationDescription {
OperationDescription::Float(FloatOperationDescription::Log(
crate::stream::UnaryOperationDescription {
input: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
},
out: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::NotInit,
},
OperationDescription::Float(FloatOperationDescription::Log(UnaryOperationDescription {
input: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
},
))
out: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::NotInit,
},
}))
}

View File

@ -1,7 +1,6 @@
use crate::stream::{
store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger},
OperationDescription,
};
use burn_tensor::repr::OperationDescription;
use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger};
/// Compare each operation in the list of operations provided by the [store](OperationsStore)
/// to verify if the newly added operations match the original list.

View File

@ -4,9 +4,7 @@ pub(crate) mod store;
mod base;
mod context;
mod multi;
mod operation;
pub use base::*;
pub use context::*;
pub use multi::*;
pub use operation::*;

View File

@ -1,20 +1,22 @@
use burn_tensor::repr::{HandleContainer, OperationDescription};
use super::{
execution::{ExecutionMode, Processor, StreamSegment},
execution::{ExecutionMode, Operation, Processor, StreamSegment},
store::{ExecutionPlanId, ExecutionPlanStore},
Operation, OperationDescription, OperationQueue, StreamId,
OperationQueue, StreamId,
};
use crate::{FusionBackend, HandleContainer};
use crate::FusionBackend;
use std::collections::HashMap;
/// Keep track of multiple concurrent streams of operations.
pub struct MultiStream<B: FusionBackend> {
streams: HashMap<StreamId, Stream<B>>,
optimizations: ExecutionPlanStore<B::Optimization>,
device: B::FusionDevice,
device: B::Device,
}
impl<B: FusionBackend> MultiStream<B> {
pub(crate) fn new(device: B::FusionDevice) -> Self {
pub(crate) fn new(device: B::Device) -> Self {
Self {
streams: HashMap::new(),
optimizations: ExecutionPlanStore::new(),
@ -146,9 +148,9 @@ impl<'i, B: FusionBackend> StreamSegment<B::Optimization> for Segment<'i, B> {
}
impl<B: FusionBackend> Stream<B> {
fn new(device: B::FusionDevice) -> Self {
fn new(device: B::Device) -> Self {
Self {
processor: Processor::new(B::optimizations(device.into())),
processor: Processor::new(B::optimizations(device)),
queue: OperationQueue::new(),
}
}

View File

@ -1,5 +1,5 @@
use super::{ExecutionPlanIndex, InsertQuery, SearchQuery};
use crate::stream::OperationDescription;
use burn_tensor::repr::OperationDescription;
use serde::{Deserialize, Serialize};
/// The store that contains all explorations done on a device.

View File

@ -1,4 +1,5 @@
use crate::stream::{store::ExecutionPlanId, OperationDescription};
use crate::stream::store::ExecutionPlanId;
use burn_tensor::repr::OperationDescription;
use serde::{Deserialize, Serialize};
use std::{
collections::{hash_map::DefaultHasher, HashMap},
@ -115,14 +116,13 @@ impl ExecutionPlanIndex {
#[cfg(test)]
mod tests {
use super::*;
use crate::{
stream::{
BinaryOperationDescription, NumericOperationDescription, ScalarOperationDescription,
},
use burn_tensor::repr::{
BinaryOperationDescription, NumericOperationDescription, ScalarOperationDescription,
TensorDescription, TensorId, TensorStatus,
};
use super::*;
#[test]
fn should_find_optimization_id_based_on_tensor_ops() {
let mut index = ExecutionPlanIndex::default();

View File

@ -2,9 +2,9 @@ use crate::{client::FusionClient, stream::StreamId};
use burn_tensor::{
backend::Backend,
ops::{FloatElem, IntElem},
repr::{TensorDescription, TensorId, TensorStatus},
Data, Reader, Shape,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kinds.
@ -33,7 +33,7 @@ impl<C: FusionClient> core::fmt::Debug for FusionTensor<C> {
self.shape,
self.is_orphan,
<C::FusionBackend as Backend>::name(),
self.client.device().clone().into(),
self.client.device().clone(),
)
.as_str(),
)
@ -127,47 +127,3 @@ impl<C: FusionClient> Drop for FusionTensor<C> {
}
}
}
/// The tensor unique identifier.
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
pub struct TensorId {
value: u64,
}
/// The status of the current tensor.
#[derive(Hash, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorStatus {
/// The tensor can be read, but not written.
ReadOnly,
/// The tensor can be mutated inplace.
ReadWrite,
/// No handle exists for that tensor.
NotInit,
}
/// A tensor definition represents a snapshot of a tensor when it was used.
///
/// # Example
///
/// A tensor that is used multiple times has its status updated for each operation.
///
/// 1. Status::NotInit
/// 2. Status::ReadOnly
/// 3. Status::ReadOnly
/// 4. Status::ReadWrite
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct TensorDescription {
/// The [tensor id](TensorId).
pub id: TensorId,
/// The shape of the tensor.
pub shape: Vec<usize>,
/// The [status](TensorStatus) of the tensor when it was used.
pub status: TensorStatus,
}
impl TensorId {
/// Create a new tensor id.
pub fn new(value: u64) -> Self {
Self { value }
}
}
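Note: the types removed above keep the same shape in their new home, `burn_tensor::repr`. A sketch of building a description there (values illustrative, mirroring the test fixtures elsewhere in this commit):

    use burn_tensor::repr::{TensorDescription, TensorId, TensorStatus};

    let desc = TensorDescription {
        id: TensorId::new(0),
        shape: vec![32, 32],
        status: TensorStatus::NotInit, // no handle has been created yet
    };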

View File

@ -1,7 +1,5 @@
#[cfg(feature = "fusion")]
use crate::fusion::JitFusionHandle;
#[cfg(feature = "fusion")]
use burn_fusion::TensorDescription;
use super::{
dialect::gpu::{self},
@ -136,8 +134,8 @@ impl CompilationSettings {
pub fn dynamic_settings<R: Runtime>(
self,
info: &CompilationInfo,
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
inputs: &[&burn_tensor::repr::TensorDescription],
outputs: &[&burn_tensor::repr::TensorDescription],
handles_inputs: &[JitFusionHandle<R>],
stateful: bool,
) -> Self {
@ -154,8 +152,8 @@ impl CompilationSettings {
fn dynamic_inplace<R: Runtime>(
self,
info: &CompilationInfo,
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
inputs: &[&burn_tensor::repr::TensorDescription],
outputs: &[&burn_tensor::repr::TensorDescription],
handles_inputs: &[JitFusionHandle<R>],
) -> Self {
let mut potential_inplace = inputs
@ -170,9 +168,9 @@ impl CompilationSettings {
}
match desc.status {
burn_fusion::TensorStatus::ReadOnly => return None,
burn_fusion::TensorStatus::NotInit => return None,
burn_fusion::TensorStatus::ReadWrite => (),
burn_tensor::repr::TensorStatus::ReadOnly => return None,
burn_tensor::repr::TensorStatus::NotInit => return None,
burn_tensor::repr::TensorStatus::ReadWrite => (),
};
Some((pos, desc, input))
@ -215,8 +213,8 @@ impl CompilationSettings {
fn dynamic_reading_strategy<R: Runtime>(
mut self,
info: &CompilationInfo,
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
inputs: &[&burn_tensor::repr::TensorDescription],
outputs: &[&burn_tensor::repr::TensorDescription],
handles_inputs: &[JitFusionHandle<R>],
) -> Self {
// First output is chosen for the layout reference.

View File

@ -4,7 +4,7 @@ use crate::{
};
use burn_compute::client::ComputeClient;
use burn_fusion::{client::MutexFusionClient, FusionBackend};
use burn_tensor::Shape;
use burn_tensor::{repr::ReprBackend, Shape};
use core::marker::PhantomData;
use serde::{Deserialize, Serialize};
@ -53,53 +53,61 @@ impl<R: Runtime> burn_fusion::Optimization<JitBackend<R>> for JitOptimization<R>
}
}
impl<R: Runtime> FusionBackend for JitBackend<R> {
type OptimizationState = JitOptimizationState;
type Optimization = JitOptimization<R>;
type FusionDevice = R::Device;
impl<R: Runtime> ReprBackend for JitBackend<R> {
type Handle = JitFusionHandle<R>;
type FusionClient = MutexFusionClient<Self>;
fn optimizations(
device: R::Device,
) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
vec![Box::new(ElementWiseBuilder::new(device))]
}
fn float_tensor<const D: usize>(
handle: Self::Handle,
shape: Shape<D>,
) -> Self::FloatTensorPrimitive<D> {
) -> burn_tensor::ops::FloatTensor<Self, D> {
handle.into_tensor(shape)
}
fn int_tensor<const D: usize>(
handle: Self::Handle,
shape: Shape<D>,
) -> Self::IntTensorPrimitive<D> {
) -> burn_tensor::ops::IntTensor<Self, D> {
handle.into_tensor(shape)
}
fn bool_tensor<const D: usize>(
handle: Self::Handle,
shape: Shape<D>,
) -> Self::BoolTensorPrimitive<D> {
) -> burn_tensor::ops::BoolTensor<Self, D> {
handle.into_tensor(shape)
}
fn float_tensor_handle<const D: usize>(tensor: Self::FloatTensorPrimitive<D>) -> Self::Handle {
fn float_tensor_handle<const D: usize>(
tensor: burn_tensor::ops::FloatTensor<Self, D>,
) -> Self::Handle {
tensor.into()
}
fn int_tensor_handle<const D: usize>(tensor: Self::IntTensorPrimitive<D>) -> Self::Handle {
fn int_tensor_handle<const D: usize>(
tensor: burn_tensor::ops::IntTensor<Self, D>,
) -> Self::Handle {
tensor.into()
}
fn bool_tensor_handle<const D: usize>(tensor: Self::BoolTensorPrimitive<D>) -> Self::Handle {
fn bool_tensor_handle<const D: usize>(
tensor: burn_tensor::ops::BoolTensor<Self, D>,
) -> Self::Handle {
tensor.into()
}
}
impl<R: Runtime> FusionBackend for JitBackend<R> {
type OptimizationState = JitOptimizationState;
type Optimization = JitOptimization<R>;
type FusionClient = MutexFusionClient<Self>;
fn optimizations(
device: R::Device,
) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
vec![Box::new(ElementWiseBuilder::new(device))]
}
}
pub fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
let mut strides = vec![0; shape.len()];

View File

@ -7,16 +7,14 @@ use crate::{
fusion::{tracing::TraceBuilder, JitOptimization},
JitBackend, Runtime,
};
use burn_fusion::{
stream::{
BaseOperationDescription, BinaryOperationDescription, FloatOperationDescription,
NumericOperationDescription, OperationDescription, ScalarOperationDescription,
UnaryOperationDescription,
},
OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription,
};
use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus};
use burn_tensor::{
ops::{FloatElem, IntElem},
repr::{
BaseOperationDescription, BinaryOperationDescription, FloatOperationDescription,
NumericOperationDescription, OperationDescription, ScalarOperationDescription,
TensorDescription, UnaryOperationDescription,
},
Device, Element,
};

View File

@ -1,3 +1,5 @@
use burn_tensor::repr::TensorDescription;
use crate::{
codegen::{
calculate_num_elems_dyn_rank,
@ -11,7 +13,6 @@ use crate::{
kernel::elemwise_workgroup,
Runtime,
};
use burn_fusion::TensorDescription;
use std::{marker::PhantomData, sync::Arc};
#[derive(new)]

View File

@ -15,7 +15,8 @@ use burn_compute::client::ComputeClient;
use burn_compute::server::Handle;
use burn_compute::tune::AutotuneOperation;
use burn_fusion::stream::Context;
use burn_fusion::{TensorDescription, TensorStatus};
use burn_tensor::repr::TensorDescription;
use burn_tensor::repr::TensorStatus;
use burn_tensor::Device;
use std::marker::PhantomData;
use std::sync::Arc;

View File

@ -1,7 +1,9 @@
use super::{trace::Trace, Scalars};
use crate::codegen::dialect::gpu::{self, Operation, Variable};
use burn_fusion::{TensorDescription, TensorId};
use burn_tensor::Element;
use burn_tensor::{
repr::{TensorDescription, TensorId, TensorStatus},
Element,
};
use hashbrown::HashMap;
/// Type facilitating building a [trace](Trace) by doing most of the conversions between the
@ -415,7 +417,7 @@ impl TraceBuilder {
// are going to be used after the fused kernel by other operations.
for entry in self.tensors.values() {
let (tensor, _) = &entry;
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
if let TensorStatus::ReadOnly = tensor.status {
if self.output_to_local.contains_key(&tensor.id) {
outputs.push(entry.clone());
}

View File

@ -1,6 +1,6 @@
use super::Scalars;
use crate::codegen::{dialect::gpu, CompilationInfo, InputInfo, OutputInfo};
use burn_fusion::TensorDescription;
use burn_tensor::repr::TensorDescription;
use serde::{Deserialize, Serialize};
/// A trace encapsulates all information necessary to perform the compilation and execution of

View File

@ -16,8 +16,7 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
/// The channel used to communicate with the compute server.
type Channel: ComputeChannel<Self::Server>;
/// The device used to retrieve the compute client.
#[cfg(any(feature = "fusion", test))]
type Device: burn_fusion::FusionDevice
type Device: burn_tensor::backend::DeviceOps
+ Default
+ core::hash::Hash
+ PartialEq
@ -26,16 +25,6 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
+ core::fmt::Debug
+ Sync
+ Send;
/// The device used to retrieve the compute client.
#[cfg(not(any(feature = "fusion", test)))]
type Device: Default
+ core::hash::Hash
+ PartialEq
+ Eq
+ Clone
+ core::fmt::Debug
+ Sync
+ Send;
/// A version of the runtime that supports full precision.
///

View File

@ -2,7 +2,7 @@ use crate::NdArrayTensor;
use crate::{element::FloatNdArrayElement, PrecisionBridge};
use alloc::string::String;
use burn_common::stub::Mutex;
use burn_tensor::backend::Backend;
use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
use core::marker::PhantomData;
use rand::{rngs::StdRng, SeedableRng};
@ -15,6 +15,14 @@ pub enum NdArrayDevice {
Cpu,
}
impl DeviceOps for NdArrayDevice {
fn id(&self) -> burn_tensor::backend::DeviceId {
match self {
NdArrayDevice::Cpu => DeviceId::new(0, 0),
}
}
}
impl Default for NdArrayDevice {
fn default() -> Self {
Self::Cpu

View File

@ -2,7 +2,7 @@ use crate::PrecisionBridge;
use super::element::TchElement;
use super::TchTensor;
use burn_tensor::backend::Backend;
use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
use burn_tensor::ops::IntTensorOps;
use burn_tensor::{Int, Tensor};
@ -59,6 +59,17 @@ impl From<tch::Device> for LibTorchDevice {
}
}
impl DeviceOps for LibTorchDevice {
fn id(&self) -> burn_tensor::backend::DeviceId {
match self {
LibTorchDevice::Cpu => DeviceId::new(0, 0),
LibTorchDevice::Cuda(index) => DeviceId::new(1, *index as u32),
LibTorchDevice::Mps => DeviceId::new(2, 0),
LibTorchDevice::Vulkan => DeviceId::new(3, 0),
}
}
}
impl Default for LibTorchDevice {
fn default() -> Self {
Self::Cpu
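Note: each backend maps its device enum onto a stable `DeviceId`, with one `type_id` per device kind and the device ordinal as `index_id`. A hedged sketch of the same convention for a hypothetical device type (`MyDevice` is illustrative):

    use burn_tensor::backend::{DeviceId, DeviceOps};

    #[derive(Clone, Debug, PartialEq, Eq, Hash)]
    enum MyDevice {
        Cpu,
        Accelerator(usize),
    }

    impl DeviceOps for MyDevice {
        fn id(&self) -> DeviceId {
            match self {
                MyDevice::Cpu => DeviceId::new(0, 0),
                MyDevice::Accelerator(index) => DeviceId::new(1, *index as u32),
            }
        }
    }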

View File

@ -11,11 +11,12 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-tensor"
version.workspace = true
[features]
default = ["std"]
default = ["std", "repr"]
doc = ["default"]
experimental-named-tensor = []
export_tests = ["burn-tensor-testgen"]
std = ["rand/std", "half/std", "num-traits/std"]
repr = []
wasm-sync = []
[dependencies]

View File

@ -11,6 +11,10 @@ extern crate alloc;
mod tensor;
/// Burn Tensor representation
#[cfg(feature = "repr")]
pub mod repr;
#[cfg(feature = "export_tests")]
#[allow(missing_docs)]
mod tests;

View File

@ -0,0 +1,26 @@
use crate::{
backend::Backend,
ops::{BoolTensor, FloatTensor, IntTensor},
Shape,
};
/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation
/// for compilation purposes or other uses.
pub trait ReprBackend: Backend {
/// The type that can be used to point to a tensor of any kind.
type Handle: Sync + Send + Clone;
/// Convert a [handle](ReprBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
fn float_tensor<const D: usize>(handle: Self::Handle, shape: Shape<D>) -> FloatTensor<Self, D>;
/// Convert a [handle](ReprBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive).
fn int_tensor<const D: usize>(handle: Self::Handle, shape: Shape<D>) -> IntTensor<Self, D>;
/// Convert a [handle](ReprBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
fn bool_tensor<const D: usize>(handle: Self::Handle, shape: Shape<D>) -> BoolTensor<Self, D>;
/// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](ReprBackend::Handle).
fn float_tensor_handle<const D: usize>(tensor: FloatTensor<Self, D>) -> Self::Handle;
/// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](ReprBackend::Handle).
fn int_tensor_handle<const D: usize>(tensor: IntTensor<Self, D>) -> Self::Handle;
/// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](ReprBackend::Handle).
fn bool_tensor_handle<const D: usize>(tensor: BoolTensor<Self, D>) -> Self::Handle;
}
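Note: the trait is symmetric: every tensor kind can be detached into an untyped `Handle` and rebuilt from that handle plus a shape. A minimal sketch of the round trip for floats (a free helper, not part of the trait):

    use burn_tensor::{ops::FloatTensor, repr::ReprBackend, Shape};

    fn round_trip<B: ReprBackend, const D: usize>(
        tensor: FloatTensor<B, D>,
        shape: Shape<D>,
    ) -> FloatTensor<B, D> {
        // Handles are shape-erased, so the shape travels alongside.
        let handle = B::float_tensor_handle(tensor);
        B::float_tensor(handle, shape)
    }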

View File

@ -1,30 +1,41 @@
use crate::{FusionBackend, TensorDescription, TensorId, TensorStatus};
use burn_tensor::Shape;
use crate::{
backend::Backend,
repr::{
backend::ReprBackend,
tensor::{TensorDescription, TensorId, TensorStatus},
},
Shape,
};
use std::{collections::HashMap, sync::Arc};
/// Keep all [tensor handles](FusionBackend::Handle) in one place and ensure that all resources
/// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources
/// are used optimally.
#[derive(Default)]
pub struct HandleContainer<B: FusionBackend> {
pub struct HandleContainer<B: ReprBackend> {
handles: HashMap<TensorId, Handle<B>>,
counter: u64,
pub(crate) handles_orphan: Vec<TensorId>,
/// Handle candidates to be freed.
pub handles_orphan: Vec<TensorId>,
/// The device on which all tensors are held.
pub device: B::Device,
}
enum Handle<B: FusionBackend> {
/// Backend [tensor handle](ReprBackend::Handle) wrapper that tracks its creation state.
pub enum Handle<B: Backend + ReprBackend> {
/// No [tensor handle](ReprBackend::Handle) has been created yet
NotInit,
/// A [tensor handle](ReprBackend::Handle) has been created
Existing(B::Handle),
}
impl<B: FusionBackend> HandleContainer<B> {
pub(crate) fn new(device_handle: B::FusionDevice) -> Self {
impl<B: ReprBackend> HandleContainer<B> {
/// Create a new HandleContainer
pub fn new(device_handle: B::Device) -> Self {
Self {
handles: HashMap::new(),
handles_orphan: Vec::new(),
counter: 0,
device: device_handle.clone().into(),
device: device_handle.clone(),
}
}
@ -59,69 +70,69 @@ impl<B: FusionBackend> HandleContainer<B> {
}
}
/// Get the [float tensor](burn_tensor::backend::Backend::FloatTensorPrimitive) corresponding to the
/// Get the [float tensor](Backend::FloatTensorPrimitive) corresponding to the
/// given [tensor description](TensorDescription).
pub fn get_float_tensor<const D: usize>(
&mut self,
tensor: &TensorDescription,
) -> B::FloatTensorPrimitive<D> {
B::float_tensor(
B::float_tensor::<D>(
self.get_handle(&tensor.id, &tensor.status),
Shape::from(&tensor.shape),
)
}
/// Get the [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) corresponding to the
/// Get the [int tensor](Backend::IntTensorPrimitive) corresponding to the
/// given [tensor description](TensorDescription).
pub fn get_int_tensor<const D: usize>(
&mut self,
tensor: &TensorDescription,
) -> B::IntTensorPrimitive<D> {
B::int_tensor(
B::int_tensor::<D>(
self.get_handle(&tensor.id, &tensor.status),
Shape::from(&tensor.shape),
)
}
/// Get the [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) corresponding to the
/// Get the [bool tensor](Backend::BoolTensorPrimitive) corresponding to the
/// given [tensor description](TensorDescription).
pub fn get_bool_tensor<const D: usize>(
&mut self,
tensor: &TensorDescription,
) -> B::BoolTensorPrimitive<D> {
B::bool_tensor(
B::bool_tensor::<D>(
self.get_handle(&tensor.id, &tensor.status),
Shape::from(&tensor.shape),
)
}
/// Register a new [float tensor](burn_tensor::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
/// Register a new [float tensor](Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
pub fn register_float_tensor<const D: usize>(
&mut self,
id: &TensorId,
tensor: B::FloatTensorPrimitive<D>,
) {
let handle = B::float_tensor_handle(tensor);
let handle = B::float_tensor_handle::<D>(tensor);
self.handles.insert(*id, Handle::Existing(handle));
}
/// Register a new [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
/// Register a new [int tensor](Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
pub fn register_int_tensor<const D: usize>(
&mut self,
id: &TensorId,
tensor: B::IntTensorPrimitive<D>,
) {
let handle = B::int_tensor_handle(tensor);
let handle = B::int_tensor_handle::<D>(tensor);
self.handles.insert(*id, Handle::Existing(handle));
}
/// Register a new [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId).
/// Register a new [bool tensor](Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId).
pub fn register_bool_tensor<const D: usize>(
&mut self,
id: &TensorId,
tensor: B::BoolTensorPrimitive<D>,
) {
let handle = B::bool_tensor_handle(tensor);
let handle = B::bool_tensor_handle::<D>(tensor);
self.handles.insert(*id, Handle::Existing(handle));
}
@ -134,7 +145,8 @@ impl<B: FusionBackend> HandleContainer<B> {
Arc::new(id)
}
pub(crate) fn free(&mut self, tensor: &TensorDescription) {
/// Remove tensor handle from container if writable
pub fn free(&mut self, tensor: &TensorDescription) {
match tensor.status {
TensorStatus::ReadOnly => (),
TensorStatus::NotInit => (),
@ -144,7 +156,8 @@ impl<B: FusionBackend> HandleContainer<B> {
}
}
pub(crate) fn free_orphans(&mut self, remaining: &[&TensorId]) {
/// Remove tensor handle from container if not in use
pub fn free_orphans(&mut self, remaining: &[&TensorId]) {
let mut handles_orphan = Vec::new();
// TODO: Optimization => Change the for loop order depending on the length of each.
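Note: with `HandleContainer` now public, the register/get pair is the core API. A hedged sketch, generic over `B: ReprBackend` (id and shape are illustrative):

    use burn_tensor::{
        ops::FloatTensor,
        repr::{HandleContainer, ReprBackend, TensorDescription, TensorId, TensorStatus},
    };

    fn register_then_get<B: ReprBackend>(
        handles: &mut HandleContainer<B>,
        tensor: FloatTensor<B, 2>,
    ) -> FloatTensor<B, 2> {
        let id = TensorId::new(0); // illustrative id
        handles.register_float_tensor::<2>(&id, tensor);

        let desc = TensorDescription {
            id,
            shape: vec![32, 32],            // must match the registered tensor
            status: TensorStatus::ReadOnly, // read without freeing the handle
        };
        handles.get_float_tensor::<2>(&desc)
    }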

View File

@ -0,0 +1,9 @@
mod backend;
mod handle;
mod operation;
mod tensor;
pub use backend::*;
pub use handle::*;
pub use operation::*;
pub use tensor::*;

View File

@ -1,15 +1,11 @@
use crate::FusionBackend;
use crate::{HandleContainer, TensorDescription};
use burn_tensor::ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions};
use burn_tensor::{Distribution, Element};
use serde::{Deserialize, Serialize};
use std::ops::Range;
/// General trait to abstract how a single operation is executed.
pub trait Operation<B: FusionBackend>: Send + Sync {
/// Execute the operation.
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>);
}
use crate::{
ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions},
repr::tensor::TensorDescription,
Distribution, Element,
};
/// Describe all tensor operations possible.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
@ -37,92 +33,92 @@ pub enum OperationDescription {
/// Operation description specific to a float tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum FloatOperationDescription {
/// Operation corresponding to [exp](burn_tensor::ops::FloatTensorOps::float_exp).
/// Operation corresponding to [exp](crate::ops::FloatTensorOps::float_exp).
Exp(UnaryOperationDescription),
/// Operation corresponding to [log](burn_tensor::ops::FloatTensorOps::float_log).
/// Operation corresponding to [log](crate::ops::FloatTensorOps::float_log).
Log(UnaryOperationDescription),
/// Operation corresponding to [log1p](burn_tensor::ops::FloatTensorOps::float_log1p).
/// Operation corresponding to [log1p](crate::ops::FloatTensorOps::float_log1p).
Log1p(UnaryOperationDescription),
/// Operation corresponding to [erf](burn_tensor::ops::FloatTensorOps::float_erf).
/// Operation corresponding to [erf](crate::ops::FloatTensorOps::float_erf).
Erf(UnaryOperationDescription),
/// Operation corresponding to [powf_scalar](burn_tensor::ops::FloatTensorOps::float_powf_scalar).
/// Operation corresponding to [powf_scalar](crate::ops::FloatTensorOps::float_powf_scalar).
PowfScalar(ScalarOperationDescription<f32>),
/// Operation corresponding to [sqrt](burn_tensor::ops::FloatTensorOps::float_sqrt).
/// Operation corresponding to [sqrt](crate::ops::FloatTensorOps::float_sqrt).
Sqrt(UnaryOperationDescription),
/// Operation corresponding to [cos](burn_tensor::ops::FloatTensorOps::float_cos).
/// Operation corresponding to [cos](crate::ops::FloatTensorOps::float_cos).
Cos(UnaryOperationDescription),
/// Operation corresponding to [sin](burn_tensor::ops::FloatTensorOps::float_sin).
/// Operation corresponding to [sin](crate::ops::FloatTensorOps::float_sin).
Sin(UnaryOperationDescription),
/// Operation corresponding to [tanh](burn_tensor::ops::FloatTensorOps::float_tanh).
/// Operation corresponding to [tanh](crate::ops::FloatTensorOps::float_tanh).
Tanh(UnaryOperationDescription),
/// Operation corresponding to [into_int](burn_tensor::ops::FloatTensorOps::float_into_int).
/// Operation corresponding to [into_int](crate::ops::FloatTensorOps::float_into_int).
IntoInt(UnaryOperationDescription),
/// Operation corresponding to [matmul](burn_tensor::ops::FloatTensorOps::float_matmul).
/// Operation corresponding to [matmul](crate::ops::FloatTensorOps::float_matmul).
Matmul(BinaryOperationDescription),
/// Operation corresponding to [random](burn_tensor::ops::FloatTensorOps::float_random).
/// Operation corresponding to [random](crate::ops::FloatTensorOps::float_random).
Random(RandomOperationDescription),
/// Operation corresponding to [recip](burn_tensor::ops::FloatTensorOps::float_recip).
/// Operation corresponding to [recip](crate::ops::FloatTensorOps::float_recip).
Recip(UnaryOperationDescription),
}
/// Operation description specific to module.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum ModuleOperationDescription {
/// Operation corresponding to [embedding](burn_tensor::ops::ModuleOps::embedding).
/// Operation corresponding to [embedding](crate::ops::ModuleOps::embedding).
Embedding(EmbeddingDescription),
/// Operation corresponding to [embedding_backward](burn_tensor::ops::ModuleOps::embedding_backward).
/// Operation corresponding to [embedding_backward](crate::ops::ModuleOps::embedding_backward).
EmbeddingBackward(EmbeddingBackwardDescription),
/// Operation corresponding to [conv1d](burn_tensor::ops::ModuleOps::conv1d).
/// Operation corresponding to [conv1d](crate::ops::ModuleOps::conv1d).
Conv1d(Conv1dDescription),
/// Operation corresponding to [conv2d](burn_tensor::ops::ModuleOps::conv2d).
/// Operation corresponding to [conv2d](crate::ops::ModuleOps::conv2d).
Conv2d(Conv2dDescription),
/// Operation corresponding to [conv transpose 1d](burn_tensor::ops::ModuleOps::conv_transpose1d).
/// Operation corresponding to [conv transpose 1d](crate::ops::ModuleOps::conv_transpose1d).
ConvTranspose1d(ConvTranspose1dDescription),
/// Operation corresponding to [conv transpose 2d](burn_tensor::ops::ModuleOps::conv_transpose2d).
/// Operation corresponding to [conv transpose 2d](crate::ops::ModuleOps::conv_transpose2d).
ConvTranspose2d(ConvTranspose2dDescription),
/// Operation corresponding to [avg pool 1d](burn_tensor::ops::ModuleOps::avg_pool1d).
/// Operation corresponding to [avg pool 1d](crate::ops::ModuleOps::avg_pool1d).
AvgPool1d(AvgPool1dDescription),
/// Operation corresponding to [avg pool 2d](burn_tensor::ops::ModuleOps::avg_pool2d).
/// Operation corresponding to [avg pool 2d](crate::ops::ModuleOps::avg_pool2d).
AvgPool2d(AvgPool2dDescription),
/// Operation corresponding to
/// [avg pool 1d backward](burn_tensor::ops::ModuleOps::avg_pool1d_backward).
/// [avg pool 1d backward](crate::ops::ModuleOps::avg_pool1d_backward).
AvgPool1dBackward(AvgPool1dBackwardDescription),
/// Operation corresponding to
/// [avg pool 2d backward](burn_tensor::ops::ModuleOps::avg_pool2d_backward).
/// [avg pool 2d backward](crate::ops::ModuleOps::avg_pool2d_backward).
AvgPool2dBackward(AvgPool2dBackwardDescription),
/// Operation corresponding to
/// [adaptive avg pool 1d](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d).
/// [adaptive avg pool 1d](crate::ops::ModuleOps::adaptive_avg_pool1d).
AdaptiveAvgPool1d(AdaptiveAvgPool1dDescription),
/// Operation corresponding to
/// [adaptive avg pool 2d](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d).
/// [adaptive avg pool 2d](crate::ops::ModuleOps::adaptive_avg_pool2d).
AdaptiveAvgPool2d(AdaptiveAvgPool2dDescription),
/// Operation corresponding to
/// [adaptive avg pool 1d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d_backward).
/// [adaptive avg pool 1d backward](crate::ops::ModuleOps::adaptive_avg_pool1d_backward).
AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardDescription),
/// Operation corresponding to
/// [adaptive avg pool 2d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d_backward).
/// [adaptive avg pool 2d backward](crate::ops::ModuleOps::adaptive_avg_pool2d_backward).
AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardDescription),
/// Operation corresponding to
/// [max pool 1d](burn_tensor::ops::ModuleOps::max_pool1d).
/// [max pool 1d](crate::ops::ModuleOps::max_pool1d).
MaxPool1d(MaxPool1dDescription),
/// Operation corresponding to
/// [max pool 1d with indices](burn_tensor::ops::ModuleOps::max_pool1d_with_indices).
/// [max pool 1d with indices](crate::ops::ModuleOps::max_pool1d_with_indices).
MaxPool1dWithIndices(MaxPool1dWithIndicesDescription),
/// Operation corresponding to
/// [max pool 1d with indices backward](burn_tensor::ops::ModuleOps::max_pool1d_with_indices_backward).
/// [max pool 1d with indices backward](crate::ops::ModuleOps::max_pool1d_with_indices_backward).
MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardDescription),
/// Operation corresponding to
/// [max pool 2d](burn_tensor::ops::ModuleOps::max_pool1d).
/// [max pool 2d](crate::ops::ModuleOps::max_pool2d).
MaxPool2d(MaxPool2dDescription),
/// Operation corresponding to
/// [max pool 2d with indices](burn_tensor::ops::ModuleOps::max_pool2d_with_indices).
/// [max pool 2d with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
MaxPool2dWithIndices(MaxPool2dWithIndicesDescription),
/// Operation corresponding to
/// [max pool 2d with indices backward](burn_tensor::ops::ModuleOps::max_pool2d_with_indices_backward).
/// [max pool 2d with indices backward](crate::ops::ModuleOps::max_pool2d_with_indices_backward).
MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardDescription),
/// Operation corresponding to [interpolate](burn_tensor::ops::ModuleOps::interpolate).
/// Operation corresponding to [interpolate](crate::ops::ModuleOps::interpolate).
Interpolate(InterpolateDescription),
/// Operation corresponding to [interpolate backward](burn_tensor::ops::ModuleOps::interpolate_backward).
/// Operation corresponding to [interpolate backward](crate::ops::ModuleOps::interpolate_backward).
InterpolateBackward(InterpolateBackwardDescription),
}
@ -131,73 +127,73 @@ pub enum ModuleOperationDescription {
pub enum BaseOperationDescription {
/// Operation corresponding to:
///
/// Float => [to device](burn_tensor::ops::FloatTensorOps::float_to_device).
/// Int => [to device](burn_tensor::ops::IntTensorOps::int_to_device).
/// Bool => [to device](burn_tensor::ops::BoolTensorOps::bool_to_device).
/// Float => [to device](crate::ops::FloatTensorOps::float_to_device).
/// Int => [to device](crate::ops::IntTensorOps::int_to_device).
/// Bool => [to device](crate::ops::BoolTensorOps::bool_to_device).
ToDevice(TensorDescription),
/// Operation corresponding to:
///
/// Float => [reshape](burn_tensor::ops::FloatTensorOps::float_reshape).
/// Int => [reshape](burn_tensor::ops::IntTensorOps::int_reshape).
/// Bool => [reshape](burn_tensor::ops::BoolTensorOps::bool_reshape).
/// Float => [reshape](crate::ops::FloatTensorOps::float_reshape).
/// Int => [reshape](crate::ops::IntTensorOps::int_reshape).
/// Bool => [reshape](crate::ops::BoolTensorOps::bool_reshape).
Reshape(ReshapeDescription),
/// Operation corresponding to:
///
/// Float => [swap_dims](burn_tensor::ops::FloatTensorOps::float_swap_dims).
/// Int => [swap_dims](burn_tensor::ops::IntTensorOps::int_swap_dims).
/// Bool => [swap_dims](burn_tensor::ops::BoolTensorOps::bool_swap_dims).
/// Float => [swap_dims](crate::ops::FloatTensorOps::float_swap_dims).
/// Int => [swap_dims](crate::ops::IntTensorOps::int_swap_dims).
/// Bool => [swap_dims](crate::ops::BoolTensorOps::bool_swap_dims).
SwapDims(SwapDimsDescription),
/// Operation corresponding to:
///
/// Float => [permute](burn_tensor::ops::FloatTensorOps::float_permute).
/// Int => [permute](burn_tensor::ops::IntTensorOps::int_permute).
/// Bool => [permute](burn_tensor::ops::BoolTensorOps::bool_permute).
/// Float => [permute](crate::ops::FloatTensorOps::float_permute).
/// Int => [permute](crate::ops::IntTensorOps::int_permute).
/// Bool => [permute](crate::ops::BoolTensorOps::bool_permute).
Permute(PermuteOperationDescription),
/// Operation corresponding to:
/// Float => [flip](burn_tensor::ops::FloatTensorOps::float_flip).
/// Int => [flip](burn_tensor::ops::IntTensorOps::int_flip).
/// Bool => [flip](burn_tensor::ops::BoolTensorOps::bool_flip).
/// Float => [flip](crate::ops::FloatTensorOps::float_flip).
/// Int => [flip](crate::ops::IntTensorOps::int_flip).
/// Bool => [flip](crate::ops::BoolTensorOps::bool_flip).
Flip(FlipOperationDescription),
/// Operation corresponding to:
///
/// Float => [expand](burn_tensor::ops::FloatTensorOps::float_expand).
/// Int => [expand](burn_tensor::ops::IntTensorOps::int_expand).
/// Bool => [expand](burn_tensor::ops::BoolTensorOps::bool_expand).
/// Float => [expand](crate::ops::FloatTensorOps::float_expand).
/// Int => [expand](crate::ops::IntTensorOps::int_expand).
/// Bool => [expand](crate::ops::BoolTensorOps::bool_expand).
Expand(ExpandOperationDescription),
/// Operation corresponding to:
///
/// Float => [slice](burn_tensor::ops::FloatTensorOps::float_slice).
/// Int => [slice](burn_tensor::ops::IntTensorOps::int_slice).
/// Bool => [slice](burn_tensor::ops::BoolTensorOps::bool_slice).
/// Float => [slice](crate::ops::FloatTensorOps::float_slice).
/// Int => [slice](crate::ops::IntTensorOps::int_slice).
/// Bool => [slice](crate::ops::BoolTensorOps::bool_slice).
Slice(SliceOperationDescription),
/// Operation corresponding to:
///
/// Float => [slice assign](burn_tensor::ops::FloatTensorOps::float_slice_assign).
/// Int => [slice assign](burn_tensor::ops::IntTensorOps::int_slice_assign).
/// Bool => [slice assign](burn_tensor::ops::BoolTensorOps::bool_slice_assign).
/// Float => [slice assign](crate::ops::FloatTensorOps::float_slice_assign).
/// Int => [slice assign](crate::ops::IntTensorOps::int_slice_assign).
/// Bool => [slice assign](crate::ops::BoolTensorOps::bool_slice_assign).
SliceAssign(SliceAssignOperationDescription),
/// Operation corresponding to:
///
/// Float => [equal](burn_tensor::ops::FloatTensorOps::float_equal).
/// Int => [equal](burn_tensor::ops::IntTensorOps::int_equal).
/// Bool => [equal](burn_tensor::ops::BoolTensorOps::bool_equal).
/// Float => [equal](crate::ops::FloatTensorOps::float_equal).
/// Int => [equal](crate::ops::IntTensorOps::int_equal).
/// Bool => [equal](crate::ops::BoolTensorOps::bool_equal).
Equal(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [repeat](burn_tensor::ops::FloatTensorOps::float_repeat).
/// Int => [repeat](burn_tensor::ops::IntTensorOps::int_repeat).
/// Bool => [repeat](burn_tensor::ops::BoolTensorOps::bool_repeat).
/// Float => [repeat](crate::ops::FloatTensorOps::float_repeat).
/// Int => [repeat](crate::ops::IntTensorOps::int_repeat).
/// Bool => [repeat](crate::ops::BoolTensorOps::bool_repeat).
Repeat(RepeatOperationDescription),
/// Operation corresponding to:
///
/// Float => [cat](burn_tensor::ops::FloatTensorOps::float_cat).
/// Int => [cat](burn_tensor::ops::IntTensorOps::int_cat).
/// Bool => [cat](burn_tensor::ops::BoolTensorOps::bool_cat).
/// Float => [cat](crate::ops::FloatTensorOps::float_cat).
/// Int => [cat](crate::ops::IntTensorOps::int_cat).
/// Bool => [cat](crate::ops::BoolTensorOps::bool_cat).
Cat(CatOperationDescription),
}
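Every variant above wraps a plain-data description, so a consumer can inspect a recorded stream of operations without touching real backend tensors. A minimal sketch of such dispatch; the base_op_name helper is hypothetical and assumes the variant list above is exhaustive and that the enum is re-exported from burn_tensor::repr:

use burn_tensor::repr::BaseOperationDescription;

/// Hypothetical helper naming each base operation, e.g. for logging a stream.
fn base_op_name(op: &BaseOperationDescription) -> &'static str {
    match op {
        BaseOperationDescription::ToDevice(_) => "to_device",
        BaseOperationDescription::Reshape(_) => "reshape",
        BaseOperationDescription::SwapDims(_) => "swap_dims",
        BaseOperationDescription::Permute(_) => "permute",
        BaseOperationDescription::Flip(_) => "flip",
        BaseOperationDescription::Expand(_) => "expand",
        BaseOperationDescription::Slice(_) => "slice",
        BaseOperationDescription::SliceAssign(_) => "slice_assign",
        BaseOperationDescription::Equal(_) => "equal",
        BaseOperationDescription::Repeat(_) => "repeat",
        BaseOperationDescription::Cat(_) => "cat",
    }
}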
@ -206,248 +202,248 @@ pub enum BaseOperationDescription {
pub enum NumericOperationDescription<E> {
/// Operation corresponding to:
///
/// Float => [add](burn_tensor::ops::FloatTensorOps::float_add).
/// Int => [add](burn_tensor::ops::IntTensorOps::int_add).
/// Float => [add](crate::ops::FloatTensorOps::float_add).
/// Int => [add](crate::ops::IntTensorOps::int_add).
Add(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [add scalar](burn_tensor::ops::FloatTensorOps::float_add_scalar).
/// Int => [add scalar](burn_tensor::ops::IntTensorOps::int_add_scalar).
/// Float => [add scalar](crate::ops::FloatTensorOps::float_add_scalar).
/// Int => [add scalar](crate::ops::IntTensorOps::int_add_scalar).
AddScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [sub](burn_tensor::ops::FloatTensorOps::float_sub).
/// Int => [sub](burn_tensor::ops::IntTensorOps::int_sub).
/// Float => [sub](crate::ops::FloatTensorOps::float_sub).
/// Int => [sub](crate::ops::IntTensorOps::int_sub).
Sub(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [sub scalar](burn_tensor::ops::FloatTensorOps::float_sub_scalar).
/// Int => [sub scalar](burn_tensor::ops::IntTensorOps::int_sub_scalar).
/// Float => [sub scalar](crate::ops::FloatTensorOps::float_sub_scalar).
/// Int => [sub scalar](crate::ops::IntTensorOps::int_sub_scalar).
SubScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [div](burn_tensor::ops::FloatTensorOps::float_div).
/// Int => [div](burn_tensor::ops::IntTensorOps::int_div).
/// Float => [div](crate::ops::FloatTensorOps::float_div).
/// Int => [div](crate::ops::IntTensorOps::int_div).
Div(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [div scalar](burn_tensor::ops::FloatTensorOps::float_div_scalar).
/// Int => [div scalar](burn_tensor::ops::IntTensorOps::int_div_scalar).
/// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar).
/// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar).
DivScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [div](burn_tensor::ops::FloatTensorOps::float_remainder_scalar).
/// Int => [div](burn_tensor::ops::IntTensorOps::int_remainder_scalar).
/// Float => [remainder scalar](crate::ops::FloatTensorOps::float_remainder_scalar).
/// Int => [remainder scalar](crate::ops::IntTensorOps::int_remainder_scalar).
RemScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [mul](burn_tensor::ops::FloatTensorOps::float_mul).
/// Int => [mul](burn_tensor::ops::IntTensorOps::int_mul).
/// Float => [mul](crate::ops::FloatTensorOps::float_mul).
/// Int => [mul](crate::ops::IntTensorOps::int_mul).
Mul(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [mul scalar](burn_tensor::ops::FloatTensorOps::float_mul_scalar).
/// Int => [mul scalar](burn_tensor::ops::IntTensorOps::int_mul_scalar).
/// Float => [mul scalar](crate::ops::FloatTensorOps::float_mul_scalar).
/// Int => [mul scalar](crate::ops::IntTensorOps::int_mul_scalar).
MulScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [abs](burn_tensor::ops::FloatTensorOps::float_abs).
/// Int => [abs](burn_tensor::ops::IntTensorOps::int_abs).
/// Float => [abs](crate::ops::FloatTensorOps::float_abs).
/// Int => [abs](crate::ops::IntTensorOps::int_abs).
Abs(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [ones](burn_tensor::ops::FloatTensorOps::float_ones).
/// Int => [ones](burn_tensor::ops::IntTensorOps::int_ones).
/// Float => [ones](crate::ops::FloatTensorOps::float_ones).
/// Int => [ones](crate::ops::IntTensorOps::int_ones).
Ones(TensorDescription),
/// Operation corresponding to:
///
/// Float => [zeros](burn_tensor::ops::FloatTensorOps::float_zeros).
/// Int => [zeros](burn_tensor::ops::IntTensorOps::int_zeros).
/// Float => [zeros](crate::ops::FloatTensorOps::float_zeros).
/// Int => [zeros](crate::ops::IntTensorOps::int_zeros).
Zeros(TensorDescription),
/// Operation corresponding to:
///
/// Float => [full](burn_tensor::ops::FloatTensorOps::float_full).
/// Int => [full](burn_tensor::ops::IntTensorOps::int_full).
/// Float => [full](crate::ops::FloatTensorOps::float_full).
/// Int => [full](crate::ops::IntTensorOps::int_full).
Full((TensorDescription, E)),
/// Operation corresponding to:
///
/// Float => [gather](burn_tensor::ops::FloatTensorOps::float_gather).
/// Int => [gather](burn_tensor::ops::IntTensorOps::int_gather).
/// Float => [gather](crate::ops::FloatTensorOps::float_gather).
/// Int => [gather](crate::ops::IntTensorOps::int_gather).
Gather(GatherOperationDescription),
/// Operation corresponding to:
///
/// Float => [scatter](burn_tensor::ops::FloatTensorOps::float_scatter).
/// Int => [scatter](burn_tensor::ops::IntTensorOps::int_scatter).
/// Float => [scatter](crate::ops::FloatTensorOps::float_scatter).
/// Int => [scatter](crate::ops::IntTensorOps::int_scatter).
Scatter(ScatterOperationDescription),
/// Operation corresponding to:
///
/// Float => [select](burn_tensor::ops::FloatTensorOps::float_select).
/// Int => [select](burn_tensor::ops::IntTensorOps::int_select).
/// Float => [select](crate::ops::FloatTensorOps::float_select).
/// Int => [select](crate::ops::IntTensorOps::int_select).
Select(SelectOperationDescription),
/// Operation corresponding to:
///
/// Float => [select assign](burn_tensor::ops::FloatTensorOps::float_select_assign).
/// Int => [select assign](burn_tensor::ops::IntTensorOps::int_select_assign).
/// Float => [select assign](crate::ops::FloatTensorOps::float_select_assign).
/// Int => [select assign](crate::ops::IntTensorOps::int_select_assign).
SelectAssign(SelectAssignOperationDescription),
/// Operation corresponding to:
///
/// Float => [mask where](burn_tensor::ops::FloatTensorOps::float_mask_where).
/// Int => [mask where](burn_tensor::ops::IntTensorOps::int_mask_where).
/// Float => [mask where](crate::ops::FloatTensorOps::float_mask_where).
/// Int => [mask where](crate::ops::IntTensorOps::int_mask_where).
MaskWhere(MaskWhereOperationDescription),
/// Operation corresponding to:
///
/// Float => [mask fill](burn_tensor::ops::FloatTensorOps::float_mask_fill).
/// Int => [mask fill](burn_tensor::ops::IntTensorOps::int_mask_fill).
/// Float => [mask fill](crate::ops::FloatTensorOps::float_mask_fill).
/// Int => [mask fill](crate::ops::IntTensorOps::int_mask_fill).
MaskFill(MaskFillOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [mean dim](burn_tensor::ops::FloatTensorOps::float_mean_dim).
/// Int => [mean dim](burn_tensor::ops::IntTensorOps::int_mean_dim).
/// Float => [mean dim](crate::ops::FloatTensorOps::float_mean_dim).
/// Int => [mean dim](crate::ops::IntTensorOps::int_mean_dim).
MeanDim(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [mean](burn_tensor::ops::FloatTensorOps::float_mean).
/// Int => [mean](burn_tensor::ops::IntTensorOps::int_mean).
/// Float => [mean](crate::ops::FloatTensorOps::float_mean).
/// Int => [mean](crate::ops::IntTensorOps::int_mean).
Mean(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [sum](burn_tensor::ops::FloatTensorOps::float_sum).
/// Int => [sum](burn_tensor::ops::IntTensorOps::int_sum).
/// Float => [sum](crate::ops::FloatTensorOps::float_sum).
/// Int => [sum](crate::ops::IntTensorOps::int_sum).
Sum(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [sum dim](burn_tensor::ops::FloatTensorOps::float_sum_dim).
/// Int => [sum dim](burn_tensor::ops::IntTensorOps::int_sum_dim).
/// Float => [sum dim](crate::ops::FloatTensorOps::float_sum_dim).
/// Int => [sum dim](crate::ops::IntTensorOps::int_sum_dim).
SumDim(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [prod](burn_tensor::ops::FloatTensorOps::float_prod).
/// Int => [prod](burn_tensor::ops::IntTensorOps::int_prod).
/// Float => [prod](crate::ops::FloatTensorOps::float_prod).
/// Int => [prod](crate::ops::IntTensorOps::int_prod).
Prod(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [prod dim](burn_tensor::ops::FloatTensorOps::float_prod_dim).
/// Int => [prod dim](burn_tensor::ops::IntTensorOps::int_prod_dim).
/// Float => [prod dim](crate::ops::FloatTensorOps::float_prod_dim).
/// Int => [prod dim](crate::ops::IntTensorOps::int_prod_dim).
ProdDim(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [equal elem](burn_tensor::ops::FloatTensorOps::float_equal_elem).
/// Int => [equal elem](burn_tensor::ops::IntTensorOps::int_equal_elem).
/// Float => [equal elem](crate::ops::FloatTensorOps::float_equal_elem).
/// Int => [equal elem](crate::ops::IntTensorOps::int_equal_elem).
EqualElem(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [greater](burn_tensor::ops::FloatTensorOps::float_greater).
/// Int => [greater](burn_tensor::ops::IntTensorOps::int_greater).
/// Float => [greater](crate::ops::FloatTensorOps::float_greater).
/// Int => [greater](crate::ops::IntTensorOps::int_greater).
Greater(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [greater elem](burn_tensor::ops::FloatTensorOps::float_greater_elem).
/// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem).
/// Float => [greater elem](crate::ops::FloatTensorOps::float_greater_elem).
/// Int => [greater elem](crate::ops::IntTensorOps::int_greater_elem).
GreaterElem(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [greater equal](burn_tensor::ops::FloatTensorOps::float_greater_elem).
/// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem).
/// Float => [greater equal](crate::ops::FloatTensorOps::float_greater_equal).
/// Int => [greater equal](crate::ops::IntTensorOps::int_greater_equal).
GreaterEqual(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [greater equal elem](burn_tensor::ops::FloatTensorOps::float_greater_equal_elem).
/// Int => [greater equal elem](burn_tensor::ops::IntTensorOps::int_greater_equal_elem).
/// Float => [greater equal elem](crate::ops::FloatTensorOps::float_greater_equal_elem).
/// Int => [greater equal elem](crate::ops::IntTensorOps::int_greater_equal_elem).
GreaterEqualElem(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [lower](burn_tensor::ops::FloatTensorOps::float_lower).
/// Int => [lower](burn_tensor::ops::IntTensorOps::int_lower).
/// Float => [lower](crate::ops::FloatTensorOps::float_lower).
/// Int => [lower](crate::ops::IntTensorOps::int_lower).
Lower(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [lower elem](burn_tensor::ops::FloatTensorOps::float_lower_elem).
/// Int => [lower elem](burn_tensor::ops::IntTensorOps::int_lower_elem).
/// Float => [lower elem](crate::ops::FloatTensorOps::float_lower_elem).
/// Int => [lower elem](crate::ops::IntTensorOps::int_lower_elem).
LowerElem(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [lower equal](burn_tensor::ops::FloatTensorOps::float_lower_equal).
/// Int => [lower equal](burn_tensor::ops::IntTensorOps::int_lower_equal).
/// Float => [lower equal](crate::ops::FloatTensorOps::float_lower_equal).
/// Int => [lower equal](crate::ops::IntTensorOps::int_lower_equal).
LowerEqual(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [lower equal elem](burn_tensor::ops::FloatTensorOps::float_lower_equal_elem).
/// Int => [lower equal elem](burn_tensor::ops::IntTensorOps::int_lower_equal_elem).
/// Float => [lower equal elem](crate::ops::FloatTensorOps::float_lower_equal_elem).
/// Int => [lower equal elem](crate::ops::IntTensorOps::int_lower_equal_elem).
LowerEqualElem(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [argmax](burn_tensor::ops::FloatTensorOps::float_argmax).
/// Int => [argmax](burn_tensor::ops::IntTensorOps::int_argmax).
/// Float => [argmax](crate::ops::FloatTensorOps::float_argmax).
/// Int => [argmax](crate::ops::IntTensorOps::int_argmax).
ArgMax(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [argmin](burn_tensor::ops::FloatTensorOps::float_argmin).
/// Int => [argmin](burn_tensor::ops::IntTensorOps::int_argmin).
/// Float => [argmin](crate::ops::FloatTensorOps::float_argmin).
/// Int => [argmin](crate::ops::IntTensorOps::int_argmin).
ArgMin(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [max](burn_tensor::ops::FloatTensorOps::float_max).
/// Int => [max](burn_tensor::ops::IntTensorOps::int_max).
/// Float => [max](crate::ops::FloatTensorOps::float_max).
/// Int => [max](crate::ops::IntTensorOps::int_max).
Max(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [max dim with indices](burn_tensor::ops::FloatTensorOps::float_max_dim_with_indices).
/// Int => [max dim with indices](burn_tensor::ops::IntTensorOps::int_max_dim_with_indices).
/// Float => [max dim with indices](crate::ops::FloatTensorOps::float_max_dim_with_indices).
/// Int => [max dim with indices](crate::ops::IntTensorOps::int_max_dim_with_indices).
MaxDimWithIndices(ReduceDimWithIndicesDescription),
/// Operation corresponding to:
///
/// Float => [min dim with indices](burn_tensor::ops::FloatTensorOps::float_min_dim_with_indices).
/// Int => [min dim with indices](burn_tensor::ops::IntTensorOps::int_min_dim_with_indices).
/// Float => [min dim with indices](crate::ops::FloatTensorOps::float_min_dim_with_indices).
/// Int => [min dim with indices](crate::ops::IntTensorOps::int_min_dim_with_indices).
MinDimWithIndices(ReduceDimWithIndicesDescription),
/// Operation corresponding to:
///
/// Float => [min](burn_tensor::ops::FloatTensorOps::float_min).
/// Int => [min](burn_tensor::ops::IntTensorOps::int_min).
/// Float => [min](crate::ops::FloatTensorOps::float_min).
/// Int => [min](crate::ops::IntTensorOps::int_min).
Min(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [max dim](burn_tensor::ops::FloatTensorOps::float_max_dim).
/// Int => [max dim](burn_tensor::ops::IntTensorOps::int_max_dim).
/// Float => [max dim](crate::ops::FloatTensorOps::float_max_dim).
/// Int => [max dim](crate::ops::IntTensorOps::int_max_dim).
MaxDim(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [min dim](burn_tensor::ops::FloatTensorOps::float_min_dim).
/// Int => [min dim](burn_tensor::ops::IntTensorOps::int_min_dim).
/// Float => [min dim](crate::ops::FloatTensorOps::float_min_dim).
/// Int => [min dim](crate::ops::IntTensorOps::int_min_dim).
MinDim(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [clamp](burn_tensor::ops::FloatTensorOps::float_clamp).
/// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp).
/// Float => [clamp](crate::ops::FloatTensorOps::float_clamp).
/// Int => [clamp](crate::ops::IntTensorOps::int_clamp).
Clamp(ClampOperationDescription<E>),
/// Operation corresponding to:
///
/// Int => [random](burn_tensor::ops::IntTensorOps::int_random).
/// Int => [random](crate::ops::IntTensorOps::int_random).
IntRandom(RandomOperationDescription),
/// Operation corresponding to:
///
/// Float => [powf](burn_tensor::ops::FloatTensorOps::float_powf).
/// Int => [powf](burn_tensor::ops::IntTensorOps::int_powf).
/// Float => [powf](crate::ops::FloatTensorOps::float_powf).
/// Int => [powf](crate::ops::IntTensorOps::int_powf).
Powf(BinaryOperationDescription),
}
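Because the numeric variants are shared by float and int tensors, a backend can classify them generically over the element type E. A hedged sketch; the is_comparison helper is hypothetical, and the non-exhaustive matches! keeps it valid even for variants not shown here:

use burn_tensor::repr::NumericOperationDescription;

/// Hypothetical helper: true if the operation is a comparison producing a bool tensor.
fn is_comparison<E>(op: &NumericOperationDescription<E>) -> bool {
    matches!(
        op,
        NumericOperationDescription::EqualElem(_)
            | NumericOperationDescription::Greater(_)
            | NumericOperationDescription::GreaterElem(_)
            | NumericOperationDescription::GreaterEqual(_)
            | NumericOperationDescription::GreaterEqualElem(_)
            | NumericOperationDescription::Lower(_)
            | NumericOperationDescription::LowerElem(_)
            | NumericOperationDescription::LowerEqual(_)
            | NumericOperationDescription::LowerEqualElem(_)
    )
}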
/// Operation description specific to an int tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum IntOperationDescription {
/// Operation corresponding to [into float](burn_tensor::ops::IntTensorOps::int_into_float).
/// Operation corresponding to [into float](crate::ops::IntTensorOps::int_into_float).
IntoFloat(UnaryOperationDescription),
}
/// Operation description specific to a bool tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BoolOperationDescription {
/// Operation corresponding to [into float](burn_tensor::ops::BoolTensorOps::bool_into_float).
/// Operation corresponding to [into float](crate::ops::BoolTensorOps::bool_into_float).
IntoFloat(UnaryOperationDescription),
/// Operation corresponding to [into int](burn_tensor::ops::BoolTensorOps::bool_into_int).
/// Operation corresponding to [into int](crate::ops::BoolTensorOps::bool_into_int).
IntoInt(UnaryOperationDescription),
/// Operation corresponding to [not](burn_tensor::ops::BoolTensorOps::bool_not).
/// Operation corresponding to [not](crate::ops::BoolTensorOps::bool_not).
Not(UnaryOperationDescription),
}
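The int- and bool-specific enums stay small because most operations route through the base and numeric descriptions; only true type conversions (and bool negation) need dedicated variants. A hypothetical naming helper, assuming the enum is re-exported from burn_tensor::repr:

use burn_tensor::repr::BoolOperationDescription;

/// Hypothetical helper naming each bool-specific operation.
fn bool_op_name(op: &BoolOperationDescription) -> &'static str {
    match op {
        BoolOperationDescription::IntoFloat(_) => "bool_into_float",
        BoolOperationDescription::IntoInt(_) => "bool_into_int",
        BoolOperationDescription::Not(_) => "bool_not",
    }
}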
@ -1057,7 +1053,7 @@ pub struct InterpolateBackwardDescription {
impl OperationDescription {
/// Return the tensor descriptions used by this operation, e.g. to clean up any remaining handles.
pub(crate) fn nodes(&self) -> Vec<&TensorDescription> {
pub fn nodes(&self) -> Vec<&TensorDescription> {
match self {
OperationDescription::BaseFloat(ops) => ops.nodes(),
OperationDescription::BaseInt(ops) => ops.nodes(),

View File

@ -0,0 +1,45 @@
use serde::{Deserialize, Serialize};
/// The tensor unique identifier.
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
pub struct TensorId {
value: u64,
}
/// The status of the current tensor.
#[derive(Hash, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorStatus {
/// The tensor can be read, but not written.
ReadOnly,
/// The tensor can be mutated inplace.
ReadWrite,
/// No handle exists for that tensor.
NotInit,
}
/// A tensor description represents a snapshot of a tensor when it was used.
///
/// # Example
///
/// A tensor that is used multiple times has its status updated for each operation.
///
/// 1. TensorStatus::NotInit
/// 2. TensorStatus::ReadOnly
/// 3. TensorStatus::ReadOnly
/// 4. TensorStatus::ReadWrite
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct TensorDescription {
/// The [tensor id](TensorId).
pub id: TensorId,
/// The shape of the tensor.
pub shape: Vec<usize>,
/// The [status](TensorStatus) of the tensor when it was used.
pub status: TensorStatus,
}
impl TensorId {
/// Create a new tensor id.
pub fn new(value: u64) -> Self {
Self { value }
}
}
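The status lifecycle from the example above can be sketched directly, assuming these types are re-exported from burn_tensor::repr as the fusion imports earlier in this diff suggest:

use burn_tensor::repr::{TensorDescription, TensorId, TensorStatus};

fn main() {
    // First use: no backend handle exists yet.
    let mut desc = TensorDescription {
        id: TensorId::new(42),
        shape: vec![2, 3],
        status: TensorStatus::NotInit,
    };

    // Intermediate uses only read the handle.
    desc.status = TensorStatus::ReadOnly;

    // Last use: the handle can be consumed or mutated in place.
    desc.status = TensorStatus::ReadWrite;
    println!("{:?}", desc.status);
}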

View File

@ -3,7 +3,7 @@ use alloc::string::String;
use crate::ops::*;
use crate::tensor::Element;
use super::BackendBridge;
use super::{BackendBridge, DeviceOps};
/// This trait defines all types and functions needed for a backend to be used with burn.
///
@ -66,7 +66,7 @@ pub trait Backend:
+ 'static
{
/// Device type.
type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync;
type Device: DeviceOps;
/// A bridge that can cast tensors to full precision.
type FullPrecisionBridge: BackendBridge<Self> + 'static;

View File

@ -0,0 +1,14 @@
/// The device id.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
pub struct DeviceId {
/// The type id identifies the type of the device.
pub type_id: u16,
/// The index id identifies the device number.
pub index_id: u32,
}
/// The device operations trait allows retrieving the [device id](DeviceId) of a backend device.
pub trait DeviceOps: Clone + Default + PartialEq + Send + Sync + core::fmt::Debug {
/// Return the [device id](DeviceId).
fn id(&self) -> DeviceId;
}
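For a backend that exposes a single device, the implementation is tiny. A sketch with a hypothetical MyCpuDevice type; DeviceId::new comes from the derive-new attribute above:

use burn_tensor::backend::{DeviceId, DeviceOps};

#[derive(Clone, Copy, Debug, Default, PartialEq)]
struct MyCpuDevice;

impl DeviceOps for MyCpuDevice {
    fn id(&self) -> DeviceId {
        // One device type (0) with a single instance (0).
        DeviceId::new(0, 0)
    }
}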

View File

@ -1,8 +1,10 @@
mod base;
mod bridge;
mod device;
pub use base::*;
pub use bridge::*;
pub use device::*;
// Not needed for now, useful for different tensor memory layout
// pub mod conversion;

View File

@ -1,14 +0,0 @@
use crate::WgpuDevice;
use burn_fusion::{DeviceId, FusionDevice};
impl FusionDevice for WgpuDevice {
fn id(&self) -> DeviceId {
match self {
WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32),
WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32),
WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32),
WgpuDevice::Cpu => DeviceId::new(3, 0),
WgpuDevice::BestAvailable => DeviceId::new(4, 0),
}
}
}

View File

@ -10,9 +10,6 @@ mod element;
mod graphics;
mod runtime;
#[cfg(feature = "fusion")]
mod fusion;
#[cfg(feature = "template")]
pub use burn_jit::{
compute::Kernel,

View File

@ -13,6 +13,7 @@ use burn_compute::{
ComputeRuntime,
};
use burn_jit::Runtime;
use burn_tensor::backend::{DeviceId, DeviceOps};
use std::marker::PhantomData;
use wgpu::{AdapterInfo, DeviceDescriptor};
@ -52,6 +53,18 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> Runtime for WgpuRuntime<G,
}
}
impl DeviceOps for WgpuDevice {
fn id(&self) -> DeviceId {
match self {
WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32),
WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32),
WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32),
WgpuDevice::Cpu => DeviceId::new(3, 0),
WgpuDevice::BestAvailable => DeviceId::new(4, 0),
}
}
}
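With this impl in place, anything generic over DeviceOps can key caches or clients by id instead of by a fusion-specific device type. An illustrative check based on the mapping above:

fn main() {
    let id = WgpuDevice::DiscreteGpu(0).id();
    assert_eq!(id, DeviceId::new(0, 0)); // type_id 0 => discrete GPU, index 0
}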
/// The values that control how a WGPU Runtime will perform its calculations.
pub struct RuntimeOptions {
/// How the buffers are deallocated.