mirror of https://github.com/tracel-ai/burn.git
Move HandleContainer and Tensor Ops descriptions from burn-fusion to burn-tensor (#1654)
* Move HandlerContainer and Tensor Ops description to burn-tensor Move HandleContainer and Tensor operations descriptions to burn-tensor crate. Removed the FusionDevice and replaced it with a DeviceOps trait bound to Backend::Device. For now added modules to burn-tensor are excluded from no-std as they rely on Arc. * [burn-tensor] Flatten module hierarchy for tensor representation + Add new repr feature to cargo file. * Remove prefix on dosctring * [burn-fusion] Require default features of burn-tensor
This commit is contained in:
parent
e6b1b7a317
commit
c579686a8a
|
@ -3696,6 +3696,14 @@ dependencies = [
|
|||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "refactor"
|
||||
version = "0.14.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.10.4"
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use burn_tensor::{backend::Backend, Device};
|
||||
use burn_tensor::{
|
||||
backend::{Backend, DeviceId, DeviceOps},
|
||||
Device,
|
||||
};
|
||||
use candle_core::DeviceLocation;
|
||||
|
||||
use crate::{
|
||||
|
@ -60,6 +63,16 @@ impl From<candle_core::Device> for CandleDevice {
|
|||
}
|
||||
}
|
||||
|
||||
impl DeviceOps for CandleDevice {
|
||||
fn id(&self) -> burn_tensor::backend::DeviceId {
|
||||
match self {
|
||||
CandleDevice::Cpu => DeviceId::new(0, 0),
|
||||
CandleDevice::Cuda(index) => DeviceId::new(1, *index as u32),
|
||||
CandleDevice::Metal(index) => DeviceId::new(2, *index as u32),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CandleDevice {
|
||||
fn default() -> Self {
|
||||
Self::Cpu
|
||||
|
|
|
@ -16,7 +16,7 @@ std = ["serde/std"]
|
|||
doc = ["default"]
|
||||
|
||||
[dependencies]
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.14.0", default-features = false }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.14.0" }
|
||||
burn-common = { path = "../burn-common", version = "0.14.0" }
|
||||
hashbrown = { workspace = true }
|
||||
derive-new = {workspace = true }
|
||||
|
|
|
@ -1,15 +1,17 @@
|
|||
use crate::{
|
||||
client::FusionClient,
|
||||
stream::{Context, OperationDescription},
|
||||
FusionClientLocator, FusionTensor, PrecisionBridge,
|
||||
client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge,
|
||||
};
|
||||
use burn_tensor::{
|
||||
backend::Backend,
|
||||
repr::{OperationDescription, ReprBackend},
|
||||
Device,
|
||||
};
|
||||
use burn_tensor::{backend::Backend, Device, Shape};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();
|
||||
|
||||
pub(crate) fn get_client<B: FusionBackend>(device: &B::FusionDevice) -> B::FusionClient {
|
||||
pub(crate) fn get_client<B: FusionBackend>(device: &B::Device) -> B::FusionClient {
|
||||
CLIENTS.client(device)
|
||||
}
|
||||
|
||||
|
@ -43,7 +45,7 @@ impl<B: FusionBackend> Backend for Fusion<B> {
|
|||
}
|
||||
|
||||
fn sync(device: &Self::Device) {
|
||||
let client = CLIENTS.client::<B::FusionClient>(&device.clone().into());
|
||||
let client = CLIENTS.client::<B::FusionClient>(&device.clone());
|
||||
client.drain();
|
||||
B::sync(device)
|
||||
}
|
||||
|
@ -114,62 +116,17 @@ pub trait Optimization<B: FusionBackend>: Send {
|
|||
fn from_state(device: &B::Device, state: B::OptimizationState) -> Self;
|
||||
}
|
||||
|
||||
/// The device id.
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
|
||||
pub struct DeviceId {
|
||||
/// The type id identifies the type of the device.
|
||||
pub type_id: u16,
|
||||
/// The index id identifies the device number.
|
||||
pub index_id: u32,
|
||||
}
|
||||
|
||||
/// The handle device trait allows to get an id for a backend device.
|
||||
pub trait FusionDevice: Clone + Send + Sync + PartialEq {
|
||||
/// Return the [device id](DeviceId).
|
||||
fn id(&self) -> DeviceId;
|
||||
}
|
||||
|
||||
/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
|
||||
/// [operation builder](crate::OptimizationBuilder).
|
||||
pub trait FusionBackend: Backend {
|
||||
pub trait FusionBackend: Backend + ReprBackend {
|
||||
/// The state that can be serialized for an optimization.
|
||||
type OptimizationState: Serialize + DeserializeOwned;
|
||||
/// Optimization type for the backend.
|
||||
type Optimization: Optimization<Self>;
|
||||
|
||||
/// The device type that can return an ID.
|
||||
///
|
||||
/// It can be the same as (Backend::Device), but must implement (FusionDevice).
|
||||
type FusionDevice: FusionDevice + From<Self::Device> + Into<Self::Device> + core::fmt::Debug;
|
||||
/// The type that can be used to point to a tensor of any kind.
|
||||
type Handle: Sync + Send + Clone;
|
||||
/// What kind of client should be used.
|
||||
type FusionClient: FusionClient<FusionBackend = Self>;
|
||||
|
||||
/// The list of optimizations that will be used to optimize the computational graph.
|
||||
fn optimizations(device: Device<Self>)
|
||||
-> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
|
||||
|
||||
/// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
|
||||
fn float_tensor<const D: usize>(
|
||||
handle: Self::Handle,
|
||||
shape: Shape<D>,
|
||||
) -> Self::FloatTensorPrimitive<D>;
|
||||
/// Convert a [handle](FusionBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive).
|
||||
fn int_tensor<const D: usize>(
|
||||
handle: Self::Handle,
|
||||
shape: Shape<D>,
|
||||
) -> Self::IntTensorPrimitive<D>;
|
||||
/// Convert a [handle](FusionBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
|
||||
fn bool_tensor<const D: usize>(
|
||||
handle: Self::Handle,
|
||||
shape: Shape<D>,
|
||||
) -> Self::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](FusionBackend::Handle).
|
||||
fn float_tensor_handle<const D: usize>(tensor: Self::FloatTensorPrimitive<D>) -> Self::Handle;
|
||||
/// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](FusionBackend::Handle).
|
||||
fn int_tensor_handle<const D: usize>(tensor: Self::IntTensorPrimitive<D>) -> Self::Handle;
|
||||
/// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](FusionBackend::Handle).
|
||||
fn bool_tensor_handle<const D: usize>(tensor: Self::BoolTensorPrimitive<D>) -> Self::Handle;
|
||||
}
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
use crate::{
|
||||
stream::{Operation, OperationDescription, StreamId},
|
||||
FusionBackend, FusionTensor, Handle, TensorDescription, TensorId,
|
||||
stream::{execution::Operation, StreamId},
|
||||
FusionBackend, FusionTensor, Handle,
|
||||
};
|
||||
use burn_tensor::{
|
||||
backend::Backend,
|
||||
ops::{FloatElem, IntElem},
|
||||
Data, Reader,
|
||||
repr::{OperationDescription, TensorDescription, TensorId},
|
||||
Data, Device, Reader,
|
||||
};
|
||||
|
||||
/// Define how to interact with the fusion server.
|
||||
|
@ -12,8 +14,8 @@ pub trait FusionClient: Send + Sync + Clone {
|
|||
/// The [fusion backend](FusionBackend) associated type.
|
||||
type FusionBackend: FusionBackend;
|
||||
|
||||
/// Create a new client for the given [fusion device](FusionBackend::FusionDevice).
|
||||
fn new(device: <Self::FusionBackend as FusionBackend>::FusionDevice) -> Self;
|
||||
/// Create a new client for the given [device](Backend::Device).
|
||||
fn new(device: Device<Self::FusionBackend>) -> Self;
|
||||
/// Register a new [tensor operation description](OperationDescription).
|
||||
fn register<O: Operation<Self::FusionBackend> + 'static>(
|
||||
&self,
|
||||
|
@ -24,7 +26,7 @@ pub trait FusionClient: Send + Sync + Clone {
|
|||
/// Register all lazy computation.
|
||||
fn drain(&self);
|
||||
/// Get the current device used by all operations handled by this client.
|
||||
fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice;
|
||||
fn device(&self) -> &<Self::FusionBackend as Backend>::Device;
|
||||
/// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it.
|
||||
fn tensor_uninitialized(&self, shape: Vec<usize>) -> FusionTensor<Self>;
|
||||
/// Create a tensor with the given handle and shape.
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
use super::FusionClient;
|
||||
use crate::{
|
||||
stream::{Operation, OperationDescription, StreamId},
|
||||
stream::{execution::Operation, StreamId},
|
||||
FusionBackend, FusionServer, FusionTensor, Handle,
|
||||
};
|
||||
use burn_tensor::ops::FloatElem;
|
||||
use burn_tensor::{
|
||||
backend::Backend,
|
||||
ops::FloatElem,
|
||||
repr::{OperationDescription, TensorDescription, TensorId},
|
||||
};
|
||||
use spin::Mutex;
|
||||
use std::sync::Arc;
|
||||
|
||||
|
@ -13,7 +17,7 @@ where
|
|||
B: FusionBackend,
|
||||
{
|
||||
server: Arc<Mutex<FusionServer<B>>>,
|
||||
device: B::FusionDevice,
|
||||
device: B::Device,
|
||||
}
|
||||
|
||||
impl<B> Clone for MutexFusionClient<B>
|
||||
|
@ -34,7 +38,7 @@ where
|
|||
{
|
||||
type FusionBackend = B;
|
||||
|
||||
fn new(device: B::FusionDevice) -> Self {
|
||||
fn new(device: B::Device) -> Self {
|
||||
Self {
|
||||
device: device.clone(),
|
||||
server: Arc::new(Mutex::new(FusionServer::new(device))),
|
||||
|
@ -63,7 +67,7 @@ where
|
|||
FusionTensor::new(id, shape, self.clone(), StreamId::current())
|
||||
}
|
||||
|
||||
fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice {
|
||||
fn device(&self) -> &<Self::FusionBackend as Backend>::Device {
|
||||
&self.device
|
||||
}
|
||||
fn register_tensor(
|
||||
|
@ -82,7 +86,7 @@ where
|
|||
|
||||
fn read_tensor_float<const D: usize>(
|
||||
&self,
|
||||
tensor: crate::TensorDescription,
|
||||
tensor: TensorDescription,
|
||||
stream: StreamId,
|
||||
) -> burn_tensor::Reader<burn_tensor::Data<FloatElem<Self::FusionBackend>, D>> {
|
||||
self.server.lock().read_float(tensor, stream)
|
||||
|
@ -90,7 +94,7 @@ where
|
|||
|
||||
fn read_tensor_int<const D: usize>(
|
||||
&self,
|
||||
tensor: crate::TensorDescription,
|
||||
tensor: TensorDescription,
|
||||
id: StreamId,
|
||||
) -> burn_tensor::Reader<burn_tensor::Data<burn_tensor::ops::IntElem<Self::FusionBackend>, D>>
|
||||
{
|
||||
|
@ -99,7 +103,7 @@ where
|
|||
|
||||
fn read_tensor_bool<const D: usize>(
|
||||
&self,
|
||||
tensor: crate::TensorDescription,
|
||||
tensor: TensorDescription,
|
||||
stream: StreamId,
|
||||
) -> burn_tensor::Reader<burn_tensor::Data<bool, D>> {
|
||||
self.server.lock().read_bool(tensor, stream)
|
||||
|
@ -107,17 +111,16 @@ where
|
|||
|
||||
fn change_client_float<const D: usize>(
|
||||
&self,
|
||||
tensor: crate::TensorDescription,
|
||||
tensor: TensorDescription,
|
||||
client: Self,
|
||||
stream: StreamId,
|
||||
) -> FusionTensor<Self> {
|
||||
let device = client.device.clone().into();
|
||||
|
||||
let mut server_other = client.server.lock();
|
||||
let mut server_current = self.server.lock();
|
||||
server_current.drain_stream(stream);
|
||||
|
||||
let id = server_current.change_server_float::<D>(&tensor, &device, &mut server_other);
|
||||
let id =
|
||||
server_current.change_server_float::<D>(&tensor, &client.device, &mut server_other);
|
||||
|
||||
core::mem::drop(server_other);
|
||||
core::mem::drop(server_current);
|
||||
|
@ -127,17 +130,15 @@ where
|
|||
|
||||
fn change_client_int<const D: usize>(
|
||||
&self,
|
||||
tensor: crate::TensorDescription,
|
||||
tensor: TensorDescription,
|
||||
client: Self,
|
||||
stream: StreamId,
|
||||
) -> FusionTensor<Self> {
|
||||
let device = client.device.clone().into();
|
||||
|
||||
let mut server_other = client.server.lock();
|
||||
let mut server_current = self.server.lock();
|
||||
server_current.drain_stream(stream);
|
||||
|
||||
let id = server_current.change_server_int::<D>(&tensor, &device, &mut server_other);
|
||||
let id = server_current.change_server_int::<D>(&tensor, &client.device, &mut server_other);
|
||||
|
||||
core::mem::drop(server_other);
|
||||
core::mem::drop(server_current);
|
||||
|
@ -147,17 +148,15 @@ where
|
|||
|
||||
fn change_client_bool<const D: usize>(
|
||||
&self,
|
||||
tensor: crate::TensorDescription,
|
||||
tensor: TensorDescription,
|
||||
client: Self,
|
||||
stream: StreamId,
|
||||
) -> FusionTensor<Self> {
|
||||
let device = client.device.clone().into();
|
||||
|
||||
let mut server_other = client.server.lock();
|
||||
let mut server_current = self.server.lock();
|
||||
server_current.drain_stream(stream);
|
||||
|
||||
let id = server_current.change_server_bool::<D>(&tensor, &device, &mut server_other);
|
||||
let id = server_current.change_server_bool::<D>(&tensor, &client.device, &mut server_other);
|
||||
|
||||
core::mem::drop(server_other);
|
||||
core::mem::drop(server_current);
|
||||
|
@ -165,7 +164,7 @@ where
|
|||
FusionTensor::new(id, tensor.shape, client, StreamId::current())
|
||||
}
|
||||
|
||||
fn register_orphan(&self, id: &crate::TensorId) {
|
||||
fn register_orphan(&self, id: &TensorId) {
|
||||
self.server.lock().drop_tensor_handle(*id);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,14 @@
|
|||
use crate::{client::FusionClient, DeviceId, FusionBackend, FusionDevice};
|
||||
use burn_tensor::{
|
||||
backend::{Backend, DeviceId, DeviceOps},
|
||||
repr::ReprBackend,
|
||||
};
|
||||
|
||||
use crate::client::FusionClient;
|
||||
|
||||
use std::{any::Any, collections::HashMap, ops::DerefMut};
|
||||
|
||||
/// Type alias for [fusion backend handle](FusionBackend::Handle).
|
||||
pub type Handle<B> = <B as FusionBackend>::Handle;
|
||||
/// Type alias for [representation backend handle](burn_tensor::repr::ReprBackend::Handle).
|
||||
pub type Handle<B> = <B as ReprBackend>::Handle;
|
||||
type Key = (core::any::TypeId, DeviceId);
|
||||
|
||||
pub(crate) struct FusionClientLocator {
|
||||
|
@ -22,7 +28,7 @@ impl FusionClientLocator {
|
|||
/// Provide the init function to create a new client if it isn't already initialized.
|
||||
pub fn client<C: FusionClient + 'static>(
|
||||
&self,
|
||||
device: &<C::FusionBackend as FusionBackend>::FusionDevice,
|
||||
device: &<C::FusionBackend as Backend>::Device,
|
||||
) -> C {
|
||||
let device_id = device.id();
|
||||
let client_id = (core::any::TypeId::of::<C>(), device_id);
|
||||
|
|
|
@ -16,7 +16,6 @@ pub mod stream;
|
|||
mod backend;
|
||||
mod bridge;
|
||||
mod fusion;
|
||||
mod handle;
|
||||
mod ops;
|
||||
mod server;
|
||||
mod tensor;
|
||||
|
@ -26,5 +25,4 @@ pub(crate) use server::*;
|
|||
pub use backend::*;
|
||||
pub use bridge::*;
|
||||
pub use fusion::*;
|
||||
pub use handle::*;
|
||||
pub use tensor::*;
|
||||
|
|
|
@ -11,7 +11,7 @@ macro_rules! binary_float_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
|
||||
let rhs = handles.get_float_tensor(&self.desc.rhs);
|
||||
let output = $ops(lhs, rhs);
|
||||
|
@ -35,7 +35,7 @@ macro_rules! binary_float_cmp_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
|
||||
let rhs = handles.get_float_tensor(&self.desc.rhs);
|
||||
let output = $ops(lhs, rhs);
|
||||
|
@ -59,7 +59,7 @@ macro_rules! binary_int_cmp_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
|
||||
let rhs = handles.get_int_tensor(&self.desc.rhs);
|
||||
let output = $ops(lhs, rhs);
|
||||
|
@ -93,7 +93,7 @@ macro_rules! binary_int_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
|
||||
let rhs = handles.get_int_tensor(&self.desc.rhs);
|
||||
let output = $ops(lhs, rhs);
|
||||
|
|
|
@ -2,23 +2,24 @@ use crate::{
|
|||
client::FusionClient,
|
||||
get_client,
|
||||
ops::binary::binary_ops_shape,
|
||||
stream::{
|
||||
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
|
||||
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, Operation,
|
||||
OperationDescription, PermuteOperationDescription, RepeatOperationDescription,
|
||||
ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, StreamId,
|
||||
SwapDimsDescription, UnaryOperationDescription,
|
||||
},
|
||||
stream::{execution::Operation, StreamId},
|
||||
Fusion, FusionBackend,
|
||||
};
|
||||
use burn_tensor::{
|
||||
ops::{BoolTensor, BoolTensorOps},
|
||||
repr::{
|
||||
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
|
||||
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
|
||||
HandleContainer, OperationDescription, PermuteOperationDescription,
|
||||
RepeatOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
|
||||
SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
|
||||
},
|
||||
Device, Shape,
|
||||
};
|
||||
|
||||
impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
||||
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let tensor = B::bool_empty(shape.clone(), device);
|
||||
|
||||
client.register_tensor(
|
||||
|
@ -42,7 +43,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
data: burn_tensor::Data<bool, D>,
|
||||
device: &Device<Self>,
|
||||
) -> BoolTensor<Self, D> {
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let tensor = B::bool_from_data(data, device);
|
||||
let shape = B::bool_shape(&tensor);
|
||||
|
||||
|
@ -62,7 +63,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for IntoIntOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let output = B::bool_into_int(input);
|
||||
handles.register_int_tensor(&self.desc.out.id, output);
|
||||
|
@ -95,7 +96,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for IntoFloatOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let output = B::bool_into_float(input);
|
||||
handles.register_float_tensor(&self.desc.out.id, output);
|
||||
|
@ -119,15 +120,15 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
fn bool_device<const D: usize>(tensor: &BoolTensor<Self, D>) -> Device<Self> {
|
||||
tensor.client.device().clone().into()
|
||||
tensor.client.device().clone()
|
||||
}
|
||||
|
||||
fn bool_to_device<const D: usize>(
|
||||
tensor: BoolTensor<Self, D>,
|
||||
device: &Device<Self>,
|
||||
) -> BoolTensor<Self, D> {
|
||||
let device_original: &B::FusionDevice = tensor.client.device();
|
||||
let device_target: B::FusionDevice = device.clone().into();
|
||||
let device_original: &B::Device = tensor.client.device();
|
||||
let device_target: B::Device = device.clone();
|
||||
|
||||
if device_original == &device_target {
|
||||
return tensor;
|
||||
|
@ -154,7 +155,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for ReshapeDimsOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D1>(&self.desc.input);
|
||||
let output = B::bool_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
|
||||
handles.register_bool_tensor(&self.desc.out.id, output);
|
||||
|
@ -188,7 +189,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_bool_tensor::<D1>(&self.desc.tensor);
|
||||
|
||||
let output =
|
||||
|
@ -232,7 +233,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceAssignOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_bool_tensor::<D1>(&self.desc.tensor);
|
||||
let value = handles.get_bool_tensor::<D1>(&self.desc.value);
|
||||
|
||||
|
@ -277,7 +278,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for CatOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensors = self
|
||||
.desc
|
||||
.tensors
|
||||
|
@ -329,7 +330,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for EqualOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let lhs = handles.get_bool_tensor::<D>(&self.desc.lhs);
|
||||
let rhs = handles.get_bool_tensor(&self.desc.rhs);
|
||||
let output = B::bool_equal(lhs, rhs);
|
||||
|
@ -364,7 +365,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for NotOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let output = B::bool_not(input);
|
||||
handles.register_bool_tensor(&self.desc.out.id, output);
|
||||
|
@ -381,7 +382,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Bool(crate::stream::BoolOperationDescription::Not(desc.clone())),
|
||||
OperationDescription::Bool(BoolOperationDescription::Not(desc.clone())),
|
||||
NotOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -399,7 +400,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for SwapDimsOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2);
|
||||
handles.register_bool_tensor(&self.desc.out.id, output);
|
||||
|
@ -438,7 +439,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for PermuteDimsOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let axes: [usize; D] = self.desc.axes.try_into().unwrap();
|
||||
let output = B::bool_permute(input, axes);
|
||||
|
@ -478,7 +479,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, const D2: usize, B: FusionBackend> Operation<B> for ExpandOps<D, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let shape: [usize; D2] = self.desc.shape.try_into().unwrap();
|
||||
let output = B::bool_expand(input, shape.into());
|
||||
|
@ -516,7 +517,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for FlipOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let output = B::bool_flip(input, self.desc.axes.as_slice());
|
||||
handles.register_bool_tensor(&self.desc.out.id, output);
|
||||
|
@ -552,7 +553,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for RepeatOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_bool_tensor::<D>(&self.desc.tensor);
|
||||
|
||||
let output = B::bool_repeat::<D>(tensor, self.desc.dim, self.desc.times);
|
||||
|
|
|
@ -4,21 +4,23 @@ use crate::{
|
|||
get_client,
|
||||
ops::binary::binary_ops_shape,
|
||||
scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
|
||||
stream::{
|
||||
stream::{execution::Operation, StreamId},
|
||||
unary_float_ops, Fusion, FusionBackend,
|
||||
};
|
||||
|
||||
use burn_tensor::{
|
||||
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
|
||||
repr::{
|
||||
BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
|
||||
ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
|
||||
FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription,
|
||||
MaskWhereOperationDescription, NumericOperationDescription, Operation,
|
||||
FloatOperationDescription, GatherOperationDescription, HandleContainer,
|
||||
MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
|
||||
OperationDescription, PermuteOperationDescription, RandomOperationDescription,
|
||||
ReduceDimWithIndicesDescription, RepeatOperationDescription, ReshapeDescription,
|
||||
ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
|
||||
SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
|
||||
StreamId, SwapDimsDescription, UnaryOperationDescription,
|
||||
SwapDimsDescription, TensorDescription, UnaryOperationDescription,
|
||||
},
|
||||
unary_float_ops, Fusion, FusionBackend, TensorDescription,
|
||||
};
|
||||
use burn_tensor::{
|
||||
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
|
||||
Data, Device, Distribution, ElementConversion, Reader, Shape,
|
||||
};
|
||||
use std::ops::Range;
|
||||
|
@ -28,7 +30,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
data: Data<FloatElem<Self>, D>,
|
||||
device: &Device<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let tensor = B::float_from_data(data, device);
|
||||
let shape = B::float_shape(&tensor);
|
||||
|
||||
|
@ -50,7 +52,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for RandomOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let shape = Shape::from(self.desc.out.shape.clone());
|
||||
let output: B::FloatTensorPrimitive<D> =
|
||||
B::float_random(shape, self.desc.distribution, &handles.device);
|
||||
|
@ -60,7 +62,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let stream = StreamId::current();
|
||||
let shape: Vec<usize> = shape.dims.into();
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let out = client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = RandomOperationDescription {
|
||||
|
@ -83,7 +85,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for ZerosOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let shape = Shape::from(self.out.shape.clone());
|
||||
let output = B::float_zeros::<D>(shape, &handles.device);
|
||||
handles.register_float_tensor(&self.out.id, output);
|
||||
|
@ -92,7 +94,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let stream = StreamId::current();
|
||||
let shape: Vec<usize> = shape.dims.into();
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let out = client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = out.to_description_out();
|
||||
|
@ -112,7 +114,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for OnesOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let shape = Shape::from(self.out.shape.clone());
|
||||
let output = B::float_ones::<D>(shape, &handles.device);
|
||||
handles.register_float_tensor(&self.out.id, output);
|
||||
|
@ -121,7 +123,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let stream = StreamId::current();
|
||||
let shape: Vec<usize> = shape.dims.into();
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let out = client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = out.to_description_out();
|
||||
|
@ -146,7 +148,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for FullOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let shape = Shape::from(self.out.shape.clone());
|
||||
let output: B::FloatTensorPrimitive<D> =
|
||||
B::float_full(shape, self.elem.elem(), &handles.device);
|
||||
|
@ -156,7 +158,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let stream = StreamId::current();
|
||||
let shape: Vec<usize> = shape.dims.into();
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let out = client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = (out.to_description_out(), fill_value.elem::<f32>());
|
||||
|
@ -180,15 +182,15 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
fn float_device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
|
||||
tensor.client.device().clone().into()
|
||||
tensor.client.device().clone()
|
||||
}
|
||||
|
||||
fn float_to_device<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
device: &Device<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
let device_original: &B::FusionDevice = tensor.client.device();
|
||||
let device_target: B::FusionDevice = device.clone().into();
|
||||
let device_original: &B::Device = tensor.client.device();
|
||||
let device_target: B::Device = device.clone();
|
||||
|
||||
if device_original == &device_target {
|
||||
return tensor;
|
||||
|
@ -212,7 +214,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for IntoIntOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_float_tensor::<D>(&self.desc.input);
|
||||
let output = B::float_into_int(input);
|
||||
|
||||
|
@ -237,7 +239,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
fn float_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let stream = StreamId::current();
|
||||
let tensor = B::float_empty(shape.clone(), device);
|
||||
|
||||
|
@ -307,7 +309,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for ClampOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_float_tensor::<D>(&self.desc.tensor);
|
||||
let output = B::float_clamp(input, self.desc.min.elem(), self.desc.max.elem());
|
||||
|
||||
|
@ -551,7 +553,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for SwapDimsOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_float_tensor::<D>(&self.desc.input);
|
||||
let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2);
|
||||
handles.register_float_tensor(&self.desc.out.id, output);
|
||||
|
@ -591,7 +593,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for ReshapeDimsOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_float_tensor::<D1>(&self.desc.input);
|
||||
let output = B::float_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
|
||||
handles.register_float_tensor(&self.desc.out.id, output);
|
||||
|
@ -626,7 +628,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for GatherOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor(&self.desc.indices);
|
||||
|
||||
|
@ -667,7 +669,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for ScatterOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor(&self.desc.indices);
|
||||
let value = handles.get_float_tensor(&self.desc.value);
|
||||
|
@ -712,7 +714,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for SelectOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor(&self.desc.indices);
|
||||
|
||||
|
@ -754,7 +756,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for SelectAssignOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor(&self.desc.indices);
|
||||
let value = handles.get_float_tensor(&self.desc.value);
|
||||
|
@ -799,7 +801,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_float_tensor::<D1>(&self.desc.tensor);
|
||||
|
||||
let output =
|
||||
|
@ -842,7 +844,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceAssignOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_float_tensor::<D1>(&self.desc.tensor);
|
||||
let value = handles.get_float_tensor::<D1>(&self.desc.value);
|
||||
|
||||
|
@ -887,7 +889,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for MaskWhereOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
|
||||
let value = handles.get_float_tensor(&self.desc.value);
|
||||
let mask = handles.get_bool_tensor(&self.desc.mask);
|
||||
|
@ -932,7 +934,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for MaskFillOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
|
||||
let mask = handles.get_bool_tensor(&self.desc.mask);
|
||||
|
||||
|
@ -1530,7 +1532,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for CatOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensors = self
|
||||
.desc
|
||||
.tensors
|
||||
|
@ -1607,7 +1609,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for RepeatOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
|
||||
|
||||
let output = B::float_repeat::<D>(tensor, self.desc.dim, self.desc.times);
|
||||
|
@ -1715,7 +1717,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for MaxDimWithIndicesOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
|
||||
let (output, indices) = B::float_max_dim_with_indices(tensor, self.desc.dim);
|
||||
|
||||
|
@ -1802,7 +1804,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for MinDimWithIndicesOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
|
||||
let (output, indices) = B::float_min_dim_with_indices(tensor, self.desc.dim);
|
||||
|
||||
|
@ -1871,7 +1873,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for PermuteDimsOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_float_tensor::<D>(&self.desc.input);
|
||||
let axes: [usize; D] = self.desc.axes.try_into().unwrap();
|
||||
let output = B::float_permute(input, axes);
|
||||
|
@ -1911,7 +1913,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, const D2: usize, B: FusionBackend> Operation<B> for ExpandOps<D, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_float_tensor::<D>(&self.desc.input);
|
||||
let shape: [usize; D2] = self.desc.shape.try_into().unwrap();
|
||||
let output = B::float_expand(input, shape.into());
|
||||
|
@ -1949,7 +1951,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for FlipOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_float_tensor::<D>(&self.desc.input);
|
||||
let output = B::float_flip(input, &self.desc.axes);
|
||||
handles.register_float_tensor(&self.desc.out.id, output);
|
||||
|
|
|
@ -4,28 +4,29 @@ use crate::{
|
|||
get_client,
|
||||
ops::binary::binary_ops_shape,
|
||||
scalar_int_cmp_ops, scalar_int_ops,
|
||||
stream::{
|
||||
self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
|
||||
ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
|
||||
GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
|
||||
NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
|
||||
RandomOperationDescription, ReduceDimWithIndicesDescription, RepeatOperationDescription,
|
||||
ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
|
||||
SelectAssignOperationDescription, SelectOperationDescription,
|
||||
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
|
||||
UnaryOperationDescription,
|
||||
},
|
||||
unary_int_ops, Fusion, FusionBackend, TensorDescription,
|
||||
stream::{execution::Operation, StreamId},
|
||||
unary_int_ops, Fusion, FusionBackend,
|
||||
};
|
||||
use burn_tensor::{
|
||||
ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
|
||||
repr::{
|
||||
self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
|
||||
ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
|
||||
GatherOperationDescription, HandleContainer, MaskFillOperationDescription,
|
||||
MaskWhereOperationDescription, NumericOperationDescription, OperationDescription,
|
||||
PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
|
||||
RepeatOperationDescription, ReshapeDescription, ScalarOperationDescription,
|
||||
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
|
||||
SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription,
|
||||
TensorDescription, UnaryOperationDescription,
|
||||
},
|
||||
Data, Device, Distribution, ElementConversion, Reader, Shape,
|
||||
};
|
||||
use core::ops::Range;
|
||||
|
||||
impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
||||
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let tensor = B::int_empty(shape.clone(), device);
|
||||
let stream = StreamId::current();
|
||||
|
||||
|
@ -44,7 +45,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
data: Data<IntElem<Self>, D>,
|
||||
device: &Device<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let tensor = B::int_from_data(data, device);
|
||||
let shape = B::int_shape(&tensor);
|
||||
let stream = StreamId::current();
|
||||
|
@ -53,15 +54,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
fn int_device<const D: usize>(tensor: &IntTensor<Self, D>) -> Device<Self> {
|
||||
tensor.client.device().clone().into()
|
||||
tensor.client.device().clone()
|
||||
}
|
||||
|
||||
fn int_to_device<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
device: &Device<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
let device_original: &B::FusionDevice = tensor.client.device();
|
||||
let device_target: B::FusionDevice = device.clone().into();
|
||||
let device_original: &B::Device = tensor.client.device();
|
||||
let device_target: B::Device = device.clone();
|
||||
|
||||
if device_original == &device_target {
|
||||
return tensor;
|
||||
|
@ -86,7 +87,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for ReshapeDimsOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D1>(&self.desc.input);
|
||||
let output = B::int_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
|
||||
handles.register_int_tensor(&self.desc.out.id, output);
|
||||
|
@ -120,7 +121,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D1>(&self.desc.tensor);
|
||||
|
||||
let output =
|
||||
|
@ -164,7 +165,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceAssignOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D1>(&self.desc.tensor);
|
||||
let value = handles.get_int_tensor::<D1>(&self.desc.value);
|
||||
|
||||
|
@ -208,7 +209,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for MaskWhereOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let value = handles.get_int_tensor(&self.desc.value);
|
||||
let mask = handles.get_bool_tensor(&self.desc.mask);
|
||||
|
@ -251,7 +252,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for MaskFillOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let mask = handles.get_bool_tensor(&self.desc.mask);
|
||||
|
||||
|
@ -291,7 +292,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for GatherOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor(&self.desc.indices);
|
||||
|
||||
|
@ -331,7 +332,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for ScatterOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor(&self.desc.indices);
|
||||
let value = handles.get_int_tensor(&self.desc.value);
|
||||
|
@ -374,7 +375,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for SelectOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor(&self.desc.indices);
|
||||
|
||||
|
@ -416,7 +417,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for SelectAssignOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor(&self.desc.indices);
|
||||
let value = handles.get_int_tensor(&self.desc.value);
|
||||
|
@ -457,7 +458,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for CatOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensors = self
|
||||
.desc
|
||||
.tensors
|
||||
|
@ -770,9 +771,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::Add(
|
||||
desc.clone(),
|
||||
)),
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::Add(desc.clone())),
|
||||
AddOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -795,7 +794,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::AddScalar(
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::AddScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
AddOps::<D>::new(desc),
|
||||
|
@ -823,9 +822,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::Sub(
|
||||
desc.clone(),
|
||||
)),
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::Sub(desc.clone())),
|
||||
SubOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -848,7 +845,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::SubScalar(
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::SubScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
SubOps::<D>::new(desc),
|
||||
|
@ -876,9 +873,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::Mul(
|
||||
desc.clone(),
|
||||
)),
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::Mul(desc.clone())),
|
||||
MulOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -901,7 +896,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::MulScalar(
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::MulScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
MulOps::<D>::new(desc),
|
||||
|
@ -929,9 +924,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::Div(
|
||||
desc.clone(),
|
||||
)),
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::Div(desc.clone())),
|
||||
DivOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -954,7 +947,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::DivScalar(
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::DivScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
DivOps::<D>::new(desc),
|
||||
|
@ -979,7 +972,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::RemScalar(
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::RemScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
ModOps::<D>::new(desc),
|
||||
|
@ -995,7 +988,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for ZerosOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let shape = Shape::from(self.desc.shape.clone());
|
||||
let output = B::int_zeros::<D>(shape, &handles.device);
|
||||
handles.register_int_tensor(&self.desc.id, output);
|
||||
|
@ -1004,7 +997,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let stream = StreamId::current();
|
||||
let shape: Vec<usize> = shape.dims.into();
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let out = client.tensor_uninitialized(shape);
|
||||
let desc = out.to_description_out();
|
||||
client.register(
|
||||
|
@ -1023,7 +1016,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for OnesOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let shape = Shape::from(self.desc.shape.clone());
|
||||
let output = B::int_ones::<D>(shape, &handles.device);
|
||||
handles.register_int_tensor(&self.desc.id, output);
|
||||
|
@ -1032,7 +1025,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let stream = StreamId::current();
|
||||
let shape: Vec<usize> = shape.dims.into();
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let out = client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = out.to_description_out();
|
||||
|
@ -1223,7 +1216,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for ClampOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let output = B::int_clamp(input, self.desc.min.elem(), self.desc.max.elem());
|
||||
|
||||
|
@ -1274,7 +1267,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for IntoFloatOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D>(&self.desc.input);
|
||||
let output = B::int_into_float(input);
|
||||
handles.register_float_tensor(&self.desc.out.id, output);
|
||||
|
@ -1289,7 +1282,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Int(stream::IntOperationDescription::IntoFloat(desc.clone())),
|
||||
OperationDescription::Int(repr::IntOperationDescription::IntoFloat(desc.clone())),
|
||||
IntoFloatOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1307,7 +1300,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for SwapDimsOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D>(&self.desc.input);
|
||||
let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2);
|
||||
handles.register_int_tensor(&self.desc.out.id, output);
|
||||
|
@ -1387,7 +1380,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for MaxDimWithIndicesOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim);
|
||||
|
||||
|
@ -1470,7 +1463,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for MinDimWithIndicesOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim);
|
||||
|
||||
|
@ -1513,7 +1506,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for IntRandomOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let shape = Shape::from(self.desc.out.shape.clone());
|
||||
let output: B::IntTensorPrimitive<D> =
|
||||
B::int_random(shape, self.desc.distribution, &handles.device);
|
||||
|
@ -1523,7 +1516,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let stream = StreamId::current();
|
||||
let shape: Vec<usize> = shape.dims.into();
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let client = get_client::<B>(&device.clone());
|
||||
let out = client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = RandomOperationDescription {
|
||||
|
@ -1549,7 +1542,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for PermuteDimsOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D>(&self.desc.input);
|
||||
let axes: [usize; D] = self.desc.axes.try_into().unwrap();
|
||||
let output = B::int_permute(input, axes);
|
||||
|
@ -1589,7 +1582,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, const D2: usize, B: FusionBackend> Operation<B> for ExpandOps<D, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let shape: [usize; D2] = self.desc.shape.try_into().unwrap();
|
||||
let output = B::bool_expand(input, shape.into());
|
||||
|
@ -1623,7 +1616,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for FlipDimsOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D>(&self.desc.input);
|
||||
let axes = &self.desc.axes;
|
||||
let output = B::int_flip(input, axes);
|
||||
|
@ -1661,7 +1654,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for RepeatOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
|
||||
let output = B::int_repeat::<D>(tensor, self.desc.dim, self.desc.times);
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
use crate::stream::InterpolateBackwardDescription;
|
||||
use crate::{
|
||||
client::FusionClient,
|
||||
stream::{
|
||||
use crate::{client::FusionClient, stream::execution::Operation, Fusion, FusionBackend};
|
||||
use burn_tensor::{
|
||||
ops::{
|
||||
conv::{
|
||||
calculate_conv_output_size, calculate_conv_transpose_output_size,
|
||||
calculate_pool_output_size,
|
||||
},
|
||||
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateOptions,
|
||||
MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices,
|
||||
ModuleOps,
|
||||
},
|
||||
repr::{
|
||||
AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription,
|
||||
AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription,
|
||||
AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription,
|
||||
AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription,
|
||||
ConvTranspose2dDescription, InterpolateDescription, MaxPool1dDescription,
|
||||
MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription,
|
||||
MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription,
|
||||
MaxPool2dWithIndicesDescription, Operation, OperationDescription,
|
||||
ConvTranspose2dDescription, HandleContainer, InterpolateBackwardDescription,
|
||||
InterpolateDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
|
||||
MaxPool1dWithIndicesDescription, MaxPool2dDescription,
|
||||
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription,
|
||||
ModuleOperationDescription, OperationDescription,
|
||||
},
|
||||
Fusion, FusionBackend, HandleContainer,
|
||||
};
|
||||
use burn_tensor::ops::{
|
||||
conv::{
|
||||
calculate_conv_output_size, calculate_conv_transpose_output_size,
|
||||
calculate_pool_output_size,
|
||||
},
|
||||
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateOptions,
|
||||
MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
|
||||
};
|
||||
|
||||
macro_rules! make_ops {
|
||||
|
@ -30,7 +30,7 @@ macro_rules! make_ops {
|
|||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B> for $name {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
$fn(self.desc, handles)
|
||||
}
|
||||
|
@ -88,9 +88,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.clone().register(
|
||||
streams,
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::Conv1d(
|
||||
description.clone(),
|
||||
)),
|
||||
OperationDescription::Module(ModuleOperationDescription::Conv1d(description.clone())),
|
||||
Conv1dOps::new(description),
|
||||
);
|
||||
|
||||
|
@ -155,9 +153,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
streams,
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::Conv2d(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::Module(ModuleOperationDescription::Conv2d(desc.clone())),
|
||||
Conv2dOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -216,9 +212,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
streams,
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::ConvTranspose1d(desc.clone()),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::ConvTranspose1d(desc.clone())),
|
||||
ConvTranspose1dOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -285,9 +279,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
streams,
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::ConvTranspose2d(desc.clone()),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::ConvTranspose2d(desc.clone())),
|
||||
ConvTranspose2dOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -333,9 +325,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::AvgPool1d(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::Module(ModuleOperationDescription::AvgPool1d(desc.clone())),
|
||||
AvgPool1dOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -385,9 +375,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::AvgPool2d(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::Module(ModuleOperationDescription::AvgPool2d(desc.clone())),
|
||||
AvgPool2dOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -436,9 +424,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::AvgPool1dBackward(desc.clone()),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::AvgPool1dBackward(
|
||||
desc.clone(),
|
||||
)),
|
||||
AvgPool1dBackwardOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -487,9 +475,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::AvgPool2dBackward(desc.clone()),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::AvgPool2dBackward(
|
||||
desc.clone(),
|
||||
)),
|
||||
AvgPool2dBackwardOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -536,9 +524,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::MaxPool1d(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::Module(ModuleOperationDescription::MaxPool1d(desc.clone())),
|
||||
MaxPool1dOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -598,9 +584,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::MaxPool2d(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::Module(ModuleOperationDescription::MaxPool2d(desc.clone())),
|
||||
MaxPool2dOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -649,9 +633,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::MaxPool1dWithIndices(desc.clone()),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::MaxPool1dWithIndices(
|
||||
desc.clone(),
|
||||
)),
|
||||
MaxPool1dWithIndicesOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -714,9 +698,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::MaxPool2dWithIndices(desc.clone()),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::MaxPool2dWithIndices(
|
||||
desc.clone(),
|
||||
)),
|
||||
MaxPool2dWithIndicesOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -770,11 +754,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2, stream_3],
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::MaxPool1dWithIndicesBackward(
|
||||
desc.clone(),
|
||||
),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::MaxPool1dWithIndicesBackward(
|
||||
desc.clone(),
|
||||
)),
|
||||
MaxPool1dWithIndicesBackwardOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -828,11 +810,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2, stream_3],
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::MaxPool2dWithIndicesBackward(
|
||||
desc.clone(),
|
||||
),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::MaxPool2dWithIndicesBackward(
|
||||
desc.clone(),
|
||||
)),
|
||||
MaxPool2dWithIndicesBackwardOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -862,9 +842,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::AdaptiveAvgPool1d(desc.clone()),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool1d(
|
||||
desc.clone(),
|
||||
)),
|
||||
AdaptiveAvgPool1dOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -897,9 +877,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::AdaptiveAvgPool2d(desc.clone()),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool2d(
|
||||
desc.clone(),
|
||||
)),
|
||||
AdaptiveAvgPool2dOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -933,9 +913,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::AdaptiveAvgPool1dBackward(desc.clone()),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool1dBackward(
|
||||
desc.clone(),
|
||||
)),
|
||||
AdaptiveAvgPool1dBackwardOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -969,9 +949,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::AdaptiveAvgPool2dBackward(desc.clone()),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool2dBackward(
|
||||
desc.clone(),
|
||||
)),
|
||||
AdaptiveAvgPool2dBackwardOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1006,9 +986,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::Interpolate(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::Module(ModuleOperationDescription::Interpolate(desc.clone())),
|
||||
InterpolateOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1047,9 +1025,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::InterpolateBackward(desc.clone()),
|
||||
),
|
||||
OperationDescription::Module(ModuleOperationDescription::InterpolateBackward(
|
||||
desc.clone(),
|
||||
)),
|
||||
InterpolateBackwardOps::new(desc),
|
||||
);
|
||||
out
|
||||
|
|
|
@ -18,7 +18,7 @@ macro_rules! scalar_float_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
|
||||
|
||||
|
@ -38,7 +38,7 @@ macro_rules! scalar_float_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs);
|
||||
|
||||
|
@ -62,7 +62,7 @@ macro_rules! scalar_float2int_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs.clone());
|
||||
|
||||
|
@ -85,7 +85,7 @@ macro_rules! unary_float_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_float_tensor::<D>(&self.desc.input);
|
||||
let output = $ops(input);
|
||||
|
||||
|
@ -108,7 +108,7 @@ macro_rules! unary_int_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D>(&self.desc.input);
|
||||
let output = $ops(input);
|
||||
|
||||
|
@ -131,7 +131,7 @@ macro_rules! scalar_float_cmp_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
|
||||
|
||||
|
@ -154,7 +154,7 @@ macro_rules! scalar_int_cmp_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
|
||||
|
||||
|
@ -184,7 +184,7 @@ macro_rules! scalar_int_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
|
||||
|
||||
|
@ -204,7 +204,7 @@ macro_rules! scalar_int_ops {
|
|||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>) {
|
||||
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs);
|
||||
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
use crate::{
|
||||
stream::{MultiStream, Operation, OperationDescription, StreamId},
|
||||
FusionBackend, HandleContainer, TensorId,
|
||||
stream::{execution::Operation, MultiStream, StreamId},
|
||||
FusionBackend,
|
||||
};
|
||||
use burn_tensor::{
|
||||
ops::{FloatElem, IntElem},
|
||||
repr::{HandleContainer, OperationDescription, TensorDescription, TensorId},
|
||||
};
|
||||
use burn_tensor::ops::{FloatElem, IntElem};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct FusionServer<B>
|
||||
|
@ -11,14 +14,14 @@ where
|
|||
{
|
||||
streams: MultiStream<B>,
|
||||
pub(crate) handles: HandleContainer<B>,
|
||||
pub device: B::FusionDevice,
|
||||
pub device: B::Device,
|
||||
}
|
||||
|
||||
impl<B> FusionServer<B>
|
||||
where
|
||||
B: FusionBackend,
|
||||
{
|
||||
pub fn new(device: B::FusionDevice) -> Self {
|
||||
pub fn new(device: B::Device) -> Self {
|
||||
Self {
|
||||
streams: MultiStream::new(device.clone()),
|
||||
handles: HandleContainer::new(device.clone()),
|
||||
|
@ -46,7 +49,7 @@ where
|
|||
|
||||
pub fn read_float<const D: usize>(
|
||||
&mut self,
|
||||
tensor: crate::TensorDescription,
|
||||
tensor: TensorDescription,
|
||||
id: StreamId,
|
||||
) -> burn_tensor::Reader<burn_tensor::Data<FloatElem<B>, D>> {
|
||||
// Make sure all registered operations are executed.
|
||||
|
@ -59,7 +62,7 @@ where
|
|||
|
||||
pub fn read_int<const D: usize>(
|
||||
&mut self,
|
||||
tensor: crate::TensorDescription,
|
||||
tensor: TensorDescription,
|
||||
id: StreamId,
|
||||
) -> burn_tensor::Reader<burn_tensor::Data<IntElem<B>, D>> {
|
||||
// Make sure all registered operations are executed.
|
||||
|
@ -72,7 +75,7 @@ where
|
|||
|
||||
pub fn read_bool<const D: usize>(
|
||||
&mut self,
|
||||
tensor: crate::TensorDescription,
|
||||
tensor: TensorDescription,
|
||||
id: StreamId,
|
||||
) -> burn_tensor::Reader<burn_tensor::Data<bool, D>> {
|
||||
// Make sure all registered operations are executed.
|
||||
|
@ -85,7 +88,7 @@ where
|
|||
|
||||
pub fn change_server_float<const D: usize>(
|
||||
&mut self,
|
||||
tensor: &crate::TensorDescription,
|
||||
tensor: &TensorDescription,
|
||||
device: &B::Device,
|
||||
server_device: &mut Self,
|
||||
) -> Arc<TensorId> {
|
||||
|
@ -101,7 +104,7 @@ where
|
|||
}
|
||||
pub fn change_server_int<const D: usize>(
|
||||
&mut self,
|
||||
tensor: &crate::TensorDescription,
|
||||
tensor: &TensorDescription,
|
||||
device: &B::Device,
|
||||
server_device: &mut Self,
|
||||
) -> Arc<TensorId> {
|
||||
|
@ -117,7 +120,7 @@ where
|
|||
}
|
||||
pub fn change_server_bool<const D: usize>(
|
||||
&mut self,
|
||||
tensor: &crate::TensorDescription,
|
||||
tensor: &TensorDescription,
|
||||
device: &B::Device,
|
||||
server_device: &mut Self,
|
||||
) -> Arc<TensorId> {
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
use super::Operation;
|
||||
use super::OperationConverter;
|
||||
use super::OperationDescription;
|
||||
use burn_tensor::repr::OperationDescription;
|
||||
|
||||
use crate::FusionBackend;
|
||||
|
||||
use super::{execution::Operation, OperationConverter, RelativeOps};
|
||||
|
||||
/// A growing list of [tensor operation descriptions](OperationDescription).
|
||||
pub struct OperationQueue<B: FusionBackend> {
|
||||
pub(crate) global: Vec<OperationDescription>,
|
||||
|
|
|
@ -1,23 +1,5 @@
|
|||
use super::{
|
||||
AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription,
|
||||
AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription,
|
||||
AvgPool2dBackwardDescription, AvgPool2dDescription, BaseOperationDescription,
|
||||
BinaryOperationDescription, BoolOperationDescription, ClampOperationDescription,
|
||||
Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, ConvTranspose2dDescription,
|
||||
EmbeddingBackwardDescription, EmbeddingDescription, ExpandOperationDescription,
|
||||
FlipOperationDescription, FloatOperationDescription, GatherOperationDescription,
|
||||
IntOperationDescription, InterpolateBackwardDescription, InterpolateDescription,
|
||||
MaskFillOperationDescription, MaskWhereOperationDescription, MaxPool1dDescription,
|
||||
MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription, MaxPool2dDescription,
|
||||
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription,
|
||||
ModuleOperationDescription, NumericOperationDescription, OperationDescription,
|
||||
PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
|
||||
ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
|
||||
SelectAssignOperationDescription, SelectOperationDescription, SliceOperationDescription,
|
||||
SwapDimsDescription, UnaryOperationDescription,
|
||||
};
|
||||
use crate::{FusionBackend, HandleContainer, TensorDescription, TensorId};
|
||||
use burn_tensor::{Element, ElementConversion};
|
||||
use crate::FusionBackend;
|
||||
use burn_tensor::{repr::*, Element, ElementConversion};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// The context contains the relative graph tensor mapping so that a relative tensor id can be
|
||||
|
@ -49,6 +31,16 @@ pub(crate) struct OperationConverter {
|
|||
scalar_ints: Vec<i32>,
|
||||
}
|
||||
|
||||
pub(crate) trait RelativeOps {
|
||||
fn to_relative(&self, converter: &mut OperationConverter) -> Self;
|
||||
}
|
||||
|
||||
trait RelativeOpsScalar<E: Element> {
|
||||
fn to_relative<F>(&self, converter: &mut OperationConverter, local_elem: F) -> Self
|
||||
where
|
||||
F: Fn(&mut OperationConverter, &E) -> E;
|
||||
}
|
||||
|
||||
impl OperationConverter {
|
||||
pub(crate) fn context<'a, B: FusionBackend>(
|
||||
&'a self,
|
||||
|
@ -85,8 +77,8 @@ impl OperationConverter {
|
|||
}
|
||||
}
|
||||
|
||||
impl OperationDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
impl RelativeOps for OperationDescription {
|
||||
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
match self {
|
||||
OperationDescription::BaseFloat(ops) => {
|
||||
OperationDescription::BaseFloat(ops.to_relative(converter))
|
||||
|
@ -117,8 +109,8 @@ impl OperationDescription {
|
|||
}
|
||||
}
|
||||
|
||||
impl ModuleOperationDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
impl RelativeOps for ModuleOperationDescription {
|
||||
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
match self {
|
||||
ModuleOperationDescription::Embedding(desc) => {
|
||||
ModuleOperationDescription::Embedding(EmbeddingDescription {
|
||||
|
@ -172,7 +164,7 @@ impl ModuleOperationDescription {
|
|||
})
|
||||
}
|
||||
ModuleOperationDescription::AvgPool1d(desc) => {
|
||||
ModuleOperationDescription::AvgPool1d(super::AvgPool1dDescription {
|
||||
ModuleOperationDescription::AvgPool1d(AvgPool1dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
kernel_size: desc.kernel_size,
|
||||
stride: desc.stride,
|
||||
|
@ -192,7 +184,7 @@ impl ModuleOperationDescription {
|
|||
})
|
||||
}
|
||||
ModuleOperationDescription::AvgPool1dBackward(desc) => {
|
||||
ModuleOperationDescription::AvgPool1dBackward(super::AvgPool1dBackwardDescription {
|
||||
ModuleOperationDescription::AvgPool1dBackward(AvgPool1dBackwardDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
grad: desc.grad.to_relative(converter),
|
||||
kernel_size: desc.kernel_size,
|
||||
|
@ -336,8 +328,8 @@ impl ModuleOperationDescription {
|
|||
}
|
||||
}
|
||||
|
||||
impl FloatOperationDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
impl RelativeOps for FloatOperationDescription {
|
||||
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
match self {
|
||||
FloatOperationDescription::Exp(desc) => {
|
||||
FloatOperationDescription::Exp(UnaryOperationDescription {
|
||||
|
@ -423,8 +415,8 @@ impl FloatOperationDescription {
|
|||
}
|
||||
}
|
||||
|
||||
impl BoolOperationDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
impl RelativeOps for BoolOperationDescription {
|
||||
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
match self {
|
||||
BoolOperationDescription::IntoFloat(desc) => {
|
||||
BoolOperationDescription::IntoFloat(UnaryOperationDescription {
|
||||
|
@ -448,8 +440,8 @@ impl BoolOperationDescription {
|
|||
}
|
||||
}
|
||||
|
||||
impl IntOperationDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
impl RelativeOps for IntOperationDescription {
|
||||
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
match self {
|
||||
IntOperationDescription::IntoFloat(desc) => {
|
||||
IntOperationDescription::IntoFloat(UnaryOperationDescription {
|
||||
|
@ -461,8 +453,8 @@ impl IntOperationDescription {
|
|||
}
|
||||
}
|
||||
|
||||
impl<E: Element> NumericOperationDescription<E> {
|
||||
pub(crate) fn to_relative<F>(&self, converter: &mut OperationConverter, local_elem: F) -> Self
|
||||
impl<E: Element> RelativeOpsScalar<E> for NumericOperationDescription<E> {
|
||||
fn to_relative<F>(&self, converter: &mut OperationConverter, local_elem: F) -> Self
|
||||
where
|
||||
F: Fn(&mut OperationConverter, &E) -> E,
|
||||
{
|
||||
|
@ -779,8 +771,8 @@ impl<E: Element> NumericOperationDescription<E> {
|
|||
}
|
||||
}
|
||||
|
||||
impl BaseOperationDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
impl RelativeOps for BaseOperationDescription {
|
||||
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
match self {
|
||||
BaseOperationDescription::ToDevice(desc) => {
|
||||
BaseOperationDescription::ToDevice(desc.to_relative(converter))
|
||||
|
@ -828,7 +820,7 @@ impl BaseOperationDescription {
|
|||
})
|
||||
}
|
||||
BaseOperationDescription::SliceAssign(desc) => {
|
||||
BaseOperationDescription::SliceAssign(super::SliceAssignOperationDescription {
|
||||
BaseOperationDescription::SliceAssign(SliceAssignOperationDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
ranges: desc.ranges.iter().map(|_range| 0..1).collect(),
|
||||
value: desc.value.to_relative(converter),
|
||||
|
@ -836,14 +828,14 @@ impl BaseOperationDescription {
|
|||
})
|
||||
}
|
||||
BaseOperationDescription::Equal(desc) => {
|
||||
BaseOperationDescription::Equal(super::BinaryOperationDescription {
|
||||
BaseOperationDescription::Equal(BinaryOperationDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
BaseOperationDescription::Repeat(desc) => {
|
||||
BaseOperationDescription::Repeat(super::RepeatOperationDescription {
|
||||
BaseOperationDescription::Repeat(RepeatOperationDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
dim: desc.dim,
|
||||
times: desc.times,
|
||||
|
@ -851,7 +843,7 @@ impl BaseOperationDescription {
|
|||
})
|
||||
}
|
||||
BaseOperationDescription::Cat(desc) => {
|
||||
BaseOperationDescription::Cat(super::CatOperationDescription {
|
||||
BaseOperationDescription::Cat(CatOperationDescription {
|
||||
tensors: desc
|
||||
.tensors
|
||||
.iter()
|
||||
|
@ -865,8 +857,8 @@ impl BaseOperationDescription {
|
|||
}
|
||||
}
|
||||
|
||||
impl TensorDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
impl RelativeOps for TensorDescription {
|
||||
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
let relative_id = if let Some(value) = converter.tensors_global2relative.get(&self.id) {
|
||||
// If we already have the same tensor registered, we have to update its value, but not
|
||||
// its id.
|
||||
|
@ -911,9 +903,8 @@ impl TensorDescription {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::TensorStatus;
|
||||
|
||||
use super::*;
|
||||
use burn_tensor::repr::{TensorDescription, TensorId, TensorStatus};
|
||||
|
||||
#[test]
|
||||
fn tensor_description_to_relative() {
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
use burn_tensor::repr::HandleContainer;
|
||||
|
||||
use crate::{
|
||||
stream::{
|
||||
store::{ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy},
|
||||
OperationQueue,
|
||||
OperationQueue, RelativeOps,
|
||||
},
|
||||
FusionBackend, HandleContainer, Optimization,
|
||||
FusionBackend, Optimization,
|
||||
};
|
||||
|
||||
/// The mode in which the execution is done.
|
||||
|
@ -13,6 +15,12 @@ pub(crate) enum ExecutionMode {
|
|||
Sync,
|
||||
}
|
||||
|
||||
/// General trait to abstract how a single operation is executed.
|
||||
pub trait Operation<B: FusionBackend>: Send + Sync {
|
||||
/// Execute the operation.
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>);
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> OperationQueue<B> {
|
||||
/// Execute the queue partially following the execution strategy from the plan.
|
||||
pub(crate) fn execute(
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use burn_tensor::repr::OperationDescription;
|
||||
|
||||
use super::ExecutionMode;
|
||||
use crate::{stream::OperationDescription, OptimizationBuilder, OptimizationStatus};
|
||||
use crate::{OptimizationBuilder, OptimizationStatus};
|
||||
|
||||
/// Explore and create new optimization.
|
||||
pub struct Explorer<O> {
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
use burn_tensor::repr::OperationDescription;
|
||||
|
||||
use super::validator::{
|
||||
ExecutionPlanOperationsStore, TriggerOperationsStore, TriggerProgress, TriggerValidator,
|
||||
ValidatorState,
|
||||
};
|
||||
use super::ExecutionMode;
|
||||
use crate::stream::execution::validator::OperationsValidator;
|
||||
use crate::stream::{
|
||||
store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger, SearchQuery},
|
||||
OperationDescription,
|
||||
};
|
||||
use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger, SearchQuery};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// The policy keeps track of all possible execution plans for the current operations.
|
||||
|
@ -266,14 +265,13 @@ impl<O> Policy<O> {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
stream::{
|
||||
store::{ExecutionPlan, ExecutionStrategy, ExecutionTrigger},
|
||||
FloatOperationDescription, UnaryOperationDescription,
|
||||
},
|
||||
TensorDescription, TensorId, TensorStatus,
|
||||
use burn_tensor::repr::{
|
||||
FloatOperationDescription, TensorDescription, TensorId, TensorStatus,
|
||||
UnaryOperationDescription,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use crate::stream::store::{ExecutionPlan, ExecutionStrategy, ExecutionTrigger};
|
||||
use std::ops::Range;
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
use burn_tensor::repr::OperationDescription;
|
||||
|
||||
use super::{ExecutionMode, Exploration, Explorer};
|
||||
use crate::stream::execution::{Action, Policy};
|
||||
use crate::stream::store::{
|
||||
ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,
|
||||
};
|
||||
use crate::stream::OperationDescription;
|
||||
use crate::OptimizationBuilder;
|
||||
|
||||
/// Process a [stream segment](StreamSegment) following a [policy](Policy).
|
||||
|
|
|
@ -6,16 +6,17 @@
|
|||
//! To test these components effectively, we create mock types for the stream, optimization,
|
||||
//! optimization builder, and stream segment. These mock types aid in comprehensively
|
||||
//! understanding the process of optimizing streams.
|
||||
use burn_tensor::repr::{
|
||||
BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription,
|
||||
OperationDescription, ScalarOperationDescription, TensorDescription, TensorId, TensorStatus,
|
||||
UnaryOperationDescription,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
stream::{
|
||||
store::{
|
||||
ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,
|
||||
},
|
||||
BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription,
|
||||
OperationDescription, ScalarOperationDescription,
|
||||
stream::store::{
|
||||
ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,
|
||||
},
|
||||
OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription, TensorId,
|
||||
TensorStatus,
|
||||
OptimizationBuilder, OptimizationProperties, OptimizationStatus,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
@ -558,18 +559,16 @@ fn operation_2() -> OperationDescription {
|
|||
|
||||
/// Just a simple operation.
|
||||
fn operation_3() -> OperationDescription {
|
||||
OperationDescription::Float(FloatOperationDescription::Log(
|
||||
crate::stream::UnaryOperationDescription {
|
||||
input: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
},
|
||||
out: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::NotInit,
|
||||
},
|
||||
OperationDescription::Float(FloatOperationDescription::Log(UnaryOperationDescription {
|
||||
input: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
},
|
||||
))
|
||||
out: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::NotInit,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
use crate::stream::{
|
||||
store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger},
|
||||
OperationDescription,
|
||||
};
|
||||
use burn_tensor::repr::OperationDescription;
|
||||
|
||||
use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger};
|
||||
|
||||
/// Compare each operation in the list of operations provided by the [store](OperationsStore)
|
||||
/// to verify if the newly added operations match the original list.
|
||||
|
|
|
@ -4,9 +4,7 @@ pub(crate) mod store;
|
|||
mod base;
|
||||
mod context;
|
||||
mod multi;
|
||||
mod operation;
|
||||
|
||||
pub use base::*;
|
||||
pub use context::*;
|
||||
pub use multi::*;
|
||||
pub use operation::*;
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
use burn_tensor::repr::{HandleContainer, OperationDescription};
|
||||
|
||||
use super::{
|
||||
execution::{ExecutionMode, Processor, StreamSegment},
|
||||
execution::{ExecutionMode, Operation, Processor, StreamSegment},
|
||||
store::{ExecutionPlanId, ExecutionPlanStore},
|
||||
Operation, OperationDescription, OperationQueue, StreamId,
|
||||
OperationQueue, StreamId,
|
||||
};
|
||||
use crate::{FusionBackend, HandleContainer};
|
||||
use crate::FusionBackend;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Keep track of multiple concurrent streams of operations.
|
||||
pub struct MultiStream<B: FusionBackend> {
|
||||
streams: HashMap<StreamId, Stream<B>>,
|
||||
optimizations: ExecutionPlanStore<B::Optimization>,
|
||||
device: B::FusionDevice,
|
||||
device: B::Device,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> MultiStream<B> {
|
||||
pub(crate) fn new(device: B::FusionDevice) -> Self {
|
||||
pub(crate) fn new(device: B::Device) -> Self {
|
||||
Self {
|
||||
streams: HashMap::new(),
|
||||
optimizations: ExecutionPlanStore::new(),
|
||||
|
@ -146,9 +148,9 @@ impl<'i, B: FusionBackend> StreamSegment<B::Optimization> for Segment<'i, B> {
|
|||
}
|
||||
|
||||
impl<B: FusionBackend> Stream<B> {
|
||||
fn new(device: B::FusionDevice) -> Self {
|
||||
fn new(device: B::Device) -> Self {
|
||||
Self {
|
||||
processor: Processor::new(B::optimizations(device.into())),
|
||||
processor: Processor::new(B::optimizations(device)),
|
||||
queue: OperationQueue::new(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::{ExecutionPlanIndex, InsertQuery, SearchQuery};
|
||||
use crate::stream::OperationDescription;
|
||||
use burn_tensor::repr::OperationDescription;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// The store that contains all explorations done on a device.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::stream::{store::ExecutionPlanId, OperationDescription};
|
||||
use crate::stream::store::ExecutionPlanId;
|
||||
use burn_tensor::repr::OperationDescription;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::{hash_map::DefaultHasher, HashMap},
|
||||
|
@ -115,14 +116,13 @@ impl ExecutionPlanIndex {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
stream::{
|
||||
BinaryOperationDescription, NumericOperationDescription, ScalarOperationDescription,
|
||||
},
|
||||
use burn_tensor::repr::{
|
||||
BinaryOperationDescription, NumericOperationDescription, ScalarOperationDescription,
|
||||
TensorDescription, TensorId, TensorStatus,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn should_find_optimization_id_based_on_tensor_ops() {
|
||||
let mut index = ExecutionPlanIndex::default();
|
||||
|
|
|
@ -2,9 +2,9 @@ use crate::{client::FusionClient, stream::StreamId};
|
|||
use burn_tensor::{
|
||||
backend::Backend,
|
||||
ops::{FloatElem, IntElem},
|
||||
repr::{TensorDescription, TensorId, TensorStatus},
|
||||
Data, Reader, Shape,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kind.
|
||||
|
@ -33,7 +33,7 @@ impl<C: FusionClient> core::fmt::Debug for FusionTensor<C> {
|
|||
self.shape,
|
||||
self.is_orphan,
|
||||
<C::FusionBackend as Backend>::name(),
|
||||
self.client.device().clone().into(),
|
||||
self.client.device().clone(),
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
|
@ -127,47 +127,3 @@ impl<C: FusionClient> Drop for FusionTensor<C> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The tensor unique identifier.
|
||||
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
|
||||
pub struct TensorId {
|
||||
value: u64,
|
||||
}
|
||||
|
||||
/// The status of the current tensor.
|
||||
#[derive(Hash, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum TensorStatus {
|
||||
/// The tensor can be read, but not written.
|
||||
ReadOnly,
|
||||
/// The tensor can be mutated inplace.
|
||||
ReadWrite,
|
||||
/// No handle exists for that tensor.
|
||||
NotInit,
|
||||
}
|
||||
|
||||
/// A tensor definition represents a snapshot of a tensor when it was used.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// A tensor that is used multiple times has its status updated for each operation.
|
||||
///
|
||||
/// 1. Status::NotInit
|
||||
/// 2. Status::ReadOnly
|
||||
/// 3. Status::ReadOnly
|
||||
/// 4. Status::ReadWrite
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct TensorDescription {
|
||||
/// The [tensor id](TensorId).
|
||||
pub id: TensorId,
|
||||
/// The shape of the tensor.
|
||||
pub shape: Vec<usize>,
|
||||
/// The [status](TensorStatus) of the tensor when it was used.
|
||||
pub status: TensorStatus,
|
||||
}
|
||||
|
||||
impl TensorId {
|
||||
/// Create a new tensor id.
|
||||
pub fn new(value: u64) -> Self {
|
||||
Self { value }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
#[cfg(feature = "fusion")]
|
||||
use crate::fusion::JitFusionHandle;
|
||||
#[cfg(feature = "fusion")]
|
||||
use burn_fusion::TensorDescription;
|
||||
|
||||
use super::{
|
||||
dialect::gpu::{self},
|
||||
|
@ -136,8 +134,8 @@ impl CompilationSettings {
|
|||
pub fn dynamic_settings<R: Runtime>(
|
||||
self,
|
||||
info: &CompilationInfo,
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
inputs: &[&burn_tensor::repr::TensorDescription],
|
||||
outputs: &[&burn_tensor::repr::TensorDescription],
|
||||
handles_inputs: &[JitFusionHandle<R>],
|
||||
stateful: bool,
|
||||
) -> Self {
|
||||
|
@ -154,8 +152,8 @@ impl CompilationSettings {
|
|||
fn dynamic_inplace<R: Runtime>(
|
||||
self,
|
||||
info: &CompilationInfo,
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
inputs: &[&burn_tensor::repr::TensorDescription],
|
||||
outputs: &[&burn_tensor::repr::TensorDescription],
|
||||
handles_inputs: &[JitFusionHandle<R>],
|
||||
) -> Self {
|
||||
let mut potential_inplace = inputs
|
||||
|
@ -170,9 +168,9 @@ impl CompilationSettings {
|
|||
}
|
||||
|
||||
match desc.status {
|
||||
burn_fusion::TensorStatus::ReadOnly => return None,
|
||||
burn_fusion::TensorStatus::NotInit => return None,
|
||||
burn_fusion::TensorStatus::ReadWrite => (),
|
||||
burn_tensor::repr::TensorStatus::ReadOnly => return None,
|
||||
burn_tensor::repr::TensorStatus::NotInit => return None,
|
||||
burn_tensor::repr::TensorStatus::ReadWrite => (),
|
||||
};
|
||||
|
||||
Some((pos, desc, input))
|
||||
|
@ -215,8 +213,8 @@ impl CompilationSettings {
|
|||
fn dynamic_reading_strategy<R: Runtime>(
|
||||
mut self,
|
||||
info: &CompilationInfo,
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
inputs: &[&burn_tensor::repr::TensorDescription],
|
||||
outputs: &[&burn_tensor::repr::TensorDescription],
|
||||
handles_inputs: &[JitFusionHandle<R>],
|
||||
) -> Self {
|
||||
// First output is chosen for the layout reference.
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::{
|
|||
};
|
||||
use burn_compute::client::ComputeClient;
|
||||
use burn_fusion::{client::MutexFusionClient, FusionBackend};
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::{repr::ReprBackend, Shape};
|
||||
use core::marker::PhantomData;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
@ -53,53 +53,61 @@ impl<R: Runtime> burn_fusion::Optimization<JitBackend<R>> for JitOptimization<R>
|
|||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> FusionBackend for JitBackend<R> {
|
||||
type OptimizationState = JitOptimizationState;
|
||||
type Optimization = JitOptimization<R>;
|
||||
type FusionDevice = R::Device;
|
||||
impl<R: Runtime> ReprBackend for JitBackend<R> {
|
||||
type Handle = JitFusionHandle<R>;
|
||||
type FusionClient = MutexFusionClient<Self>;
|
||||
|
||||
fn optimizations(
|
||||
device: R::Device,
|
||||
) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
|
||||
vec![Box::new(ElementWiseBuilder::new(device))]
|
||||
}
|
||||
|
||||
fn float_tensor<const D: usize>(
|
||||
handle: Self::Handle,
|
||||
shape: Shape<D>,
|
||||
) -> Self::FloatTensorPrimitive<D> {
|
||||
) -> burn_tensor::ops::FloatTensor<Self, D> {
|
||||
handle.into_tensor(shape)
|
||||
}
|
||||
|
||||
fn int_tensor<const D: usize>(
|
||||
handle: Self::Handle,
|
||||
shape: Shape<D>,
|
||||
) -> Self::IntTensorPrimitive<D> {
|
||||
) -> burn_tensor::ops::IntTensor<Self, D> {
|
||||
handle.into_tensor(shape)
|
||||
}
|
||||
|
||||
fn bool_tensor<const D: usize>(
|
||||
handle: Self::Handle,
|
||||
shape: Shape<D>,
|
||||
) -> Self::BoolTensorPrimitive<D> {
|
||||
) -> burn_tensor::ops::BoolTensor<Self, D> {
|
||||
handle.into_tensor(shape)
|
||||
}
|
||||
|
||||
fn float_tensor_handle<const D: usize>(tensor: Self::FloatTensorPrimitive<D>) -> Self::Handle {
|
||||
fn float_tensor_handle<const D: usize>(
|
||||
tensor: burn_tensor::ops::FloatTensor<Self, D>,
|
||||
) -> Self::Handle {
|
||||
tensor.into()
|
||||
}
|
||||
|
||||
fn int_tensor_handle<const D: usize>(tensor: Self::IntTensorPrimitive<D>) -> Self::Handle {
|
||||
fn int_tensor_handle<const D: usize>(
|
||||
tensor: burn_tensor::ops::IntTensor<Self, D>,
|
||||
) -> Self::Handle {
|
||||
tensor.into()
|
||||
}
|
||||
|
||||
fn bool_tensor_handle<const D: usize>(tensor: Self::BoolTensorPrimitive<D>) -> Self::Handle {
|
||||
fn bool_tensor_handle<const D: usize>(
|
||||
tensor: burn_tensor::ops::BoolTensor<Self, D>,
|
||||
) -> Self::Handle {
|
||||
tensor.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> FusionBackend for JitBackend<R> {
|
||||
type OptimizationState = JitOptimizationState;
|
||||
type Optimization = JitOptimization<R>;
|
||||
type FusionClient = MutexFusionClient<Self>;
|
||||
|
||||
fn optimizations(
|
||||
device: R::Device,
|
||||
) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
|
||||
vec![Box::new(ElementWiseBuilder::new(device))]
|
||||
}
|
||||
}
|
||||
|
||||
pub fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
|
||||
let mut strides = vec![0; shape.len()];
|
||||
|
||||
|
|
|
@ -7,16 +7,14 @@ use crate::{
|
|||
fusion::{tracing::TraceBuilder, JitOptimization},
|
||||
JitBackend, Runtime,
|
||||
};
|
||||
use burn_fusion::{
|
||||
stream::{
|
||||
BaseOperationDescription, BinaryOperationDescription, FloatOperationDescription,
|
||||
NumericOperationDescription, OperationDescription, ScalarOperationDescription,
|
||||
UnaryOperationDescription,
|
||||
},
|
||||
OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription,
|
||||
};
|
||||
use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus};
|
||||
use burn_tensor::{
|
||||
ops::{FloatElem, IntElem},
|
||||
repr::{
|
||||
BaseOperationDescription, BinaryOperationDescription, FloatOperationDescription,
|
||||
NumericOperationDescription, OperationDescription, ScalarOperationDescription,
|
||||
TensorDescription, UnaryOperationDescription,
|
||||
},
|
||||
Device, Element,
|
||||
};
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use burn_tensor::repr::TensorDescription;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
calculate_num_elems_dyn_rank,
|
||||
|
@ -11,7 +13,6 @@ use crate::{
|
|||
kernel::elemwise_workgroup,
|
||||
Runtime,
|
||||
};
|
||||
use burn_fusion::TensorDescription;
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
|
||||
#[derive(new)]
|
||||
|
|
|
@ -15,7 +15,8 @@ use burn_compute::client::ComputeClient;
|
|||
use burn_compute::server::Handle;
|
||||
use burn_compute::tune::AutotuneOperation;
|
||||
use burn_fusion::stream::Context;
|
||||
use burn_fusion::{TensorDescription, TensorStatus};
|
||||
use burn_tensor::repr::TensorDescription;
|
||||
use burn_tensor::repr::TensorStatus;
|
||||
use burn_tensor::Device;
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
use super::{trace::Trace, Scalars};
|
||||
use crate::codegen::dialect::gpu::{self, Operation, Variable};
|
||||
use burn_fusion::{TensorDescription, TensorId};
|
||||
use burn_tensor::Element;
|
||||
use burn_tensor::{
|
||||
repr::{TensorDescription, TensorId, TensorStatus},
|
||||
Element,
|
||||
};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// Type facilitating building a [trace](Trace) by doing most of the conversions between the
|
||||
|
@ -415,7 +417,7 @@ impl TraceBuilder {
|
|||
// are going to be used after the fused kernel by other operations.
|
||||
for entry in self.tensors.values() {
|
||||
let (tensor, _) = &entry;
|
||||
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
|
||||
if let TensorStatus::ReadOnly = tensor.status {
|
||||
if self.output_to_local.contains_key(&tensor.id) {
|
||||
outputs.push(entry.clone());
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::Scalars;
|
||||
use crate::codegen::{dialect::gpu, CompilationInfo, InputInfo, OutputInfo};
|
||||
use burn_fusion::TensorDescription;
|
||||
use burn_tensor::repr::TensorDescription;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A trace encapsulates all information necessary to perform the compilation and execution of
|
||||
|
|
|
@ -16,8 +16,7 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
|
|||
/// The channel used to communicate with the compute server.
|
||||
type Channel: ComputeChannel<Self::Server>;
|
||||
/// The device used to retrieve the compute client.
|
||||
#[cfg(any(feature = "fusion", test))]
|
||||
type Device: burn_fusion::FusionDevice
|
||||
type Device: burn_tensor::backend::DeviceOps
|
||||
+ Default
|
||||
+ core::hash::Hash
|
||||
+ PartialEq
|
||||
|
@ -26,16 +25,6 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
|
|||
+ core::fmt::Debug
|
||||
+ Sync
|
||||
+ Send;
|
||||
/// The device used to retrieve the compute client.
|
||||
#[cfg(not(any(feature = "fusion", test)))]
|
||||
type Device: Default
|
||||
+ core::hash::Hash
|
||||
+ PartialEq
|
||||
+ Eq
|
||||
+ Clone
|
||||
+ core::fmt::Debug
|
||||
+ Sync
|
||||
+ Send;
|
||||
|
||||
/// A version of the runtime that supports full precision.
|
||||
///
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::NdArrayTensor;
|
|||
use crate::{element::FloatNdArrayElement, PrecisionBridge};
|
||||
use alloc::string::String;
|
||||
use burn_common::stub::Mutex;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
|
||||
use core::marker::PhantomData;
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
|
@ -15,6 +15,14 @@ pub enum NdArrayDevice {
|
|||
Cpu,
|
||||
}
|
||||
|
||||
impl DeviceOps for NdArrayDevice {
|
||||
fn id(&self) -> burn_tensor::backend::DeviceId {
|
||||
match self {
|
||||
NdArrayDevice::Cpu => DeviceId::new(0, 0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NdArrayDevice {
|
||||
fn default() -> Self {
|
||||
Self::Cpu
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::PrecisionBridge;
|
|||
|
||||
use super::element::TchElement;
|
||||
use super::TchTensor;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
|
||||
use burn_tensor::ops::IntTensorOps;
|
||||
use burn_tensor::{Int, Tensor};
|
||||
|
||||
|
@ -59,6 +59,17 @@ impl From<tch::Device> for LibTorchDevice {
|
|||
}
|
||||
}
|
||||
|
||||
impl DeviceOps for LibTorchDevice {
|
||||
fn id(&self) -> burn_tensor::backend::DeviceId {
|
||||
match self {
|
||||
LibTorchDevice::Cpu => DeviceId::new(0, 0),
|
||||
LibTorchDevice::Cuda(index) => DeviceId::new(1, *index as u32),
|
||||
LibTorchDevice::Mps => DeviceId::new(2, 0),
|
||||
LibTorchDevice::Vulkan => DeviceId::new(3, 0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LibTorchDevice {
|
||||
fn default() -> Self {
|
||||
Self::Cpu
|
||||
|
|
|
@ -11,11 +11,12 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-tensor"
|
|||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
default = ["std", "repr"]
|
||||
doc = ["default"]
|
||||
experimental-named-tensor = []
|
||||
export_tests = ["burn-tensor-testgen"]
|
||||
std = ["rand/std", "half/std", "num-traits/std"]
|
||||
repr = []
|
||||
wasm-sync = []
|
||||
|
||||
[dependencies]
|
||||
|
|
|
@ -11,6 +11,10 @@ extern crate alloc;
|
|||
|
||||
mod tensor;
|
||||
|
||||
/// Burn Tensor representaton
|
||||
#[cfg(feature = "repr")]
|
||||
pub mod repr;
|
||||
|
||||
#[cfg(feature = "export_tests")]
|
||||
#[allow(missing_docs)]
|
||||
mod tests;
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
use crate::{
|
||||
backend::Backend,
|
||||
ops::{BoolTensor, FloatTensor, IntTensor},
|
||||
Shape,
|
||||
};
|
||||
|
||||
/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation
|
||||
/// for compilation purpose or other...
|
||||
pub trait ReprBackend: Backend {
|
||||
/// The type that can be used to point to a tensor of any kind.
|
||||
type Handle: Sync + Send + Clone;
|
||||
|
||||
/// Convert a [handle](ReprBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
|
||||
fn float_tensor<const D: usize>(handle: Self::Handle, shape: Shape<D>) -> FloatTensor<Self, D>;
|
||||
/// Convert a [handle](ReprBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive).
|
||||
fn int_tensor<const D: usize>(handle: Self::Handle, shape: Shape<D>) -> IntTensor<Self, D>;
|
||||
/// Convert a [handle](ReprBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
|
||||
fn bool_tensor<const D: usize>(handle: Self::Handle, shape: Shape<D>) -> BoolTensor<Self, D>;
|
||||
|
||||
/// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](ReprBackend::Handle).
|
||||
fn float_tensor_handle<const D: usize>(tensor: FloatTensor<Self, D>) -> Self::Handle;
|
||||
/// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](ReprBackend::Handle).
|
||||
fn int_tensor_handle<const D: usize>(tensor: IntTensor<Self, D>) -> Self::Handle;
|
||||
/// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](ReprBackend::Handle).
|
||||
fn bool_tensor_handle<const D: usize>(tensor: BoolTensor<Self, D>) -> Self::Handle;
|
||||
}
|
|
@ -1,30 +1,41 @@
|
|||
use crate::{FusionBackend, TensorDescription, TensorId, TensorStatus};
|
||||
use burn_tensor::Shape;
|
||||
use crate::{
|
||||
backend::Backend,
|
||||
repr::{
|
||||
backend::ReprBackend,
|
||||
tensor::{TensorDescription, TensorId, TensorStatus},
|
||||
},
|
||||
Shape,
|
||||
};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
/// Keep all [tensor handles](FusionBackend::Handle) in one place and ensure that all resources
|
||||
/// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources
|
||||
/// are used optimally.
|
||||
#[derive(Default)]
|
||||
pub struct HandleContainer<B: FusionBackend> {
|
||||
pub struct HandleContainer<B: ReprBackend> {
|
||||
handles: HashMap<TensorId, Handle<B>>,
|
||||
counter: u64,
|
||||
pub(crate) handles_orphan: Vec<TensorId>,
|
||||
/// Handle candidates to be freed.
|
||||
pub handles_orphan: Vec<TensorId>,
|
||||
/// The device on which all tensors are held.
|
||||
pub device: B::Device,
|
||||
}
|
||||
|
||||
enum Handle<B: FusionBackend> {
|
||||
/// Backend [tensor handle](ReprBackend::Handle) wrapper tracking their creation state
|
||||
pub enum Handle<B: Backend + ReprBackend> {
|
||||
/// No [tensor handle](ReprBackend::Handle) has been created yet
|
||||
NotInit,
|
||||
/// A [tensor handle](ReprBackend::Handle) has been created
|
||||
Existing(B::Handle),
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> HandleContainer<B> {
|
||||
pub(crate) fn new(device_handle: B::FusionDevice) -> Self {
|
||||
impl<B: ReprBackend> HandleContainer<B> {
|
||||
/// Create a new HandleContainer
|
||||
pub fn new(device_handle: B::Device) -> Self {
|
||||
Self {
|
||||
handles: HashMap::new(),
|
||||
handles_orphan: Vec::new(),
|
||||
counter: 0,
|
||||
device: device_handle.clone().into(),
|
||||
device: device_handle.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -59,69 +70,69 @@ impl<B: FusionBackend> HandleContainer<B> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Get the [float tensor](burn_tensor::backend::Backend::FloatTensorPrimitive) corresponding to the
|
||||
/// Get the [float tensor](Backend::FloatTensorPrimitive) corresponding to the
|
||||
/// given [tensor description](TensorDescription).
|
||||
pub fn get_float_tensor<const D: usize>(
|
||||
&mut self,
|
||||
tensor: &TensorDescription,
|
||||
) -> B::FloatTensorPrimitive<D> {
|
||||
B::float_tensor(
|
||||
B::float_tensor::<D>(
|
||||
self.get_handle(&tensor.id, &tensor.status),
|
||||
Shape::from(&tensor.shape),
|
||||
)
|
||||
}
|
||||
|
||||
/// Get the [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) corresponding to the
|
||||
/// Get the [int tensor](Backend::IntTensorPrimitive) corresponding to the
|
||||
/// given [tensor description](TensorDescription).
|
||||
pub fn get_int_tensor<const D: usize>(
|
||||
&mut self,
|
||||
tensor: &TensorDescription,
|
||||
) -> B::IntTensorPrimitive<D> {
|
||||
B::int_tensor(
|
||||
B::int_tensor::<D>(
|
||||
self.get_handle(&tensor.id, &tensor.status),
|
||||
Shape::from(&tensor.shape),
|
||||
)
|
||||
}
|
||||
|
||||
/// Get the [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) corresponding to the
|
||||
/// Get the [bool tensor](Backend::BoolTensorPrimitive) corresponding to the
|
||||
/// given [tensor description](TensorDescription).
|
||||
pub fn get_bool_tensor<const D: usize>(
|
||||
&mut self,
|
||||
tensor: &TensorDescription,
|
||||
) -> B::BoolTensorPrimitive<D> {
|
||||
B::bool_tensor(
|
||||
B::bool_tensor::<D>(
|
||||
self.get_handle(&tensor.id, &tensor.status),
|
||||
Shape::from(&tensor.shape),
|
||||
)
|
||||
}
|
||||
|
||||
/// Register a new [float tensor](burn_tensor::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
|
||||
/// Register a new [float tensor](Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
|
||||
pub fn register_float_tensor<const D: usize>(
|
||||
&mut self,
|
||||
id: &TensorId,
|
||||
tensor: B::FloatTensorPrimitive<D>,
|
||||
) {
|
||||
let handle = B::float_tensor_handle(tensor);
|
||||
let handle = B::float_tensor_handle::<D>(tensor);
|
||||
self.handles.insert(*id, Handle::Existing(handle));
|
||||
}
|
||||
|
||||
/// Register a new [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
|
||||
/// Register a new [int tensor](Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
|
||||
pub fn register_int_tensor<const D: usize>(
|
||||
&mut self,
|
||||
id: &TensorId,
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
) {
|
||||
let handle = B::int_tensor_handle(tensor);
|
||||
let handle = B::int_tensor_handle::<D>(tensor);
|
||||
self.handles.insert(*id, Handle::Existing(handle));
|
||||
}
|
||||
|
||||
/// Register a new [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId).
|
||||
/// Register a new [bool tensor](Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId).
|
||||
pub fn register_bool_tensor<const D: usize>(
|
||||
&mut self,
|
||||
id: &TensorId,
|
||||
tensor: B::BoolTensorPrimitive<D>,
|
||||
) {
|
||||
let handle = B::bool_tensor_handle(tensor);
|
||||
let handle = B::bool_tensor_handle::<D>(tensor);
|
||||
self.handles.insert(*id, Handle::Existing(handle));
|
||||
}
|
||||
|
||||
|
@ -134,7 +145,8 @@ impl<B: FusionBackend> HandleContainer<B> {
|
|||
Arc::new(id)
|
||||
}
|
||||
|
||||
pub(crate) fn free(&mut self, tensor: &TensorDescription) {
|
||||
/// Remove tensor handle from container if writable
|
||||
pub fn free(&mut self, tensor: &TensorDescription) {
|
||||
match tensor.status {
|
||||
TensorStatus::ReadOnly => (),
|
||||
TensorStatus::NotInit => (),
|
||||
|
@ -144,7 +156,8 @@ impl<B: FusionBackend> HandleContainer<B> {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn free_orphans(&mut self, remaining: &[&TensorId]) {
|
||||
/// Remove tensor handle from container if not in use
|
||||
pub fn free_orphans(&mut self, remaining: &[&TensorId]) {
|
||||
let mut handles_orphan = Vec::new();
|
||||
|
||||
// TODO: Optimization => Change the for loop order depending of the length of each.
|
|
@ -0,0 +1,9 @@
|
|||
mod backend;
|
||||
mod handle;
|
||||
mod operation;
|
||||
mod tensor;
|
||||
|
||||
pub use backend::*;
|
||||
pub use handle::*;
|
||||
pub use operation::*;
|
||||
pub use tensor::*;
|
|
@ -1,15 +1,11 @@
|
|||
use crate::FusionBackend;
|
||||
use crate::{HandleContainer, TensorDescription};
|
||||
use burn_tensor::ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions};
|
||||
use burn_tensor::{Distribution, Element};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ops::Range;
|
||||
|
||||
/// General trait to abstract how a single operation is executed.
|
||||
pub trait Operation<B: FusionBackend>: Send + Sync {
|
||||
/// Execute the operation.
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B>);
|
||||
}
|
||||
use crate::{
|
||||
ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions},
|
||||
repr::tensor::TensorDescription,
|
||||
Distribution, Element,
|
||||
};
|
||||
|
||||
/// Describe all tensor operations possible.
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
|
@ -37,92 +33,92 @@ pub enum OperationDescription {
|
|||
/// Operation description specific to a float tensor.
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
pub enum FloatOperationDescription {
|
||||
/// Operation corresponding to [exp](burn_tensor::ops::FloatTensorOps::float_exp).
|
||||
/// Operation corresponding to [exp](crate::ops::FloatTensorOps::float_exp).
|
||||
Exp(UnaryOperationDescription),
|
||||
/// Operation corresponding to [log](burn_tensor::ops::FloatTensorOps::float_log).
|
||||
/// Operation corresponding to [log](crate::ops::FloatTensorOps::float_log).
|
||||
Log(UnaryOperationDescription),
|
||||
/// Operation corresponding to [log1p](burn_tensor::ops::FloatTensorOps::float_log1p).
|
||||
/// Operation corresponding to [log1p](crate::ops::FloatTensorOps::float_log1p).
|
||||
Log1p(UnaryOperationDescription),
|
||||
/// Operation corresponding to [erf](burn_tensor::ops::FloatTensorOps::float_erf).
|
||||
/// Operation corresponding to [erf](crate::ops::FloatTensorOps::float_erf).
|
||||
Erf(UnaryOperationDescription),
|
||||
/// Operation corresponding to [powf_scalar](burn_tensor::ops::FloatTensorOps::float_powf_scalar).
|
||||
/// Operation corresponding to [powf_scalar](crate::ops::FloatTensorOps::float_powf_scalar).
|
||||
PowfScalar(ScalarOperationDescription<f32>),
|
||||
/// Operation corresponding to [sqrt](burn_tensor::ops::FloatTensorOps::float_sqrt).
|
||||
/// Operation corresponding to [sqrt](crate::ops::FloatTensorOps::float_sqrt).
|
||||
Sqrt(UnaryOperationDescription),
|
||||
/// Operation corresponding to [cos](burn_tensor::ops::FloatTensorOps::float_cos).
|
||||
/// Operation corresponding to [cos](crate::ops::FloatTensorOps::float_cos).
|
||||
Cos(UnaryOperationDescription),
|
||||
/// Operation corresponding to [sin](burn_tensor::ops::FloatTensorOps::float_sin).
|
||||
/// Operation corresponding to [sin](crate::ops::FloatTensorOps::float_sin).
|
||||
Sin(UnaryOperationDescription),
|
||||
/// Operation corresponding to [tanh](burn_tensor::ops::FloatTensorOps::float_tanh).
|
||||
/// Operation corresponding to [tanh](crate::ops::FloatTensorOps::float_tanh).
|
||||
Tanh(UnaryOperationDescription),
|
||||
/// Operation corresponding to [into_int](burn_tensor::ops::FloatTensorOps::float_into_int).
|
||||
/// Operation corresponding to [into_int](crate::ops::FloatTensorOps::float_into_int).
|
||||
IntoInt(UnaryOperationDescription),
|
||||
/// Operation corresponding to [matmul](burn_tensor::ops::FloatTensorOps::float_matmul).
|
||||
/// Operation corresponding to [matmul](crate::ops::FloatTensorOps::float_matmul).
|
||||
Matmul(BinaryOperationDescription),
|
||||
/// Operation corresponding to [random](burn_tensor::ops::FloatTensorOps::float_random).
|
||||
/// Operation corresponding to [random](crate::ops::FloatTensorOps::float_random).
|
||||
Random(RandomOperationDescription),
|
||||
/// Operation corresponding to [recip](burn_tensor::ops::FloatTensorOps::float_recip).
|
||||
/// Operation corresponding to [recip](crate::ops::FloatTensorOps::float_recip).
|
||||
Recip(UnaryOperationDescription),
|
||||
}
|
||||
|
||||
/// Operation description specific to module.
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
pub enum ModuleOperationDescription {
|
||||
/// Operation corresponding to [embedding](burn_tensor::ops::ModuleOps::embedding).
|
||||
/// Operation corresponding to [embedding](crate::ops::ModuleOps::embedding).
|
||||
Embedding(EmbeddingDescription),
|
||||
/// Operation corresponding to [embedding_backward](burn_tensor::ops::ModuleOps::embedding_backward).
|
||||
/// Operation corresponding to [embedding_backward](crate::ops::ModuleOps::embedding_backward).
|
||||
EmbeddingBackward(EmbeddingBackwardDescription),
|
||||
/// Operation corresponding to [conv1d](burn_tensor::ops::ModuleOps::conv1d).
|
||||
/// Operation corresponding to [conv1d](crate::ops::ModuleOps::conv1d).
|
||||
Conv1d(Conv1dDescription),
|
||||
/// Operation corresponding to [conv2d](burn_tensor::ops::ModuleOps::conv2d).
|
||||
/// Operation corresponding to [conv2d](crate::ops::ModuleOps::conv2d).
|
||||
Conv2d(Conv2dDescription),
|
||||
/// Operation corresponding to [conv transpose 1d](burn_tensor::ops::ModuleOps::conv_transpose1d).
|
||||
/// Operation corresponding to [conv transpose 1d](crate::ops::ModuleOps::conv_transpose1d).
|
||||
ConvTranspose1d(ConvTranspose1dDescription),
|
||||
/// Operation corresponding to [conv transpose 2d](burn_tensor::ops::ModuleOps::conv_transpose2d).
|
||||
/// Operation corresponding to [conv transpose 2d](crate::ops::ModuleOps::conv_transpose2d).
|
||||
ConvTranspose2d(ConvTranspose2dDescription),
|
||||
/// Operation corresponding to [avg pool 1d](burn_tensor::ops::ModuleOps::avg_pool1d).
|
||||
/// Operation corresponding to [avg pool 1d](crate::ops::ModuleOps::avg_pool1d).
|
||||
AvgPool1d(AvgPool1dDescription),
|
||||
/// Operation corresponding to [avg pool 2d](burn_tensor::ops::ModuleOps::avg_pool2d).
|
||||
/// Operation corresponding to [avg pool 2d](crate::ops::ModuleOps::avg_pool2d).
|
||||
AvgPool2d(AvgPool2dDescription),
|
||||
/// Operation corresponding to
|
||||
/// [avg pool 1d backward](burn_tensor::ops::ModuleOps::avg_pool1d_backward).
|
||||
/// [avg pool 1d backward](crate::ops::ModuleOps::avg_pool1d_backward).
|
||||
AvgPool1dBackward(AvgPool1dBackwardDescription),
|
||||
/// Operation corresponding to
|
||||
/// [avg pool 2d backward](burn_tensor::ops::ModuleOps::avg_pool2d_backward).
|
||||
/// [avg pool 2d backward](crate::ops::ModuleOps::avg_pool2d_backward).
|
||||
AvgPool2dBackward(AvgPool2dBackwardDescription),
|
||||
/// Operation corresponding to
|
||||
/// [adaptive avg pool 1d](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d).
|
||||
/// [adaptive avg pool 1d](crate::ops::ModuleOps::adaptive_avg_pool1d).
|
||||
AdaptiveAvgPool1d(AdaptiveAvgPool1dDescription),
|
||||
/// Operation corresponding to
|
||||
/// [adaptive avg pool 2d](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d).
|
||||
/// [adaptive avg pool 2d](crate::ops::ModuleOps::adaptive_avg_pool2d).
|
||||
AdaptiveAvgPool2d(AdaptiveAvgPool2dDescription),
|
||||
/// Operation corresponding to
|
||||
/// [adaptive avg pool 1d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d_backward).
|
||||
/// [adaptive avg pool 1d backward](crate::ops::ModuleOps::adaptive_avg_pool1d_backward).
|
||||
AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardDescription),
|
||||
/// Operation corresponding to
|
||||
/// [adaptive avg pool 2d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d_backward).
|
||||
/// [adaptive avg pool 2d backward](crate::ops::ModuleOps::adaptive_avg_pool2d_backward).
|
||||
AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardDescription),
|
||||
/// Operation corresponding to
|
||||
/// [max pool 1d](burn_tensor::ops::ModuleOps::max_pool1d).
|
||||
/// [max pool 1d](crate::ops::ModuleOps::max_pool1d).
|
||||
MaxPool1d(MaxPool1dDescription),
|
||||
/// Operation corresponding to
|
||||
/// [max pool 1d with indices](burn_tensor::ops::ModuleOps::max_pool1d_with_indices).
|
||||
/// [max pool 1d with indices](crate::ops::ModuleOps::max_pool1d_with_indices).
|
||||
MaxPool1dWithIndices(MaxPool1dWithIndicesDescription),
|
||||
/// Operation corresponding to
|
||||
/// [max pool 1d with indices backward](burn_tensor::ops::ModuleOps::max_pool1d_with_indices_backward).
|
||||
/// [max pool 1d with indices backward](crate::ops::ModuleOps::max_pool1d_with_indices_backward).
|
||||
MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardDescription),
|
||||
/// Operation corresponding to
|
||||
/// [max pool 2d](burn_tensor::ops::ModuleOps::max_pool1d).
|
||||
/// [max pool 2d](crate::ops::ModuleOps::max_pool1d).
|
||||
MaxPool2d(MaxPool2dDescription),
|
||||
/// Operation corresponding to
|
||||
/// [max pool 2d with indices](burn_tensor::ops::ModuleOps::max_pool2d_with_indices).
|
||||
/// [max pool 2d with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
|
||||
MaxPool2dWithIndices(MaxPool2dWithIndicesDescription),
|
||||
/// Operation corresponding to
|
||||
/// [max pool 2d with indices backward](burn_tensor::ops::ModuleOps::max_pool2d_with_indices_backward).
|
||||
/// [max pool 2d with indices backward](crate::ops::ModuleOps::max_pool2d_with_indices_backward).
|
||||
MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardDescription),
|
||||
/// Operation corresponding to [interpolate](burn_tensor::ops::ModuleOps::interpolate).
|
||||
/// Operation corresponding to [interpolate](crate::ops::ModuleOps::interpolate).
|
||||
Interpolate(InterpolateDescription),
|
||||
/// Operation corresponding to [interpolate backward](burn_tensor::ops::ModuleOps::interpolate_backward).
|
||||
/// Operation corresponding to [interpolate backward](crate::ops::ModuleOps::interpolate_backward).
|
||||
InterpolateBackward(InterpolateBackwardDescription),
|
||||
}
|
||||
|
||||
|
@ -131,73 +127,73 @@ pub enum ModuleOperationDescription {
|
|||
pub enum BaseOperationDescription {
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [to device](burn_tensor::ops::FloatTensorOps::float_to_device).
|
||||
/// Int => [to device](burn_tensor::ops::IntTensorOps::int_to_device).
|
||||
/// Bool => [to device](burn_tensor::ops::BoolTensorOps::bool_to_device).
|
||||
/// Float => [to device](crate::ops::FloatTensorOps::float_to_device).
|
||||
/// Int => [to device](crate::ops::IntTensorOps::int_to_device).
|
||||
/// Bool => [to device](crate::ops::BoolTensorOps::bool_to_device).
|
||||
ToDevice(TensorDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [reshape](burn_tensor::ops::FloatTensorOps::float_reshape).
|
||||
/// Int => [reshape](burn_tensor::ops::IntTensorOps::int_reshape).
|
||||
/// Bool => [reshape](burn_tensor::ops::BoolTensorOps::bool_reshape).
|
||||
/// Float => [reshape](crate::ops::FloatTensorOps::float_reshape).
|
||||
/// Int => [reshape](crate::ops::IntTensorOps::int_reshape).
|
||||
/// Bool => [reshape](crate::ops::BoolTensorOps::bool_reshape).
|
||||
Reshape(ReshapeDescription),
|
||||
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [swap_dims](burn_tensor::ops::FloatTensorOps::float_swap_dims).
|
||||
/// Int => [swap_dims](burn_tensor::ops::IntTensorOps::int_swap_dims).
|
||||
/// Bool => [swap_dims](burn_tensor::ops::BoolTensorOps::bool_swap_dims).
|
||||
/// Float => [swap_dims](crate::ops::FloatTensorOps::float_swap_dims).
|
||||
/// Int => [swap_dims](crate::ops::IntTensorOps::int_swap_dims).
|
||||
/// Bool => [swap_dims](crate::ops::BoolTensorOps::bool_swap_dims).
|
||||
SwapDims(SwapDimsDescription),
|
||||
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [permute](burn_tensor::ops::FloatTensorOps::float_permute).
|
||||
/// Int => [permute](burn_tensor::ops::IntTensorOps::int_permute).
|
||||
/// Bool => [permute](burn_tensor::ops::BoolTensorOps::bool_permute).
|
||||
/// Float => [permute](crate::ops::FloatTensorOps::float_permute).
|
||||
/// Int => [permute](crate::ops::IntTensorOps::int_permute).
|
||||
/// Bool => [permute](crate::ops::BoolTensorOps::bool_permute).
|
||||
Permute(PermuteOperationDescription),
|
||||
|
||||
/// Operation corresponding to:
|
||||
/// Float => [flip](burn_tensor::ops::FloatTensorOps::float_flip).
|
||||
/// Int => [flip](burn_tensor::ops::IntTensorOps::int_flip).
|
||||
/// Bool => [flip](burn_tensor::ops::BoolTensorOps::bool_flip).
|
||||
/// Float => [flip](crate::ops::FloatTensorOps::float_flip).
|
||||
/// Int => [flip](crate::ops::IntTensorOps::int_flip).
|
||||
/// Bool => [flip](crate::ops::BoolTensorOps::bool_flip).
|
||||
Flip(FlipOperationDescription),
|
||||
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [expand](burn_tensor::ops::FloatTensorOps::float_expand).
|
||||
/// Int => [expand](burn_tensor::ops::IntTensorOps::int_expand).
|
||||
/// Bool => [expand](burn_tensor::ops::BoolTensorOps::bool_expand).
|
||||
/// Float => [expand](crate::ops::FloatTensorOps::float_expand).
|
||||
/// Int => [expand](crate::ops::IntTensorOps::int_expand).
|
||||
/// Bool => [expand](crate::ops::BoolTensorOps::bool_expand).
|
||||
Expand(ExpandOperationDescription),
|
||||
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [slice](burn_tensor::ops::FloatTensorOps::float_slice).
|
||||
/// Int => [slice](burn_tensor::ops::IntTensorOps::int_slice).
|
||||
/// Bool => [slice](burn_tensor::ops::BoolTensorOps::bool_slice).
|
||||
/// Float => [slice](crate::ops::FloatTensorOps::float_slice).
|
||||
/// Int => [slice](crate::ops::IntTensorOps::int_slice).
|
||||
/// Bool => [slice](crate::ops::BoolTensorOps::bool_slice).
|
||||
Slice(SliceOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [slice assign](burn_tensor::ops::FloatTensorOps::float_slice_assign).
|
||||
/// Int => [slice assign](burn_tensor::ops::IntTensorOps::int_slice_assign).
|
||||
/// Bool => [slice assign](burn_tensor::ops::BoolTensorOps::bool_slice_assign).
|
||||
/// Float => [slice assign](crate::ops::FloatTensorOps::float_slice_assign).
|
||||
/// Int => [slice assign](crate::ops::IntTensorOps::int_slice_assign).
|
||||
/// Bool => [slice assign](crate::ops::BoolTensorOps::bool_slice_assign).
|
||||
SliceAssign(SliceAssignOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [equal](burn_tensor::ops::FloatTensorOps::float_equal).
|
||||
/// Int => [equal](burn_tensor::ops::IntTensorOps::int_equal).
|
||||
/// Bool => [equal](burn_tensor::ops::BoolTensorOps::bool_equal).
|
||||
/// Float => [equal](crate::ops::FloatTensorOps::float_equal).
|
||||
/// Int => [equal](crate::ops::IntTensorOps::int_equal).
|
||||
/// Bool => [equal](crate::ops::BoolTensorOps::bool_equal).
|
||||
Equal(BinaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [repeat](burn_tensor::ops::FloatTensorOps::float_repeat).
|
||||
/// Int => [repeat](burn_tensor::ops::IntTensorOps::int_repeat).
|
||||
/// Bool => [repeat](burn_tensor::ops::BoolTensorOps::bool_repeat).
|
||||
/// Float => [repeat](crate::ops::FloatTensorOps::float_repeat).
|
||||
/// Int => [repeat](crate::ops::IntTensorOps::int_repeat).
|
||||
/// Bool => [repeat](crate::ops::BoolTensorOps::bool_repeat).
|
||||
Repeat(RepeatOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [cat](burn_tensor::ops::FloatTensorOps::float_cat).
|
||||
/// Int => [cat](burn_tensor::ops::IntTensorOps::int_cat).
|
||||
/// Bool => [cat](burn_tensor::ops::BoolTensorOps::bool_cat).
|
||||
/// Float => [cat](crate::ops::FloatTensorOps::float_cat).
|
||||
/// Int => [cat](crate::ops::IntTensorOps::int_cat).
|
||||
/// Bool => [cat](crate::ops::BoolTensorOps::bool_cat).
|
||||
Cat(CatOperationDescription),
|
||||
}
|
||||
|
||||
|
@ -206,248 +202,248 @@ pub enum BaseOperationDescription {
|
|||
pub enum NumericOperationDescription<E> {
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [add](burn_tensor::ops::FloatTensorOps::float_add).
|
||||
/// Int => [add](burn_tensor::ops::IntTensorOps::int_add).
|
||||
/// Float => [add](crate::ops::FloatTensorOps::float_add).
|
||||
/// Int => [add](crate::ops::IntTensorOps::int_add).
|
||||
Add(BinaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [add scalar](burn_tensor::ops::FloatTensorOps::float_add_scalar).
|
||||
/// Int => [add scalar](burn_tensor::ops::IntTensorOps::int_add_scalar).
|
||||
/// Float => [add scalar](crate::ops::FloatTensorOps::float_add_scalar).
|
||||
/// Int => [add scalar](crate::ops::IntTensorOps::int_add_scalar).
|
||||
AddScalar(ScalarOperationDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [sub](burn_tensor::ops::FloatTensorOps::float_sub).
|
||||
/// Int => [sub](burn_tensor::ops::IntTensorOps::int_sub).
|
||||
/// Float => [sub](crate::ops::FloatTensorOps::float_sub).
|
||||
/// Int => [sub](crate::ops::IntTensorOps::int_sub).
|
||||
Sub(BinaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [sub scalar](burn_tensor::ops::FloatTensorOps::float_sub_scalar).
|
||||
/// Int => [sub scalar](burn_tensor::ops::IntTensorOps::int_sub_scalar).
|
||||
/// Float => [sub scalar](crate::ops::FloatTensorOps::float_sub_scalar).
|
||||
/// Int => [sub scalar](crate::ops::IntTensorOps::int_sub_scalar).
|
||||
SubScalar(ScalarOperationDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [div](burn_tensor::ops::FloatTensorOps::float_div).
|
||||
/// Int => [div](burn_tensor::ops::IntTensorOps::int_div).
|
||||
/// Float => [div](crate::ops::FloatTensorOps::float_div).
|
||||
/// Int => [div](crate::ops::IntTensorOps::int_div).
|
||||
Div(BinaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [div scalar](burn_tensor::ops::FloatTensorOps::float_div_scalar).
|
||||
/// Int => [div scalar](burn_tensor::ops::IntTensorOps::int_div_scalar).
|
||||
/// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar).
|
||||
/// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar).
|
||||
DivScalar(ScalarOperationDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [div](burn_tensor::ops::FloatTensorOps::float_remainder_scalar).
|
||||
/// Int => [div](burn_tensor::ops::IntTensorOps::int_remainder_scalar).
|
||||
/// Float => [div](crate::ops::FloatTensorOps::float_remainder_scalar).
|
||||
/// Int => [div](crate::ops::IntTensorOps::int_remainder_scalar).
|
||||
RemScalar(ScalarOperationDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [mul](burn_tensor::ops::FloatTensorOps::float_mul).
|
||||
/// Int => [mul](burn_tensor::ops::IntTensorOps::int_mul).
|
||||
/// Float => [mul](crate::ops::FloatTensorOps::float_mul).
|
||||
/// Int => [mul](crate::ops::IntTensorOps::int_mul).
|
||||
Mul(BinaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [mul scalar](burn_tensor::ops::FloatTensorOps::float_mul_scalar).
|
||||
/// Int => [mul scalar](burn_tensor::ops::IntTensorOps::int_mul_scalar).
|
||||
/// Float => [mul scalar](crate::ops::FloatTensorOps::float_mul_scalar).
|
||||
/// Int => [mul scalar](crate::ops::IntTensorOps::int_mul_scalar).
|
||||
MulScalar(ScalarOperationDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [abs](burn_tensor::ops::FloatTensorOps::float_abs).
|
||||
/// Int => [abs](burn_tensor::ops::IntTensorOps::int_abs).
|
||||
/// Float => [abs](crate::ops::FloatTensorOps::float_abs).
|
||||
/// Int => [abs](crate::ops::IntTensorOps::int_abs).
|
||||
Abs(UnaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [ones](burn_tensor::ops::FloatTensorOps::float_ones).
|
||||
/// Int => [ones](burn_tensor::ops::IntTensorOps::int_ones).
|
||||
/// Float => [ones](crate::ops::FloatTensorOps::float_ones).
|
||||
/// Int => [ones](crate::ops::IntTensorOps::int_ones).
|
||||
Ones(TensorDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [zeros](burn_tensor::ops::FloatTensorOps::float_zeros).
|
||||
/// Int => [zeros](burn_tensor::ops::IntTensorOps::int_zeros).
|
||||
/// Float => [zeros](crate::ops::FloatTensorOps::float_zeros).
|
||||
/// Int => [zeros](crate::ops::IntTensorOps::int_zeros).
|
||||
Zeros(TensorDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [full](burn_tensor::ops::FloatTensorOps::float_full).
|
||||
/// Int => [full](burn_tensor::ops::IntTensorOps::int_full).
|
||||
/// Float => [full](crate::ops::FloatTensorOps::float_full).
|
||||
/// Int => [full](crate::ops::IntTensorOps::int_full).
|
||||
Full((TensorDescription, E)),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [gather](burn_tensor::ops::FloatTensorOps::float_gather).
|
||||
/// Int => [gather](burn_tensor::ops::IntTensorOps::int_gather).
|
||||
/// Float => [gather](crate::ops::FloatTensorOps::float_gather).
|
||||
/// Int => [gather](crate::ops::IntTensorOps::int_gather).
|
||||
Gather(GatherOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [scatter](burn_tensor::ops::FloatTensorOps::float_scatter).
|
||||
/// Int => [scatter](burn_tensor::ops::IntTensorOps::int_scatter).
|
||||
/// Float => [scatter](crate::ops::FloatTensorOps::float_scatter).
|
||||
/// Int => [scatter](crate::ops::IntTensorOps::int_scatter).
|
||||
Scatter(ScatterOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [select](burn_tensor::ops::FloatTensorOps::float_select).
|
||||
/// Int => [select](burn_tensor::ops::IntTensorOps::int_select).
|
||||
/// Float => [select](crate::ops::FloatTensorOps::float_select).
|
||||
/// Int => [select](crate::ops::IntTensorOps::int_select).
|
||||
Select(SelectOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [select assign](burn_tensor::ops::FloatTensorOps::float_select_assign).
|
||||
/// Int => [select assign](burn_tensor::ops::IntTensorOps::int_select_assign).
|
||||
/// Float => [select assign](crate::ops::FloatTensorOps::float_select_assign).
|
||||
/// Int => [select assign](crate::ops::IntTensorOps::int_select_assign).
|
||||
SelectAssign(SelectAssignOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [mask where](burn_tensor::ops::FloatTensorOps::float_mask_where).
|
||||
/// Int => [mask where](burn_tensor::ops::IntTensorOps::int_mask_where).
|
||||
/// Float => [mask where](crate::ops::FloatTensorOps::float_mask_where).
|
||||
/// Int => [mask where](crate::ops::IntTensorOps::int_mask_where).
|
||||
MaskWhere(MaskWhereOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [mask fill](burn_tensor::ops::FloatTensorOps::float_mask_fill).
|
||||
/// Int => [mask fill](burn_tensor::ops::IntTensorOps::int_mask_fill).
|
||||
/// Float => [mask fill](crate::ops::FloatTensorOps::float_mask_fill).
|
||||
/// Int => [mask fill](crate::ops::IntTensorOps::int_mask_fill).
|
||||
MaskFill(MaskFillOperationDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [mean dim](burn_tensor::ops::FloatTensorOps::float_mean_dim).
|
||||
/// Int => [mean dim](burn_tensor::ops::IntTensorOps::int_mean_dim).
|
||||
/// Float => [mean dim](crate::ops::FloatTensorOps::float_mean_dim).
|
||||
/// Int => [mean dim](crate::ops::IntTensorOps::int_mean_dim).
|
||||
MeanDim(ScalarOperationDescription<usize>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [mean](burn_tensor::ops::FloatTensorOps::float_mean).
|
||||
/// Int => [mean](burn_tensor::ops::IntTensorOps::int_mean).
|
||||
/// Float => [mean](crate::ops::FloatTensorOps::float_mean).
|
||||
/// Int => [mean](crate::ops::IntTensorOps::int_mean).
|
||||
Mean(UnaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [sum](burn_tensor::ops::FloatTensorOps::float_sum).
|
||||
/// Int => [sum](burn_tensor::ops::IntTensorOps::int_sum).
|
||||
/// Float => [sum](crate::ops::FloatTensorOps::float_sum).
|
||||
/// Int => [sum](crate::ops::IntTensorOps::int_sum).
|
||||
Sum(UnaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [sum dim](burn_tensor::ops::FloatTensorOps::float_sum_dim).
|
||||
/// Int => [sum dim](burn_tensor::ops::IntTensorOps::int_sum_dim).
|
||||
/// Float => [sum dim](crate::ops::FloatTensorOps::float_sum_dim).
|
||||
/// Int => [sum dim](crate::ops::IntTensorOps::int_sum_dim).
|
||||
SumDim(ScalarOperationDescription<usize>),
|
||||
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [prod](burn_tensor::ops::FloatTensorOps::float_prod).
|
||||
/// Int => [prod](burn_tensor::ops::IntTensorOps::int_prod).
|
||||
/// Float => [prod](crate::ops::FloatTensorOps::float_prod).
|
||||
/// Int => [prod](crate::ops::IntTensorOps::int_prod).
|
||||
Prod(UnaryOperationDescription),
|
||||
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [prod dim](burn_tensor::ops::FloatTensorOps::float_prod_dim).
|
||||
/// Int => [prod dim](burn_tensor::ops::IntTensorOps::int_prod_dim).
|
||||
/// Float => [prod dim](crate::ops::FloatTensorOps::float_prod_dim).
|
||||
/// Int => [prod dim](crate::ops::IntTensorOps::int_prod_dim).
|
||||
ProdDim(ScalarOperationDescription<usize>),
|
||||
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [equal elem](burn_tensor::ops::FloatTensorOps::float_equal_elem).
|
||||
/// Int => [equal elem](burn_tensor::ops::IntTensorOps::int_equal_elem).
|
||||
/// Float => [equal elem](crate::ops::FloatTensorOps::float_equal_elem).
|
||||
/// Int => [equal elem](crate::ops::IntTensorOps::int_equal_elem).
|
||||
EqualElem(ScalarOperationDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [greater](burn_tensor::ops::FloatTensorOps::float_greater).
|
||||
/// Int => [greater](burn_tensor::ops::IntTensorOps::int_greater).
|
||||
/// Float => [greater](crate::ops::FloatTensorOps::float_greater).
|
||||
/// Int => [greater](crate::ops::IntTensorOps::int_greater).
|
||||
Greater(BinaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [greater elem](burn_tensor::ops::FloatTensorOps::float_greater_elem).
|
||||
/// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem).
|
||||
/// Float => [greater elem](crate::ops::FloatTensorOps::float_greater_elem).
|
||||
/// Int => [greater elem](crate::ops::IntTensorOps::int_greater_elem).
|
||||
GreaterElem(ScalarOperationDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [greater equal](burn_tensor::ops::FloatTensorOps::float_greater_elem).
|
||||
/// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem).
|
||||
/// Float => [greater equal](crate::ops::FloatTensorOps::float_greater_elem).
|
||||
/// Int => [greater elem](crate::ops::IntTensorOps::int_greater_elem).
|
||||
GreaterEqual(BinaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [greater equal elem](burn_tensor::ops::FloatTensorOps::float_greater_equal_elem).
|
||||
/// Int => [greater equal elem](burn_tensor::ops::IntTensorOps::int_greater_equal_elem).
|
||||
/// Float => [greater equal elem](crate::ops::FloatTensorOps::float_greater_equal_elem).
|
||||
/// Int => [greater equal elem](crate::ops::IntTensorOps::int_greater_equal_elem).
|
||||
GreaterEqualElem(ScalarOperationDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [lower](burn_tensor::ops::FloatTensorOps::float_lower).
|
||||
/// Int => [lower](burn_tensor::ops::IntTensorOps::int_lower).
|
||||
/// Float => [lower](crate::ops::FloatTensorOps::float_lower).
|
||||
/// Int => [lower](crate::ops::IntTensorOps::int_lower).
|
||||
Lower(BinaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [lower elem](burn_tensor::ops::FloatTensorOps::float_lower_elem).
|
||||
/// Int => [lower elem](burn_tensor::ops::IntTensorOps::int_lower_elem).
|
||||
/// Float => [lower elem](crate::ops::FloatTensorOps::float_lower_elem).
|
||||
/// Int => [lower elem](crate::ops::IntTensorOps::int_lower_elem).
|
||||
LowerElem(ScalarOperationDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [lower equal](burn_tensor::ops::FloatTensorOps::float_lower_equal).
|
||||
/// Int => [lower equal](burn_tensor::ops::IntTensorOps::int_lower_equal).
|
||||
/// Float => [lower equal](crate::ops::FloatTensorOps::float_lower_equal).
|
||||
/// Int => [lower equal](crate::ops::IntTensorOps::int_lower_equal).
|
||||
LowerEqual(BinaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [lower equal elem](burn_tensor::ops::FloatTensorOps::float_lower_equal_elem).
|
||||
/// Int => [lower equal elem](burn_tensor::ops::IntTensorOps::int_lower_equal_elem).
|
||||
/// Float => [lower equal elem](crate::ops::FloatTensorOps::float_lower_equal_elem).
|
||||
/// Int => [lower equal elem](crate::ops::IntTensorOps::int_lower_equal_elem).
|
||||
LowerEqualElem(ScalarOperationDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [argmax](burn_tensor::ops::FloatTensorOps::float_argmax).
|
||||
/// Int => [argmax](burn_tensor::ops::IntTensorOps::int_argmax).
|
||||
/// Float => [argmax](crate::ops::FloatTensorOps::float_argmax).
|
||||
/// Int => [argmax](crate::ops::IntTensorOps::int_argmax).
|
||||
ArgMax(ScalarOperationDescription<usize>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [argmin](burn_tensor::ops::FloatTensorOps::float_argmin).
|
||||
/// Int => [argmin](burn_tensor::ops::IntTensorOps::int_argmin).
|
||||
/// Float => [argmin](crate::ops::FloatTensorOps::float_argmin).
|
||||
/// Int => [argmin](crate::ops::IntTensorOps::int_argmin).
|
||||
ArgMin(ScalarOperationDescription<usize>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [max](burn_tensor::ops::FloatTensorOps::float_max).
|
||||
/// Int => [max](burn_tensor::ops::IntTensorOps::int_max).
|
||||
/// Float => [max](crate::ops::FloatTensorOps::float_max).
|
||||
/// Int => [max](crate::ops::IntTensorOps::int_max).
|
||||
Max(UnaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [max dim with indices](burn_tensor::ops::FloatTensorOps::float_max_dim_with_indices).
|
||||
/// Int => [max dim with indices](burn_tensor::ops::IntTensorOps::int_max_dim_with_indices).
|
||||
/// Float => [max dim with indices](crate::ops::FloatTensorOps::float_max_dim_with_indices).
|
||||
/// Int => [max dim with indices](crate::ops::IntTensorOps::int_max_dim_with_indices).
|
||||
MaxDimWithIndices(ReduceDimWithIndicesDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [min dim with indices](burn_tensor::ops::FloatTensorOps::float_min_dim_with_indices).
|
||||
/// Int => [min dim with indices](burn_tensor::ops::IntTensorOps::int_min_dim_with_indices).
|
||||
/// Float => [min dim with indices](crate::ops::FloatTensorOps::float_min_dim_with_indices).
|
||||
/// Int => [min dim with indices](crate::ops::IntTensorOps::int_min_dim_with_indices).
|
||||
MinDimWithIndices(ReduceDimWithIndicesDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [min](burn_tensor::ops::FloatTensorOps::float_min).
|
||||
/// Int => [min](burn_tensor::ops::IntTensorOps::int_min).
|
||||
/// Float => [min](crate::ops::FloatTensorOps::float_min).
|
||||
/// Int => [min](crate::ops::IntTensorOps::int_min).
|
||||
Min(UnaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [max dim](burn_tensor::ops::FloatTensorOps::float_max_dim).
|
||||
/// Int => [max dim](burn_tensor::ops::IntTensorOps::int_max_dim).
|
||||
/// Float => [max dim](crate::ops::FloatTensorOps::float_max_dim).
|
||||
/// Int => [max dim](crate::ops::IntTensorOps::int_max_dim).
|
||||
MaxDim(ScalarOperationDescription<usize>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [min dim](burn_tensor::ops::FloatTensorOps::float_min_dim).
|
||||
/// Int => [min dim](burn_tensor::ops::IntTensorOps::int_min_dim).
|
||||
/// Float => [min dim](crate::ops::FloatTensorOps::float_min_dim).
|
||||
/// Int => [min dim](crate::ops::IntTensorOps::int_min_dim).
|
||||
MinDim(ScalarOperationDescription<usize>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [clamp](burn_tensor::ops::FloatTensorOps::float_clamp).
|
||||
/// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp).
|
||||
/// Float => [clamp](crate::ops::FloatTensorOps::float_clamp).
|
||||
/// Int => [clamp](crate::ops::IntTensorOps::int_clamp).
|
||||
Clamp(ClampOperationDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Int => [random](burn_tensor::ops::IntTensorOps::int_random).
|
||||
/// Int => [random](crate::ops::IntTensorOps::int_random).
|
||||
IntRandom(RandomOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [powf](burn_tensor::ops::FloatTensorOps::float_powf).
|
||||
/// Int => [powf](burn_tensor::ops::IntTensorOps::int_powf).
|
||||
/// Float => [powf](crate::ops::FloatTensorOps::float_powf).
|
||||
/// Int => [powf](crate::ops::IntTensorOps::int_powf).
|
||||
Powf(BinaryOperationDescription),
|
||||
}
|
||||
|
||||
/// Operation description specific to an int tensor.
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
pub enum IntOperationDescription {
|
||||
/// Operation corresponding to [into float](burn_tensor::ops::IntTensorOps::int_into_float).
|
||||
/// Operation corresponding to [into float](crate::ops::IntTensorOps::int_into_float).
|
||||
IntoFloat(UnaryOperationDescription),
|
||||
}
|
||||
|
||||
/// Operation description specific to a bool tensor.
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
pub enum BoolOperationDescription {
|
||||
/// Operation corresponding to [into float](burn_tensor::ops::BoolTensorOps::bool_into_float).
|
||||
/// Operation corresponding to [into float](crate::ops::BoolTensorOps::bool_into_float).
|
||||
IntoFloat(UnaryOperationDescription),
|
||||
/// Operation corresponding to [into int](burn_tensor::ops::BoolTensorOps::bool_into_int).
|
||||
/// Operation corresponding to [into int](crate::ops::BoolTensorOps::bool_into_int).
|
||||
IntoInt(UnaryOperationDescription),
|
||||
/// Operation corresponding to [not](burn_tensor::ops::BoolTensorOps::bool_not).
|
||||
/// Operation corresponding to [not](crate::ops::BoolTensorOps::bool_not).
|
||||
Not(UnaryOperationDescription),
|
||||
}
|
||||
|
||||
|
@ -1057,7 +1053,7 @@ pub struct InterpolateBackwardDescription {
|
|||
|
||||
impl OperationDescription {
|
||||
/// Cleanup the remaining tensor handles that have not been used.
|
||||
pub(crate) fn nodes(&self) -> Vec<&TensorDescription> {
|
||||
pub fn nodes(&self) -> Vec<&TensorDescription> {
|
||||
match self {
|
||||
OperationDescription::BaseFloat(ops) => ops.nodes(),
|
||||
OperationDescription::BaseInt(ops) => ops.nodes(),
|
|
@ -0,0 +1,45 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// The tensor unique identifier.
|
||||
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
|
||||
pub struct TensorId {
|
||||
value: u64,
|
||||
}
|
||||
|
||||
/// The status of the current tensor.
|
||||
#[derive(Hash, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum TensorStatus {
|
||||
/// The tensor can be read, but not written.
|
||||
ReadOnly,
|
||||
/// The tensor can be mutated inplace.
|
||||
ReadWrite,
|
||||
/// No handle exists for that tensor.
|
||||
NotInit,
|
||||
}
|
||||
|
||||
/// A tensor definition represents a snapshot of a tensor when it was used.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// A tensor that is used multiple times has its status updated for each operation.
|
||||
///
|
||||
/// 1. Status::NotInit
|
||||
/// 2. Status::ReadOnly
|
||||
/// 3. Status::ReadOnly
|
||||
/// 4. Status::ReadWrite
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct TensorDescription {
|
||||
/// The [tensor id](TensorId).
|
||||
pub id: TensorId,
|
||||
/// The shape of the tensor.
|
||||
pub shape: Vec<usize>,
|
||||
/// The [status](TensorStatus) of the tensor when it was used.
|
||||
pub status: TensorStatus,
|
||||
}
|
||||
|
||||
impl TensorId {
|
||||
/// Create a new tensor id.
|
||||
pub fn new(value: u64) -> Self {
|
||||
Self { value }
|
||||
}
|
||||
}
|
|
@ -3,7 +3,7 @@ use alloc::string::String;
|
|||
use crate::ops::*;
|
||||
use crate::tensor::Element;
|
||||
|
||||
use super::BackendBridge;
|
||||
use super::{BackendBridge, DeviceOps};
|
||||
|
||||
/// This trait defines all types and functions needed for a backend to be used with burn.
|
||||
///
|
||||
|
@ -66,7 +66,7 @@ pub trait Backend:
|
|||
+ 'static
|
||||
{
|
||||
/// Device type.
|
||||
type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync;
|
||||
type Device: DeviceOps;
|
||||
|
||||
/// A bridge that can cast tensors to full precision.
|
||||
type FullPrecisionBridge: BackendBridge<Self> + 'static;
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
/// The device id.
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
|
||||
pub struct DeviceId {
|
||||
/// The type id identifies the type of the device.
|
||||
pub type_id: u16,
|
||||
/// The index id identifies the device number.
|
||||
pub index_id: u32,
|
||||
}
|
||||
|
||||
/// The handle device trait allows to get an id for a backend device.
|
||||
pub trait DeviceOps: Clone + Default + PartialEq + Send + Sync + core::fmt::Debug {
|
||||
/// Return the [device id](DeviceId).
|
||||
fn id(&self) -> DeviceId;
|
||||
}
|
|
@ -1,8 +1,10 @@
|
|||
mod base;
|
||||
mod bridge;
|
||||
mod device;
|
||||
|
||||
pub use base::*;
|
||||
pub use bridge::*;
|
||||
pub use device::*;
|
||||
|
||||
// Not needed for now, useful for different tensor memory layout
|
||||
// pub mod conversion;
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
use crate::WgpuDevice;
|
||||
use burn_fusion::{DeviceId, FusionDevice};
|
||||
|
||||
impl FusionDevice for WgpuDevice {
|
||||
fn id(&self) -> DeviceId {
|
||||
match self {
|
||||
WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32),
|
||||
WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32),
|
||||
WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32),
|
||||
WgpuDevice::Cpu => DeviceId::new(3, 0),
|
||||
WgpuDevice::BestAvailable => DeviceId::new(4, 0),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -10,9 +10,6 @@ mod element;
|
|||
mod graphics;
|
||||
mod runtime;
|
||||
|
||||
#[cfg(feature = "fusion")]
|
||||
mod fusion;
|
||||
|
||||
#[cfg(feature = "template")]
|
||||
pub use burn_jit::{
|
||||
compute::Kernel,
|
||||
|
|
|
@ -13,6 +13,7 @@ use burn_compute::{
|
|||
ComputeRuntime,
|
||||
};
|
||||
use burn_jit::Runtime;
|
||||
use burn_tensor::backend::{DeviceId, DeviceOps};
|
||||
use std::marker::PhantomData;
|
||||
use wgpu::{AdapterInfo, DeviceDescriptor};
|
||||
|
||||
|
@ -52,6 +53,18 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> Runtime for WgpuRuntime<G,
|
|||
}
|
||||
}
|
||||
|
||||
impl DeviceOps for WgpuDevice {
|
||||
fn id(&self) -> DeviceId {
|
||||
match self {
|
||||
WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32),
|
||||
WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32),
|
||||
WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32),
|
||||
WgpuDevice::Cpu => DeviceId::new(3, 0),
|
||||
WgpuDevice::BestAvailable => DeviceId::new(4, 0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The values that control how a WGPU Runtime will perform its calculations.
|
||||
pub struct RuntimeOptions {
|
||||
/// How the buffers are deallocated.
|
||||
|
|
Loading…
Reference in New Issue