Feat/fusion/cache (#1020)

This commit is contained in:
Nathaniel Simard 2023-12-01 12:05:11 -05:00 committed by GitHub
parent b0de56da29
commit 670280dda2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 2182 additions and 364 deletions

View File

@ -87,7 +87,7 @@ fn erf_positive<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
fn bench<B: Backend>(device: &B::Device) {
const D: usize = 3;
let shape: Shape<D> = [32, 512, 2048].into();
let num_repeats = 1;
let num_repeats = 10;
let reference_gelu = CustomGeluBenchmark::<B, D>::new(
shape.clone(),

View File

@ -17,5 +17,6 @@ std = []
[dependencies]
burn-tensor = {path = "../burn-tensor", version = "0.11.0", default-features = false }
burn-common = {path = "../burn-common", version = "0.11.0" }
hashbrown = { workspace = true }
derive-new = {workspace = true}
spin = {workspace = true}

View File

@ -1,9 +1,10 @@
use crate::{
client::FusionClient, graph::TensorOpsDescription, FusionClientLocator, FusionTensor,
HandleContainer,
client::FusionClient,
graph::{Context, OptimizationFactory, TensorOpsDescription},
FusionClientLocator, FusionTensor,
};
use burn_tensor::{backend::Backend, Device, Shape};
use core::marker::PhantomData;
use std::marker::PhantomData;
pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();
@ -49,17 +50,18 @@ impl<B: FusionBackend> Backend for Fusion<B> {
}
}
/// The status of a [fusion ops](FusionOps).
pub enum FusionStatus {
/// The status of a [builder](OptimizationBuilder).
#[derive(Clone, Debug, Copy)]
pub enum OptimizationStatus {
/// No more operations can be fused.
Closed(FusionProperties),
Closed,
/// More operations can be fused.
Open(FusionProperties),
Open,
}
/// The properties of a [fusion ops](FusionOps).
/// The properties of a [builder](OptimizationBuilder).
#[derive(Debug, Clone, Copy, Default)]
pub struct FusionProperties {
pub struct OptimizationProperties {
/// The score of the optimization, higher is better.
pub score: u64,
/// If the operation is ready to be executed.
@ -78,28 +80,42 @@ pub struct FusionProperties {
///
/// Also, it is important to return (FusionStatus::Closed) when no more registered operation can
/// improve the performance.
pub trait FusionOps<B: FusionBackend>: Send {
pub trait OptimizationBuilder<B: FusionBackend>: Send {
/// Register a new [tensor operation](TensorOpsDescription).
///
/// The return value should be either [closed](FusionStatus::Closed) or
/// [open](FusionStatus::Open).
///
/// When [closed](FusionStatus::Closed), it's assumed that no more operation can be added
/// to the current fusion operation. No [tensor operation](TensorOpsDescription) can be
/// ignored, they are either accepted or rejected, and the [status](FusionStatus) describes it.
fn register(&mut self, ops: &TensorOpsDescription) -> FusionStatus;
/// Execute the operation.
fn execute(&mut self, handles: &mut HandleContainer<B>);
fn register(&mut self, ops: &TensorOpsDescription);
/// Finish the optimization and create a fusion operation.
fn build(&self) -> Box<dyn Optimization<B>>;
/// Reset the state.
fn reset(&mut self);
/// The size of operations fused.
/// Return the builder [status](OptimizationStatus).
fn status(&self) -> OptimizationStatus;
/// Return the builder [properties](OptimizationProperties).
fn properties(&self) -> OptimizationProperties;
}
/// The operation created from the [builder](OptimizationBuilder).
pub trait Optimization<B: FusionBackend>: Send {
/// Execute the operation.
fn execute(&self, context: &mut Context<'_, B>);
/// The number of registered operations in this optimization.
fn len(&self) -> usize;
/// If the current operation is empty.
/// If the current optimization is empty.
fn is_empty(&self) -> bool {
self.len() == 0
}
}
// We implement the OptimizationFactory for all boxed optimization to be used with the Optimization
// Cache. The factory is only used to simplify types and allows better testing. It isn't a public
// crate.
impl<B: FusionBackend> OptimizationFactory<Box<dyn Optimization<B>>>
for Box<dyn OptimizationBuilder<B>>
{
fn create(&self) -> Box<dyn Optimization<B>> {
OptimizationBuilder::build(self.as_ref())
}
}
/// The device id.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
pub struct DeviceId {
@ -116,7 +132,7 @@ pub trait FusionDevice: Clone + Send + Sync + PartialEq {
}
/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
/// [fusion operation](crate::FusionOps).
/// [operation builder](crate::OptimizationBuilder).
pub trait FusionBackend: Backend {
/// The device type that can return an ID.
///
@ -127,8 +143,8 @@ pub trait FusionBackend: Backend {
/// What kind of client should be used.
type FusionClient: FusionClient<FusionBackend = Self>;
/// The list of operations that will be used to optimize the computational graph.
fn operations(device: &Device<Self>) -> Vec<Box<dyn FusionOps<Self>>>;
/// The list of optimizations that will be used to optimize the computational graph.
fn optimizations(device: &Device<Self>) -> Vec<Box<dyn OptimizationBuilder<Self>>>;
/// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive).
fn float_tensor<const D: usize>(

View File

@ -1,5 +1,5 @@
use crate::{
graph::{GraphExecution, Ops, TensorOpsDescription},
graph::{Ops, TensorOpsDescription},
FusionBackend, FusionTensor, Handle, TensorDescription, TensorId,
};
use burn_tensor::{
@ -11,8 +11,6 @@ use burn_tensor::{
pub trait FusionClient: Send + Sync + Clone {
/// The [fusion backend](FusionBackend) associated type.
type FusionBackend: FusionBackend;
/// The [graph execution](GraphExecution) associated type.
type GraphExecution: GraphExecution<Self::FusionBackend>;
/// Create a new client for the given [fusion device](FusionBackend::FusionDevice).
fn new(device: <Self::FusionBackend as FusionBackend>::FusionDevice) -> Self;

View File

@ -1,26 +1,21 @@
use super::FusionClient;
use crate::{
graph::{GraphExecution, TensorOpsDescription},
FusionBackend, FusionServer, FusionTensor, Handle,
};
use crate::{graph::TensorOpsDescription, FusionBackend, FusionServer, FusionTensor, Handle};
use burn_tensor::ops::FloatElem;
use spin::Mutex;
use std::sync::Arc;
/// Use a mutex to communicate with the fusion server.
pub struct MutexFusionClient<B, G>
pub struct MutexFusionClient<B>
where
B: FusionBackend,
G: GraphExecution<B>,
{
server: Arc<Mutex<FusionServer<B, G>>>,
server: Arc<Mutex<FusionServer<B>>>,
device: B::FusionDevice,
}
impl<B, G> Clone for MutexFusionClient<B, G>
impl<B> Clone for MutexFusionClient<B>
where
B: FusionBackend,
G: GraphExecution<B>,
{
fn clone(&self) -> Self {
Self {
@ -30,13 +25,11 @@ where
}
}
impl<B, G> FusionClient for MutexFusionClient<B, G>
impl<B> FusionClient for MutexFusionClient<B>
where
B: FusionBackend,
G: GraphExecution<B>,
{
type FusionBackend = B;
type GraphExecution = G;
fn new(device: B::FusionDevice) -> Self {
Self {

View File

@ -1,100 +1,97 @@
use super::Ops;
use super::RelativeGraphConverter;
use super::TensorOpsDescription;
use crate::{FusionBackend, FusionOps, FusionProperties, FusionStatus, HandleContainer};
use std::{ops::RangeBounds, sync::Arc};
use crate::Optimization;
use crate::{FusionBackend, HandleContainer};
use std::ops::RangeBounds;
/// The computational graph containing a list of [tensor operation descriptions](TensorOpsDescription).
pub struct Graph<B: FusionBackend> {
operations: Vec<Arc<TensorOpsDescription>>,
pub(crate) global: Vec<TensorOpsDescription>,
pub(crate) relative: Vec<TensorOpsDescription>,
converter: RelativeGraphConverter,
ops: Vec<Box<dyn Ops<B>>>,
}
impl<B: FusionBackend> Graph<B> {
pub(crate) fn new() -> Self {
Self {
operations: Vec::new(),
global: Vec::new(),
relative: Vec::new(),
converter: RelativeGraphConverter::default(),
ops: Vec::new(),
}
}
pub(crate) fn add(&mut self, description: Arc<TensorOpsDescription>, ops: Box<dyn Ops<B>>) {
self.operations.push(description);
/// Split the relative graph into every operation but the last one, and the last operation.
///
/// When the relative graph is empty, the returned slice is empty and the last operation is
/// `None`.
pub(crate) fn split_relative_graph(
    &self,
) -> (&[TensorOpsDescription], Option<&TensorOpsDescription>) {
    match self.relative.split_last() {
        Some((last, previous)) => (previous, Some(last)),
        None => (self.relative.as_slice(), None),
    }
}
pub(crate) fn add(&mut self, global: TensorOpsDescription, ops: Box<dyn Ops<B>>) {
let relative = global.to_relative(&mut self.converter);
self.relative.push(relative);
self.global.push(global);
self.ops.push(ops);
}
/// The size of the graph.
pub fn len(&self) -> usize {
self.operations.len()
pub(crate) fn len(&self) -> usize {
self.global.len()
}
/// If the graph is empty.
pub fn is_empty(&self) -> bool {
self.operations.len() == 0
}
fn remove<R: RangeBounds<usize> + Clone>(
&mut self,
range: R,
handles: &mut HandleContainer<B>,
) {
for ops in self.operations.drain(range.clone()) {
ops.cleanup_tensor(handles)
}
self.ops.drain(range);
}
fn nodes(&self) -> &[Arc<TensorOpsDescription>] {
&self.operations
pub(crate) fn is_empty(&self) -> bool {
self.len() == 0
}
pub(crate) fn execute_optimization(
&mut self,
handles: &mut HandleContainer<B>,
index: usize,
optimizations: &mut [Optimization<B>],
optimization: &dyn Optimization<B>,
) {
let optimization = optimizations.get_mut(index).unwrap();
let num_keep = optimization.ops.len();
optimization.ops.execute(handles);
let num_keep = optimization.len();
let mut context = self.converter.context(handles);
optimization.execute(&mut context);
self.remove(0..num_keep, handles);
for optimization in optimizations.iter_mut() {
optimization.reset();
for node in self.nodes() {
optimization.register(node);
}
}
self.cleanup_partial(0..num_keep, handles);
}
pub(crate) fn execute(&mut self, handles: &mut HandleContainer<B>) {
for (description, ops) in self.operations.drain(..).zip(self.ops.drain(..)) {
pub(crate) fn execute_operations(&mut self, handles: &mut HandleContainer<B>) {
for (description, ops) in self.global.drain(..).zip(self.ops.drain(..)) {
ops.execute(handles);
description.cleanup_tensor(handles);
}
self.cleanup_relative_graph();
}
}
/// An optimization that can be executed.
#[derive(new)]
pub struct Optimization<B: FusionBackend> {
/// The [fusion operation](FusionOps) to potentially be executed.
pub ops: Box<dyn FusionOps<B>>,
/// The current status of the optimization.
pub status: FusionStatus,
}
impl<B: FusionBackend> Optimization<B> {
pub(crate) fn register(&mut self, ops: &TensorOpsDescription) {
if let FusionStatus::Closed(_) = self.status {
return;
fn cleanup_partial<R: RangeBounds<usize> + Clone>(
&mut self,
range: R,
handles: &mut HandleContainer<B>,
) {
for ops in self.global.drain(range.clone()) {
ops.cleanup_tensor(handles)
}
self.ops.drain(range);
self.status = self.ops.register(ops);
// Rebuild the relative graph when partially removing the global graph.
self.cleanup_relative_graph();
for node in self.global.iter() {
let relative = node.to_relative(&mut self.converter);
self.relative.push(relative);
}
}
pub(crate) fn reset(&mut self) {
self.ops.reset();
self.status = FusionStatus::Open(FusionProperties::default());
/// Forget the relative graph: clear the converter state and the relative descriptions.
fn cleanup_relative_graph(&mut self) {
    self.converter.clear();
    self.relative.clear();
}
}

View File

@ -0,0 +1,845 @@
use super::{
AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription,
AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription,
AvgPool2dBackwardDescription, AvgPool2dDescription, BaseOpsDescription, BinaryOpsDescription,
BoolOpsDescription, ClampOpsDescription, Conv1dDescription, Conv2dDescription,
ConvTranspose1dDescription, ConvTranspose2dDescription, EmbeddingBackwardDescription,
EmbeddingDescription, FloatOpsDescription, GatherOpsDescription, IntOpsDescription,
MaskFillOpsDescription, MaskWhereOpsDescription, MaxPool1dDescription,
MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription, MaxPool2dDescription,
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, ModuleOpsDescription,
NumericOpsDescription, RandomOpsDescription, ReduceDimWithIndicesDescription,
ReshapeDescription, ScalarOpsDescription, ScatterOpsDescription, SelectAssignOpsDescription,
SelectOpsDescription, SliceOpsDescription, SwapDimsDescription, TensorOpsDescription,
UnaryOpsDescription,
};
use crate::{FusionBackend, HandleContainer, TensorDescription, TensorId};
use burn_tensor::{Element, ElementConversion};
use hashbrown::HashMap;
/// The context contains the relative graph tensor mapping so that a relative tensor id can be
/// mapped to an existing tensor that can be fetched and updated with the
/// [handle container](HandleContainer).
///
/// It also contains all scalar values, which can change even for the same graph. They are sorted
/// in the order in which they appear in the graph.
///
/// Every field is a borrow living for `'a`; only the handle container is borrowed mutably.
#[derive(new)]
pub struct Context<'a, B: FusionBackend> {
/// The tensor mapping where a relative tensor id points to the updated tensor description.
pub tensors: &'a HashMap<TensorId, TensorDescription>,
/// Handle container to retrieve tensors based on their description.
pub handles: &'a mut HandleContainer<B>,
/// Float scalars found in the graph in the order they appeared.
pub scalar_floats: &'a Vec<f32>,
/// Int scalars found in the graph in the order they appeared.
pub scalar_ints: &'a Vec<i32>,
}
/// State accumulated while converting operation descriptions to their relative form: the tensor
/// mappings in both directions, the shape id mapping and the scalar values that were collected.
#[derive(Default)]
pub(crate) struct RelativeGraphConverter {
/// Tensor mapping handed to optimizations through [Context::tensors].
tensors_relative2global: HashMap<TensorId, TensorDescription>,
/// Tensor mapping looked up with the id of the tensor being converted.
tensors_global2relative: HashMap<TensorId, TensorDescription>,
/// Only useful to create new shape ID.
/// You should use tensor descriptions to retrieve the proper shape.
shapes_global2relative: HashMap<usize, usize>,
/// Float scalars in the order they were recorded by `relative_float`.
scalar_floats: Vec<f32>,
/// Int scalars in the order they were recorded by `relative_int`.
scalar_ints: Vec<i32>,
}
impl RelativeGraphConverter {
    /// Build an execution [context](Context) that borrows the converter state together with the
    /// given [handle container](HandleContainer).
    pub(crate) fn context<'a, B: FusionBackend>(
        &'a self,
        handles: &'a mut HandleContainer<B>,
    ) -> Context<'a, B> {
        let tensors = &self.tensors_relative2global;
        let scalar_floats = &self.scalar_floats;
        let scalar_ints = &self.scalar_ints;

        Context {
            tensors,
            handles,
            scalar_floats,
            scalar_ints,
        }
    }

    /// Remove every mapping and every recorded scalar.
    pub(crate) fn clear(&mut self) {
        // Destructuring is exhaustive, so a newly added field cannot be forgotten here.
        let Self {
            tensors_relative2global,
            tensors_global2relative,
            shapes_global2relative,
            scalar_floats,
            scalar_ints,
        } = self;

        tensors_relative2global.clear();
        tensors_global2relative.clear();
        shapes_global2relative.clear();
        scalar_floats.clear();
        scalar_ints.clear();
    }

    /// Record a float scalar and return the placeholder to store in the relative description.
    pub(crate) fn relative_float<E: Element>(&mut self, elem: &E) -> E {
        let recorded: f32 = elem.elem();
        self.scalar_floats.push(recorded);

        // The placeholder is always 0 so that the id from a scalar operation is the same no
        // matter its scalar value.
        0.elem()
    }

    /// Record an int scalar and return the placeholder to store in the relative description.
    pub(crate) fn relative_int<E: Element>(&mut self, elem: &E) -> E {
        let recorded: i32 = elem.elem();
        self.scalar_ints.push(recorded);

        // The placeholder is always 0 so that the id from a scalar operation is the same no
        // matter its scalar value.
        0.elem()
    }
}
impl TensorOpsDescription {
pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
match self {
TensorOpsDescription::BaseOpsFloat(ops) => {
TensorOpsDescription::BaseOpsFloat(ops.to_relative(converter))
}
TensorOpsDescription::BaseOpsInt(ops) => {
TensorOpsDescription::BaseOpsInt(ops.to_relative(converter))
}
TensorOpsDescription::BaseOpsBool(ops) => {
TensorOpsDescription::BaseOpsBool(ops.to_relative(converter))
}
TensorOpsDescription::NumericOpsFloat(ops) => TensorOpsDescription::NumericOpsFloat(
ops.to_relative(converter, |converter, e| converter.relative_float(e)),
),
TensorOpsDescription::NumericOpsInt(ops) => TensorOpsDescription::NumericOpsInt(
ops.to_relative(converter, |converter, e| converter.relative_int(e)),
),
TensorOpsDescription::BoolOps(ops) => {
TensorOpsDescription::BoolOps(ops.to_relative(converter))
}
TensorOpsDescription::IntOps(ops) => {
TensorOpsDescription::IntOps(ops.to_relative(converter))
}
TensorOpsDescription::FloatOps(ops) => {
TensorOpsDescription::FloatOps(ops.to_relative(converter))
}
TensorOpsDescription::ModuleOps(ops) => {
TensorOpsDescription::ModuleOps(ops.to_relative(converter))
}
}
}
}
impl ModuleOpsDescription {
pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
match self {
ModuleOpsDescription::Embedding(desc) => {
ModuleOpsDescription::Embedding(EmbeddingDescription {
weights: desc.weights.to_relative(converter),
indices: desc.indices.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
ModuleOpsDescription::EmbeddingBackward(desc) => {
ModuleOpsDescription::EmbeddingBackward(EmbeddingBackwardDescription {
weights: desc.weights.to_relative(converter),
out_grad: desc.out_grad.to_relative(converter),
indices: desc.indices.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
ModuleOpsDescription::Conv1d(desc) => ModuleOpsDescription::Conv1d(Conv1dDescription {
x: desc.x.to_relative(converter),
weight: desc.weight.to_relative(converter),
bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),
options: desc.options.clone(),
out: desc.out.to_relative(converter),
}),
ModuleOpsDescription::Conv2d(desc) => ModuleOpsDescription::Conv2d(Conv2dDescription {
x: desc.x.to_relative(converter),
weight: desc.weight.to_relative(converter),
bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),
options: desc.options.clone(),
out: desc.out.to_relative(converter),
}),
ModuleOpsDescription::ConvTranspose1d(desc) => {
ModuleOpsDescription::ConvTranspose1d(ConvTranspose1dDescription {
x: desc.x.to_relative(converter),
weight: desc.weight.to_relative(converter),
bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),
options: desc.options.clone(),
out: desc.out.to_relative(converter),
})
}
ModuleOpsDescription::ConvTranspose2d(desc) => {
ModuleOpsDescription::ConvTranspose2d(ConvTranspose2dDescription {
x: desc.x.to_relative(converter),
weight: desc.weight.to_relative(converter),
bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),
options: desc.options.clone(),
out: desc.out.to_relative(converter),
})
}
ModuleOpsDescription::AvgPool1d(desc) => {
ModuleOpsDescription::AvgPool1d(super::AvgPool1dDescription {
x: desc.x.to_relative(converter),
kernel_size: desc.kernel_size,
stride: desc.stride,
padding: desc.padding,
count_include_pad: desc.count_include_pad,
out: desc.out.to_relative(converter),
})
}
ModuleOpsDescription::AvgPool2d(desc) => {
ModuleOpsDescription::AvgPool2d(AvgPool2dDescription {
x: desc.x.to_relative(converter),
kernel_size: desc.kernel_size,
stride: desc.stride,
padding: desc.padding,
count_include_pad: desc.count_include_pad,
out: desc.out.to_relative(converter),
})
}
ModuleOpsDescription::AvgPool1dBackward(desc) => {
ModuleOpsDescription::AvgPool1dBackward(super::AvgPool1dBackwardDescription {
x: desc.x.to_relative(converter),
grad: desc.grad.to_relative(converter),
kernel_size: desc.kernel_size,
stride: desc.stride,
padding: desc.padding,
count_include_pad: desc.count_include_pad,
out: desc.out.to_relative(converter),
})
}
ModuleOpsDescription::AvgPool2dBackward(desc) => {
ModuleOpsDescription::AvgPool2dBackward(AvgPool2dBackwardDescription {
x: desc.x.to_relative(converter),
grad: desc.grad.to_relative(converter),
kernel_size: desc.kernel_size,
stride: desc.stride,
padding: desc.padding,
count_include_pad: desc.count_include_pad,
out: desc.out.to_relative(converter),
})
}
ModuleOpsDescription::AdaptiveAvgPool1d(desc) => {
ModuleOpsDescription::AdaptiveAvgPool1d(AdaptiveAvgPool1dDescription {
x: desc.x.to_relative(converter),
output_size: desc.output_size,
out: desc.out.to_relative(converter),
})
}
ModuleOpsDescription::AdaptiveAvgPool2d(desc) => {
ModuleOpsDescription::AdaptiveAvgPool2d(AdaptiveAvgPool2dDescription {
x: desc.x.to_relative(converter),
output_size: desc.output_size,
out: desc.out.to_relative(converter),
})
}
ModuleOpsDescription::AdaptiveAvgPool1dBackward(desc) => {
ModuleOpsDescription::AdaptiveAvgPool1dBackward(
AdaptiveAvgPool1dBackwardDescription {
x: desc.x.to_relative(converter),
grad: desc.grad.to_relative(converter),
out: desc.out.to_relative(converter),
},
)
}
ModuleOpsDescription::AdaptiveAvgPool2dBackward(desc) => {
ModuleOpsDescription::AdaptiveAvgPool2dBackward(
AdaptiveAvgPool2dBackwardDescription {
x: desc.x.to_relative(converter),
grad: desc.grad.to_relative(converter),
out: desc.out.to_relative(converter),
},
)
}
ModuleOpsDescription::MaxPool1d(desc) => {
ModuleOpsDescription::MaxPool1d(MaxPool1dDescription {
x: desc.x.to_relative(converter),
kernel_size: desc.kernel_size,
stride: desc.stride,
padding: desc.padding,
dilation: desc.dilation,
out: desc.out.to_relative(converter),
})
}
ModuleOpsDescription::MaxPool1dWithIndices(desc) => {
ModuleOpsDescription::MaxPool1dWithIndices(MaxPool1dWithIndicesDescription {
x: desc.x.to_relative(converter),
kernel_size: desc.kernel_size,
stride: desc.stride,
padding: desc.padding,
dilation: desc.dilation,
out: desc.out.to_relative(converter),
out_indices: desc.out_indices.to_relative(converter),
})
}
ModuleOpsDescription::MaxPool1dWithIndicesBackward(desc) => {
ModuleOpsDescription::MaxPool1dWithIndicesBackward(
MaxPool1dWithIndicesBackwardDescription {
x: desc.x.to_relative(converter),
grad: desc.grad.to_relative(converter),
indices: desc.indices.to_relative(converter),
kernel_size: desc.kernel_size,
stride: desc.stride,
padding: desc.padding,
dilation: desc.dilation,
out: desc.out.to_relative(converter),
},
)
}
ModuleOpsDescription::MaxPool2d(desc) => {
ModuleOpsDescription::MaxPool2d(MaxPool2dDescription {
x: desc.x.to_relative(converter),
kernel_size: desc.kernel_size,
stride: desc.stride,
padding: desc.padding,
dilation: desc.dilation,
out: desc.out.to_relative(converter),
})
}
ModuleOpsDescription::MaxPool2dWithIndices(desc) => {
ModuleOpsDescription::MaxPool2dWithIndices(MaxPool2dWithIndicesDescription {
x: desc.x.to_relative(converter),
kernel_size: desc.kernel_size,
stride: desc.stride,
padding: desc.padding,
dilation: desc.dilation,
out: desc.out.to_relative(converter),
out_indices: desc.out_indices.to_relative(converter),
})
}
ModuleOpsDescription::MaxPool2dWithIndicesBackward(desc) => {
ModuleOpsDescription::MaxPool2dWithIndicesBackward(
MaxPool2dWithIndicesBackwardDescription {
x: desc.x.to_relative(converter),
grad: desc.grad.to_relative(converter),
indices: desc.indices.to_relative(converter),
kernel_size: desc.kernel_size,
stride: desc.stride,
padding: desc.padding,
dilation: desc.dilation,
out: desc.out.to_relative(converter),
},
)
}
}
}
}
impl FloatOpsDescription {
pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
match self {
FloatOpsDescription::Exp(desc) => FloatOpsDescription::Exp(UnaryOpsDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
}),
FloatOpsDescription::Log(desc) => FloatOpsDescription::Log(UnaryOpsDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
}),
FloatOpsDescription::Log1p(desc) => FloatOpsDescription::Log1p(UnaryOpsDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
}),
FloatOpsDescription::Erf(desc) => FloatOpsDescription::Erf(UnaryOpsDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
}),
FloatOpsDescription::Powf(desc) => FloatOpsDescription::Powf(ScalarOpsDescription {
lhs: desc.lhs.to_relative(converter),
rhs: converter.relative_float(&desc.rhs),
out: desc.out.to_relative(converter),
}),
FloatOpsDescription::Sqrt(desc) => FloatOpsDescription::Sqrt(UnaryOpsDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
}),
FloatOpsDescription::Cos(desc) => FloatOpsDescription::Cos(UnaryOpsDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
}),
FloatOpsDescription::Sin(desc) => FloatOpsDescription::Sin(UnaryOpsDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
}),
FloatOpsDescription::Tanh(desc) => FloatOpsDescription::Tanh(UnaryOpsDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
}),
FloatOpsDescription::IntoInt(desc) => {
FloatOpsDescription::IntoInt(UnaryOpsDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
FloatOpsDescription::Matmul(desc) => {
FloatOpsDescription::Matmul(BinaryOpsDescription {
lhs: desc.lhs.to_relative(converter),
rhs: desc.rhs.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
FloatOpsDescription::Random(desc) => {
FloatOpsDescription::Random(RandomOpsDescription {
out: desc.out.to_relative(converter),
distribution: desc.distribution,
})
}
FloatOpsDescription::Recip(desc) => FloatOpsDescription::Recip(UnaryOpsDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
}),
}
}
}
impl BoolOpsDescription {
    /// Convert the bool operation description to its relative form.
    ///
    /// Every variant wraps a unary description whose input is converted before its output.
    pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
        match self {
            Self::IntoFloat(desc) => Self::IntoFloat(UnaryOpsDescription {
                input: desc.input.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::IntoInt(desc) => Self::IntoInt(UnaryOpsDescription {
                input: desc.input.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::Not(desc) => Self::Not(UnaryOpsDescription {
                input: desc.input.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
        }
    }
}
impl IntOpsDescription {
    /// Convert the int operation description to its relative form.
    ///
    /// The input tensor is converted before the output tensor.
    pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
        match self {
            Self::IntoFloat(desc) => Self::IntoFloat(UnaryOpsDescription {
                input: desc.input.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
        }
    }
}
impl<E: Element> NumericOpsDescription<E> {
    /// Convert the numeric operation description to its relative form.
    ///
    /// Tensors are converted with the given converter. Scalar operands (the right-hand side of
    /// the `*Scalar`, `*Elem` and `Clamp*` operations, the fill values and the clamp bounds) are
    /// passed through `local_elem`, which returns the value stored in the relative description.
    /// Dimensions are copied as they are.
    ///
    /// Fields are converted in the order they are written, since each conversion updates the
    /// converter.
    pub(crate) fn to_relative<F>(
        &self,
        converter: &mut RelativeGraphConverter,
        local_elem: F,
    ) -> Self
    where
        F: Fn(&mut RelativeGraphConverter, &E) -> E,
    {
        match self {
            Self::Add(desc) => Self::Add(BinaryOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::AddScalar(desc) => Self::AddScalar(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: local_elem(converter, &desc.rhs),
                out: desc.out.to_relative(converter),
            }),
            Self::Sub(desc) => Self::Sub(BinaryOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::SubScalar(desc) => Self::SubScalar(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: local_elem(converter, &desc.rhs),
                out: desc.out.to_relative(converter),
            }),
            Self::Div(desc) => Self::Div(BinaryOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::DivScalar(desc) => Self::DivScalar(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: local_elem(converter, &desc.rhs),
                out: desc.out.to_relative(converter),
            }),
            Self::Mul(desc) => Self::Mul(BinaryOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::MulScalar(desc) => Self::MulScalar(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: local_elem(converter, &desc.rhs),
                out: desc.out.to_relative(converter),
            }),
            Self::Abs(desc) => Self::Abs(UnaryOpsDescription {
                input: desc.input.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::Ones(desc) => Self::Ones(desc.to_relative(converter)),
            Self::Zeros(desc) => Self::Zeros(desc.to_relative(converter)),
            Self::Full(desc) => Self::Full((
                desc.0.to_relative(converter),
                local_elem(converter, &desc.1),
            )),
            Self::Gather(desc) => Self::Gather(GatherOpsDescription {
                tensor: desc.tensor.to_relative(converter),
                dim: desc.dim,
                indices: desc.indices.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::Scatter(desc) => Self::Scatter(ScatterOpsDescription {
                tensor: desc.tensor.to_relative(converter),
                dim: desc.dim,
                indices: desc.indices.to_relative(converter),
                value: desc.value.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::Select(desc) => Self::Select(SelectOpsDescription {
                tensor: desc.tensor.to_relative(converter),
                dim: desc.dim,
                indices: desc.indices.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::SelectAssign(desc) => Self::SelectAssign(SelectAssignOpsDescription {
                tensor: desc.tensor.to_relative(converter),
                dim: desc.dim,
                indices: desc.indices.to_relative(converter),
                value: desc.value.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::MaskWhere(desc) => Self::MaskWhere(MaskWhereOpsDescription {
                tensor: desc.tensor.to_relative(converter),
                mask: desc.mask.to_relative(converter),
                value: desc.value.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::MaskFill(desc) => Self::MaskFill(MaskFillOpsDescription {
                tensor: desc.tensor.to_relative(converter),
                mask: desc.mask.to_relative(converter),
                value: local_elem(converter, &desc.value),
                out: desc.out.to_relative(converter),
            }),
            Self::MeanDim(desc) => Self::MeanDim(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs, // Dim should stay the same.
                out: desc.out.to_relative(converter),
            }),
            Self::Mean(desc) => Self::Mean(UnaryOpsDescription {
                input: desc.input.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::Sum(desc) => Self::Sum(UnaryOpsDescription {
                input: desc.input.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::SumDim(desc) => Self::SumDim(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs, // Dim should stay the same.
                out: desc.out.to_relative(converter),
            }),
            Self::EqualElem(desc) => Self::EqualElem(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: local_elem(converter, &desc.rhs),
                out: desc.out.to_relative(converter),
            }),
            Self::Greater(desc) => Self::Greater(BinaryOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::GreaterElem(desc) => Self::GreaterElem(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: local_elem(converter, &desc.rhs),
                out: desc.out.to_relative(converter),
            }),
            Self::GreaterEqual(desc) => Self::GreaterEqual(BinaryOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::GreaterEqualElem(desc) => Self::GreaterEqualElem(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: local_elem(converter, &desc.rhs),
                out: desc.out.to_relative(converter),
            }),
            Self::Lower(desc) => Self::Lower(BinaryOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::LowerElem(desc) => Self::LowerElem(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: local_elem(converter, &desc.rhs),
                out: desc.out.to_relative(converter),
            }),
            Self::LowerEqual(desc) => Self::LowerEqual(BinaryOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::LowerEqualElem(desc) => Self::LowerEqualElem(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: local_elem(converter, &desc.rhs),
                out: desc.out.to_relative(converter),
            }),
            Self::ArgMax(desc) => Self::ArgMax(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs,
                out: desc.out.to_relative(converter),
            }),
            Self::ArgMin(desc) => Self::ArgMin(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs,
                out: desc.out.to_relative(converter),
            }),
            Self::Max(desc) => Self::Max(UnaryOpsDescription {
                input: desc.input.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::MaxDimWithIndices(desc) => {
                Self::MaxDimWithIndices(ReduceDimWithIndicesDescription {
                    tensor: desc.tensor.to_relative(converter),
                    dim: desc.dim,
                    out: desc.out.to_relative(converter),
                    out_indices: desc.out_indices.to_relative(converter),
                })
            }
            Self::MinDimWithIndices(desc) => {
                Self::MinDimWithIndices(ReduceDimWithIndicesDescription {
                    tensor: desc.tensor.to_relative(converter),
                    dim: desc.dim,
                    out: desc.out.to_relative(converter),
                    out_indices: desc.out_indices.to_relative(converter),
                })
            }
            Self::Min(desc) => Self::Min(UnaryOpsDescription {
                input: desc.input.to_relative(converter),
                out: desc.out.to_relative(converter),
            }),
            Self::MaxDim(desc) => Self::MaxDim(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs,
                out: desc.out.to_relative(converter),
            }),
            Self::MinDim(desc) => Self::MinDim(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: desc.rhs,
                out: desc.out.to_relative(converter),
            }),
            Self::Clamp(desc) => Self::Clamp(ClampOpsDescription {
                tensor: desc.tensor.to_relative(converter),
                min: local_elem(converter, &desc.min),
                max: local_elem(converter, &desc.max),
                out: desc.out.to_relative(converter),
            }),
            Self::ClampMax(desc) => Self::ClampMax(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: local_elem(converter, &desc.rhs),
                out: desc.out.to_relative(converter),
            }),
            Self::ClampMin(desc) => Self::ClampMin(ScalarOpsDescription {
                lhs: desc.lhs.to_relative(converter),
                rhs: local_elem(converter, &desc.rhs),
                out: desc.out.to_relative(converter),
            }),
        }
    }
}
impl BaseOpsDescription {
    /// Build the relative version of this operation description.
    ///
    /// Every tensor referenced by the operation is converted through `converter`, in the
    /// order the fields are listed below; non-tensor fields (dims, ranges, repeat count)
    /// are copied as is.
    pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
        match self {
            BaseOpsDescription::ToDevice(tensor) => {
                BaseOpsDescription::ToDevice(tensor.to_relative(converter))
            }
            BaseOpsDescription::Reshape(op) => {
                let input = op.input.to_relative(converter);
                let out = op.out.to_relative(converter);

                BaseOpsDescription::Reshape(ReshapeDescription { input, out })
            }
            BaseOpsDescription::SwapDims(op) => {
                let input = op.input.to_relative(converter);
                let out = op.out.to_relative(converter);

                BaseOpsDescription::SwapDims(SwapDimsDescription {
                    input,
                    out,
                    dim1: op.dim1,
                    dim2: op.dim2,
                })
            }
            BaseOpsDescription::Slice(op) => {
                let tensor = op.tensor.to_relative(converter);
                let ranges = op.ranges.clone();
                let out = op.out.to_relative(converter);

                BaseOpsDescription::Slice(SliceOpsDescription {
                    tensor,
                    ranges,
                    out,
                })
            }
            BaseOpsDescription::SliceAssign(op) => {
                let tensor = op.tensor.to_relative(converter);
                let ranges = op.ranges.clone();
                let value = op.value.to_relative(converter);
                let out = op.out.to_relative(converter);

                BaseOpsDescription::SliceAssign(super::SliceAssignOpsDescription {
                    tensor,
                    ranges,
                    value,
                    out,
                })
            }
            BaseOpsDescription::Equal(op) => {
                let lhs = op.lhs.to_relative(converter);
                let rhs = op.rhs.to_relative(converter);
                let out = op.out.to_relative(converter);

                BaseOpsDescription::Equal(super::BinaryOpsDescription { lhs, rhs, out })
            }
            BaseOpsDescription::Repeat(op) => {
                let tensor = op.tensor.to_relative(converter);
                let out = op.out.to_relative(converter);

                BaseOpsDescription::Repeat(super::RepeatOpsDescription {
                    tensor,
                    dim: op.dim,
                    times: op.times,
                    out,
                })
            }
            BaseOpsDescription::Cat(op) => {
                // Inputs are converted first, in order, then the output.
                let mut tensors = Vec::with_capacity(op.tensors.len());
                for tensor in op.tensors.iter() {
                    tensors.push(tensor.to_relative(converter));
                }
                let out = op.out.to_relative(converter);

                BaseOpsDescription::Cat(super::CatOpsDescription {
                    tensors,
                    dim: op.dim,
                    out,
                })
            }
        }
    }
}
impl TensorDescription {
    /// Build the relative version of this tensor description and record it in `converter`.
    ///
    /// * The relative id is the one already associated with this tensor's global id, or the
    ///   current number of known relative tensors when the tensor is seen for the first time.
    /// * Each dimension value is replaced by an id; equal dimension values always map to the
    ///   same id, assigned in order of first appearance.
    ///
    /// Both tensor mappings of the converter are updated with the latest description, so a
    /// tensor that is converted again keeps its id but gets its stored value refreshed.
    pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
        let relative_id = match converter.tensors_global2relative.get(&self.id) {
            // Known tensor: keep its id, only its value is refreshed below.
            Some(known) => known.id.clone(),
            // Unknown tensor: allocate the next relative id.
            None => TensorId::new(converter.tensors_relative2global.len() as u64),
        };

        // Map every dimension value to its id, registering the values we have not met yet.
        let relative_shape: Vec<_> = self
            .shape
            .iter()
            .map(|dim| {
                match converter.shapes_global2relative.get(dim).copied() {
                    Some(dim_id) => dim_id,
                    None => {
                        let dim_id = converter.shapes_global2relative.len();
                        converter.shapes_global2relative.insert(*dim, dim_id);
                        dim_id
                    }
                }
            })
            .collect();

        let relative_tensor = TensorDescription {
            id: relative_id.clone(),
            shape: relative_shape,
            status: self.status.clone(),
        };

        // Keep both directions of the mapping in sync.
        converter
            .tensors_relative2global
            .insert(relative_id, self.clone());
        converter
            .tensors_global2relative
            .insert(self.id.clone(), relative_tensor.clone());

        relative_tensor
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorStatus;

    /// Build a read-only tensor description with the given id and shape.
    fn read_only(id: u64, shape: Vec<usize>) -> TensorDescription {
        TensorDescription {
            id: TensorId::new(id),
            shape,
            status: TensorStatus::ReadOnly,
        }
    }

    #[test]
    fn tensor_description_to_relative() {
        let mut converter = RelativeGraphConverter::default();

        let tensor1_local = read_only(500, vec![512, 32, 2048]).to_relative(&mut converter);
        let tensor2_local = read_only(501, vec![512, 128, 2048]).to_relative(&mut converter);

        // Ids are assigned in order of appearance, and dimension values shared by both
        // tensors (512 and 2048) map to the same relative dimension id.
        assert_eq!(tensor1_local, read_only(0, vec![0, 1, 2]));
        assert_eq!(tensor2_local, read_only(1, vec![0, 3, 2]));
    }
}

View File

@ -1,59 +1,199 @@
use super::{Graph, Optimization};
use crate::{FusionBackend, FusionStatus, HandleContainer};
/// The graph execution trait abstracts the way the graph is executing optimizations.
pub trait GraphExecution<B: FusionBackend>: Default + Send {
/// Execute the given graph using the list of potential [optimizations](Optimization).
/// May do nothing if empty or not ready
fn maybe_execute(
&mut self,
graph: &mut Graph<B>,
handles: &mut HandleContainer<B>,
optimizations: &mut [Optimization<B>],
force: bool,
);
}
use super::{CacheResult, Condition, Graph, OptimizationCache, TensorOpsDescription};
use crate::{
FusionBackend, HandleContainer, Optimization, OptimizationBuilder, OptimizationStatus,
};
/// Execute an optimization following a greedy algorithm.
#[derive(Default)]
pub struct GreedyGraphExecution;
/// Drives the execution of a graph by combining a cache of already built optimizations
/// with a list of optimization builders.
pub(crate) struct GraphExecution<B: FusionBackend> {
// Cache of the optimizations that were already built, followed operation by operation.
optimization_cache: OptimizationCache<Box<dyn Optimization<B>>>,
// Builders that receive the graph operations and may produce a new optimization.
optimizations: Vec<Box<dyn OptimizationBuilder<B>>>,
// Number of graph operations that were not registered to the builders yet because the
// cache reported the graph as being on a known path.
num_skipped: usize,
}
impl<B: FusionBackend> GraphExecution<B> for GreedyGraphExecution {
fn maybe_execute(
/// The reason why the graph execution is triggered.
#[derive(Clone, Copy, Debug)]
pub(crate) enum ExecutionMode {
/// Signal that we execute the graph after a new operation is added to the graph.
NewOps,
/// Signal that we execute the graph because of a sync, without any new operation added
/// to the graph.
Sync,
}
impl<B: FusionBackend> GraphExecution<B> {
/// Create a new graph execution with the given optimization builders.
///
/// The optimization cache starts empty and no operation is marked as skipped.
pub fn new(optimizations: Vec<Box<dyn OptimizationBuilder<B>>>) -> Self {
Self {
optimization_cache: OptimizationCache::new(),
optimizations,
num_skipped: 0,
}
}
/// Execute the graph with the provided mode.
pub fn execute(
&mut self,
graph: &mut Graph<B>,
handles: &mut HandleContainer<B>,
optimizations: &mut [Optimization<B>],
force: bool,
mode: ExecutionMode,
) {
loop {
if !force && still_optimizing(optimizations) {
if graph.is_empty() {
break;
}
match find_best_optimization_index(optimizations) {
Some(index) => {
graph.execute_optimization(handles, index, optimizations);
}
None => {
graph.execute(handles);
optimizations.iter_mut().for_each(|ops| ops.reset());
}
}
match self.cache(graph, mode) {
CacheResult::Miss => {
match self.build(graph, mode) {
BuildAction::ExecuteOptimization(ops) => {
graph.execute_optimization(handles, ops);
self.reset(graph);
}
BuildAction::ExecuteOperations => {
graph.execute_operations(handles);
self.reset(graph);
}
BuildAction::ContinueBuilding => {
if let ExecutionMode::Sync = mode {
panic!("Can't continue building when sync is called.")
}
}
};
if graph.is_empty() {
// No more ops to fuse.
if self.num_skipped == 0 {
break;
}
}
CacheResult::OnPath => {
self.num_skipped += 1;
match mode {
ExecutionMode::NewOps => break,
ExecutionMode::Sync => panic!("Can't wait while sync"),
};
}
CacheResult::Found(ops) => {
graph.execute_optimization(handles, ops.as_ref());
self.reset(graph);
}
};
if let ExecutionMode::NewOps = mode {
break;
}
}
}
/// Feed the pending operations to every builder and decide what to do next.
///
/// The pending operations are the ones skipped while the cache was on a known path, plus
/// the newest operation of the graph when executing in [ExecutionMode::NewOps].
fn build(&mut self, graph: &mut Graph<B>, mode: ExecutionMode) -> BuildAction<'_, B> {
    // When we are executing with the new ops mode, the last operation of the graph must be
    // registered even when no operation was skipped.
    let num_pending = self.num_skipped
        + match mode {
            ExecutionMode::NewOps => 1,
            ExecutionMode::Sync => 0,
        };

    // The pending operations are the last `num_pending` ones, registered oldest first.
    let first_pending = graph.relative.len() - num_pending;
    for relative in &graph.relative[first_pending..] {
        for builder in self.optimizations.iter_mut() {
            builder.register(relative);
        }
    }
    self.num_skipped = 0;

    // Can only be lazy when not sync.
    if matches!(mode, ExecutionMode::NewOps) && still_optimizing(&self.optimizations) {
        return BuildAction::ContinueBuilding;
    }

    let index = match find_best_optimization_index(&self.optimizations) {
        Some(index) => index,
        // TODO: Cache this result too.
        None => return BuildAction::ExecuteOperations,
    };

    let (relative, next_ops) = Self::split_relative_graph_owned(graph, mode);
    let builder = &self.optimizations[index];
    let optimization = self
        .optimization_cache
        .complete(builder, relative, next_ops);

    BuildAction::ExecuteOptimization(optimization.as_ref())
}
/// Reset the builders and the cache, then replay the operations that are still in the
/// graph so that the cache follows the remaining path from its beginning.
fn reset(&mut self, graph: &mut Graph<B>) {
    self.optimizations
        .iter_mut()
        .for_each(|builder| builder.reset());

    // Every operation left in the graph is now considered skipped by the builders.
    self.num_skipped = graph.relative.len();
    self.optimization_cache.reset();

    // Reset the policy state: replay each remaining operation as the next one of the path
    // made of all the operations before it.
    for (position, next_ops) in graph.relative.iter().enumerate() {
        let _ = self
            .optimization_cache
            .follow(&graph.relative[..position], Condition::NextOps(next_ops));
    }
}
/// Query the optimization cache for the current graph.
///
/// When syncing, we can't wait for more operations, so being on a known path is reported
/// as a miss.
fn cache<'a>(
    &'a mut self,
    graph: &mut Graph<B>,
    mode: ExecutionMode,
) -> CacheResult<'a, Box<dyn Optimization<B>>> {
    let (relative, next_ops) = Self::split_relative_graph_ref(graph, mode);
    let end_condition = match next_ops {
        Some(ops) => Condition::NextOps(ops),
        None => Condition::Sync,
    };
    let result = self.optimization_cache.follow(relative, end_condition);

    match (mode, result) {
        (ExecutionMode::Sync, CacheResult::OnPath) => CacheResult::Miss,
        (_, result) => result,
    }
}
fn split_relative_graph_owned(
graph: &Graph<B>,
mode: ExecutionMode,
) -> (Vec<TensorOpsDescription>, Option<TensorOpsDescription>) {
match mode {
ExecutionMode::NewOps => {
let graph = graph.split_relative_graph();
(graph.0.to_vec(), graph.1.cloned())
}
ExecutionMode::Sync => (graph.relative.clone(), None),
}
}
/// Split the relative graph into the operations to consider and the optional next
/// operation.
///
/// When syncing, the whole relative graph is considered and there is no next operation;
/// otherwise the split is delegated to the graph.
fn split_relative_graph_ref(
    graph: &Graph<B>,
    mode: ExecutionMode,
) -> (&[TensorOpsDescription], Option<&TensorOpsDescription>) {
    if let ExecutionMode::Sync = mode {
        return (graph.relative.as_slice(), None);
    }

    graph.split_relative_graph()
}
}
fn still_optimizing<B: FusionBackend>(optimizations: &[Optimization<B>]) -> bool {
/// The action to perform after the pending operations were registered to the builders.
enum BuildAction<'a, B: FusionBackend> {
/// Execute the given optimization on the graph.
ExecuteOptimization(&'a dyn Optimization<B>),
/// No optimization is ready: execute the operations of the graph without optimization.
ExecuteOperations,
/// At least one builder is still open: wait for more operations before deciding.
ContinueBuilding,
}
fn still_optimizing<B: FusionBackend>(optimizations: &[Box<dyn OptimizationBuilder<B>>]) -> bool {
let mut num_stopped = 0;
for optimization in optimizations.iter() {
if let FusionStatus::Closed(_) = optimization.status {
if let OptimizationStatus::Closed = optimization.status() {
num_stopped += 1
}
}
@ -62,16 +202,13 @@ fn still_optimizing<B: FusionBackend>(optimizations: &[Optimization<B>]) -> bool
}
fn find_best_optimization_index<B: FusionBackend>(
optimizations: &[Optimization<B>],
optimizations: &[Box<dyn OptimizationBuilder<B>>],
) -> Option<usize> {
let mut best_index = None;
let mut best_score = 0;
for (i, optimization) in optimizations.iter().enumerate() {
let properties = match optimization.status {
FusionStatus::Closed(properties) => properties,
FusionStatus::Open(properties) => properties,
};
let properties = optimization.properties();
if properties.ready && properties.score >= best_score {
best_index = Some(i);

View File

@ -1,7 +1,11 @@
pub(crate) mod execution;
mod base;
mod execution;
mod context;
mod ops;
mod path;
pub use base::*;
pub use execution::*;
pub use context::*;
pub use ops::*;
pub use path::*;

View File

@ -13,7 +13,7 @@ pub trait Ops<B: FusionBackend>: Send + Sync {
}
/// Describe all tensor operations possible.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
pub enum TensorOpsDescription {
/// Basic operation on a float tensor.
BaseOpsFloat(BaseOpsDescription),
@ -36,7 +36,7 @@ pub enum TensorOpsDescription {
}
/// Operation description specific to a float tensor.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
pub enum FloatOpsDescription {
/// Operation corresponding to [exp](burn_tensor::ops::TensorOps::exp).
Exp(UnaryOpsDescription),
@ -61,13 +61,13 @@ pub enum FloatOpsDescription {
/// Operation corresponding to [matmul](burn_tensor::ops::TensorOps::matmul).
Matmul(BinaryOpsDescription),
/// Operation corresponding to [random](burn_tensor::ops::TensorOps::random).
Random((TensorDescription, Distribution)),
Random(RandomOpsDescription),
/// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip).
Recip(UnaryOpsDescription),
}
/// Operation description specific to module.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
pub enum ModuleOpsDescription {
/// Operation corresponding to [embedding](burn_tensor::ops::ModuleOps::embedding).
Embedding(EmbeddingDescription),
@ -124,7 +124,7 @@ pub enum ModuleOpsDescription {
}
/// Basic operations that can be done on any tensor type.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
pub enum BaseOpsDescription {
/// Operation corresponding to:
///
@ -177,8 +177,8 @@ pub enum BaseOpsDescription {
}
/// Numeric operations on int and float tensors.
#[derive(Clone, Debug)]
pub enum NumericOpsDescription<E: Element> {
#[derive(Clone, Debug, PartialEq)]
pub enum NumericOpsDescription<E> {
/// Operation corresponding to:
///
/// Float => [add](burn_tensor::ops::TensorOps::add).
@ -392,14 +392,14 @@ pub enum NumericOpsDescription<E: Element> {
}
/// Operation description specific to an int tensor.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
pub enum IntOpsDescription {
/// Operation corresponding to [into float](burn_tensor::ops::IntTensorOps::int_into_float).
IntoFloat(UnaryOpsDescription),
}
/// Operation description specific to a bool tensor.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
pub enum BoolOpsDescription {
/// Operation corresponding to [into float](burn_tensor::ops::BoolTensorOps::bool_into_float).
IntoFloat(UnaryOpsDescription),
@ -409,7 +409,7 @@ pub enum BoolOpsDescription {
Not(UnaryOpsDescription),
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
/// Swap dim operation description.
pub struct SwapDimsDescription {
/// Input tensor description.
@ -422,15 +422,21 @@ pub struct SwapDimsDescription {
pub dim2: usize,
}
#[derive(Clone, Debug)]
/// Description of a random operation.
///
/// `Hash` is implemented manually for this type instead of being derived.
#[derive(Clone, Debug, PartialEq)]
#[allow(missing_docs)]
pub struct RandomOpsDescription {
// The tensor produced by the operation.
pub out: TensorDescription,
// The distribution used by the operation.
pub distribution: Distribution,
}
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct ReshapeDescription {
pub input: TensorDescription,
pub out: TensorDescription,
pub shape: Vec<usize>,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct BinaryOpsDescription {
pub lhs: TensorDescription,
@ -438,14 +444,14 @@ pub struct BinaryOpsDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct UnaryOpsDescription {
pub input: TensorDescription,
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
#[allow(missing_docs)]
pub struct ScalarOpsDescription<E> {
pub lhs: TensorDescription,
@ -453,7 +459,7 @@ pub struct ScalarOpsDescription<E> {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct GatherOpsDescription {
pub tensor: TensorDescription,
@ -462,7 +468,7 @@ pub struct GatherOpsDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct ScatterOpsDescription {
pub tensor: TensorDescription,
@ -472,7 +478,7 @@ pub struct ScatterOpsDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct SelectOpsDescription {
pub tensor: TensorDescription,
@ -481,7 +487,7 @@ pub struct SelectOpsDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct SelectAssignOpsDescription {
pub tensor: TensorDescription,
@ -491,7 +497,7 @@ pub struct SelectAssignOpsDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct SliceOpsDescription {
pub tensor: TensorDescription,
@ -499,7 +505,7 @@ pub struct SliceOpsDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct SliceAssignOpsDescription {
pub tensor: TensorDescription,
@ -508,7 +514,7 @@ pub struct SliceAssignOpsDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct MaskWhereOpsDescription {
pub tensor: TensorDescription,
@ -517,7 +523,7 @@ pub struct MaskWhereOpsDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
#[allow(missing_docs)]
pub struct MaskFillOpsDescription<E> {
pub tensor: TensorDescription,
@ -526,7 +532,7 @@ pub struct MaskFillOpsDescription<E> {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
#[allow(missing_docs)]
pub struct ClampOpsDescription<E> {
pub tensor: TensorDescription,
@ -535,17 +541,16 @@ pub struct ClampOpsDescription<E> {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct RepeatOpsDescription {
pub tensor: TensorDescription,
pub dim: usize,
pub times: usize,
pub shape: Vec<usize>,
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct CatOpsDescription {
pub tensors: Vec<TensorDescription>,
@ -553,7 +558,7 @@ pub struct CatOpsDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct ReduceDimWithIndicesDescription {
pub tensor: TensorDescription,
@ -562,7 +567,7 @@ pub struct ReduceDimWithIndicesDescription {
pub out_indices: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct EmbeddingDescription {
pub weights: TensorDescription,
@ -570,7 +575,7 @@ pub struct EmbeddingDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct EmbeddingBackwardDescription {
pub weights: TensorDescription,
@ -579,7 +584,7 @@ pub struct EmbeddingBackwardDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct Conv1dDescription {
pub x: TensorDescription,
@ -589,7 +594,7 @@ pub struct Conv1dDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct Conv2dDescription {
pub x: TensorDescription,
@ -599,7 +604,7 @@ pub struct Conv2dDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct ConvTranspose1dDescription {
pub x: TensorDescription,
@ -609,7 +614,7 @@ pub struct ConvTranspose1dDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct ConvTranspose2dDescription {
pub x: TensorDescription,
@ -619,7 +624,7 @@ pub struct ConvTranspose2dDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct AvgPool1dDescription {
pub x: TensorDescription,
@ -630,7 +635,7 @@ pub struct AvgPool1dDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct AvgPool2dDescription {
pub x: TensorDescription,
@ -641,7 +646,7 @@ pub struct AvgPool2dDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct AvgPool1dBackwardDescription {
pub x: TensorDescription,
@ -653,7 +658,7 @@ pub struct AvgPool1dBackwardDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct AvgPool2dBackwardDescription {
pub x: TensorDescription,
@ -665,7 +670,7 @@ pub struct AvgPool2dBackwardDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dDescription {
pub x: TensorDescription,
@ -673,7 +678,7 @@ pub struct AdaptiveAvgPool1dDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dDescription {
pub x: TensorDescription,
@ -681,7 +686,7 @@ pub struct AdaptiveAvgPool2dDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dBackwardDescription {
pub x: TensorDescription,
@ -689,7 +694,7 @@ pub struct AdaptiveAvgPool1dBackwardDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dBackwardDescription {
pub x: TensorDescription,
@ -697,7 +702,7 @@ pub struct AdaptiveAvgPool2dBackwardDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct MaxPool1dDescription {
pub x: TensorDescription,
@ -708,7 +713,7 @@ pub struct MaxPool1dDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesDescription {
pub x: TensorDescription,
@ -720,7 +725,7 @@ pub struct MaxPool1dWithIndicesDescription {
pub out_indices: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesBackwardDescription {
pub x: TensorDescription,
@ -733,7 +738,7 @@ pub struct MaxPool1dWithIndicesBackwardDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct MaxPool2dDescription {
pub x: TensorDescription,
@ -744,8 +749,8 @@ pub struct MaxPool2dDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug)]
#[allow(missing_docs)]
#[derive(Clone, Debug, Hash, PartialEq)]
pub struct MaxPool2dWithIndicesDescription {
pub x: TensorDescription,
pub kernel_size: [usize; 2],
@ -756,7 +761,7 @@ pub struct MaxPool2dWithIndicesDescription {
pub out_indices: TensorDescription,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Hash, PartialEq)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesBackwardDescription {
pub x: TensorDescription,
@ -1112,3 +1117,85 @@ impl ModuleOpsDescription {
}
}
}
impl core::hash::Hash for RandomOpsDescription {
    /// Hash the output tensor and the kind of distribution.
    ///
    /// The parameters of the distribution are not part of the hash, only a tag identifying
    /// the variant is.
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        self.out.hash(state);

        let distribution_tag: u8 = match self.distribution {
            Distribution::Default => 1,
            Distribution::Bernoulli(_) => 2,
            Distribution::Uniform(_, _) => 3,
            Distribution::Normal(_, _) => 4,
        };
        distribution_tag.hash(state);
    }
}
impl<E> core::hash::Hash for ScalarOpsDescription<E> {
    /// Hash the input and output tensors only: the scalar `rhs` is left out, so no `Hash`
    /// bound is needed on `E`.
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        let Self { lhs, out, .. } = self;

        lhs.hash(state);
        out.hash(state);
    }
}
impl<E> core::hash::Hash for MaskFillOpsDescription<E> {
    /// Hash the tensor, the mask and the output only, so no `Hash` bound is needed on `E`.
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        let Self {
            tensor, mask, out, ..
        } = self;

        tensor.hash(state);
        mask.hash(state);
        out.hash(state);
    }
}
impl<E> core::hash::Hash for ClampOpsDescription<E> {
    /// Hash the input and output tensors only: the clamp bounds are left out, so no `Hash`
    /// bound is needed on `E`.
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        let Self { tensor, out, .. } = self;

        tensor.hash(state);
        out.hash(state);
    }
}
impl<E> core::hash::Hash for NumericOpsDescription<E> {
    /// Hash the kind of operation followed by its description.
    ///
    /// The variant discriminant is fed to the hasher first. Without it, two different
    /// operations applied to the same tensors (for instance `Add` and `Mul`) would always
    /// produce the same hash, since their descriptions are identical. Two equal values
    /// necessarily share the same variant, so the result stays consistent with `PartialEq`.
    ///
    /// The description itself is hashed with its own `Hash` implementation. For `Full`,
    /// only the tensor description (first element of the tuple) is hashed, which keeps the
    /// implementation free of any `Hash` bound on `E`.
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        core::mem::discriminant(self).hash(state);

        match self {
            NumericOpsDescription::Add(desc) => desc.hash(state),
            NumericOpsDescription::AddScalar(desc) => desc.hash(state),
            NumericOpsDescription::Sub(desc) => desc.hash(state),
            NumericOpsDescription::SubScalar(desc) => desc.hash(state),
            NumericOpsDescription::Div(desc) => desc.hash(state),
            NumericOpsDescription::DivScalar(desc) => desc.hash(state),
            NumericOpsDescription::Mul(desc) => desc.hash(state),
            NumericOpsDescription::MulScalar(desc) => desc.hash(state),
            NumericOpsDescription::Abs(desc) => desc.hash(state),
            NumericOpsDescription::Ones(desc) => desc.hash(state),
            NumericOpsDescription::Zeros(desc) => desc.hash(state),
            NumericOpsDescription::Full(desc) => desc.0.hash(state),
            NumericOpsDescription::Gather(desc) => desc.hash(state),
            NumericOpsDescription::Scatter(desc) => desc.hash(state),
            NumericOpsDescription::Select(desc) => desc.hash(state),
            NumericOpsDescription::SelectAssign(desc) => desc.hash(state),
            NumericOpsDescription::MaskWhere(desc) => desc.hash(state),
            NumericOpsDescription::MaskFill(desc) => desc.hash(state),
            NumericOpsDescription::MeanDim(desc) => desc.hash(state),
            NumericOpsDescription::Mean(desc) => desc.hash(state),
            NumericOpsDescription::Sum(desc) => desc.hash(state),
            NumericOpsDescription::SumDim(desc) => desc.hash(state),
            NumericOpsDescription::EqualElem(desc) => desc.hash(state),
            NumericOpsDescription::Greater(desc) => desc.hash(state),
            NumericOpsDescription::GreaterElem(desc) => desc.hash(state),
            NumericOpsDescription::GreaterEqual(desc) => desc.hash(state),
            NumericOpsDescription::GreaterEqualElem(desc) => desc.hash(state),
            NumericOpsDescription::Lower(desc) => desc.hash(state),
            NumericOpsDescription::LowerElem(desc) => desc.hash(state),
            NumericOpsDescription::LowerEqual(desc) => desc.hash(state),
            NumericOpsDescription::LowerEqualElem(desc) => desc.hash(state),
            NumericOpsDescription::ArgMax(desc) => desc.hash(state),
            NumericOpsDescription::ArgMin(desc) => desc.hash(state),
            NumericOpsDescription::Max(desc) => desc.hash(state),
            NumericOpsDescription::MaxDimWithIndices(desc) => desc.hash(state),
            NumericOpsDescription::MinDimWithIndices(desc) => desc.hash(state),
            NumericOpsDescription::Min(desc) => desc.hash(state),
            NumericOpsDescription::MaxDim(desc) => desc.hash(state),
            NumericOpsDescription::MinDim(desc) => desc.hash(state),
            NumericOpsDescription::Clamp(desc) => desc.hash(state),
            NumericOpsDescription::ClampMax(desc) => desc.hash(state),
            NumericOpsDescription::ClampMin(desc) => desc.hash(state),
        }
    }
}

View File

@ -0,0 +1,535 @@
use super::starter::Starters;
use crate::graph::TensorOpsDescription;
/// The cache works by keeping track of all possible optimizations for the current graph path.
///
/// # Details
///
/// This is pretty different from a normal key-value cache.
/// There is no key to access the cached values, since computing a key for a graph is very expensive.
/// Instead, we keep track of each new edge added to the graph and invalidate potential optimizations
/// when we see a different edge is added while keeping track of the current graph path.
///
/// Therefore, the overhead is very minimal, since the time-complexity of checking the cache
/// scales with the number of concurrent potential optimizations for the current path, which isn't
/// supposed to be big at any time.
pub(crate) struct OptimizationCache<O> {
// Ids of the cached optimizations whose graph still matches the current path.
candidates: Vec<OptimizationId>,
// Optimizations whose graph fully matched the current path but whose end conditions did
// not contain the next operation, stored with the length of the matched graph.
availables: Vec<(OptimizationId, usize)>,
// All the cached optimizations, indexed by their id.
optimizations: Vec<OptimizationItem<O>>,
// Queried with the first operation of a path to get the initial candidates.
starters: Starters,
// Id of the optimization found for the current path, kept until the next reset.
found: Option<OptimizationId>,
}
impl<O> OptimizationCache<O> {
    /// Create an empty cache.
    pub(crate) fn new() -> Self {
        Self {
            candidates: Vec::new(),
            availables: Vec::new(),
            optimizations: Vec::new(),
            starters: Starters::default(),
            found: None,
        }
    }

    /// Follow the current path on the provided graph with the start/end condition.
    ///
    /// # Notes
    ///
    /// It is assumed that this function will be called for each new edge added to the graph (for
    /// each new operation). Only one graph can be cached at a time.
    pub(crate) fn follow<'a>(
        &'a mut self,
        graph: &[TensorOpsDescription],
        condition: Condition,
    ) -> CacheResult<'a, O> {
        let (next_ops_index, next_ops) = match graph.split_last() {
            Some((last, previous)) => (previous.len(), last),
            None => {
                // When the graph is empty, we use the condition as the first operation to
                // determine the new possible optimizations.
                return match condition {
                    Condition::NextOps(ops) => {
                        let candidates = self.starters.get(ops);

                        if candidates.is_empty() {
                            CacheResult::Miss
                        } else {
                            self.candidates = candidates;
                            CacheResult::OnPath
                        }
                    }
                    // Sync an empty graph doesn't make sense.
                    Condition::Sync => CacheResult::Miss,
                };
            }
        };

        // Once an optimization is found, it stays found until the next reset.
        if let Some(candidate) = self.found {
            return CacheResult::Found(&self.optimizations.get(candidate).unwrap().value);
        }

        // Collect the candidates that don't match the current path anymore.
        let mut invalidated = Vec::new();

        for id in self.candidates.iter() {
            let item = self
                .optimizations
                .get(*id)
                .expect("Should have an optimization");

            // Either the candidate graph is too short, or it has a different node at the
            // current position: invalidated in both cases.
            if item.graph.get(next_ops_index) != Some(next_ops) {
                invalidated.push(*id);
                continue;
            }

            // The candidate is only complete when its whole graph has been followed.
            if item.graph.len() != graph.len() {
                continue;
            }

            match condition {
                Condition::Sync => {
                    self.found = Some(*id);
                    return CacheResult::Found(&item.value);
                }
                Condition::NextOps(ops) if item.end_conditions.contains(ops) => {
                    self.found = Some(*id);
                    return CacheResult::Found(&item.value);
                }
                Condition::NextOps(_) => {
                    // Same graph, but unknown end condition: keep it available so that the
                    // end condition can be registered on completion.
                    self.availables.push((*id, graph.len()));
                    invalidated.push(*id);
                }
            }
        }

        self.candidates
            .retain(|candidate| !invalidated.contains(candidate));

        if self.candidates.is_empty() {
            CacheResult::Miss
        } else {
            CacheResult::OnPath
        }
    }

    /// Signal the completion of a graph path that reached a new optimization.
    ///
    /// # Notes
    ///
    /// The optimization factory will only be called if the optimization is on a new graph.
    /// When the optimization already exists, but with a different end condition, a new end
    /// condition will be registered, but the old optimization will be used in following calls.
    /// This is intended since we want the factory to be called only once per graph, but reused
    /// as much as possible.
    pub fn complete<'a, Factory: OptimizationFactory<O>>(
        &'a mut self,
        factory: &Factory,
        graph: Vec<TensorOpsDescription>,
        next_ops: Option<TensorOpsDescription>,
    ) -> &'a O {
        let graph_len = graph.len();
        let existing_id = self
            .availables
            .iter()
            .find(|(_, len)| *len == graph_len)
            .map(|(id, _)| *id);

        if let Some(id) = existing_id {
            // Known graph: only register the new end condition, if any.
            let optimization = self.optimizations.get_mut(id).unwrap();
            optimization.end_conditions.extend(next_ops);

            return &optimization.value;
        }

        // New graph: register its first operation as a starter for the new optimization.
        let new_id = self.optimizations.len();
        self.starters.insert(graph.first().unwrap(), new_id);

        self.optimizations.push(OptimizationItem {
            graph,
            end_conditions: next_ops.into_iter().collect(),
            value: factory.create(),
        });

        &self.optimizations.last().unwrap().value
    }

    /// Signal that a new path will begin.
    ///
    /// The cached optimizations are kept; only the state of the current path is cleared.
    pub(crate) fn reset(&mut self) {
        self.found = None;
        self.availables.clear();
        self.candidates.clear();
    }
}
/// Action to be made depending on the graph.
///
/// The lifetime `'a` is the one of the borrowed cached value carried by [CacheResult::Found].
#[derive(PartialEq, Eq)]
pub enum CacheResult<'a, T> {
/// Continue exploring optimizations using the [builder](crate::OptimizationBuilder).
Miss,
/// The current graph indicates that an optimization may be possible in the future, so the
/// best action is to wait for the optimization to become available.
///
/// Sometimes, it can be a false positive and a new optimization should be built from scratch.
/// Therefore it's important to keep the previous operations to rebuild the state if it
/// happens.
OnPath,
/// An optimization has been found, and the best action is to execute it!
Found(&'a T),
}
/// When checking if an optimization is possible, a start or an end condition ensures that this optimization is
/// always optimal.
///
/// With an empty graph the condition is used as the first operation of the path; otherwise
/// it is compared against the end conditions registered for a cached optimization.
#[derive(Clone)]
pub enum Condition<'a> {
/// The next operation that signals the start or end of the operation.
NextOps(&'a TensorOpsDescription),
/// When sync, we should execute the optimization if found no matter what comes next.
Sync,
}
impl<'a, T> core::fmt::Debug for CacheResult<'a, T> {
    /// Write the name of the variant only, so that `T` doesn't have to implement `Debug`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let variant_name = match self {
            CacheResult::Miss => "CacheResult::Miss",
            CacheResult::OnPath => "CacheResult::OnPath",
            CacheResult::Found(_) => "CacheResult::Found",
        };

        f.write_str(variant_name)
    }
}
/// Create an optimization.
///
/// `T` is the type of the optimization that is produced.
pub(crate) trait OptimizationFactory<T> {
/// Build a new optimization value.
///
/// Call only when a new optimization is found.
fn create(&self) -> T;
}
/// Identifier of a cached optimization, used as an index into the cache's list of optimizations.
pub(super) type OptimizationId = usize;
/// A cached optimization together with the graph it was registered for.
struct OptimizationItem<O> {
/// The operations covered by the optimization.
graph: Vec<TensorOpsDescription>,
/// The next operations that were recorded as end conditions when this optimization was
/// completed; empty when it was only ever completed without a next operation.
end_conditions: Vec<TensorOpsDescription>,
/// The optimization itself, as created by the factory.
value: O,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
graph::{FloatOpsDescription, UnaryOpsDescription},
TensorDescription, TensorId, TensorStatus,
};
/// An optimization registered without an end condition is returned when the same operations
/// are followed again and the path ends with [`Condition::Sync`].
#[test]
fn should_cache_optimization_end_condition_forced() {
// A graph with 2 ops.
let graph = TestGraph::new(2);
let mut path = OptimizationCache::new();
// First following: nothing is cached yet, so every step is expected to be a miss.
graph.follow_misses(&mut path);
// Register the optimization for both ops, without any end condition.
let optimization = path.complete(&Optimization1, graph.edges[0..2].to_vec(), None);
assert_eq!(optimization, &Optimization1.create());
// Second following on the same ops.
path.reset();
let result1 = path.follow(&[], Condition::NextOps(&graph.edges[0]));
assert_eq!(result1, CacheResult::OnPath);
let result2 = path.follow(&graph.edges[0..1], Condition::NextOps(&graph.edges[1]));
assert_eq!(result2, CacheResult::OnPath);
// Syncing forces the end of the path, so the cached optimization must be returned.
let result3 = path.follow(&graph.edges[0..2], Condition::Sync);
match result3 {
CacheResult::Found(ops) => assert_eq!(ops, &Optimization1.create()),
_ => panic!("Should have found the cached operation"),
};
}
/// Once an optimization has been found on a path, following more operations on that same path
/// (without resetting) keeps returning the same optimization.
#[test]
fn once_found_perfect_should_always_return_found() {
let mut graph = TestGraph::new(2);
let mut path = OptimizationCache::new();
graph.follow_misses(&mut path);
// Register the optimization for the first op only, with the second op as end condition.
let _optimization = path.complete(
&Optimization1,
graph.edges[0..1].to_vec(),
Some(graph.edges[1].clone()),
);
path.reset();
// Grow the graph so that the path can be followed past the registered end condition.
graph.new_ops();
graph.new_ops();
let result = path.follow(&[], Condition::NextOps(&graph.edges[0]));
assert_eq!(result, CacheResult::OnPath);
let result = path.follow(&graph.edges[0..1], Condition::NextOps(&graph.edges[1]));
match result {
CacheResult::Found(ops) => assert_eq!(ops, &Optimization1.create()),
_ => panic!("Should have found the cached operation"),
}
// Following one more operation must still report the optimization that was found.
let result = path.follow(&graph.edges[0..2], Condition::NextOps(&graph.edges[2]));
match result {
CacheResult::Found(ops) => assert_eq!(ops, &Optimization1.create()),
_ => panic!("Should have found the cached operation"),
}
}
/// An optimization registered with a next operation as end condition is returned when the same
/// operations are followed again and the same next operation comes up.
#[test]
fn should_cache_optimization_end_condition_next_ops() {
// A graph with 3 ops.
let graph = TestGraph::new(3);
let mut path = OptimizationCache::new();
// First following: nothing is cached yet, so every step is expected to be a miss.
graph.follow_misses(&mut path);
// Register the optimization for the first two ops, with the third op as end condition.
let optimization = path.complete(
&Optimization1,
graph.edges[0..2].to_vec(),
Some(graph.edges[2].clone()),
);
assert_eq!(optimization, &Optimization1.create());
// Second following on the same ops.
path.reset();
let result1 = path.follow(&[], Condition::NextOps(&graph.edges[0]));
assert_eq!(result1, CacheResult::OnPath);
let result2 = path.follow(&graph.edges[0..1], Condition::NextOps(&graph.edges[1]));
assert_eq!(result2, CacheResult::OnPath);
let result3 = path.follow(&graph.edges[0..2], Condition::NextOps(&graph.edges[2]));
match result3 {
CacheResult::Found(ops) => assert_eq!(ops, &Optimization1.create()),
_ => panic!("Should have found the cached operation"),
};
}
/// Completing the same graph a second time with a different end condition must not create a
/// second optimization: the one registered first is returned.
#[test]
fn should_support_many_different_end_conditions() {
// Both graphs start with the same two ops; only the last registered op differs
// (`Exp` for graph 1, `Log` for graph 2).
let mut graph1 = TestGraph::new(2);
graph1.register_ops(|desc| TensorOpsDescription::FloatOps(FloatOpsDescription::Exp(desc)));
let mut graph2 = TestGraph::new(2);
graph2.register_ops(|desc| TensorOpsDescription::FloatOps(FloatOpsDescription::Log(desc)));
let mut path = OptimizationCache::<String>::new();
let last_edge_index = graph1.edges.len() - 1;
// Follow graph 1 with only misses.
graph1.follow_misses(&mut path);
let _ = path.complete(
&Optimization1,
graph1.edges[0..last_edge_index].to_vec(),
Some(graph1.edges[last_edge_index].clone()),
);
// Follow graph 2.
let result = path.follow(&[], Condition::NextOps(&graph2.edges[0]));
assert_eq!(result, CacheResult::OnPath);
let result = path.follow(&graph2.edges[0..1], Condition::NextOps(&graph2.edges[1]));
assert_eq!(result, CacheResult::OnPath);
// The last op of graph 2 was never recorded as an end condition, hence the miss.
let result = path.follow(&graph2.edges[0..2], Condition::NextOps(&graph2.edges[2]));
assert_eq!(result, CacheResult::Miss);
let optimization = path.complete(
&Optimization2,
graph2.edges[0..last_edge_index].to_vec(),
Some(graph2.edges[last_edge_index].clone()),
);
assert_eq!(
optimization,
&Optimization1.create(),
"Optimization 1 should still be returned, since same graph but not same end condition."
);
}
/// Two graphs sharing the same first operation but diverging afterwards must each get their own
/// cached optimization, and each one must be found again when its graph is followed.
#[test]
fn should_support_multiple_concurrent_paths() {
// Two different graphs with a different second ops, but the same last ops.
let mut graph1 = TestGraph::new(1);
graph1.register_ops(|desc| TensorOpsDescription::FloatOps(FloatOpsDescription::Exp(desc)));
graph1.new_ops();
let mut graph2 = TestGraph::new(1);
graph2.register_ops(|desc| TensorOpsDescription::FloatOps(FloatOpsDescription::Cos(desc)));
graph2.new_ops();
let mut path = OptimizationCache::<String>::new();
// Follow graph 1 with only misses.
graph1.follow_misses(&mut path);
// Register the optimization 1 for graph 1.
let last_edge_index = graph1.edges.len() - 1;
let _ = path.complete(
&Optimization1,
graph1.edges[0..last_edge_index].to_vec(),
Some(graph1.edges[last_edge_index].clone()),
);
// Follow graph 2 and register a new optimization.
path.reset();
let result = path.follow(&[], Condition::NextOps(&graph2.edges[0]));
assert_eq!(result, CacheResult::OnPath);
let result = path.follow(&graph2.edges[0..1], Condition::NextOps(&graph2.edges[1]));
assert_eq!(result, CacheResult::OnPath);
let result = path.follow(&graph2.edges[0..2], Condition::NextOps(&graph2.edges[2]));
assert_eq!(
result,
CacheResult::Miss,
"Should invalidate the second operation"
);
// Register new optimization for path 2.
let _ = path.complete(
&Optimization2,
graph2.edges[0..last_edge_index].to_vec(),
Some(graph2.edges[last_edge_index].clone()),
);
// Now let's validate that the cache works.
// New path instance on graph 1.
path.reset();
let result = path.follow(&[], Condition::NextOps(&graph1.edges[0]));
assert_eq!(result, CacheResult::OnPath);
let result = path.follow(&graph1.edges[0..1], Condition::NextOps(&graph1.edges[1]));
assert_eq!(result, CacheResult::OnPath);
let result = path.follow(&graph1.edges[0..2], Condition::NextOps(&graph1.edges[2]));
match result {
CacheResult::Found(ops) => assert_eq!(ops, &Optimization1.create()),
_ => panic!("Should have found the cached operation"),
};
// New path instance on graph 2.
path.reset();
let result = path.follow(&[], Condition::NextOps(&graph2.edges[0]));
assert_eq!(result, CacheResult::OnPath);
let result = path.follow(&graph2.edges[0..1], Condition::NextOps(&graph2.edges[1]));
assert_eq!(result, CacheResult::OnPath);
let result = path.follow(&graph2.edges[0..2], Condition::NextOps(&graph2.edges[2]));
match result {
CacheResult::Found(ops) => assert_eq!(ops, &Optimization2.create()),
_ => panic!("Should have found the cached operation"),
};
}
/// Minimal chain of unary operations used to drive the cache in the tests.
#[derive(Default, Debug)]
struct TestGraph {
/// The tensors of the graph, in creation order.
nodes: Vec<TensorDescription>,
/// The operations of the graph; each one reads the tensor before last and writes the last
/// tensor that existed when it was registered.
edges: Vec<TensorOpsDescription>,
}
impl TestGraph {
    /// Build a graph on which `num_ops` simple operations have already been registered.
    pub fn new(num_ops: usize) -> Self {
        let mut graph = Self::default();
        (0..num_ops).for_each(|_| graph.new_ops());
        graph
    }

    /// Follow every operation of the graph in order and assert that each step is a cache miss,
    /// which is what is expected the first time a graph is followed.
    pub fn follow_misses(&self, path: &mut OptimizationCache<String>) {
        for (index, next_ops) in self.edges.iter().enumerate() {
            let followed = &self.edges[..index];
            let result = path.follow(followed, Condition::NextOps(next_ops));

            assert_eq!(result, CacheResult::Miss);
        }
    }

    /// Register a unary operation built by `func` from the two most recent nodes, after a new
    /// output node has been added.
    pub fn register_ops<F>(&mut self, func: F)
    where
        F: Fn(UnaryOpsDescription) -> TensorOpsDescription,
    {
        self.new_empty_node();

        let ops = func(self.unary_description());
        self.edges.push(ops);
    }

    /// Add a simple operation (a float `Log`) to the graph.
    pub fn new_ops(&mut self) {
        // The very first operation also needs a root node to read from.
        if self.nodes.is_empty() {
            self.new_empty_node();
        }

        // Node receiving the output of the operation.
        self.new_empty_node();

        let desc = self.unary_description();
        self.edges
            .push(TensorOpsDescription::FloatOps(FloatOpsDescription::Log(
                desc,
            )));
    }

    /// Append a not-yet-initialized tensor whose id is its position in `nodes`.
    fn new_empty_node(&mut self) {
        let id = TensorId::new(self.nodes.len() as u64);

        self.nodes.push(TensorDescription {
            id,
            shape: vec![32, 32, 1],
            status: TensorStatus::NotInit,
        });
    }

    /// Describe a unary operation going from the node before last to the last node.
    fn unary_description(&self) -> UnaryOpsDescription {
        let size = self.nodes.len();
        let input = self.nodes[size - 2].clone();
        let out = self.nodes[size - 1].clone();

        UnaryOpsDescription { input, out }
    }
}
/// Factory whose optimization is the string `"Optimization1"`.
struct Optimization1;

/// Factory whose optimization is the string `"Optimization2"`.
struct Optimization2;

impl OptimizationFactory<String> for Optimization1 {
    fn create(&self) -> String {
        String::from("Optimization1")
    }
}

impl OptimizationFactory<String> for Optimization2 {
    fn create(&self) -> String {
        String::from("Optimization2")
    }
}
}

View File

@ -0,0 +1,4 @@
mod base;
pub use base::*;
mod starter;

View File

@ -0,0 +1,75 @@
use super::OptimizationId;
use crate::graph::TensorOpsDescription;
use std::{
collections::{hash_map::DefaultHasher, HashMap},
hash::{Hash, Hasher},
};
/// Index from the first operation of a graph to the optimizations registered with it.
///
/// Operations are bucketed by their hash; every bucket keeps the operations themselves so that
/// two different operations sharing a hash stay separate.
#[derive(Default)]
pub(crate) struct Starters {
    /// For every hash, the operations with that hash paired with their index in `starters`.
    starter_indices: HashMap<u64, Vec<(TensorOpsDescription, usize)>>,
    /// For every known starting operation, the ids registered with it, in insertion order.
    starters: Vec<Vec<OptimizationId>>,
}

impl Starters {
    /// Return the ids registered for `ops`, or an empty list when `ops` is unknown.
    pub(crate) fn get(&self, ops: &TensorOpsDescription) -> Vec<OptimizationId> {
        self.starter_indices
            .get(&self.graph_key(ops))
            .and_then(|bucket| bucket.iter().find(|(candidate, _)| candidate == ops))
            .and_then(|(_, index)| self.starters.get(*index))
            .cloned()
            .unwrap_or_default()
    }

    /// Register `new_id` as one of the optimizations starting with `ops`.
    pub(crate) fn insert(&mut self, ops: &TensorOpsDescription, new_id: OptimizationId) {
        let key = self.graph_key(ops);
        let bucket = self.starter_indices.entry(key).or_default();
        let existing = bucket
            .iter()
            .find(|(candidate, _)| candidate == ops)
            .map(|(_, index)| *index);

        match existing {
            // New optimization for an existing starter.
            Some(index) => self
                .starters
                .get_mut(index)
                .expect("Should exist")
                .push(new_id),
            // New starter ops, whether its hash was already known (collision) or not.
            None => {
                let index = self.starters.len();
                self.starters.push(vec![new_id]);
                bucket.push((ops.clone(), index));
            }
        }
    }

    /// Hash `ops` with the standard library's default hasher.
    fn graph_key(&self, ops: &TensorOpsDescription) -> u64 {
        let mut state = DefaultHasher::new();
        ops.hash(&mut state);
        state.finish()
    }
}

View File

@ -138,17 +138,16 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for ReshapeDimsOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_bool_tensor::<D1>(&self.desc.input);
let output = B::bool_reshape::<D1, D2>(input, Shape::from(&self.desc.shape));
let output = B::bool_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
handles.register_bool_tensor(&self.desc.out.id, output);
}
}
let shape: Vec<usize> = shape.dims.into();
let out = tensor.client.tensor_uninitialized(shape.clone());
let out = tensor.client.tensor_uninitialized(shape);
let desc = ReshapeDescription {
input: tensor.into_description(),
shape,
out: out.to_description_out(),
};
out.client.register(

View File

@ -5,10 +5,10 @@ use crate::{
graph::{
BaseOpsDescription, BinaryOpsDescription, CatOpsDescription, ClampOpsDescription,
FloatOpsDescription, GatherOpsDescription, MaskFillOpsDescription, MaskWhereOpsDescription,
NumericOpsDescription, Ops, ReduceDimWithIndicesDescription, ReshapeDescription,
ScalarOpsDescription, ScatterOpsDescription, SelectAssignOpsDescription,
SelectOpsDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription,
TensorOpsDescription, UnaryOpsDescription,
NumericOpsDescription, Ops, RandomOpsDescription, ReduceDimWithIndicesDescription,
ReshapeDescription, ScalarOpsDescription, ScatterOpsDescription,
SelectAssignOpsDescription, SelectOpsDescription, SliceAssignOpsDescription,
SliceOpsDescription, SwapDimsDescription, TensorOpsDescription, UnaryOpsDescription,
},
ops::binary::binary_ops_shape,
scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops, unary_float_ops, Fusion,
@ -39,16 +39,15 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
) -> FloatTensor<Self, D> {
#[derive(new)]
struct RandomOps<const D: usize> {
out: TensorDescription,
distribution: Distribution,
desc: RandomOpsDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for RandomOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let shape = Shape::from(self.out.shape.clone());
let shape = Shape::from(self.desc.out.shape.clone());
let output: B::TensorPrimitive<D> =
B::random(shape, self.distribution, &handles.device);
handles.register_float_tensor(&self.out.id, output);
B::random(shape, self.desc.distribution, &handles.device);
handles.register_float_tensor(&self.desc.out.id, output);
}
}
@ -56,10 +55,13 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let client = get_client::<B>(&device.clone().into());
let out = client.tensor_uninitialized(shape);
let desc = (out.to_description_out(), distribution);
let desc = RandomOpsDescription {
out: out.to_description_out(),
distribution,
};
client.register(
TensorOpsDescription::FloatOps(FloatOpsDescription::Random(desc.clone())),
RandomOps::<D>::new(desc.0, desc.1),
RandomOps::<D>::new(desc),
);
out
@ -548,17 +550,16 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for ReshapeDimsOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_float_tensor::<D1>(&self.desc.input);
let output = B::reshape::<D1, D2>(input, Shape::from(&self.desc.shape));
let output = B::reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
handles.register_float_tensor(&self.desc.out.id, output);
}
}
let shape: Vec<usize> = shape.dims.into();
let out = tensor.client.tensor_uninitialized(shape.clone());
let out = tensor.client.tensor_uninitialized(shape);
let desc = ReshapeDescription {
input: tensor.into_description(),
shape,
out: out.to_description_out(),
};
out.client.register(

View File

@ -81,17 +81,16 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for ReshapeDimsOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_int_tensor::<D1>(&self.desc.input);
let output = B::int_reshape::<D1, D2>(input, Shape::from(&self.desc.shape));
let output = B::int_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
handles.register_int_tensor(&self.desc.out.id, output);
}
}
let shape: Vec<usize> = shape.dims.into();
let out = tensor.client.tensor_uninitialized(shape.clone());
let out = tensor.client.tensor_uninitialized(shape);
let desc = ReshapeDescription {
input: tensor.into_description(),
shape,
out: out.to_description_out(),
};
out.client.register(

View File

@ -1,69 +1,48 @@
use crate::{
graph::{Graph, GraphExecution, Ops, Optimization, TensorOpsDescription},
FusionBackend, FusionProperties, FusionStatus, HandleContainer, TensorId,
graph::{
execution::{ExecutionMode, GraphExecution},
Graph, Ops, TensorOpsDescription,
},
FusionBackend, HandleContainer, TensorId,
};
use burn_tensor::ops::{FloatElem, IntElem};
use std::sync::Arc;
pub struct FusionServer<B, G>
pub struct FusionServer<B>
where
B: FusionBackend,
G: GraphExecution<B>,
{
optimizations: Vec<Optimization<B>>,
execution: GraphExecution<B>,
graph: Graph<B>,
pub(crate) handles: HandleContainer<B>,
execution: G,
pub device: B::FusionDevice,
pub num_skipped: usize,
}
impl<B, G> FusionServer<B, G>
impl<B> FusionServer<B>
where
B: FusionBackend,
G: GraphExecution<B>,
{
pub fn new(device: B::FusionDevice) -> Self {
let optimizations = B::operations(&device.clone().into())
.into_iter()
.map(|ops| Optimization::new(ops, FusionStatus::Open(FusionProperties::default())))
.collect();
Self {
optimizations,
execution: GraphExecution::new(B::optimizations(&device.clone().into())),
graph: Graph::new(),
handles: HandleContainer::new(device.clone()),
execution: G::default(),
num_skipped: 0,
device,
}
}
pub fn register(&mut self, desc: TensorOpsDescription, op: Box<dyn Ops<B>>) {
let ops = Arc::new(desc);
self.graph.add(ops.clone(), op);
self.optimizations
.iter_mut()
.for_each(|optimization| optimization.register(&ops));
self.execution.maybe_execute(
&mut self.graph,
&mut self.handles,
&mut self.optimizations,
false,
);
pub fn register(&mut self, ops_desc: TensorOpsDescription, ops: Box<dyn Ops<B>>) {
self.graph.add(ops_desc, ops);
self.execution
.execute(&mut self.graph, &mut self.handles, ExecutionMode::NewOps);
}
pub fn drain_graph(&mut self) {
if self.graph.is_empty() {
return;
}
self.execution.maybe_execute(
&mut self.graph,
&mut self.handles,
&mut self.optimizations,
true,
);
// Check if we can execute.
self.execution
.execute(&mut self.graph, &mut self.handles, ExecutionMode::Sync);
}
pub fn create_empty_handle(&mut self) -> Arc<TensorId> {

View File

@ -26,7 +26,7 @@ pub struct Data<E, const D: usize> {
}
/// Distribution for random value of a tensor.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Distribution {
/// Uniform distribution from 0 (inclusive) to 1 (exclusive).
Default,

View File

@ -66,7 +66,7 @@ pub struct Conv1dBackward<B: Backend> {
}
/// Convolution options.
#[derive(new, Debug, Clone, Hash)]
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
/// Stride.
pub stride: [usize; N],
@ -82,7 +82,7 @@ pub struct ConvOptions<const N: usize> {
}
/// Transposed convolution options.
#[derive(new, Debug, Clone, Hash)]
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
/// Stride.
pub stride: [usize; N],

View File

@ -1,13 +1,11 @@
use crate::{
compute::{WgpuComputeClient, WgpuHandle},
element::WgpuElement,
fusion::FloatElementWiseFusionOps,
fusion::FloatElementWiseBuilder,
tensor::WgpuTensor,
FloatElement, GraphicsApi, IntElement, Wgpu, WgpuDevice,
};
use burn_fusion::{
client::MutexFusionClient, graph::GreedyGraphExecution, DeviceId, FusionBackend, FusionDevice,
};
use burn_fusion::{client::MutexFusionClient, DeviceId, FusionBackend, FusionDevice};
use burn_tensor::Shape;
use core::marker::PhantomData;
@ -31,10 +29,10 @@ where
{
type FusionDevice = WgpuDevice;
type Handle = WgpuFusionHandle;
type FusionClient = MutexFusionClient<Self, GreedyGraphExecution>;
type FusionClient = MutexFusionClient<Self>;
fn operations(device: &WgpuDevice) -> Vec<Box<dyn burn_fusion::FusionOps<Self>>> {
vec![Box::new(FloatElementWiseFusionOps::new(device.clone()))]
fn optimizations(device: &WgpuDevice) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self>>> {
vec![Box::new(FloatElementWiseBuilder::new(device.clone()))]
}
fn float_tensor<const D: usize>(

View File

@ -1,7 +1,6 @@
use crate::{
element::WgpuElement,
fusion::codegen::{Elem, Operator, Variable},
fusion::kernel::FusionKernel,
FloatElement, GraphicsApi, IntElement, Wgpu,
};
use burn_fusion::{
@ -9,13 +8,16 @@ use burn_fusion::{
BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription,
ScalarOpsDescription, TensorOpsDescription, UnaryOpsDescription,
},
FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription, TensorId,
Optimization, OptimizationBuilder, OptimizationProperties, OptimizationStatus,
TensorDescription, TensorId,
};
use burn_tensor::{Device, Element};
use hashbrown::HashMap;
use super::optimization::FloatElementWise;
/// Fused element wise operations that are normally memory bound.
pub struct FloatElementWiseFusionOps<G, F, I>
pub(crate) struct FloatElementWiseBuilder<G, F, I>
where
G: GraphicsApi,
F: FloatElement,
@ -24,48 +26,56 @@ where
pub(crate) inputs: Vec<TensorDescription>,
pub(crate) locals: HashMap<TensorId, u16>,
pub(crate) tensors: HashMap<TensorId, (TensorDescription, Elem)>,
pub(crate) scalars_f32: Vec<f32>,
pub(crate) scalars_i32: Vec<i32>,
pub(crate) scalars_u32: Vec<u32>,
pub(crate) booleans: Vec<bool>,
pub(crate) scalars_f32: usize,
pub(crate) scalars_i32: usize,
pub(crate) scalars_u32: usize,
pub(crate) booleans: usize,
pub(crate) operators: Vec<Operator>,
pub(crate) properties: FusionProperties,
pub(crate) current_output_shape: Vec<usize>,
device: Device<Wgpu<G, F, I>>,
pub(crate) status: OptimizationStatus,
pub(crate) device: Device<Wgpu<G, F, I>>,
}
impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G, F, I>>
for FloatElementWiseFusionOps<G, F, I>
impl<G, F, I> OptimizationBuilder<Wgpu<G, F, I>> for FloatElementWiseBuilder<G, F, I>
where
G: GraphicsApi,
F: FloatElement,
I: IntElement,
{
fn register(&mut self, ops: &TensorOpsDescription) -> FusionStatus {
fn register(&mut self, ops: &TensorOpsDescription) {
if let OptimizationStatus::Closed = self.status {
return;
}
match ops {
TensorOpsDescription::BaseOpsFloat(ops) => {
if !self.register_base::<F>(ops) {
return FusionStatus::Closed(self.properties);
self.status = OptimizationStatus::Closed;
return;
}
}
TensorOpsDescription::FloatOps(ops) => {
if !self.register_float::<F>(ops) {
return FusionStatus::Closed(self.properties);
self.status = OptimizationStatus::Closed;
return;
}
}
TensorOpsDescription::NumericOpsFloat(ops) => {
if !self.register_numeric(ops) {
return FusionStatus::Closed(self.properties);
self.status = OptimizationStatus::Closed;
return;
}
}
_ => {
return FusionStatus::Closed(self.properties);
self.status = OptimizationStatus::Closed;
return;
}
};
self.properties.score += 1;
self.properties.ready = self.operators.len() > 1;
FusionStatus::Open(self.properties)
self.status = OptimizationStatus::Open;
}
fn execute(&mut self, handles: &mut HandleContainer<Wgpu<G, F, I>>) {
fn build(&self) -> Box<dyn Optimization<Wgpu<G, F, I>>> {
let inputs = self.input_descriptions();
let outputs = self.output_descriptions();
let locals = outputs
@ -73,32 +83,47 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
.map(|out| *self.locals.get(&out.0.id).unwrap())
.collect::<Vec<_>>();
FusionKernel::new(&self.device)
.inputs(&inputs, &self.scalars_f32)
.body(&self.operators)
.outputs(&outputs, &locals)
.execute(handles);
Box::new(FloatElementWise {
inputs,
outputs,
locals,
operators: self.operators.clone(),
scalars_f32: self.scalars_f32,
device: self.device.clone(),
})
}
fn reset(&mut self) {
self.inputs.clear();
self.locals.drain();
self.tensors.clear();
self.scalars_f32.clear();
self.scalars_i32.clear();
self.scalars_u32.clear();
self.booleans.clear();
self.scalars_f32 = 0;
self.scalars_i32 = 0;
self.scalars_u32 = 0;
self.booleans = 0;
self.operators.clear();
self.properties = FusionProperties::default();
self.status = OptimizationStatus::Open;
self.current_output_shape.clear();
}
fn len(&self) -> usize {
self.operators.len()
fn status(&self) -> OptimizationStatus {
self.status
}
fn properties(&self) -> OptimizationProperties {
let ready = match self.status {
OptimizationStatus::Closed => false,
OptimizationStatus::Open => self.operators.len() > 1,
};
OptimizationProperties {
ready,
score: self.operators.len() as u64,
}
}
}
impl<G, F, I> FloatElementWiseFusionOps<G, F, I>
impl<G, F, I> FloatElementWiseBuilder<G, F, I>
where
G: GraphicsApi,
F: FloatElement,
@ -109,28 +134,28 @@ where
inputs: Vec::new(),
locals: HashMap::new(),
tensors: HashMap::new(),
scalars_f32: Vec::new(),
scalars_i32: Vec::new(),
scalars_u32: Vec::new(),
booleans: Vec::new(),
scalars_f32: 0,
scalars_i32: 0,
scalars_u32: 0,
booleans: 0,
operators: Vec::new(),
current_output_shape: Vec::new(),
properties: FusionProperties::default(),
status: OptimizationStatus::Open,
device,
}
}
fn input_descriptions(&self) -> Vec<&(TensorDescription, Elem)> {
fn input_descriptions(&self) -> Vec<(TensorDescription, Elem)> {
self.inputs
.iter()
.map(|input| {
let updated_tensor = self.tensors.get(&input.id).unwrap();
updated_tensor
updated_tensor.clone()
})
.collect::<Vec<_>>()
}
fn output_descriptions(&self) -> Vec<&(TensorDescription, Elem)> {
fn output_descriptions(&self) -> Vec<(TensorDescription, Elem)> {
let mut outputs = Vec::new();
let mut local_tensor_ids_input = Vec::new();
let mut local_tensor_ids_output = Vec::new();
@ -271,7 +296,7 @@ where
let is_read = local_tensor_ids_input.contains(&out);
if !is_read {
outputs.push(self.tensors.get(&out).unwrap());
outputs.push(self.tensors.get(&out).unwrap().clone());
}
}
@ -281,7 +306,7 @@ where
let (tensor, _) = &entry;
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
if self.locals.contains_key(&tensor.id) {
outputs.push(entry);
outputs.push(entry.clone());
}
}
}
@ -503,12 +528,13 @@ where
let rhs = self.input_to_var(&desc.tensor, E::elem_type());
let out = self.output_to_var(&desc.out, E::elem_type());
self.operators.push(Operator::ConditionalAssign {
let ops = Operator::ConditionalAssign {
cond,
lhs,
rhs,
out,
});
};
self.operators.push(ops);
true
}
@ -600,19 +626,19 @@ where
true
}
fn scalar_to_var<E: Element>(&mut self, value: &E, elem_type: Elem) -> Variable {
fn scalar_to_var<E: Element>(&mut self, _value: &E, elem_type: Elem) -> Variable {
match elem_type {
Elem::F32 => {
self.scalars_f32.push(value.elem());
Variable::Scalar(self.scalars_f32.len() as u16 - 1, Elem::F32)
self.scalars_f32 += 1;
Variable::Scalar(self.scalars_f32 as u16 - 1, Elem::F32)
}
Elem::I32 => {
self.scalars_i32.push(value.elem());
Variable::Scalar(self.scalars_i32.len() as u16 - 1, Elem::I32)
self.scalars_i32 += 1;
Variable::Scalar(self.scalars_i32 as u16 - 1, Elem::I32)
}
Elem::U32 => {
self.scalars_u32.push(value.elem());
Variable::Scalar(self.scalars_u32.len() as u16 - 1, Elem::U32)
self.scalars_u32 += 1;
Variable::Scalar(self.scalars_u32 as u16 - 1, Elem::U32)
}
Elem::Bool => {
panic!("Bool scalars not supported")
@ -630,50 +656,3 @@ where
true
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn_fusion::graph::Ops;
use burn_fusion::{Fusion, FusionBackend};
use burn_tensor::Tensor;
struct FakeAddOps;
impl<B: FusionBackend> Ops<B> for FakeAddOps {
fn execute(self: Box<Self>, _: &mut HandleContainer<B>) {
todo!()
}
}
#[test]
fn test_fusion_same_behavior() {
type Backend = Wgpu;
type FusedBackend = Fusion<Wgpu>;
let data_1 =
Tensor::<Backend, 2>::random([1, 32], burn_tensor::Distribution::Default).into_data();
let data_2 =
Tensor::<Backend, 2>::random([32, 32], burn_tensor::Distribution::Default).into_data();
let tensor_1 = Tensor::<Backend, 2>::from_data(data_1.clone());
let tensor_2 = Tensor::<Backend, 2>::from_data(data_2.clone());
let tensor_3 = tensor_1.clone() + tensor_2;
let tensor_4 = tensor_3.clone() - tensor_1;
let tensor_5 = tensor_4.clone() + 5.0;
let tensor_6 = tensor_5 + tensor_3.clone();
let mask = tensor_4.lower_equal(tensor_3);
let result_ref = tensor_6.mask_fill(mask, 0.3).into_data();
let tensor_1 = Tensor::<FusedBackend, 2>::from_data(data_1);
let tensor_2 = Tensor::<FusedBackend, 2>::from_data(data_2);
let tensor_3 = tensor_1.clone() + tensor_2;
let tensor_4 = tensor_3.clone() - tensor_1;
let tensor_5 = tensor_4.clone() + 5.0;
let tensor_6 = tensor_5 + tensor_3.clone();
let mask = tensor_4.lower_equal(tensor_3);
let result_fused = tensor_6.mask_fill(mask, 0.3).into_data();
result_fused.assert_approx_eq(&result_ref, 3);
}
}

View File

@ -1,3 +1,4 @@
mod ops;
mod builder;
mod optimization;
pub use ops::*;
pub(crate) use builder::*;

View File

@ -0,0 +1,170 @@
use crate::{
fusion::codegen::{Elem, Operator},
fusion::kernel::FusionKernel,
FloatElement, GraphicsApi, IntElement, Wgpu,
};
use burn_fusion::{graph::Context, Optimization, TensorDescription};
use burn_tensor::Device;
/// Fused element-wise optimization: everything needed to launch one fused kernel, captured
/// when the optimization is built.
#[derive(Clone)]
pub(crate) struct FloatElementWise<G, F, I>
where
G: GraphicsApi,
F: FloatElement,
I: IntElement,
{
/// Tensors read by the kernel with their element type. Only the ids are used at execution
/// time: the descriptions are looked up again in the execution context.
pub(crate) inputs: Vec<(TensorDescription, Elem)>,
/// Tensors written by the kernel with their element type, resolved at execution time the
/// same way as the inputs.
pub(crate) outputs: Vec<(TensorDescription, Elem)>,
/// Local variable indices handed to the kernel together with the outputs.
pub(crate) locals: Vec<u16>,
/// The operators forming the body of the kernel.
pub(crate) operators: Vec<Operator>,
/// Number of `f32` scalars the kernel takes from the start of the context's scalars.
pub(crate) scalars_f32: usize,
/// The device on which the kernel is created.
pub(crate) device: Device<Wgpu<G, F, I>>,
}
impl<G, F, I> Optimization<Wgpu<G, F, I>> for FloatElementWise<G, F, I>
where
    G: GraphicsApi,
    F: FloatElement,
    I: IntElement,
{
    /// Launch the fused kernel using the tensors, scalars and handles of `context`.
    ///
    /// Panics when an input or output tensor id is missing from `context.tensors`.
    fn execute(&self, context: &mut Context<'_, Wgpu<G, F, I>>) {
        // Resolve the recorded tensors against the context through their ids.
        let mut inputs = Vec::with_capacity(self.inputs.len());
        for (tensor, elem) in &self.inputs {
            let current = context.tensors.get(&tensor.id).unwrap();
            inputs.push((current, *elem));
        }

        let mut outputs = Vec::with_capacity(self.outputs.len());
        for (tensor, elem) in &self.outputs {
            let current = context.tensors.get(&tensor.id).unwrap();
            outputs.push((current, *elem));
        }

        // The context may contain scalars for the end condition, which may vary.
        let num_scalars = self.scalars_f32;
        let scalars_f32 = &context.scalar_floats[..num_scalars];

        FusionKernel::new(&self.device)
            .inputs(&inputs, scalars_f32)
            .body(&self.operators)
            .outputs(&outputs, &self.locals)
            .execute(context.handles);
    }

    /// Number of operators fused in this optimization.
    fn len(&self) -> usize {
        self.operators.len()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use burn_fusion::graph::Ops;
use burn_fusion::{Fusion, FusionBackend};
use burn_tensor::{backend::Backend, Data, Tensor};
/// The fused backend must give approximately the same result as the plain backend for the same
/// computation (`Variant1`) on the same input data.
#[test]
fn test_fusion_same_behavior() {
type Backend = Wgpu;
type FusedBackend = Fusion<Wgpu>;
// Note that the first input is generated with the fused backend and the second one with the
// plain backend.
let data_1 = Tensor::<FusedBackend, 2>::random([1, 32], burn_tensor::Distribution::Default)
.into_data();
let data_2 =
Tensor::<Backend, 2>::random([32, 32], burn_tensor::Distribution::Default).into_data();
let result_ref = execute::<Backend>(
data_1.clone(),
data_2.clone(),
ImplementationDetails::Variant1,
);
let result_fused = execute::<FusedBackend>(
data_1.clone(),
data_2.clone(),
ImplementationDetails::Variant1,
);
result_ref.assert_approx_eq(&result_fused, 3);
}
/// Running both variants of the computation one after the other on the fused backend must give
/// approximately the same result as the plain backend.
///
/// `Variant2` only adds and then subtracts 1, so both variants are compared against the same
/// reference.
#[test]
fn test_fusion_same_behavior_different_variant() {
type Backend = Wgpu;
type FusedBackend = Fusion<Wgpu>;
let data_1 = Tensor::<FusedBackend, 2>::random([1, 32], burn_tensor::Distribution::Default)
.into_data();
let data_2 =
Tensor::<Backend, 2>::random([32, 32], burn_tensor::Distribution::Default).into_data();
let result_ref = execute::<Backend>(
data_1.clone(),
data_2.clone(),
ImplementationDetails::Variant2,
);
// NOTE(review): presumably the point is to follow two slightly different graphs on the same
// fused backend, one after the other - confirm.
let result_fused_variant1 = execute::<FusedBackend>(
data_1.clone(),
data_2.clone(),
ImplementationDetails::Variant1,
);
let result_fused_variant2 = execute::<FusedBackend>(
data_1.clone(),
data_2.clone(),
ImplementationDetails::Variant2,
);
result_ref.assert_approx_eq(&result_fused_variant1, 3);
result_ref.assert_approx_eq(&result_fused_variant2, 3);
}
/// Syncing must work when the operation following the fused ones is a scalar operation on a
/// tensor of a different shape.
#[test]
fn test_end_condition_scalar_ops() {
type Backend = Fusion<Wgpu>;
let tensor1 = Tensor::<Backend, 2>::ones([32, 32]);
let tensor2 = Tensor::<Backend, 2>::ones([32, 42]);
let output = tensor1.exp().log();
// This will add a scalar to the context, even if the actual operation can't be fused with
// the preceding ones because of the shape difference.
let _ = tensor2 + 2;
// When we try to execute the operations, the number of bindings can be different if we are
// not careful.
Backend::sync(&output.device());
}
/// An [`Ops`] implementation whose execution always panics.
///
/// It is not referenced by the tests of this module.
struct FakeAddOps;
impl<B: FusionBackend> Ops<B> for FakeAddOps {
fn execute(self: Box<Self>, _: &mut burn_fusion::HandleContainer<B>) {
panic!("Should always fused during tests.")
}
}
/// Selects which version of the computation the `execute` test helper runs.
enum ImplementationDetails {
/// The base computation, without any extra operation.
Variant1,
/// The base computation with 1 added to and then subtracted from an intermediate tensor.
Variant2,
}
/// Run the reference computation of the tests on backend `B` and return the resulting data.
///
/// The computation chains additions, a subtraction, a scalar addition, a GELU activation and a
/// masked fill; `variant` decides whether the extra `+ 1` / `- 1` pair is inserted.
fn execute<B: Backend>(
data_1: Data<f32, 2>,
data_2: Data<f32, 2>,
variant: ImplementationDetails,
) -> Data<f32, 2> {
let tensor_1 = Tensor::<B, 2>::from_data(data_1.convert());
let tensor_2 = Tensor::<B, 2>::from_data(data_2.convert());
let tensor_3 = tensor_1.clone() + tensor_2;
let tensor_4 = tensor_3.clone() - tensor_1;
let mut tensor_5 = tensor_4.clone() + 5.0;
match variant {
ImplementationDetails::Variant1 => {}
ImplementationDetails::Variant2 => {
// Leaves the values unchanged but adds two operations to the computation.
tensor_5 = tensor_5 + 1;
tensor_5 = tensor_5 - 1;
}
}
let tensor_6 = burn_tensor::activation::gelu(tensor_5 + tensor_3.clone());
let mask = tensor_4.lower_equal(tensor_3);
let tmp = tensor_6.mask_fill(mask, 0.3);
tmp.into_data().convert()
}
}

View File

@ -83,7 +83,7 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, Input
/// Register the inputs used by the kernel.
pub fn inputs(
mut self,
inputs_tensor: &[&(TensorDescription, Elem)],
inputs_tensor: &[(&TensorDescription, Elem)],
inputs_scalar_f32: &[f32],
) -> FusionKernel<G, F, I, BodyPhase> {
for (i, (input, elem)) in inputs_tensor.iter().enumerate() {
@ -198,7 +198,7 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, Outpu
/// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0).
pub fn outputs(
mut self,
outputs: &[&(TensorDescription, Elem)],
outputs: &[(&TensorDescription, Elem)],
locals: &[u16],
) -> FusionKernel<G, F, I, ExecutionPhase> {
let mut num_elems_launch_option = 0;

View File

@ -5,4 +5,4 @@ pub(crate) mod codegen;
pub(crate) mod kernel;
pub use base::*;
pub use elemwise::*;
pub(crate) use elemwise::*;