Fusion wgpu compilation cache (#1069)

* Refactor fusion in the wgpu backend

* WIP

* Refactor

* WIP

* Fix inplace ops

* Works ish

* Cleanup Output

* Refactoring

* Refactor Clamp

* Cleanup

* Cleanup

* Updates

* Fix CI

* Code review
Nathaniel Simard 2023-12-18 12:16:08 -05:00, committed by GitHub
parent 042454a9db
commit b5c49c5bf7
42 changed files with 1040 additions and 1358 deletions


@ -96,7 +96,7 @@ pub trait OptimizationBuilder<B: FusionBackend>: Send {
/// The operation created from the [builder](OptimizationBuilder).
pub trait Optimization<B: FusionBackend>: Send {
/// Execute the operation.
fn execute(&self, context: &mut Context<'_, B>);
fn execute(&mut self, context: &mut Context<'_, B>);
/// The number of registered operations in this optimization.
fn len(&self) -> usize;
/// If the current optimization is empty.
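// Why `execute` now takes `&mut self`: an optimization can memoize state
// across executions, such as its compiled kernel. A minimal sketch with
// hypothetical types (not the actual trait bounds):
struct MemoizedOptimization {
    compiled_id: Option<String>,
}

impl MemoizedOptimization {
    fn execute(&mut self) {
        // Compile once; reuse the cached id on every later call.
        let id = self
            .compiled_id
            .get_or_insert_with(|| "generated-kernel-id".to_string());
        println!("launching kernel {id}");
    }
}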


@ -53,7 +53,7 @@ impl<B: FusionBackend> Graph<B> {
pub(crate) fn execute_optimization(
&mut self,
handles: &mut HandleContainer<B>,
optimization: &dyn Optimization<B>,
optimization: &mut dyn Optimization<B>,
) {
let num_keep = optimization.len();
let mut context = self.converter.context(handles);


@ -682,20 +682,6 @@ impl<E: Element> NumericOpsDescription<E> {
out: desc.out.to_relative(converter),
})
}
NumericOpsDescription::ClampMax(desc) => {
NumericOpsDescription::ClampMax(ScalarOpsDescription {
lhs: desc.lhs.to_relative(converter),
rhs: local_elem(converter, &desc.rhs),
out: desc.out.to_relative(converter),
})
}
NumericOpsDescription::ClampMin(desc) => {
NumericOpsDescription::ClampMin(ScalarOpsDescription {
lhs: desc.lhs.to_relative(converter),
rhs: local_elem(converter, &desc.rhs),
out: desc.out.to_relative(converter),
})
}
}
}
}


@ -71,7 +71,7 @@ impl<B: FusionBackend> GraphExecution<B> {
};
}
CacheResult::Found(ops) => {
graph.execute_optimization(handles, ops.as_ref());
graph.execute_optimization(handles, ops.as_mut());
self.reset(graph);
}
};
@ -107,14 +107,14 @@ impl<B: FusionBackend> GraphExecution<B> {
}
}
match find_best_optimization_index(&self.optimizations) {
match find_best_optimization_index(&mut self.optimizations) {
Some(index) => {
let (relative, next_ops) = Self::split_relative_graph_owned(graph, mode);
let optimization = &self.optimizations[index];
let ops = self
.optimization_cache
.complete(optimization, relative, next_ops);
BuildAction::ExecuteOptimization(ops.as_ref())
BuildAction::ExecuteOptimization(ops.as_mut())
}
None => {
// TODO: Cache this result too.
@ -184,7 +184,7 @@ impl<B: FusionBackend> GraphExecution<B> {
}
enum BuildAction<'a, B: FusionBackend> {
ExecuteOptimization(&'a dyn Optimization<B>),
ExecuteOptimization(&'a mut dyn Optimization<B>),
ExecuteOperations,
ContinueBuilding,
}
@ -202,7 +202,7 @@ fn still_optimizing<B: FusionBackend>(optimizations: &[Box<dyn OptimizationBuild
}
fn find_best_optimization_index<B: FusionBackend>(
optimizations: &[Box<dyn OptimizationBuilder<B>>],
optimizations: &mut [Box<dyn OptimizationBuilder<B>>],
) -> Option<usize> {
let mut best_index = None;
let mut best_score = 0;


@ -379,16 +379,6 @@ pub enum NumericOpsDescription<E> {
/// Float => [clamp](burn_tensor::ops::TensorOps::clamp).
/// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp).
Clamp(ClampOpsDescription<E>),
/// Operation corresponding to:
///
/// Float => [clamp max](burn_tensor::ops::TensorOps::clamp_max).
/// Int => [clamp max](burn_tensor::ops::IntTensorOps::int_clamp_max).
ClampMax(ScalarOpsDescription<E>),
/// Operation corresponding to:
///
/// Float => [clamp min](burn_tensor::ops::TensorOps::clamp_min).
/// Int => [clamp min](burn_tensor::ops::IntTensorOps::int_clamp_min).
ClampMin(ScalarOpsDescription<E>),
}
/// Operation description specific to an int tensor.
@ -900,12 +890,6 @@ impl<E: Element> NumericOpsDescription<E> {
NumericOpsDescription::Clamp(desc) => {
vec![&desc.tensor, &desc.out]
}
NumericOpsDescription::ClampMin(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOpsDescription::ClampMax(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOpsDescription::Abs(desc) => {
vec![&desc.input, &desc.out]
}
@ -1144,8 +1128,6 @@ impl<E> core::hash::Hash for NumericOpsDescription<E> {
NumericOpsDescription::MaxDim(desc) => desc.hash(state),
NumericOpsDescription::MinDim(desc) => desc.hash(state),
NumericOpsDescription::Clamp(desc) => desc.hash(state),
NumericOpsDescription::ClampMax(desc) => desc.hash(state),
NumericOpsDescription::ClampMin(desc) => desc.hash(state),
}
}
}


@ -60,16 +60,13 @@ impl<O> OptimizationCache<O> {
}
if let Some(candidate) = self.found {
return CacheResult::Found(&self.optimizations.get(candidate).unwrap().value);
return CacheResult::Found(&mut self.optimizations.get_mut(candidate).unwrap().value);
}
// Invalidate candidates.
let mut invalidated_candidate = Vec::new();
for id in self.candidates.iter() {
let item = match self.optimizations.get(*id) {
Some(item) => item,
None => panic!("Should have an optimization"),
};
let item = &self.optimizations[*id];
let next_ops = graph.last().expect("Validated earlier");
let next_ops_index = graph.len() - 1;
let next_ops_candidate = match item.graph.get(next_ops_index) {
@ -93,13 +90,13 @@ impl<O> OptimizationCache<O> {
Condition::NextOps(ops) => ops,
Condition::Sync => {
self.found = Some(*id);
return CacheResult::Found(&item.value);
break;
}
};
if item.end_conditions.contains(ops) {
self.found = Some(*id);
return CacheResult::Found(&item.value);
break;
} else {
self.availables.push((*id, graph.len()));
invalidated_candidate.push(*id);
@ -107,6 +104,10 @@ impl<O> OptimizationCache<O> {
}
}
if let Some(id) = self.found {
return CacheResult::Found(&mut self.optimizations[id].value);
}
let mut updated_candidates = Vec::new();
core::mem::swap(&mut updated_candidates, &mut self.candidates);
@ -136,7 +137,7 @@ impl<O> OptimizationCache<O> {
factory: &Factory,
graph: Vec<TensorOpsDescription>,
next_ops: Option<TensorOpsDescription>,
) -> &'a O {
) -> &'a mut O {
let existing_optim = self
.availables
.iter()
@ -149,7 +150,7 @@ impl<O> OptimizationCache<O> {
optimization.end_conditions.push(ops)
};
return &optimization.value;
return &mut optimization.value;
};
self.starters
@ -164,7 +165,9 @@ impl<O> OptimizationCache<O> {
};
self.optimizations.push(optimization);
&self.optimizations.last().unwrap().value
let last_index = self.optimizations.len() - 1;
&mut self.optimizations[last_index].value
}
// Signal that a new path will begin.
@ -188,7 +191,7 @@ pub enum CacheResult<'a, T> {
/// happens.
OnPath,
/// An optimization has been found, and the best action is to execute it!
Found(&'a T),
Found(&'a mut T),
}
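// Sketch of the lookup pattern used above (simplified types). Returning
// `&mut` from inside the candidate loop would hold the borrow across the
// rest of the iteration, so the code records the matching index in
// `self.found`, breaks, and re-borrows mutably after the loop:
struct SimpleCache<T> {
    optimizations: Vec<T>,
    found: Option<usize>,
}

impl<T> SimpleCache<T> {
    fn follow(&mut self, matches: impl Fn(&T) -> bool) -> Option<&mut T> {
        for (id, item) in self.optimizations.iter().enumerate() {
            if matches(item) {
                self.found = Some(id); // remember the match instead of returning
                break;
            }
        }
        self.found.map(|id| &mut self.optimizations[id])
    }
}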
/// When checking if an optimization is possible, a start or an end condition ensures that this optimization is


@ -265,48 +265,6 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn clamp_min<const D: usize>(
tensor: FloatTensor<Self, D>,
min: FloatElem<Self>,
) -> FloatTensor<Self, D> {
scalar_float_ops!(ClampMinOps, B::clamp_min);
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
let desc = ScalarOpsDescription {
lhs: tensor.into_description(),
rhs: min.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::ClampMin(desc.clone())),
ClampMinOps::<D>::new(desc),
);
out
}
fn clamp_max<const D: usize>(
tensor: FloatTensor<Self, D>,
max: FloatElem<Self>,
) -> FloatTensor<Self, D> {
scalar_float_ops!(ClampMaxOps, B::clamp_max);
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
let desc = ScalarOpsDescription {
lhs: tensor.into_description(),
rhs: max.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::ClampMax(desc.clone())),
ClampMaxOps::<D>::new(desc),
);
out
}
fn clamp<const D: usize>(
tensor: FloatTensor<Self, D>,
min: FloatElem<Self>,


@ -1034,48 +1034,6 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out
}
fn int_clamp_min<const D: usize>(
tensor: IntTensor<Self, D>,
min: IntElem<Self>,
) -> IntTensor<Self, D> {
scalar_int_ops!(ClampMinOps, B::int_clamp_min);
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
let desc = ScalarOpsDescription {
lhs: tensor.into_description(),
rhs: min.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ClampMin(desc.clone())),
ClampMinOps::<D>::new(desc),
);
out
}
fn int_clamp_max<const D: usize>(
tensor: IntTensor<Self, D>,
max: IntElem<Self>,
) -> IntTensor<Self, D> {
scalar_int_ops!(ClampMaxOps, B::int_clamp_max);
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
let desc = ScalarOpsDescription {
lhs: tensor.into_description(),
rhs: max.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ClampMax(desc.clone())),
ClampMaxOps::<D>::new(desc),
);
out
}
fn int_clamp<const D: usize>(
tensor: IntTensor<Self, D>,
min: IntElem<Self>,


@ -5,7 +5,7 @@ use std::fmt::Display;
///
/// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size
/// X and Y, but with Z=1.
#[derive(Hash, new)]
#[derive(new)]
pub struct Body {
operators: Vec<Operator>,
}


@ -2,7 +2,7 @@ use super::Elem;
use std::fmt::Display;
/// Not all functions are native to WGSL, so this enum adds support for additional functions.
#[derive(Hash, PartialEq, Eq, Clone)]
#[derive(PartialEq, Eq, Clone)]
pub enum Function {
Powf(Elem),
Erf(Elem),


@ -0,0 +1,359 @@
use crate::codegen::{
Binding, Body, ComputeShader, Elem, Function, Location, Operator, Variable, Visibility,
WorkgroupSize,
};
use crate::compute::{StaticKernel, WgpuComputeClient, WgpuHandle};
use crate::element::WgpuElement;
use crate::kernel::{elemwise_workgroup, StaticKernelSource, WORKGROUP_DEFAULT};
use std::marker::PhantomData;
/// Kernel creation input phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
pub struct InputPhase;
/// Kernel creation body phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
pub struct BodyPhase;
/// Kernel creation output phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
pub struct OutputPhase;
/// Kernel compilation phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
pub struct CompilationPhase;
/// Allows creating custom WGSL kernels from configured inputs, a body, and outputs.
///
/// This type has four phases that must be executed in order; the type system enforces the
/// ordering, so you can't invoke them incorrectly.
///
/// 1. [Input Phase](InputPhase)
/// This phase focuses on registering the input arrays and scalars that are going to be used by
/// the kernel.
/// 2. [Body Phase](BodyPhase)
/// After the input phase is done, all the operations that happen in the body must be
/// registered.
/// 3. [Output Phase](OutputPhase)
/// This step focuses on registering all output arrays or inputs that the kernel needs to write to.
/// 4. [Compilation Phase](CompilationPhase)
/// Now that all other phases are completed, we can actually compile the kernel.
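///
/// # Example
///
/// An illustrative sketch of the phase flow (the input, operators, and output
/// shown are placeholders):
///
/// ```ignore
/// let shader = ElemWiseKernelCodegen::new()
///     .inputs(&[Input::Array {
///         elem: Elem::F32,
///         visibility: Visibility::Read,
///         strategy: ReadingStrategy::IntoContiguous,
///     }])
///     .body(&[/* operators to fuse */])
///     .outputs(&[Output::Array { elem: Elem::F32, local: 0 }])
///     .compile();
/// ```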
pub struct ElemWiseKernelCodegen<Phase = InputPhase> {
operations: Vec<Operator>,
input_bindings: Vec<Binding>,
output_bindings: Vec<Binding>,
named_bindings: Vec<(String, Binding)>,
functions: Vec<Function>,
_phase: PhantomData<Phase>,
}
pub enum Input {
Array {
elem: Elem,
visibility: Visibility,
strategy: ReadingStrategy,
},
Scalar {
elem: Elem,
size: usize,
},
}
pub enum ReadingStrategy {
IntoContiguous,
Plain,
}
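// Note on the strategies: `IntoContiguous` emits a strided read producing
// values laid out like the output buffer, while `Plain` reads `input[id]`
// directly; the in-place kernels in this commit pair `Plain` with
// `Output::Input`, since input and output share the same buffer.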
pub enum Output {
Array { elem: Elem, local: u16 },
Input { elem: Elem, input: u16, local: u16 },
}
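// `Output::Array` assigns a local variable to a dedicated output buffer, while
// `Output::Input` writes the local back into an existing input buffer, which
// is how the in-place variants avoid allocating an output.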
impl ElemWiseKernelCodegen<InputPhase> {
/// Create a new fusion kernel on the given device.
pub fn new() -> Self {
Self {
operations: Vec::new(),
input_bindings: Vec::new(),
output_bindings: Vec::new(),
named_bindings: Vec::new(),
functions: Vec::new(),
_phase: PhantomData,
}
}
/// Register the inputs used by the kernel.
pub fn inputs(mut self, inputs: &[Input]) -> ElemWiseKernelCodegen<BodyPhase> {
let mut index: u16 = 0;
let first_output_index = inputs
.iter()
.filter(|input| match input {
Input::Array {
elem: _,
visibility: _,
strategy: _,
} => true,
Input::Scalar { elem: _, size: _ } => false,
})
.count();
for input in inputs {
match input {
Input::Array {
elem,
visibility,
strategy,
} => {
self.input_bindings.push(Binding {
elem: bool_elem(*elem),
visibility: *visibility,
location: Location::Storage,
size: None,
});
match strategy {
ReadingStrategy::IntoContiguous => {
self.operations.push(Operator::ReadGlobalIntoContiguous {
variable: Variable::Input(index, *elem),
position: index as usize,
position_out: first_output_index, // First output
});
}
ReadingStrategy::Plain => {
self.operations.push(Operator::ReadGlobal {
variable: Variable::Input(index, *elem),
});
}
}
index += 1;
}
Input::Scalar { elem, size } => {
let elem = bool_elem(*elem);
self.named_bindings.push((
format!("scalars_{}", elem),
Binding {
elem,
visibility: Visibility::Read,
location: Location::Storage,
size: Some(*size),
},
));
}
}
}
ElemWiseKernelCodegen {
operations: self.operations,
input_bindings: self.input_bindings,
output_bindings: self.output_bindings,
named_bindings: self.named_bindings,
functions: self.functions,
_phase: PhantomData,
}
}
}
impl ElemWiseKernelCodegen<BodyPhase> {
/// Register the [operators](Operator) that the kernel must execute in the order provided.
pub fn body(mut self, operators: &[Operator]) -> ElemWiseKernelCodegen<OutputPhase> {
let mut register_function = |function: Function| {
if !self.functions.contains(&function) {
self.functions.push(function);
}
};
// Since not all operators are native to WGSL, we need to add the custom ones.
for ops in operators.iter() {
match ops {
Operator::Powf {
lhs: _,
rhs: _,
out: _,
} => {
register_function(Function::Powf(Elem::F32));
}
Operator::Erf { input: _, out: _ } => {
register_function(Function::Erf(Elem::F32));
}
_ => {}
}
self.operations.push(ops.clone());
}
ElemWiseKernelCodegen {
operations: self.operations,
input_bindings: self.input_bindings,
output_bindings: self.output_bindings,
named_bindings: self.named_bindings,
functions: self.functions,
_phase: PhantomData,
}
}
}
impl ElemWiseKernelCodegen<OutputPhase> {
/// Register the outputs with their local variable index.
///
/// Note that the index corresponds to the position of the registered [operator](Operator) in the
/// [body phase](BodyPhase).
/// So the 4th operator registered creates local variable 3 (N-1, since indices start at 0).
pub fn outputs(mut self, outputs: &[Output]) -> ElemWiseKernelCodegen<CompilationPhase> {
let mut index = 0;
for array in outputs {
match array {
Output::Array { elem, local } => {
let elem_adapted = bool_elem(*elem);
self.output_bindings.push(Binding {
elem: elem_adapted,
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
});
self.operations.push(Operator::AssignGlobal {
input: Variable::Local(*local, *elem),
out: Variable::Output(index, elem_adapted),
});
index += 1;
}
Output::Input { elem, input, local } => {
self.operations.push(Operator::AssignGlobal {
input: Variable::Local(*local, *elem),
out: Variable::Input(*input, bool_elem(*elem)),
});
}
}
}
ElemWiseKernelCodegen {
operations: self.operations,
input_bindings: self.input_bindings,
output_bindings: self.output_bindings,
named_bindings: self.named_bindings,
functions: self.functions,
_phase: PhantomData,
}
}
}
impl ElemWiseKernelCodegen<CompilationPhase> {
/// Compile the kernel into a [compute shader](ComputeShader).
pub fn compile(self) -> ComputeShader {
let inputs = self.input_bindings;
let outputs = self.output_bindings;
let mut named = Vec::with_capacity(2);
named.push((
"info".to_string(),
Binding {
elem: Elem::U32,
visibility: Visibility::Read,
location: Location::Storage,
size: None, // We avoid putting the length here since it will force a new kernel
// for each tensor rank.
},
));
for (name, binding) in self.named_bindings.into_iter() {
named.push((name, binding));
}
ComputeShader {
inputs,
outputs,
named,
workgroup_size: WorkgroupSize::default(),
body: Body::new(self.operations),
num_workgroups: true,
global_invocation_id: true,
functions: self.functions,
}
}
}
#[derive(new)]
pub struct StaticHandle<'a> {
handle: &'a WgpuHandle,
strides: &'a [usize],
shape: &'a [usize],
}
/// Execute a static kernel.
///
/// The limitation of this method is that you can't launch a kernel with multiple scalar types.
pub fn execute_static<K, E: WgpuElement>(
inputs: &[StaticHandle],
outputs: &[StaticHandle],
scalar_elems: Option<&[E]>,
client: WgpuComputeClient,
) where
K: StaticKernelSource + 'static,
{
let mut info = Vec::new();
let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2);
// Inner function to fill the info buffer.
let mut register_info_tensor = |strides: &[usize], shape: &[usize]| {
if info.is_empty() {
info.push(strides.len() as u32);
}
for s in strides.iter() {
info.push(*s as u32);
}
for s in shape.iter() {
info.push(*s as u32);
}
};
// We start by registering the inputs.
for input in inputs.iter() {
register_info_tensor(input.strides, input.shape);
handles.push(input.handle);
}
let mut num_elems_output = 0;
// Then we follow with the outputs.
for output in outputs.iter() {
let num_elems = calculate_num_elems_dyn_rank(output.shape);
if num_elems > num_elems_output {
num_elems_output = num_elems;
}
register_info_tensor(output.strides, output.shape);
handles.push(output.handle);
}
let info = &client.create(bytemuck::cast_slice(&info));
handles.push(info);
// Finally we finish with the named bindings.
let mut scalars = None;
if let Some(values) = &scalar_elems {
scalars = Some(client.create(bytemuck::cast_slice(values)));
}
if let Some(scalars) = scalars.as_ref() {
handles.push(scalars);
}
let workgroup = elemwise_workgroup(num_elems_output, WORKGROUP_DEFAULT);
let kernel = Box::new(StaticKernel::<K>::new(workgroup));
client.execute(kernel, &handles);
}
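// Illustrative call site (handles, strides, and shapes are assumed to already
// exist; `MyKernel` is a hypothetical StaticKernelSource):
//
// execute_static::<MyKernel, f32>(
//     &[StaticHandle::new(&in_handle, &in_strides, &in_shape)],
//     &[StaticHandle::new(&out_handle, &out_strides, &out_shape)],
//     Some(&[0.5f32]), // a single f32 scalar binding
//     client,
// );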
pub(crate) fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
let mut num_elems = 1;
for i in shape.iter() {
num_elems *= i;
}
num_elems
}
fn bool_elem(elem: Elem) -> Elem {
match elem {
// I32 are used for bool tensors
Elem::Bool => Elem::I32,
_ => elem,
}
}


@ -0,0 +1,13 @@
mod body;
mod function;
mod kernel;
mod operator;
mod shader;
mod variable;
pub(crate) use body::*;
pub(crate) use function::*;
pub(crate) use kernel::*;
pub(crate) use operator::*;
pub(crate) use shader::*;
pub(crate) use variable::*;


@ -1,8 +1,9 @@
use super::Variable;
use super::variable::Variable;
use std::fmt::Display;
/// All operators that can be fused in a WGSL compute shader.
#[derive(Debug, Hash, Clone)]
#[derive(Debug, Clone)]
#[allow(dead_code)] // Some variants may be unused depending on the enabled feature flags
pub enum Operator {
Add {
lhs: Variable,
@ -57,6 +58,10 @@ pub enum Operator {
rhs: Variable,
out: Variable,
},
Sqrt {
input: Variable,
out: Variable,
},
Erf {
input: Variable,
out: Variable,
@ -75,6 +80,12 @@ pub enum Operator {
rhs: Variable,
out: Variable,
},
Clamp {
input: Variable,
min_value: Variable,
max_value: Variable,
out: Variable,
},
Greater {
lhs: Variable,
rhs: Variable,
@ -100,8 +111,15 @@ pub enum Operator {
input: Variable,
out: Variable,
},
AssignLocal {
input: Variable,
out: Variable,
},
ReadGlobal {
variable: Variable,
},
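// Like ReadGlobal, but reads through the tensor's strides (from the info
// buffer) so the value ends up contiguous with respect to the output at
// `position_out`.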
ReadGlobalIntoContiguous {
variable: Variable,
position: usize,
position_out: usize,
},
@ -125,9 +143,20 @@ impl Display for Operator {
Operator::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")),
Operator::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")),
Operator::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")),
Operator::Clamp {
input,
min_value,
max_value,
out,
} => f.write_fmt(format_args!(
"let {out} = clamp({input}, {min_value}, {max_value});"
)),
Operator::Powf { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});"))
}
Operator::Sqrt { input, out } => {
f.write_fmt(format_args!("let {out} = sqrt({input});"))
}
Operator::Log1p { input, out } => {
f.write_fmt(format_args!("let {out} = log({input} + 1.0);"))
}
@ -159,7 +188,21 @@ impl Display for Operator {
let elem = out.elem();
f.write_fmt(format_args!("{out}_global[id] = {elem}({input});"))
}
Operator::ReadGlobal {
Operator::AssignLocal { input, out } => {
let elem = out.elem();
f.write_fmt(format_args!("let {out} = {elem}({input});"))
}
Operator::ReadGlobal { variable } => match variable {
Variable::Input(number, _elem) => f.write_fmt(format_args!(
"let input_{number} = input_{number}_global[id];"
)),
Variable::Local(_, _) => panic!("can't read global local variable."),
Variable::Output(number, _elem) => f.write_fmt(format_args!(
"let output_{number} = output_{number}_global[id];"
)),
Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
},
Operator::ReadGlobalIntoContiguous {
variable,
position,
position_out,


@ -1,34 +1,29 @@
use super::{Body, Function};
use crate::kernel::{DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT};
use std::{
collections::hash_map::DefaultHasher,
fmt::Display,
hash::{Hash, Hasher},
};
use crate::kernel::WORKGROUP_DEFAULT;
use std::fmt::Display;
#[derive(Hash, PartialEq, Eq)]
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum Location {
Storage,
#[allow(dead_code)]
Workgroup,
}
#[derive(Hash, PartialEq, Eq)]
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum Visibility {
Read,
ReadWrite,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum Elem {
F32,
#[allow(dead_code)]
I32,
U32,
Bool,
}
#[derive(Hash, PartialEq, Eq)]
#[derive(PartialEq, Eq, Clone)]
pub struct Binding {
pub location: Location,
pub visibility: Visibility,
@ -36,7 +31,7 @@ pub struct Binding {
pub size: Option<usize>,
}
#[derive(Hash, PartialEq, Eq)]
#[derive(PartialEq, Eq)]
pub struct WorkgroupSize {
pub x: usize,
pub y: usize,
@ -53,7 +48,6 @@ impl Default for WorkgroupSize {
}
}
#[derive(Hash)]
pub struct ComputeShader {
pub inputs: Vec<Binding>,
pub outputs: Vec<Binding>,
@ -65,19 +59,6 @@ pub struct ComputeShader {
pub functions: Vec<Function>,
}
impl DynamicKernelSource for ComputeShader {
fn source(&self) -> SourceTemplate {
SourceTemplate::new(self.to_string())
}
fn id(&self) -> String {
let mut s = DefaultHasher::new();
self.hash(&mut s);
s.finish().to_string()
}
}
impl Display for ComputeShader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Self::format_bindings(f, "input", &self.inputs, 0)?;


@ -1,7 +1,7 @@
use super::Elem;
use std::fmt::Display;
#[derive(Debug, Hash, Clone)]
#[derive(Debug, Clone)]
pub enum Variable {
Input(u16, Elem),
Scalar(u16, Elem),


@ -9,8 +9,7 @@ where
fn type_name() -> &'static str;
fn as_bytes(slice: &[Self]) -> &[u8];
fn from_bytes(bytes: &[u8]) -> &[Self];
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem;
fn elem_type() -> crate::codegen::Elem;
}
/// The float element type for the wgpu backend.
@ -29,9 +28,8 @@ impl WgpuElement for u32 {
fn from_bytes(bytes: &[u8]) -> &[Self] {
bytemuck::cast_slice(bytes)
}
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem {
crate::fusion::codegen::Elem::U32
fn elem_type() -> crate::codegen::Elem {
crate::codegen::Elem::U32
}
}
@ -45,9 +43,8 @@ impl WgpuElement for i32 {
fn from_bytes(bytes: &[u8]) -> &[Self] {
bytemuck::cast_slice(bytes)
}
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem {
crate::fusion::codegen::Elem::I32
fn elem_type() -> crate::codegen::Elem {
crate::codegen::Elem::I32
}
}
@ -62,9 +59,8 @@ impl WgpuElement for f32 {
bytemuck::cast_slice(bytes)
}
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem {
crate::fusion::codegen::Elem::F32
fn elem_type() -> crate::codegen::Elem {
crate::codegen::Elem::F32
}
}


@ -81,14 +81,6 @@ pub fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
strides
}
pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
let mut num_elems = 1;
for i in shape.iter() {
num_elems *= i;
}
num_elems
}
#[derive(new, Debug, Clone)]
/// Handle to be used when fusing operations.
pub struct WgpuFusionHandle {


@ -0,0 +1,57 @@
use crate::{
codegen::ComputeShader,
kernel::{DynamicKernelSource, SourceTemplate},
};
use hashbrown::HashSet;
/// This cache ensures that source code generation is only done once, when the kernel is executed
/// for the first time. Afterwards, we only include the ID in the dynamic kernel source, since we
/// rely on the compilation cache of the WGPU compute server.
///
/// If this ever causes problems, we could cache the compute shader and wrap it in an Arc to avoid
/// deep cloning.
#[derive(Default, Debug)]
pub struct KernelCompilationCache {
already_compiled_ids: HashSet<String>,
}
#[derive(new)]
pub enum FusedKernelSource {
AlreadyCompiled { id: String },
NewKernel { id: String, shader: ComputeShader },
}
impl DynamicKernelSource for FusedKernelSource {
fn source(&self) -> SourceTemplate {
match self {
FusedKernelSource::AlreadyCompiled { id: _ } => {
panic!("Can't get the source of an already compiled kernel.")
}
FusedKernelSource::NewKernel {
id: _,
shader: source,
} => SourceTemplate::new(source.to_string()),
}
}
fn id(&self) -> String {
match self {
FusedKernelSource::AlreadyCompiled { id } => id.clone(),
FusedKernelSource::NewKernel { id, shader: _ } => id.clone(),
}
}
}
impl KernelCompilationCache {
pub fn get(&self, id: &str) -> Option<FusedKernelSource> {
if self.already_compiled_ids.contains(id) {
return Some(FusedKernelSource::AlreadyCompiled { id: id.to_string() });
}
None
}
pub fn insert(&mut self, id: String) {
self.already_compiled_ids.insert(id);
}
}
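// Illustrative flow, roughly mirroring FloatElementWise::execute later in this
// commit (`id` and `shader` are assumed to come from codegen):
//
// let mut cache = KernelCompilationCache::default();
// let source = match cache.get(&id) {
//     Some(compiled) => compiled, // later runs: only the id, no codegen
//     None => {
//         cache.insert(id.clone()); // first run: remember the id
//         FusedKernelSource::NewKernel { id: id.clone(), shader }
//     }
// };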


@ -1,11 +0,0 @@
mod body;
mod function;
mod operator;
mod shader;
mod variable;
pub use body::*;
pub use function::*;
pub use operator::*;
pub use shader::*;
pub use variable::*;


@ -1,8 +1,10 @@
use crate::{
codegen::{Elem, Operator, Variable},
element::WgpuElement,
fusion::codegen::{Elem, Operator, Variable},
fusion::cache::KernelCompilationCache,
FloatElement, GraphicsApi, IntElement, Wgpu,
};
use burn_common::id::IdGenerator;
use burn_fusion::{
graph::{
BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription,
@ -84,12 +86,14 @@ where
.collect::<Vec<_>>();
Box::new(FloatElementWise {
id: IdGenerator::generate(),
inputs,
outputs,
locals,
operators: self.operators.clone(),
scalars_f32: self.scalars_f32,
device: self.device.clone(),
cache: KernelCompilationCache::default(),
})
}
@ -183,13 +187,19 @@ where
Operator::AssignGlobal { input: _, out: _ } => {
// Nothing to do here.
}
Operator::ReadGlobal {
Operator::AssignLocal { input: _, out: _ } => {
// Nothing to do here.
}
Operator::ReadGlobalIntoContiguous {
variable: _,
position: _,
position_out: _,
} => {
// Nothing to do here.
}
Operator::ReadGlobal { variable: _ } => {
// Nothing to do here.
}
Operator::Add { lhs, rhs, out } => {
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
@ -242,6 +252,15 @@ where
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Clamp {
input,
min_value: _,
max_value: _,
out,
} => {
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Powf { lhs, rhs, out } => {
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
@ -287,6 +306,10 @@ where
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Sqrt { input, out } => {
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
}
}


@ -1,24 +1,73 @@
use crate::{
fusion::codegen::{Elem, Operator},
fusion::kernel::FusionKernel,
codegen::{
ComputeShader, Elem, ElemWiseKernelCodegen, Input, Operator, Output, ReadingStrategy,
Visibility,
},
fusion::{
cache::{FusedKernelSource, KernelCompilationCache},
kernel,
},
FloatElement, GraphicsApi, IntElement, Wgpu,
};
use burn_fusion::{graph::Context, Optimization, TensorDescription};
use burn_tensor::Device;
#[derive(Clone)]
pub(crate) struct FloatElementWise<G, F, I>
where
G: GraphicsApi,
F: FloatElement,
I: IntElement,
{
pub(crate) id: String,
pub(crate) inputs: Vec<(TensorDescription, Elem)>,
pub(crate) outputs: Vec<(TensorDescription, Elem)>,
pub(crate) locals: Vec<u16>,
pub(crate) operators: Vec<Operator>,
pub(crate) scalars_f32: usize,
pub(crate) device: Device<Wgpu<G, F, I>>,
pub(crate) cache: KernelCompilationCache,
}
impl<G, F, I> FloatElementWise<G, F, I>
where
G: GraphicsApi,
F: FloatElement,
I: IntElement,
{
pub fn compile(&mut self) -> ComputeShader {
let mut inputs = self
.inputs
.iter()
.map(|(_tensor, elem)| Input::Array {
elem: *elem,
visibility: Visibility::Read,
strategy: ReadingStrategy::IntoContiguous,
})
.collect::<Vec<_>>();
let outputs = self
.outputs
.iter()
.zip(self.locals.iter())
.map(|((_tensor, elem), local)| Output::Array {
elem: *elem,
local: *local,
})
.collect::<Vec<_>>();
if self.scalars_f32 > 0 {
inputs.push(Input::Scalar {
elem: Elem::F32,
size: self.scalars_f32,
})
}
ElemWiseKernelCodegen::new()
.inputs(&inputs)
.body(&self.operators)
.outputs(&outputs)
.compile()
}
}
impl<G, F, I> Optimization<Wgpu<G, F, I>> for FloatElementWise<G, F, I>
@ -27,27 +76,33 @@ where
F: FloatElement,
I: IntElement,
{
fn execute(&self, context: &mut Context<'_, Wgpu<G, F, I>>) {
let inputs = self
.inputs
.iter()
.map(|(tensor, elem)| (context.tensors.get(&tensor.id).unwrap(), *elem))
.collect::<Vec<_>>();
fn execute(&mut self, context: &mut Context<'_, Wgpu<G, F, I>>) {
if let Some(kernel) = self.cache.get(&self.id) {
kernel::execute_fusion(
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
self.scalars_f32,
kernel,
context,
self.device.clone(),
);
} else {
let shader = self.compile();
let outputs = self
.outputs
.iter()
.map(|(tensor, elem)| (context.tensors.get(&tensor.id).unwrap(), *elem))
.collect::<Vec<_>>();
kernel::execute_fusion(
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
self.scalars_f32,
FusedKernelSource::NewKernel {
id: self.id.to_string(),
shader,
},
context,
self.device.clone(),
);
// The context may contain scalars for the end condition, which may vary.
let scalars_f32 = &context.scalar_floats[0..self.scalars_f32];
FusionKernel::new(&self.device)
.inputs(&inputs, scalars_f32)
.body(&self.operators)
.outputs(&outputs, &self.locals)
.execute(context.handles);
self.cache.insert(self.id.clone());
}
}
fn len(&self) -> usize {
@ -144,6 +199,7 @@ mod tests {
Variant1,
Variant2,
}
fn execute<B: Backend>(
data_1: Data<f32, 2>,
data_2: Data<f32, 2>,


@ -1,355 +1,82 @@
use super::codegen::Body;
use crate::compute::{compute_client, DynamicKernel, WgpuComputeClient};
use crate::fusion::codegen::Function;
use crate::fusion::{calculate_num_elems_dyn_rank, strides_dyn_rank};
use crate::fusion::{
codegen::{
Binding, ComputeShader, Elem, Location, Operator, Variable, Visibility, WorkgroupSize,
},
WgpuFusionHandle,
};
use super::cache::FusedKernelSource;
use crate::codegen::calculate_num_elems_dyn_rank;
use crate::compute::{compute_client, DynamicKernel};
use crate::fusion::strides_dyn_rank;
use crate::fusion::WgpuFusionHandle;
use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT};
use crate::{FloatElement, GraphicsApi, IntElement, Wgpu};
use burn_fusion::{HandleContainer, TensorDescription};
use burn_fusion::graph::Context;
use burn_fusion::TensorDescription;
use burn_tensor::Device;
use std::marker::PhantomData;
/// Kernel creation input phase, see [fusion kernel](FusionKernel) for more details.
pub struct InputPhase;
/// Kernel creation body phase, see [fusion kernel](FusionKernel) for more details.
pub struct BodyPhase;
/// Kernel creation output phase, see [fusion kernel](FusionKernel) for more details.
pub struct OutputPhase;
/// Kernel execution phase, see [fusion kernel](FusionKernel) for more details.
pub struct ExecutionPhase;
/// Allows creating custom WGSL kernels from configured inputs, a body, and outputs.
///
/// This type has 4 phases that must be executed in order; the type system enforces the
/// ordering, so you can't invoke them incorrectly.
///
/// 1. [Input Phase](InputPhase)
/// This phase focuses on registering the input tensor descriptions that are going to be used by
/// the fused kernel.
/// 2. [Body Phase](BodyPhase)
/// After the input phase is done, all the operations that happen in the body must be
/// registered.
/// 3. [Output Phase](OutputPhase)
/// This step focuses on registering all tensor descriptions that the kernel needs to write to.
/// 4. [Execution Phase](ExecutionPhase)
/// Now that all other phases are completed, we can actually run the kernel on the given
/// [handles](HandleContainer). Note that the actual chosen kernel may vary based on the
/// handles provided.
pub struct FusionKernel<G, F, I, Phase = InputPhase>
where
G: GraphicsApi,
F: FloatElement,
I: IntElement,
{
operations: Vec<Operator>,
input_bindings: Vec<(Binding, TensorDescription)>,
output_bindings: Vec<(Binding, TensorDescription)>,
named_bindings: Vec<(String, Binding, DataBuffer)>,
functions: Vec<Function>,
num_elems_output: usize,
pub fn execute_fusion<G: GraphicsApi, F: FloatElement, I: IntElement>(
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
scalars_f32: usize,
kernel: FusedKernelSource,
context: &mut Context<'_, Wgpu<G, F, I>>,
device: Device<Wgpu<G, F, I>>,
client: WgpuComputeClient,
_phase: PhantomData<Phase>,
}
) {
let client = compute_client::<G>(&device);
let mut info = Vec::new();
let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2);
enum DataBuffer {
F32(Vec<f32>),
U32(Vec<u32>),
}
// Inner function to fill the info buffer.
let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| {
if info.is_empty() {
info.push(handle.strides.len() as u32);
}
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, InputPhase> {
/// Create a new fusion kernel on the given device.
pub fn new(device: &Device<Wgpu<G, F, I>>) -> Self {
let client = compute_client::<G>(device);
for s in handle.strides.iter() {
info.push(*s as u32);
}
for s in tensor.shape.iter() {
info.push(*s as u32);
}
};
Self {
operations: Vec::new(),
input_bindings: Vec::new(),
output_bindings: Vec::new(),
named_bindings: Vec::new(),
functions: Vec::new(),
num_elems_output: 0,
// We start by registering the inputs.
for tensor in inputs.iter() {
let tensor = context.tensors.get(&tensor.id).unwrap();
let handle = context.handles.get_handle(tensor);
register_info_tensor(tensor, &handle);
handles.push(handle.handle);
}
let mut num_elems_output = 0;
// Then we follow with the outputs.
for tensor in outputs.iter() {
let tensor = context.tensors.get(&tensor.id).unwrap();
let num_elems = calculate_num_elems_dyn_rank(&tensor.shape);
if num_elems > num_elems_output {
num_elems_output = num_elems;
}
let handle_fusion = WgpuFusionHandle {
client: client.clone(),
device: device.clone(),
client,
_phase: PhantomData,
}
}
/// Register the inputs used by the kernel.
pub fn inputs(
mut self,
inputs_tensor: &[(&TensorDescription, Elem)],
inputs_scalar_f32: &[f32],
) -> FusionKernel<G, F, I, BodyPhase> {
for (i, (input, elem)) in inputs_tensor.iter().enumerate() {
if elem != &Elem::Bool {
self.input_bindings.push((
Binding {
elem: *elem,
visibility: Visibility::Read,
location: Location::Storage,
size: None,
},
(*input).clone(),
));
self.operations.push(Operator::ReadGlobal {
variable: Variable::Input(i as u16, *elem),
position: i,
position_out: inputs_tensor.len(), // First output
});
} else {
self.input_bindings.push((
Binding {
elem: Elem::I32,
visibility: Visibility::Read,
location: Location::Storage,
size: None,
},
(*input).clone(),
));
self.operations.push(Operator::ReadGlobal {
variable: Variable::Input(i as u16, *elem),
position: i,
position_out: inputs_tensor.len(), // First output
});
}
}
if !inputs_scalar_f32.is_empty() {
self.named_bindings.push((
"scalars_f32".to_string(),
Binding {
elem: Elem::F32,
visibility: Visibility::Read,
location: Location::Storage,
size: Some(inputs_scalar_f32.len()),
},
DataBuffer::F32(inputs_scalar_f32.to_vec()),
));
}
FusionKernel {
operations: self.operations,
input_bindings: self.input_bindings,
output_bindings: self.output_bindings,
named_bindings: self.named_bindings,
functions: self.functions,
num_elems_output: self.num_elems_output,
device: self.device,
client: self.client,
_phase: PhantomData,
}
}
}
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, BodyPhase> {
/// Register the [operators](Operator) that the kernel must execute in the order provided.
pub fn body(mut self, operators: &[Operator]) -> FusionKernel<G, F, I, OutputPhase> {
let mut register_function = |function: Function| {
if !self.functions.contains(&function) {
self.functions.push(function);
}
strides: strides_dyn_rank(&tensor.shape),
handle: client.empty(core::mem::size_of::<F>() * num_elems),
};
// Since not all operators are native to WGSL, we need to add the custom ones.
for ops in operators.iter() {
match ops {
Operator::Powf {
lhs: _,
rhs: _,
out: _,
} => {
register_function(Function::Powf(Elem::F32));
}
Operator::Erf { input: _, out: _ } => {
register_function(Function::Erf(Elem::F32));
}
_ => {}
}
self.operations.push(ops.clone());
}
register_info_tensor(tensor, &handle_fusion);
FusionKernel {
operations: self.operations,
input_bindings: self.input_bindings,
output_bindings: self.output_bindings,
named_bindings: self.named_bindings,
functions: self.functions,
num_elems_output: self.num_elems_output,
device: self.device,
client: self.client,
_phase: PhantomData,
}
}
}
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, OutputPhase> {
/// Register the outputs with their local variable index.
///
/// Note that the index corresponds to the registered [operator](Operator) number at the
/// [body phase](BodyPhase).
/// So the 4th operator registered creates local variable 3 (N-1, since indices start at 0).
pub fn outputs(
mut self,
outputs: &[(&TensorDescription, Elem)],
locals: &[u16],
) -> FusionKernel<G, F, I, ExecutionPhase> {
let mut num_elems_launch_option = 0;
for (i, ((output, elem), local)) in outputs.iter().zip(locals).enumerate() {
let num_elems_output = calculate_num_elems_dyn_rank(&output.shape);
if num_elems_output > num_elems_launch_option {
num_elems_launch_option = num_elems_output;
}
if elem != &Elem::Bool {
self.output_bindings.push((
Binding {
elem: *elem,
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
},
(*output).clone(),
));
self.operations.push(Operator::AssignGlobal {
input: Variable::Local(*local, *elem),
out: Variable::Output(i as u16, *elem),
});
} else {
self.output_bindings.push((
Binding {
elem: Elem::I32, // I32 are used for bool tensors
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
},
(*output).clone(),
));
self.operations.push(Operator::AssignGlobal {
input: Variable::Local(*local, *elem),
out: Variable::Output(i as u16, Elem::I32),
});
}
}
self.num_elems_output = num_elems_launch_option;
FusionKernel {
operations: self.operations,
input_bindings: self.input_bindings,
output_bindings: self.output_bindings,
named_bindings: self.named_bindings,
functions: self.functions,
num_elems_output: self.num_elems_output,
device: self.device,
client: self.client,
_phase: PhantomData,
}
}
}
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, ExecutionPhase> {
/// Execute the kernel on the provided [handles](HandleContainer).
pub fn execute(mut self, handle_container: &mut HandleContainer<Wgpu<G, F, I>>) {
let mut inputs = Vec::with_capacity(self.input_bindings.len());
let mut outputs = Vec::with_capacity(self.output_bindings.len());
let mut named = Vec::with_capacity(2);
let mut info = Vec::new();
let mut handles =
Vec::with_capacity(inputs.capacity() + outputs.capacity() + named.capacity());
// Inner function to fill the info buffer.
let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| {
if info.is_empty() {
info.push(handle.strides.len() as u32);
}
for s in handle.strides.iter() {
info.push(*s as u32);
}
for s in tensor.shape.iter() {
info.push(*s as u32);
}
};
// We start by registering the inputs.
for (binding, tensor) in self.input_bindings.into_iter() {
let handle = handle_container.get_handle(&tensor);
register_info_tensor(&tensor, &handle);
inputs.push(binding);
handles.push(handle.handle);
}
// Then we follow with the outputs.
for (binding, tensor) in self.output_bindings {
let num_elems = calculate_num_elems_dyn_rank(&tensor.shape);
let handle_fusion = WgpuFusionHandle {
client: self.client.clone(),
device: self.device.clone(),
strides: strides_dyn_rank(&tensor.shape),
handle: self.client.empty(core::mem::size_of::<F>() * num_elems),
};
register_info_tensor(&tensor, &handle_fusion);
handles.push(handle_fusion.handle.clone());
handle_container.register_handle(tensor.id, handle_fusion);
outputs.push(binding);
}
// Now we can create the info handle.
Self::build_info_handle(&mut self.named_bindings, info);
// Finally we finish with the named bindings.
for (name, binding, data) in self.named_bindings {
let handle = self.client.create(match &data {
DataBuffer::F32(values) => bytemuck::cast_slice(values),
DataBuffer::U32(values) => bytemuck::cast_slice(values),
});
named.push((name, binding));
handles.push(handle);
}
// We create the shader codegen type and launch the kernel.
let kernel = ComputeShader {
inputs,
outputs,
named,
workgroup_size: WorkgroupSize::default(),
body: Body::new(self.operations),
num_workgroups: true,
global_invocation_id: true,
functions: self.functions,
};
let workgroup = elemwise_workgroup(self.num_elems_output, WORKGROUP_DEFAULT);
let kernel = Box::new(DynamicKernel::new(kernel, workgroup));
self.client
.execute(kernel, &handles.iter().collect::<Vec<_>>());
}
fn build_info_handle(named_bindings: &mut Vec<(String, Binding, DataBuffer)>, info: Vec<u32>) {
named_bindings.push((
"info".to_string(),
Binding {
elem: Elem::U32,
visibility: Visibility::Read,
location: Location::Storage,
size: None, // We avoid putting the length here since it will force a new kernel
// for each tensor rank.
},
DataBuffer::U32(info),
));
handles.push(handle_fusion.handle.clone());
context
.handles
.register_handle(tensor.id.clone(), handle_fusion);
}
handles.push(client.create(bytemuck::cast_slice(&info)));
// Finally we finish with the named bindings.
if scalars_f32 > 0 {
handles.push(client.create(bytemuck::cast_slice(&context.scalar_floats[0..scalars_f32])));
}
let workgroup = elemwise_workgroup(num_elems_output, WORKGROUP_DEFAULT);
let kernel = Box::new(DynamicKernel::new(kernel, workgroup));
client.execute(kernel, &handles.iter().collect::<Vec<_>>());
}
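// Info buffer layout, as filled by `register_info_tensor` above: one leading
// u32 holding the rank, then for each tensor (inputs first, then outputs) its
// strides followed by its shape, all cast to u32.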


@ -1,7 +1,7 @@
mod base;
mod elemwise;
pub(crate) mod codegen;
pub(crate) mod cache;
pub(crate) mod kernel;
pub use base::*;


@ -1,78 +1,27 @@
use super::unary;
use crate::{
compute::StaticKernel,
codegen::{Operator, Variable},
element::WgpuElement,
kernel::{unary_scalar, unary_scalar_inplace_default, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
unary_scalar, unary_scalar_inplace,
unary,
};
use super::{elemwise_workgroup, KernelSettings};
kernel_wgsl!(Clamp, "../template/clamp/clamp.wgsl");
kernel_wgsl!(ClampInplace, "../template/clamp/clamp_inplace.wgsl");
pub(crate) fn clamp_min<E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
min_value: E,
) -> WgpuTensor<E, D> {
unary_scalar!(ClampMin, func "max");
unary_scalar_inplace!(ClampMinInplace, func "max");
if input.can_mut() {
return unary_scalar_inplace_default::<ClampMinInplace, E, D>(input, min_value);
}
unary_scalar::<ClampMin, E, D, WORKGROUP_DEFAULT>(input, min_value)
}
pub(crate) fn clamp_max<E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
max_value: E,
) -> WgpuTensor<E, D> {
unary_scalar!(ClampMax, func "min");
unary_scalar_inplace!(ClampMaxInPlace, func "min");
if input.can_mut() {
return unary_scalar_inplace_default::<ClampMaxInPlace, E, D>(input, max_value);
}
unary_scalar::<ClampMax, E, D, WORKGROUP_DEFAULT>(input, max_value)
}
unary!(
|elem| Operator::Clamp {
input: Variable::Input(0, elem),
min_value: Variable::Scalar(0, elem),
max_value: Variable::Scalar(1, elem),
out: Variable::Local(0, elem),
},
scalar 2
);
pub(crate) fn clamp<E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
min_value: E,
max_value: E,
) -> WgpuTensor<E, D> {
let num_elems = input.shape.num_elements();
let min_handle = input.client.create(E::as_bytes(&[min_value]));
let max_handle = input.client.create(E::as_bytes(&[max_value]));
if input.can_mut() {
let kernel = StaticKernel::<
KernelSettings<ClampInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
input
.client
.execute(Box::new(kernel), &[&input.handle, &min_handle, &max_handle]);
return input;
}
let output = empty_device(input.client.clone(), input.device.clone(), input.shape);
let kernel = StaticKernel::<
KernelSettings<Clamp, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
input.client.execute(
Box::new(kernel),
&[&input.handle, &output.handle, &min_handle, &max_handle],
);
output
unary::<Ops<E>, OpsInplace<E>, E, D>(input, Some(&[min_value, max_value]))
}
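// Per the Display impl for Operator::Clamp added in this commit, the fused
// body renders as `let {out} = clamp({input}, {min_value}, {max_value});`,
// with min and max supplied through the two-element scalar binding (`scalar 2`).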
#[cfg(test)]
@ -80,30 +29,6 @@ mod tests {
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{Distribution, Tensor};
#[test]
fn clamp_min_should_match_reference() {
let input = Tensor::<TestBackend, 4>::random([1, 5, 32, 32], Distribution::Default);
let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data());
let output = input.clamp_min(0.5);
output
.into_data()
.assert_approx_eq(&input_ref.clamp_min(0.5).into_data(), 3);
}
#[test]
fn clamp_max_should_match_reference() {
let input = Tensor::<TestBackend, 4>::random([1, 5, 32, 32], Distribution::Default);
let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data());
let output = input.clamp_max(0.5);
output
.into_data()
.assert_approx_eq(&input_ref.clamp_max(0.5).into_data(), 3);
}
#[test]
fn clamp_should_match_reference() {
let input = Tensor::<TestBackend, 4>::random([1, 5, 32, 32], Distribution::Default);


@ -8,14 +8,12 @@ mod index;
mod mask;
mod source;
mod unary;
mod unary_scalar;
pub use base::*;
pub use binary_elemwise::*;
pub use cast::*;
pub use source::*;
pub use unary::*;
pub use unary_scalar::*;
/// Convolution kernels
pub mod conv;


@ -1,169 +1,220 @@
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT};
use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl");
use super::StaticKernelSource;
use crate::{
codegen::{execute_static, StaticHandle},
element::WgpuElement,
tensor::WgpuTensor,
};
/// Creates a unary kernel.
#[macro_export]
macro_rules! unary {
(
$struct:ident,
func $func:expr
) => {
pub struct $struct;
operator: $ops:expr,
input: $input:expr,
elem: $elem:ty
) => {{
unary!($ops);
impl $crate::kernel::StaticKernelSource for $struct {
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, None)
}};
(
operator: $ops:expr,
input: $input:expr; $scalar:expr,
elem: $elem:ty
) => {{
unary!($ops, scalar 1);
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, Some(&[$scalar]))
}};
(
$ops:expr
) => {
pub struct Ops<E> {
_e: core::marker::PhantomData<E>,
}
pub struct OpsInplace<E> {
_e: core::marker::PhantomData<E>,
}
#[allow(clippy::redundant_closure_call)]
impl<E: $crate::element::WgpuElement> $crate::kernel::StaticKernelSource for Ops<E> {
fn source() -> $crate::kernel::SourceTemplate {
let source = $crate::kernel::UnaryRaw::source();
source.register("body", format!("output[id] = {}(input[id]);", $func))
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[$crate::codegen::Input::Array {
elem: E::elem_type(),
visibility: $crate::codegen::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::IntoContiguous,
}])
.body(&[$ops(E::elem_type())])
.outputs(&[$crate::codegen::Output::Array {
elem: E::elem_type(),
local: 0,
}])
.compile();
$crate::kernel::SourceTemplate::new(shader.to_string())
}
}
#[allow(clippy::redundant_closure_call)]
impl<E: $crate::element::WgpuElement> $crate::kernel::StaticKernelSource for OpsInplace<E> {
fn source() -> $crate::kernel::SourceTemplate {
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[$crate::codegen::Input::Array {
elem: E::elem_type(),
visibility: $crate::codegen::Visibility::ReadWrite,
strategy: $crate::codegen::ReadingStrategy::Plain,
}])
.body(&[$ops(E::elem_type())])
.outputs(&[$crate::codegen::Output::Input {
elem: E::elem_type(),
input: 0,
local: 0,
}])
.compile();
$crate::kernel::SourceTemplate::new(shader.to_string())
}
}
};
(
$struct:ident,
body $body:expr
$ops:expr,
scalar $num:expr
) => {
pub struct $struct;
pub struct Ops<E> {
_e: core::marker::PhantomData<E>,
}
pub struct OpsInplace<E> {
_e: core::marker::PhantomData<E>,
}
impl $crate::kernel::StaticKernelSource for $struct {
#[allow(clippy::redundant_closure_call)]
impl<E: $crate::element::WgpuElement> $crate::kernel::StaticKernelSource for Ops<E> {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryRaw::source().register("body", $body)
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[
$crate::codegen::Input::Array {
elem: E::elem_type(),
visibility: $crate::codegen::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::IntoContiguous,
},
$crate::codegen::Input::Scalar {
elem: E::elem_type(),
size: $num,
},
])
.body(&[$ops(E::elem_type())])
.outputs(&[$crate::codegen::Output::Array {
elem: E::elem_type(),
local: 0,
}])
.compile();
$crate::kernel::SourceTemplate::new(shader.to_string())
}
}
};
(
$struct:ident,
func $func:expr,
include $file:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
#[allow(clippy::redundant_closure_call)]
impl<E: $crate::element::WgpuElement> $crate::kernel::StaticKernelSource for OpsInplace<E> {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryRaw::source()
.register("body", format!("output[id] = {}(input[id]);", $func))
.add_template(include_str!($file))
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[
$crate::codegen::Input::Array {
elem: E::elem_type(),
visibility: $crate::codegen::Visibility::ReadWrite,
strategy: $crate::codegen::ReadingStrategy::Plain,
},
$crate::codegen::Input::Scalar {
elem: E::elem_type(),
size: $num,
},
])
.body(&[$ops(E::elem_type())])
.outputs(&[$crate::codegen::Output::Input {
elem: E::elem_type(),
input: 0,
local: 0,
}])
.compile();
$crate::kernel::SourceTemplate::new(shader.to_string())
}
}
};
}
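// Illustrative invocation of the rewritten macro (same shape as the test at
// the bottom of this file); it generates the `Ops<E>` and `OpsInplace<E>`
// kernel sources:
//
// unary!(|elem| Operator::Abs {
//     input: Variable::Input(0, elem),
//     out: Variable::Local(0, elem),
// });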
/// Creates a unary inplace kernel.
#[macro_export]
macro_rules! unary_inplace {
(
$struct:ident,
func $func:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryInplaceRaw::source()
.register("body", format!("input[id] = {}(input[id]);", $func))
}
}
};
(
$struct:ident,
body $body:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryInplaceRaw::source().register("body", $body)
}
}
};
(
$struct:ident,
func $func:expr,
include $file:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryInplaceRaw::source()
.register("body", format!("input[id] = {}(input[id]);", $func))
.add_template(include_str!($file))
}
}
};
}
/// Execute a unary kernel using the default settings.
pub fn unary_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
unary::<K, E, D, WORKGROUP_DEFAULT>(input)
}
/// Execute a unary inplace kernel using the default settings.
pub fn unary_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
unary_inplace::<K, E, D, WORKGROUP_DEFAULT>(input)
}
/// Execute a unary inplace kernel using the provided WORKGROUP.
pub fn unary_inplace<
/// Launch a unary operation.
pub fn unary<K, KI, E, const D: usize>(
tensor: WgpuTensor<E, D>,
scalars: Option<&[E]>,
) -> WgpuTensor<E, D>
where
K: StaticKernelSource,
KI: StaticKernelSource,
E: WgpuElement,
const D: usize,
const WORKGROUP: usize,
>(
input: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
let num_elems = input.shape.num_elements();
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
{
if !tensor.can_mut() {
let num_elems = tensor.shape.num_elements();
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(
tensor.client.clone(),
tensor.device,
tensor.shape.clone(),
buffer,
);
input.client.execute(Box::new(kernel), &[&input.handle]);
execute_static::<K, E>(
&[StaticHandle::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[StaticHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
scalars,
tensor.client,
);
input
}
output
} else {
execute_static::<KI, E>(
&[],
&[StaticHandle::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
scalars,
tensor.client.clone(),
);
/// Execute a unary kernel using the provided WORKGROUP.
pub fn unary<K: StaticKernelSource, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
input: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
let num_elems = input.shape.num_elements();
let buffer = input.client.empty(num_elems * core::mem::size_of::<E>());
let mut output = WgpuTensor::new(input.client.clone(), input.device, input.shape, buffer);
// Since we don't handle the stride inside the kernel, the output tensor has the same strides
// as the input tensor. It might not be in the default format.
output.strides = input.strides;
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
input
.client
.execute(Box::new(kernel), &[&input.handle, &output.handle]);
output
tensor
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::codegen::{Operator, Variable};
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{Distribution, Tensor};
unary!(TestKernel, func "log");
unary_inplace!(TestKernelInplace, func "log");
unary!(|elem| Operator::Tanh {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
});
#[test]
fn unary_should_work_with_multiple_invocations() {
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let actual = unary::<TestKernel, _, 2, 16>(tensor.into_primitive());
let expected = tensor_ref.log();
let actual = unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None);
let expected = tensor_ref.tanh();
expected.into_data().assert_approx_eq(
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
@ -176,8 +227,8 @@ mod tests {
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let actual = unary_inplace::<TestKernelInplace, _, 2, 16>(tensor.into_primitive());
let expected = tensor_ref.log();
let actual = unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None);
let expected = tensor_ref.tanh();
expected.into_data().assert_approx_eq(
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),


@ -1,220 +0,0 @@
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT};
use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl");
kernel_wgsl!(
UnaryScalarInplaceRaw,
"../template/unary_scalar_inplace.wgsl"
);
/// Creates a unary scalar kernel.
#[macro_export]
macro_rules! unary_scalar {
(
$struct:ident,
ops $ops:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarRaw::source()
.register("body", format!("output[id] = lhs[id] {} rhs;", $ops))
}
}
};
(
$struct:ident,
func $func:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarRaw::source()
.register("body", format!("output[id] = {}(lhs[id], rhs);", $func))
}
}
};
(
$struct:ident,
func $func:expr,
include $file:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarRaw::source()
.register("body", format!("output[id] = {}(lhs[id], rhs);", $func))
.add_template(include_str!($file))
}
}
};
}
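For reference, a typical invocation (mirroring the test kernel further below) registers the operator into the template body:
// Generates a StaticKernelSource whose rendered body line becomes
// `output[id] = lhs[id] * rhs;` inside unary_scalar.wgsl.
unary_scalar!(MulScalar, ops "*");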
/// Creates a unary scalar inplace kernel.
#[macro_export]
macro_rules! unary_scalar_inplace {
(
$struct:ident,
ops $ops:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarInplaceRaw::source()
.register("body", format!("lhs[id] = lhs[id] {} rhs;", $ops))
}
}
};
(
$struct:ident,
body $body:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarInplaceRaw::source().register("body", $body)
}
}
};
(
$struct:ident,
func $func:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarInplaceRaw::source()
.register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func))
}
}
};
(
$struct:ident,
func $func:expr,
include $file:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarInplaceRaw::source()
.register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func))
.add_template(include_str!($file))
}
}
};
}
/// Execute a unary scalar kernel using the default settings.
pub fn unary_scalar_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
scalar: E,
) -> WgpuTensor<E, D> {
unary_scalar::<K, E, D, WORKGROUP_DEFAULT>(lhs, scalar)
}
/// Execute a unary scalar kernel using the provided WORKGROUP.
pub fn unary_scalar<
K: StaticKernelSource,
E: WgpuElement,
const D: usize,
const WORKGROUP: usize,
>(
lhs: WgpuTensor<E, D>,
scalar: E,
) -> WgpuTensor<E, D> {
let num_elems = lhs.shape.num_elements();
let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(lhs.client.clone(), lhs.device, lhs.shape, buffer);
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let rhs_handle = lhs.client.create(E::as_bytes(&[scalar]));
lhs.client.execute(
Box::new(kernel),
&[&lhs.handle, &rhs_handle, &output.handle],
);
output
}
/// Execute a unary scalar inplace kernel using the default settings.
pub fn unary_scalar_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
scalar: E,
) -> WgpuTensor<E, D> {
unary_scalar_inplace::<K, E, D, WORKGROUP_DEFAULT>(lhs, scalar)
}
/// Execute a unary scalar inplace kernel using the provided WORKGROUP.
pub fn unary_scalar_inplace<
K: StaticKernelSource,
E: WgpuElement,
const D: usize,
const WORKGROUP: usize,
>(
lhs: WgpuTensor<E, D>,
scalar: E,
) -> WgpuTensor<E, D> {
let num_elems = lhs.shape.num_elements();
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let rhs_handle = lhs.client.create(E::as_bytes(&[scalar]));
lhs.client
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]);
lhs
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{Distribution, Tensor};
unary_scalar!(TestKernel, ops "*");
unary_scalar_inplace!(TestKernelInplace, ops "*");
#[test]
fn unary_scalar_should_work_with_multiple_invocations() {
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let actual = unary_scalar::<TestKernel, _, 2, 16>(tensor.into_primitive(), 5.0);
let expected = tensor_ref.mul_scalar(5.0);
expected.into_data().assert_approx_eq(
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
3,
);
}
#[test]
fn unary_scalar_inplace_should_work_with_multiple_invocations() {
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let actual =
unary_scalar_inplace::<TestKernelInplace, _, 2, 16>(tensor.into_primitive(), 5.0);
let expected = tensor_ref.mul_scalar(5.0);
expected.into_data().assert_approx_eq(
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
3,
);
}
}

View File

@ -15,6 +15,8 @@ pub mod kernel;
/// Tensor module.
pub mod tensor;
pub(crate) mod codegen;
mod element;
pub use element::{FloatElement, IntElement};

View File

@ -1,10 +1,8 @@
use burn_tensor::ops::{ActivationOps, FloatTensor};
use crate::{
element::{FloatElement, IntElement},
kernel::{unary_default, unary_inplace_default},
unary, unary_inplace, GraphicsApi, Wgpu,
GraphicsApi, Wgpu,
};
use burn_tensor::ops::ActivationOps;
impl<G, F, I> ActivationOps<Wgpu<G, F, I>> for Wgpu<G, F, I>
where
@ -12,14 +10,4 @@ where
F: FloatElement,
I: IntElement,
{
fn relu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Relu, body "output[id] = max(input[id], 0.0);");
unary_inplace!(ReluInplace, body "input[id] = max(input[id], 0.0);");
if tensor.can_mut() {
return unary_inplace_default::<ReluInplace, F, D>(tensor);
}
unary_default::<Relu, F, D>(tensor)
}
}

View File

@ -1,4 +1,5 @@
use super::numeric;
use crate::codegen::{Elem, Operator, Variable};
#[cfg(not(feature = "autotune"))]
use crate::kernel::matmul::init_matmul_output;
#[cfg(feature = "autotune")]
@ -8,18 +9,14 @@ use crate::kernel::matmul::vec4::matmul_tiling_2d_vec4;
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
#[cfg(not(feature = "autotune"))]
use crate::kernel::reduce::init_reduce_output;
use crate::kernel::{
self, reduce, unary_default, unary_inplace_default, unary_scalar_default,
unary_scalar_inplace_default,
};
use crate::{unary, unary_inplace, unary_scalar, FloatElement, GraphicsApi, IntElement, Wgpu};
use crate::{unary_scalar_inplace, WgpuDevice};
use crate::kernel::{self, reduce};
use crate::WgpuDevice;
use crate::{unary, FloatElement, GraphicsApi, IntElement, Wgpu};
use burn_tensor::ops::{
BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor,
};
use burn_tensor::{ops::TensorOps, Data, Distribution, Shape};
use burn_tensor::{ElementConversion, Reader};
use std::ops::Range;
impl<G, F, I> TensorOps<Wgpu<G, F, I>> for Wgpu<G, F, I>
@ -357,122 +354,115 @@ where
kernel::cast(tensor)
}
fn exp<const D: usize>(lhs: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Exp, func "exp");
unary_inplace!(ExpInplace, func "exp");
if lhs.can_mut() {
return unary_inplace_default::<ExpInplace, F, D>(lhs);
}
unary_default::<Exp, F, D>(lhs)
fn exp<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operator: |elem: Elem| Operator::Exp {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
},
input: tensor,
elem: F
)
}
fn log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Log, func "log");
unary_inplace!(LogInplace, func "log");
if tensor.can_mut() {
return unary_inplace_default::<LogInplace, F, D>(tensor);
}
unary_default::<Log, F, D>(tensor)
unary!(
operator: |elem: Elem| Operator::Log {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
},
input: tensor,
elem: F
)
}
fn log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Log1p, body "output[id] = log(1.0 + input[id]);");
unary_inplace!(Log1pInplace, body "input[id] = log(1.0 + input[id]);");
if tensor.can_mut() {
return unary_inplace_default::<Log1pInplace, F, D>(tensor);
}
unary_default::<Log1p, F, D>(tensor)
unary!(
operator: |elem: Elem| Operator::Log1p {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
},
input: tensor,
elem: F
)
}
fn powf<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
unary_scalar!(Powf, func "powf", include "../template/powf.wgsl");
unary_scalar_inplace!(PowfInplace, func "powf", include "../template/powf.wgsl");
if lhs.can_mut() {
return unary_scalar_inplace_default::<PowfInplace, F, D>(lhs, rhs.elem());
}
unary_scalar_default::<Powf, F, D>(lhs, rhs.elem())
unary!(
operator: |elem: Elem| Operator::Powf {
lhs: Variable::Input(0, elem),
rhs: Variable::Scalar(0, elem),
out: Variable::Local(0, elem),
},
input: lhs; rhs.elem(),
elem: F
)
}
fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Sqrt, func "sqrt");
unary_inplace!(SqrtInplace, func "sqrt");
if tensor.can_mut() {
return unary_inplace_default::<SqrtInplace, F, D>(tensor);
}
unary_default::<Sqrt, F, D>(tensor)
unary!(
operator: |elem: Elem| Operator::Sqrt {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
},
input: tensor,
elem: F
)
}
fn abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Abs, func "abs");
unary_inplace!(AbsInplace, func "abs");
if tensor.can_mut() {
return unary_inplace_default::<AbsInplace, F, D>(tensor);
}
unary_default::<Abs, F, D>(tensor)
unary!(
operator: |elem: Elem| Operator::Abs {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
},
input: tensor,
elem: F
)
}
fn cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Cos, func "cos");
unary_inplace!(CosInplace, func "cos");
if tensor.can_mut() {
return unary_inplace_default::<CosInplace, F, D>(tensor);
}
unary_default::<Cos, F, D>(tensor)
unary!(
operator: |elem: Elem| Operator::Cos {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
},
input: tensor,
elem: F
)
}
fn sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Sin, func "sin");
unary_inplace!(SinInplace, func "sin");
if tensor.can_mut() {
return unary_inplace_default::<SinInplace, F, D>(tensor);
}
unary_default::<Sin, F, D>(tensor)
unary!(
operator: |elem: Elem| Operator::Sin {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
},
input: tensor,
elem: F
)
}
fn tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
// Metal has a weird numerical behaviour with tanh, which requires a custom function
#[cfg(target_os = "macos")]
unary!(Tanh, func "safe_tanh", include "../template/safe_tanh.wgsl");
#[cfg(target_os = "macos")]
unary_inplace!(TanhInplace, func "safe_tanh", include "../template/safe_tanh.wgsl");
#[cfg(not(target_os = "macos"))]
unary!(Tanh, func "tanh");
#[cfg(not(target_os = "macos"))]
unary_inplace!(TanhInplace, func "tanh");
if tensor.can_mut() {
return unary_inplace_default::<TanhInplace, F, D>(tensor);
}
unary_default::<Tanh, F, D>(tensor)
unary!(
operator: |elem: Elem| Operator::Tanh {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
},
input: tensor,
elem: F
)
}
fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Erf, func "erf", include "../template/erf.wgsl");
unary_inplace!(ErfInplace, func "erf", include "../template/erf.wgsl");
if tensor.can_mut() {
return unary_inplace_default::<ErfInplace, F, D>(tensor);
}
unary_default::<Erf, F, D>(tensor)
unary!(
operator: |elem: Elem| Operator::Erf {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
},
input: tensor,
elem: F
)
}
fn cat<const D: usize>(tensors: Vec<FloatTensor<Self, D>>, dim: usize) -> FloatTensor<Self, D> {
@ -491,20 +481,6 @@ where
kernel::cast(tensor)
}
fn clamp_min<const D: usize>(
tensor: FloatTensor<Self, D>,
min: FloatElem<Self>,
) -> FloatTensor<Self, D> {
kernel::clamp_min(tensor, min)
}
fn clamp_max<const D: usize>(
tensor: FloatTensor<Self, D>,
max: FloatElem<Self>,
) -> FloatTensor<Self, D> {
kernel::clamp_max(tensor, max)
}
fn clamp<const D: usize>(
tensor: FloatTensor<Self, D>,
min: FloatElem<Self>,
@ -516,14 +492,14 @@ where
fn recip<const D: usize>(
tensor: FloatTensor<Wgpu<G, F, I>, D>,
) -> FloatTensor<Wgpu<G, F, I>, D> {
unary!(Recip, func "1.0 /");
unary_inplace!(RecipInplace, func "1.0 /");
if tensor.can_mut() {
return unary_inplace_default::<RecipInplace, F, D>(tensor);
}
unary_default::<Recip, F, D>(tensor)
unary!(
operator: |elem: Elem| Operator::Recip {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
},
input: tensor,
elem: F
)
}
fn repeat<const D: usize>(

View File

@ -1,10 +1,10 @@
use super::numeric;
use crate::codegen::{Elem, Operator, Variable};
use crate::kernel::reduce::{self, init_reduce_output};
use crate::kernel::{unary_default, unary_inplace_default};
use crate::{
element::{FloatElement, IntElement},
kernel, unary, unary_inplace, GraphicsApi, Wgpu,
kernel, unary, GraphicsApi, Wgpu,
};
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
@ -280,20 +280,6 @@ where
kernel::reduce::argmin(tensor, dim)
}
fn int_clamp_min<const D: usize>(
tensor: IntTensor<Self, D>,
min: IntElem<Self>,
) -> IntTensor<Self, D> {
kernel::clamp_min(tensor, min)
}
fn int_clamp_max<const D: usize>(
tensor: IntTensor<Self, D>,
max: IntElem<Self>,
) -> IntTensor<Self, D> {
kernel::clamp_max(tensor, max)
}
fn int_clamp<const D: usize>(
tensor: IntTensor<Self, D>,
min: IntElem<Self>,
@ -303,14 +289,14 @@ where
}
fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
unary!(IntAbs, func "abs");
unary_inplace!(IntAbsInplace, func "abs");
if tensor.can_mut() {
return unary_inplace_default::<IntAbsInplace, I, D>(tensor);
}
unary_default::<IntAbs, I, D>(tensor)
unary!(
operator: |elem: Elem| Operator::Abs {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
},
input: tensor,
elem: I
)
}
fn int_into_float<const D: usize>(tensor: IntTensor<Self, D>) -> FloatTensor<Self, D> {

View File

@ -1,11 +1,8 @@
use crate::codegen::{Elem, Operator, Variable};
use crate::compute::{compute_client, WgpuComputeClient};
use crate::kernel::{
binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default,
unary_scalar_inplace_default,
};
use crate::kernel::{binary_elemwise_default, binary_elemwise_inplace_default};
use crate::{
binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor,
unary_scalar, unary_scalar_inplace,
binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor, unary,
};
use crate::{GraphicsApi, WgpuDevice};
use burn_tensor::{Element, ElementConversion, Shape};
@ -28,8 +25,14 @@ pub fn full_device<E: WgpuElement + Element, const D: usize>(
) -> WgpuTensor<E, D> {
let empty = empty_device(client, device, shape);
unary_scalar_inplace!(Full, body "lhs[id] = rhs;");
unary_scalar_inplace_default::<Full, E, D>(empty, value)
unary!(
operator: |elem: Elem| Operator::AssignLocal {
input: Variable::Scalar(0, elem),
out: Variable::Local(0, elem),
},
input: empty; value,
elem: E
)
}
pub fn zeros<G: GraphicsApi, E: WgpuElement + Element, const D: usize>(
@ -98,14 +101,15 @@ pub fn add_scalar<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<E, D> {
unary_scalar!(AddScalar, ops "+");
unary_scalar_inplace!(AddScalarInplace, ops "+");
if lhs.can_mut() {
return unary_scalar_inplace_default::<AddScalarInplace, E, D>(lhs, rhs);
}
unary_scalar_default::<AddScalar, E, D>(lhs, rhs)
unary!(
operator: |elem: Elem| Operator::Add {
lhs: Variable::Input(0, elem),
rhs: Variable::Scalar(0, elem),
out: Variable::Local(0, elem),
},
input: lhs; rhs,
elem: E
)
}
pub fn sub<E: WgpuElement, const D: usize>(
@ -126,14 +130,15 @@ pub fn sub_scalar<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<E, D> {
unary_scalar!(SubScalar, ops "-");
unary_scalar_inplace!(SubScalarInplace, ops "-");
if lhs.can_mut() {
return unary_scalar_inplace_default::<SubScalarInplace, E, D>(lhs, rhs);
}
unary_scalar_default::<SubScalar, E, D>(lhs, rhs)
unary!(
operator: |elem: Elem| Operator::Sub {
lhs: Variable::Input(0, elem),
rhs: Variable::Scalar(0, elem),
out: Variable::Local(0, elem),
},
input: lhs; rhs,
elem: E
)
}
pub fn mul<E: WgpuElement, const D: usize>(
@ -158,14 +163,15 @@ pub fn mul_scalar<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<E, D> {
unary_scalar!(MulScalar, ops "*");
unary_scalar_inplace!(MulScalarInplace, ops "*");
if lhs.can_mut() {
return unary_scalar_inplace_default::<MulScalarInplace, E, D>(lhs, rhs);
}
unary_scalar_default::<MulScalar, E, D>(lhs, rhs)
unary!(
operator: |elem: Elem| Operator::Mul {
lhs: Variable::Input(0, elem),
rhs: Variable::Scalar(0, elem),
out: Variable::Local(0, elem),
},
input: lhs; rhs,
elem: E
)
}
pub fn div<E: WgpuElement, const D: usize>(
@ -186,12 +192,13 @@ pub fn div_scalar<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<E, D> {
unary_scalar!(DivScalar, ops "/");
unary_scalar_inplace!(DivScalarInplace, ops "/");
if lhs.can_mut() {
return unary_scalar_inplace_default::<DivScalarInplace, E, D>(lhs, rhs);
}
unary_scalar_default::<DivScalar, E, D>(lhs, rhs)
unary!(
operator: |elem: Elem| Operator::Div {
lhs: Variable::Input(0, elem),
rhs: Variable::Scalar(0, elem),
out: Variable::Local(0, elem),
},
input: lhs; rhs,
elem: E
)
}

View File

@ -1,25 +0,0 @@
@group(0)
@binding(0)
var<storage, read> input: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read> min_value: {{ elem }};
@group(0)
@binding(3)
var<storage, read> max_value: {{ elem }};
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
output[id] = clamp(input[id], min_value, max_value);
}
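As a quick check of the flat-index arithmetic these elementwise kernels share, a small Rust sketch with an assumed launch configuration (values are illustrative only):
// id = global_id.y * (num_workgroups.x * workgroup_size_x) + global_id.x
let (workgroup_size_x, num_workgroups_x) = (32u32, 4u32); // assumed launch config
let (global_x, global_y) = (5u32, 2u32); // one thread's global_invocation_id
let id = global_y * (num_workgroups_x * workgroup_size_x) + global_x;
assert_eq!(id, 261); // 2 * 128 + 5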

View File

@ -1,21 +0,0 @@
@group(0)
@binding(0)
var<storage, read_write> input: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> min_value: {{ elem }};
@group(0)
@binding(2)
var<storage, read> max_value: {{ elem }};
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
input[id] = clamp(input[id], min_value, max_value);
}

View File

@ -1,25 +0,0 @@
/// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
///
/// > (maximum error: 1.5×10⁻⁷)
/// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf(x) is an odd function, so erf(−x) = −erf(x).
fn erf_positive(x: {{ elem }}) -> {{ elem }} {
let p = 0.3275911;
let a1 = 0.254829592;
let a2 = -0.284496736;
let a3 = 1.421413741;
let a4 = -1.453152027;
let a5 = 1.061405429;
let t = 1.0 / (1.0 + p * abs(x));
let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;
return 1.0 - (tmp * t * exp(-x * x));
}
fn erf(x: {{ elem }}) -> {{ elem }} {
if (x < 0.0) {
return -1.0 * erf_positive(-1.0 * x);
}
return erf_positive(x);
}
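A CPU-side Rust sketch of the same Abramowitz and Stegun approximation (an assumed port of the template above, handy for validating kernel output against a reference):
// Mirrors the deleted WGSL template; maximum error ~1.5×10⁻⁷.
fn erf(x: f32) -> f32 {
    fn erf_positive(x: f32) -> f32 {
        let p = 0.3275911;
        let (a1, a2, a3, a4, a5) = (
            0.254829592,
            -0.284496736,
            1.421413741,
            -1.453152027,
            1.061405429,
        );
        let t = 1.0 / (1.0 + p * x);
        let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;
        1.0 - tmp * t * (-x * x).exp()
    }
    // erf is odd, so negative inputs reduce to the positive case.
    if x < 0.0 {
        -erf_positive(-x)
    } else {
        erf_positive(x)
    }
}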

View File

@ -1,14 +0,0 @@
fn powf(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {
let modulo = rhs % 2.0;
if (modulo == 0.0) {
// Even exponent: the result is positive regardless of the sign of lhs
return pow(abs(lhs), rhs);
} else if (modulo == 1.0 && lhs < 0.0) {
// Odd exponent with a negative base: negate the result
return -1.0 * pow(-1.0 * lhs, rhs);
} else {
// Fractional exponent: defer to the regular pow
return pow(lhs, rhs);
}
}
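Hand-checked expectations for the three branches (a sketch, assuming a CPU port named powf with the same semantics), where plain pow() would return NaN for a negative base:
assert_eq!(powf(-2.0, 2.0), 4.0); // even exponent: pow(abs(lhs), rhs)
assert_eq!(powf(-2.0, 3.0), -8.0); // odd exponent, negative base: -pow(-lhs, rhs)
assert_eq!(powf(4.0, 0.5), 2.0); // fractional exponent: plain pow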

View File

@ -1,8 +0,0 @@
/// Metal has a weird numerical behaviour with tanh for inputs over 43.0
fn safe_tanh(x: {{ elem }}) -> {{ elem }} {
if x > 43.0 {
return 1.0;
} else {
return tanh(x);
}
}
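A CPU equivalent of the guard (a sketch; the 43.0 cutoff is lossless at f32 precision, since 1 − tanh(43) ≈ 8.7×10⁻³⁸ is far below one f32 ulp of 1.0):
fn safe_tanh(x: f32) -> f32 {
    // Clamp to 1.0 where Metal's tanh misbehaves; exact at f32 precision.
    if x > 43.0 {
        1.0
    } else {
        x.tanh()
    }
}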

View File

@ -1,17 +0,0 @@
@group(0)
@binding(0)
var<storage, read> input: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
{{ body }}
}

View File

@ -1,13 +0,0 @@
@group(0)
@binding(0)
var<storage, read_write> input: array<{{ elem }}>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
{{ body }}
}

View File

@ -1,21 +0,0 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: {{ elem }};
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
{{ body }}
}

View File

@ -1,17 +0,0 @@
@group(0)
@binding(0)
var<storage, read_write> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: {{ elem }};
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
{{ body }}
}

View File

@ -1,8 +1,9 @@
use crate::codegen::{Elem, Operator, Variable};
use crate::element::WgpuElement;
use crate::{
compute::{WgpuComputeClient, WgpuHandle},
unary, WgpuDevice,
};
use crate::{element::WgpuElement, kernel::unary_default};
use burn_tensor::Shape;
use std::marker::PhantomData;
@ -96,8 +97,14 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
// slowdowns.
//
// The solution is just to use a simple unary compute shader.
unary!(CopyBuffer, body "output[id] = input[id];");
unary_default::<CopyBuffer, E, D>(self.clone())
unary!(
operator: |elem: Elem| Operator::AssignLocal {
input: Variable::Input(0, elem),
out: Variable::Local(0, elem),
},
input: self.clone(),
elem: E
)
}
/// Check if the tensor is safe to mutate.