mirror of https://github.com/tracel-ai/burn.git
Fusion wgpu compilation cache (#1069)
* Refactor fusion in the wgpu backend * WIP * Refactor * WIP * Fix inplace ops * Works ish * Cleanup Output * Refactoring * Refactor Clamp * Cleanup * Cleanup * Updates * Fix CI * Code review
This commit is contained in:
parent
042454a9db
commit
b5c49c5bf7
|
@ -96,7 +96,7 @@ pub trait OptimizationBuilder<B: FusionBackend>: Send {
|
||||||
/// The operation created from the [builder](OptimizationBuilder).
|
/// The operation created from the [builder](OptimizationBuilder).
|
||||||
pub trait Optimization<B: FusionBackend>: Send {
|
pub trait Optimization<B: FusionBackend>: Send {
|
||||||
/// Execute the operation.
|
/// Execute the operation.
|
||||||
fn execute(&self, context: &mut Context<'_, B>);
|
fn execute(&mut self, context: &mut Context<'_, B>);
|
||||||
/// The number of registered operations in this optimization.
|
/// The number of registered operations in this optimization.
|
||||||
fn len(&self) -> usize;
|
fn len(&self) -> usize;
|
||||||
/// If the current optimization is empty.
|
/// If the current optimization is empty.
|
||||||
|
|
|
@ -53,7 +53,7 @@ impl<B: FusionBackend> Graph<B> {
|
||||||
pub(crate) fn execute_optimization(
|
pub(crate) fn execute_optimization(
|
||||||
&mut self,
|
&mut self,
|
||||||
handles: &mut HandleContainer<B>,
|
handles: &mut HandleContainer<B>,
|
||||||
optimization: &dyn Optimization<B>,
|
optimization: &mut dyn Optimization<B>,
|
||||||
) {
|
) {
|
||||||
let num_keep = optimization.len();
|
let num_keep = optimization.len();
|
||||||
let mut context = self.converter.context(handles);
|
let mut context = self.converter.context(handles);
|
||||||
|
|
|
@ -682,20 +682,6 @@ impl<E: Element> NumericOpsDescription<E> {
|
||||||
out: desc.out.to_relative(converter),
|
out: desc.out.to_relative(converter),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
NumericOpsDescription::ClampMax(desc) => {
|
|
||||||
NumericOpsDescription::ClampMax(ScalarOpsDescription {
|
|
||||||
lhs: desc.lhs.to_relative(converter),
|
|
||||||
rhs: local_elem(converter, &desc.rhs),
|
|
||||||
out: desc.out.to_relative(converter),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
NumericOpsDescription::ClampMin(desc) => {
|
|
||||||
NumericOpsDescription::ClampMin(ScalarOpsDescription {
|
|
||||||
lhs: desc.lhs.to_relative(converter),
|
|
||||||
rhs: local_elem(converter, &desc.rhs),
|
|
||||||
out: desc.out.to_relative(converter),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -71,7 +71,7 @@ impl<B: FusionBackend> GraphExecution<B> {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
CacheResult::Found(ops) => {
|
CacheResult::Found(ops) => {
|
||||||
graph.execute_optimization(handles, ops.as_ref());
|
graph.execute_optimization(handles, ops.as_mut());
|
||||||
self.reset(graph);
|
self.reset(graph);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -107,14 +107,14 @@ impl<B: FusionBackend> GraphExecution<B> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
match find_best_optimization_index(&self.optimizations) {
|
match find_best_optimization_index(&mut self.optimizations) {
|
||||||
Some(index) => {
|
Some(index) => {
|
||||||
let (relative, next_ops) = Self::split_relative_graph_owned(graph, mode);
|
let (relative, next_ops) = Self::split_relative_graph_owned(graph, mode);
|
||||||
let optimization = &self.optimizations[index];
|
let optimization = &self.optimizations[index];
|
||||||
let ops = self
|
let ops = self
|
||||||
.optimization_cache
|
.optimization_cache
|
||||||
.complete(optimization, relative, next_ops);
|
.complete(optimization, relative, next_ops);
|
||||||
BuildAction::ExecuteOptimization(ops.as_ref())
|
BuildAction::ExecuteOptimization(ops.as_mut())
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
// TODO: Cache this result too.
|
// TODO: Cache this result too.
|
||||||
|
@ -184,7 +184,7 @@ impl<B: FusionBackend> GraphExecution<B> {
|
||||||
}
|
}
|
||||||
|
|
||||||
enum BuildAction<'a, B: FusionBackend> {
|
enum BuildAction<'a, B: FusionBackend> {
|
||||||
ExecuteOptimization(&'a dyn Optimization<B>),
|
ExecuteOptimization(&'a mut dyn Optimization<B>),
|
||||||
ExecuteOperations,
|
ExecuteOperations,
|
||||||
ContinueBuilding,
|
ContinueBuilding,
|
||||||
}
|
}
|
||||||
|
@ -202,7 +202,7 @@ fn still_optimizing<B: FusionBackend>(optimizations: &[Box<dyn OptimizationBuild
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_best_optimization_index<B: FusionBackend>(
|
fn find_best_optimization_index<B: FusionBackend>(
|
||||||
optimizations: &[Box<dyn OptimizationBuilder<B>>],
|
optimizations: &mut [Box<dyn OptimizationBuilder<B>>],
|
||||||
) -> Option<usize> {
|
) -> Option<usize> {
|
||||||
let mut best_index = None;
|
let mut best_index = None;
|
||||||
let mut best_score = 0;
|
let mut best_score = 0;
|
||||||
|
|
|
@ -379,16 +379,6 @@ pub enum NumericOpsDescription<E> {
|
||||||
/// Float => [clamp](burn_tensor::ops::TensorOps::clamp).
|
/// Float => [clamp](burn_tensor::ops::TensorOps::clamp).
|
||||||
/// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp).
|
/// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp).
|
||||||
Clamp(ClampOpsDescription<E>),
|
Clamp(ClampOpsDescription<E>),
|
||||||
/// Operation corresponding to:
|
|
||||||
///
|
|
||||||
/// Float => [clamp max](burn_tensor::ops::TensorOps::clamp_max).
|
|
||||||
/// Int => [clamp max](burn_tensor::ops::IntTensorOps::int_clamp_max).
|
|
||||||
ClampMax(ScalarOpsDescription<E>),
|
|
||||||
/// Operation corresponding to:
|
|
||||||
///
|
|
||||||
/// Float => [clamp min](burn_tensor::ops::TensorOps::clamp_min).
|
|
||||||
/// Int => [clamp min](burn_tensor::ops::IntTensorOps::int_clamp_min).
|
|
||||||
ClampMin(ScalarOpsDescription<E>),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation description specific to an int tensor.
|
/// Operation description specific to an int tensor.
|
||||||
|
@ -900,12 +890,6 @@ impl<E: Element> NumericOpsDescription<E> {
|
||||||
NumericOpsDescription::Clamp(desc) => {
|
NumericOpsDescription::Clamp(desc) => {
|
||||||
vec![&desc.tensor, &desc.out]
|
vec![&desc.tensor, &desc.out]
|
||||||
}
|
}
|
||||||
NumericOpsDescription::ClampMin(desc) => {
|
|
||||||
vec![&desc.lhs, &desc.out]
|
|
||||||
}
|
|
||||||
NumericOpsDescription::ClampMax(desc) => {
|
|
||||||
vec![&desc.lhs, &desc.out]
|
|
||||||
}
|
|
||||||
NumericOpsDescription::Abs(desc) => {
|
NumericOpsDescription::Abs(desc) => {
|
||||||
vec![&desc.input, &desc.out]
|
vec![&desc.input, &desc.out]
|
||||||
}
|
}
|
||||||
|
@ -1144,8 +1128,6 @@ impl<E> core::hash::Hash for NumericOpsDescription<E> {
|
||||||
NumericOpsDescription::MaxDim(desc) => desc.hash(state),
|
NumericOpsDescription::MaxDim(desc) => desc.hash(state),
|
||||||
NumericOpsDescription::MinDim(desc) => desc.hash(state),
|
NumericOpsDescription::MinDim(desc) => desc.hash(state),
|
||||||
NumericOpsDescription::Clamp(desc) => desc.hash(state),
|
NumericOpsDescription::Clamp(desc) => desc.hash(state),
|
||||||
NumericOpsDescription::ClampMax(desc) => desc.hash(state),
|
|
||||||
NumericOpsDescription::ClampMin(desc) => desc.hash(state),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,16 +60,13 @@ impl<O> OptimizationCache<O> {
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(candidate) = self.found {
|
if let Some(candidate) = self.found {
|
||||||
return CacheResult::Found(&self.optimizations.get(candidate).unwrap().value);
|
return CacheResult::Found(&mut self.optimizations.get_mut(candidate).unwrap().value);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invalidate candidates.
|
// Invalidate candidates.
|
||||||
let mut invalidated_candidate = Vec::new();
|
let mut invalidated_candidate = Vec::new();
|
||||||
for id in self.candidates.iter() {
|
for id in self.candidates.iter() {
|
||||||
let item = match self.optimizations.get(*id) {
|
let item = &self.optimizations[*id];
|
||||||
Some(item) => item,
|
|
||||||
None => panic!("Should have an optimization"),
|
|
||||||
};
|
|
||||||
let next_ops = graph.last().expect("Validated earlier");
|
let next_ops = graph.last().expect("Validated earlier");
|
||||||
let next_ops_index = graph.len() - 1;
|
let next_ops_index = graph.len() - 1;
|
||||||
let next_ops_candidate = match item.graph.get(next_ops_index) {
|
let next_ops_candidate = match item.graph.get(next_ops_index) {
|
||||||
|
@ -93,13 +90,13 @@ impl<O> OptimizationCache<O> {
|
||||||
Condition::NextOps(ops) => ops,
|
Condition::NextOps(ops) => ops,
|
||||||
Condition::Sync => {
|
Condition::Sync => {
|
||||||
self.found = Some(*id);
|
self.found = Some(*id);
|
||||||
return CacheResult::Found(&item.value);
|
break;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if item.end_conditions.contains(ops) {
|
if item.end_conditions.contains(ops) {
|
||||||
self.found = Some(*id);
|
self.found = Some(*id);
|
||||||
return CacheResult::Found(&item.value);
|
break;
|
||||||
} else {
|
} else {
|
||||||
self.availables.push((*id, graph.len()));
|
self.availables.push((*id, graph.len()));
|
||||||
invalidated_candidate.push(*id);
|
invalidated_candidate.push(*id);
|
||||||
|
@ -107,6 +104,10 @@ impl<O> OptimizationCache<O> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(id) = self.found {
|
||||||
|
return CacheResult::Found(&mut self.optimizations[id].value);
|
||||||
|
}
|
||||||
|
|
||||||
let mut updated_candidates = Vec::new();
|
let mut updated_candidates = Vec::new();
|
||||||
core::mem::swap(&mut updated_candidates, &mut self.candidates);
|
core::mem::swap(&mut updated_candidates, &mut self.candidates);
|
||||||
|
|
||||||
|
@ -136,7 +137,7 @@ impl<O> OptimizationCache<O> {
|
||||||
factory: &Factory,
|
factory: &Factory,
|
||||||
graph: Vec<TensorOpsDescription>,
|
graph: Vec<TensorOpsDescription>,
|
||||||
next_ops: Option<TensorOpsDescription>,
|
next_ops: Option<TensorOpsDescription>,
|
||||||
) -> &'a O {
|
) -> &'a mut O {
|
||||||
let existing_optim = self
|
let existing_optim = self
|
||||||
.availables
|
.availables
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -149,7 +150,7 @@ impl<O> OptimizationCache<O> {
|
||||||
optimization.end_conditions.push(ops)
|
optimization.end_conditions.push(ops)
|
||||||
};
|
};
|
||||||
|
|
||||||
return &optimization.value;
|
return &mut optimization.value;
|
||||||
};
|
};
|
||||||
|
|
||||||
self.starters
|
self.starters
|
||||||
|
@ -164,7 +165,9 @@ impl<O> OptimizationCache<O> {
|
||||||
};
|
};
|
||||||
|
|
||||||
self.optimizations.push(optimization);
|
self.optimizations.push(optimization);
|
||||||
&self.optimizations.last().unwrap().value
|
|
||||||
|
let last_index = self.optimizations.len() - 1;
|
||||||
|
&mut self.optimizations[last_index].value
|
||||||
}
|
}
|
||||||
|
|
||||||
// Signal that a new path will begin.
|
// Signal that a new path will begin.
|
||||||
|
@ -188,7 +191,7 @@ pub enum CacheResult<'a, T> {
|
||||||
/// happens.
|
/// happens.
|
||||||
OnPath,
|
OnPath,
|
||||||
/// An optimization has been found, and the best action is to execute it!
|
/// An optimization has been found, and the best action is to execute it!
|
||||||
Found(&'a T),
|
Found(&'a mut T),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// When checking if an optimization is possible, a start or an end condition ensures that this optimization is
|
/// When checking if an optimization is possible, a start or an end condition ensures that this optimization is
|
||||||
|
|
|
@ -265,48 +265,6 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
fn clamp_min<const D: usize>(
|
|
||||||
tensor: FloatTensor<Self, D>,
|
|
||||||
min: FloatElem<Self>,
|
|
||||||
) -> FloatTensor<Self, D> {
|
|
||||||
scalar_float_ops!(ClampMinOps, B::clamp_min);
|
|
||||||
|
|
||||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
|
||||||
|
|
||||||
let desc = ScalarOpsDescription {
|
|
||||||
lhs: tensor.into_description(),
|
|
||||||
rhs: min.elem(),
|
|
||||||
out: out.to_description_out(),
|
|
||||||
};
|
|
||||||
out.client.register(
|
|
||||||
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::ClampMin(desc.clone())),
|
|
||||||
ClampMinOps::<D>::new(desc),
|
|
||||||
);
|
|
||||||
|
|
||||||
out
|
|
||||||
}
|
|
||||||
|
|
||||||
fn clamp_max<const D: usize>(
|
|
||||||
tensor: FloatTensor<Self, D>,
|
|
||||||
max: FloatElem<Self>,
|
|
||||||
) -> FloatTensor<Self, D> {
|
|
||||||
scalar_float_ops!(ClampMaxOps, B::clamp_max);
|
|
||||||
|
|
||||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
|
||||||
|
|
||||||
let desc = ScalarOpsDescription {
|
|
||||||
lhs: tensor.into_description(),
|
|
||||||
rhs: max.elem(),
|
|
||||||
out: out.to_description_out(),
|
|
||||||
};
|
|
||||||
out.client.register(
|
|
||||||
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::ClampMax(desc.clone())),
|
|
||||||
ClampMaxOps::<D>::new(desc),
|
|
||||||
);
|
|
||||||
|
|
||||||
out
|
|
||||||
}
|
|
||||||
|
|
||||||
fn clamp<const D: usize>(
|
fn clamp<const D: usize>(
|
||||||
tensor: FloatTensor<Self, D>,
|
tensor: FloatTensor<Self, D>,
|
||||||
min: FloatElem<Self>,
|
min: FloatElem<Self>,
|
||||||
|
|
|
@ -1034,48 +1034,6 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
fn int_clamp_min<const D: usize>(
|
|
||||||
tensor: IntTensor<Self, D>,
|
|
||||||
min: IntElem<Self>,
|
|
||||||
) -> IntTensor<Self, D> {
|
|
||||||
scalar_int_ops!(ClampMinOps, B::int_clamp_min);
|
|
||||||
|
|
||||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
|
||||||
|
|
||||||
let desc = ScalarOpsDescription {
|
|
||||||
lhs: tensor.into_description(),
|
|
||||||
rhs: min.elem(),
|
|
||||||
out: out.to_description_out(),
|
|
||||||
};
|
|
||||||
out.client.register(
|
|
||||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ClampMin(desc.clone())),
|
|
||||||
ClampMinOps::<D>::new(desc),
|
|
||||||
);
|
|
||||||
|
|
||||||
out
|
|
||||||
}
|
|
||||||
|
|
||||||
fn int_clamp_max<const D: usize>(
|
|
||||||
tensor: IntTensor<Self, D>,
|
|
||||||
max: IntElem<Self>,
|
|
||||||
) -> IntTensor<Self, D> {
|
|
||||||
scalar_int_ops!(ClampMaxOps, B::int_clamp_max);
|
|
||||||
|
|
||||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
|
||||||
|
|
||||||
let desc = ScalarOpsDescription {
|
|
||||||
lhs: tensor.into_description(),
|
|
||||||
rhs: max.elem(),
|
|
||||||
out: out.to_description_out(),
|
|
||||||
};
|
|
||||||
out.client.register(
|
|
||||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ClampMax(desc.clone())),
|
|
||||||
ClampMaxOps::<D>::new(desc),
|
|
||||||
);
|
|
||||||
|
|
||||||
out
|
|
||||||
}
|
|
||||||
|
|
||||||
fn int_clamp<const D: usize>(
|
fn int_clamp<const D: usize>(
|
||||||
tensor: IntTensor<Self, D>,
|
tensor: IntTensor<Self, D>,
|
||||||
min: IntElem<Self>,
|
min: IntElem<Self>,
|
||||||
|
|
|
@ -5,7 +5,7 @@ use std::fmt::Display;
|
||||||
///
|
///
|
||||||
/// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size
|
/// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size
|
||||||
/// X and Y, but with Z=1.
|
/// X and Y, but with Z=1.
|
||||||
#[derive(Hash, new)]
|
#[derive(new)]
|
||||||
pub struct Body {
|
pub struct Body {
|
||||||
operators: Vec<Operator>,
|
operators: Vec<Operator>,
|
||||||
}
|
}
|
|
@ -2,7 +2,7 @@ use super::Elem;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
|
|
||||||
/// Not all functions are native to WGSL, so this struct allows to support more functions.
|
/// Not all functions are native to WGSL, so this struct allows to support more functions.
|
||||||
#[derive(Hash, PartialEq, Eq, Clone)]
|
#[derive(PartialEq, Eq, Clone)]
|
||||||
pub enum Function {
|
pub enum Function {
|
||||||
Powf(Elem),
|
Powf(Elem),
|
||||||
Erf(Elem),
|
Erf(Elem),
|
|
@ -0,0 +1,359 @@
|
||||||
|
use crate::codegen::{
|
||||||
|
Binding, Body, ComputeShader, Elem, Function, Location, Operator, Variable, Visibility,
|
||||||
|
WorkgroupSize,
|
||||||
|
};
|
||||||
|
use crate::compute::{StaticKernel, WgpuComputeClient, WgpuHandle};
|
||||||
|
use crate::element::WgpuElement;
|
||||||
|
use crate::kernel::{elemwise_workgroup, StaticKernelSource, WORKGROUP_DEFAULT};
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
/// Kernel creation input phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
///
/// Marker type without data: it is only used as the `Phase` type parameter of
/// [`ElemWiseKernelCodegen`].
pub struct InputPhase;
/// Kernel creation body phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
///
/// Marker type without data: it is only used as the `Phase` type parameter of
/// [`ElemWiseKernelCodegen`].
pub struct BodyPhase;
/// Kernel creation output phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
///
/// Marker type without data: it is only used as the `Phase` type parameter of
/// [`ElemWiseKernelCodegen`].
pub struct OutputPhase;
/// Kernel compilation phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
///
/// Marker type without data: it is only used as the `Phase` type parameter of
/// [`ElemWiseKernelCodegen`].
pub struct CompilationPhase;
|
||||||
|
|
||||||
|
/// Allows creating custom WGSL kernels based on configured inputs, body and outputs.
///
/// This type has 4 phases that must be executed in order, but don't worry: the type system won't
/// allow you to make mistakes.
///
///   1. [Input Phase](InputPhase)
///     This phase focuses on registering the input arrays and scalars that are going to be used by
///     the kernel.
///   2. [Body Phase](BodyPhase)
///     After the input phase is done, all the operations that happen in the body must be
///     registered.
///   3. [Output Phase](OutputPhase)
///     This step focuses on registering all output arrays or inputs that the kernel needs to write to.
///   4. [Compilation Phase](CompilationPhase)
///     Now that all other phases are completed, we can actually compile the kernel.
pub struct ElemWiseKernelCodegen<Phase = InputPhase> {
    /// Operators accumulated so far, in emission order: the reads pushed by the input phase, the
    /// operators of the body phase, then the global assignments pushed by the output phase.
    operations: Vec<Operator>,
    /// One binding per registered `Input::Array`.
    input_bindings: Vec<Binding>,
    /// One binding per registered `Output::Array`.
    output_bindings: Vec<Binding>,
    /// Named bindings registered for `Input::Scalar` entries; the `info` binding is added in
    /// front of them at compilation.
    named_bindings: Vec<(String, Binding)>,
    /// Extra functions required by the registered operators, without duplicates.
    functions: Vec<Function>,
    /// Compile-time marker of the current phase; holds no data.
    _phase: PhantomData<Phase>,
}
|
||||||
|
|
||||||
|
/// Description of one kernel input, consumed by the input phase of [`ElemWiseKernelCodegen`].
pub enum Input {
    /// An array input: a storage binding is registered for it and a read operator, chosen by
    /// `strategy`, is pushed.
    Array {
        /// Element type of the array; `Elem::Bool` is bound as `Elem::I32`.
        elem: Elem,
        /// Visibility given to the storage binding.
        visibility: Visibility,
        /// Which read operator is emitted for this array.
        strategy: ReadingStrategy,
    },
    /// Scalars of a given element type: registered as a read-only storage binding named
    /// `scalars_{elem}`.
    Scalar {
        /// Element type of the scalars; `Elem::Bool` is bound as `Elem::I32`.
        elem: Elem,
        /// Size recorded on the binding.
        size: usize,
    },
}
|
||||||
|
|
||||||
|
/// Selects the operator used to read an [`Input::Array`].
pub enum ReadingStrategy {
    /// Emit `Operator::ReadGlobalIntoContiguous`, which also receives the input position and the
    /// position of the first output.
    IntoContiguous,
    /// Emit `Operator::ReadGlobal`.
    Plain,
}
|
||||||
|
|
||||||
|
/// Description of one kernel output, consumed by the output phase of [`ElemWiseKernelCodegen`].
pub enum Output {
    /// Write local variable `local` into a new read-write output binding of type `elem`
    /// (`Elem::Bool` is bound as `Elem::I32`).
    Array { elem: Elem, local: u16 },
    /// Write local variable `local` into the already registered input number `input`; no new
    /// binding is created.
    Input { elem: Elem, input: u16, local: u16 },
}
|
||||||
|
|
||||||
|
impl ElemWiseKernelCodegen<InputPhase> {
|
||||||
|
/// Create a new fusion kernel on the given device.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
operations: Vec::new(),
|
||||||
|
input_bindings: Vec::new(),
|
||||||
|
output_bindings: Vec::new(),
|
||||||
|
named_bindings: Vec::new(),
|
||||||
|
functions: Vec::new(),
|
||||||
|
_phase: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register the inputs used by the kernel.
|
||||||
|
pub fn inputs(mut self, inputs: &[Input]) -> ElemWiseKernelCodegen<BodyPhase> {
|
||||||
|
let mut index: u16 = 0;
|
||||||
|
|
||||||
|
let first_output_index = inputs
|
||||||
|
.iter()
|
||||||
|
.filter(|input| match input {
|
||||||
|
Input::Array {
|
||||||
|
elem: _,
|
||||||
|
visibility: _,
|
||||||
|
strategy: _,
|
||||||
|
} => true,
|
||||||
|
Input::Scalar { elem: _, size: _ } => false,
|
||||||
|
})
|
||||||
|
.count();
|
||||||
|
|
||||||
|
for input in inputs {
|
||||||
|
match input {
|
||||||
|
Input::Array {
|
||||||
|
elem,
|
||||||
|
visibility,
|
||||||
|
strategy,
|
||||||
|
} => {
|
||||||
|
self.input_bindings.push(Binding {
|
||||||
|
elem: bool_elem(*elem),
|
||||||
|
visibility: *visibility,
|
||||||
|
location: Location::Storage,
|
||||||
|
size: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
match strategy {
|
||||||
|
ReadingStrategy::IntoContiguous => {
|
||||||
|
self.operations.push(Operator::ReadGlobalIntoContiguous {
|
||||||
|
variable: Variable::Input(index, *elem),
|
||||||
|
position: index as usize,
|
||||||
|
position_out: first_output_index, // First output
|
||||||
|
});
|
||||||
|
}
|
||||||
|
ReadingStrategy::Plain => {
|
||||||
|
self.operations.push(Operator::ReadGlobal {
|
||||||
|
variable: Variable::Input(index, *elem),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
Input::Scalar { elem, size } => {
|
||||||
|
let elem = bool_elem(*elem);
|
||||||
|
|
||||||
|
self.named_bindings.push((
|
||||||
|
format!("scalars_{}", elem),
|
||||||
|
Binding {
|
||||||
|
elem,
|
||||||
|
visibility: Visibility::Read,
|
||||||
|
location: Location::Storage,
|
||||||
|
size: Some(*size),
|
||||||
|
},
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ElemWiseKernelCodegen {
|
||||||
|
operations: self.operations,
|
||||||
|
input_bindings: self.input_bindings,
|
||||||
|
output_bindings: self.output_bindings,
|
||||||
|
named_bindings: self.named_bindings,
|
||||||
|
functions: self.functions,
|
||||||
|
_phase: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ElemWiseKernelCodegen<BodyPhase> {
|
||||||
|
/// Register the [operators](Operator) that the kernel must execute in the order provided.
|
||||||
|
pub fn body(mut self, operators: &[Operator]) -> ElemWiseKernelCodegen<OutputPhase> {
|
||||||
|
let mut register_function = |function: Function| {
|
||||||
|
if !self.functions.contains(&function) {
|
||||||
|
self.functions.push(function);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Since not all operators are native to WGSL, we need to add the custom ones.
|
||||||
|
for ops in operators.iter() {
|
||||||
|
match ops {
|
||||||
|
Operator::Powf {
|
||||||
|
lhs: _,
|
||||||
|
rhs: _,
|
||||||
|
out: _,
|
||||||
|
} => {
|
||||||
|
register_function(Function::Powf(Elem::F32));
|
||||||
|
}
|
||||||
|
Operator::Erf { input: _, out: _ } => {
|
||||||
|
register_function(Function::Erf(Elem::F32));
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
self.operations.push(ops.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
ElemWiseKernelCodegen {
|
||||||
|
operations: self.operations,
|
||||||
|
input_bindings: self.input_bindings,
|
||||||
|
output_bindings: self.output_bindings,
|
||||||
|
named_bindings: self.named_bindings,
|
||||||
|
functions: self.functions,
|
||||||
|
_phase: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ElemWiseKernelCodegen<OutputPhase> {
|
||||||
|
/// Register the outputs with their local variable index.
|
||||||
|
///
|
||||||
|
/// Note that the index corresponds to the registered [operator](Operator) number at the
|
||||||
|
/// [body phase](BodyPhase).
|
||||||
|
/// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0).
|
||||||
|
pub fn outputs(mut self, outputs: &[Output]) -> ElemWiseKernelCodegen<CompilationPhase> {
|
||||||
|
let mut index = 0;
|
||||||
|
|
||||||
|
for array in outputs {
|
||||||
|
match array {
|
||||||
|
Output::Array { elem, local } => {
|
||||||
|
let elem_adapted = bool_elem(*elem);
|
||||||
|
|
||||||
|
self.output_bindings.push(Binding {
|
||||||
|
elem: elem_adapted,
|
||||||
|
visibility: Visibility::ReadWrite,
|
||||||
|
location: Location::Storage,
|
||||||
|
size: None,
|
||||||
|
});
|
||||||
|
self.operations.push(Operator::AssignGlobal {
|
||||||
|
input: Variable::Local(*local, *elem),
|
||||||
|
out: Variable::Output(index, elem_adapted),
|
||||||
|
});
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
Output::Input { elem, input, local } => {
|
||||||
|
self.operations.push(Operator::AssignGlobal {
|
||||||
|
input: Variable::Local(*local, *elem),
|
||||||
|
out: Variable::Input(*input, bool_elem(*elem)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ElemWiseKernelCodegen {
|
||||||
|
operations: self.operations,
|
||||||
|
input_bindings: self.input_bindings,
|
||||||
|
output_bindings: self.output_bindings,
|
||||||
|
named_bindings: self.named_bindings,
|
||||||
|
functions: self.functions,
|
||||||
|
_phase: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ElemWiseKernelCodegen<CompilationPhase> {
|
||||||
|
/// Compile the kernel into a [compute shader](ComputeShader).
|
||||||
|
pub fn compile(self) -> ComputeShader {
|
||||||
|
let inputs = self.input_bindings;
|
||||||
|
let outputs = self.output_bindings;
|
||||||
|
let mut named = Vec::with_capacity(2);
|
||||||
|
|
||||||
|
named.push((
|
||||||
|
"info".to_string(),
|
||||||
|
Binding {
|
||||||
|
elem: Elem::U32,
|
||||||
|
visibility: Visibility::Read,
|
||||||
|
location: Location::Storage,
|
||||||
|
size: None, // We avoid putting the length here since it will force a new kernel
|
||||||
|
// for each tensor rank.
|
||||||
|
},
|
||||||
|
));
|
||||||
|
|
||||||
|
for (name, binding) in self.named_bindings.into_iter() {
|
||||||
|
named.push((name, binding));
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputeShader {
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
named,
|
||||||
|
workgroup_size: WorkgroupSize::default(),
|
||||||
|
body: Body::new(self.operations),
|
||||||
|
num_workgroups: true,
|
||||||
|
global_invocation_id: true,
|
||||||
|
functions: self.functions,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Borrowed description of a tensor passed to [`execute_static`]: the buffer handle bound to the
/// kernel, plus the strides and shape written to the kernel's info buffer.
#[derive(new)]
pub struct StaticHandle<'a> {
    /// Handle of the buffer bound to the kernel.
    handle: &'a WgpuHandle,
    /// Strides of the tensor.
    strides: &'a [usize],
    /// Shape of the tensor.
    shape: &'a [usize],
}
|
||||||
|
|
||||||
|
/// Execute a static kernel.
|
||||||
|
///
|
||||||
|
///
|
||||||
|
/// The limitation from this method is that you can't launch a kernel with multiple types of
|
||||||
|
/// scalar.
|
||||||
|
pub fn execute_static<K, E: WgpuElement>(
|
||||||
|
inputs: &[StaticHandle],
|
||||||
|
outputs: &[StaticHandle],
|
||||||
|
scalar_elems: Option<&[E]>,
|
||||||
|
client: WgpuComputeClient,
|
||||||
|
) where
|
||||||
|
K: StaticKernelSource + 'static,
|
||||||
|
{
|
||||||
|
let mut info = Vec::new();
|
||||||
|
let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2);
|
||||||
|
|
||||||
|
// Inner function to fill the info buffer.
|
||||||
|
let mut register_info_tensor = |strides: &[usize], shape: &[usize]| {
|
||||||
|
if info.is_empty() {
|
||||||
|
info.push(strides.len() as u32);
|
||||||
|
}
|
||||||
|
|
||||||
|
for s in strides.iter() {
|
||||||
|
info.push(*s as u32);
|
||||||
|
}
|
||||||
|
for s in shape.iter() {
|
||||||
|
info.push(*s as u32);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// We start by registering the inputs.
|
||||||
|
for input in inputs.iter() {
|
||||||
|
register_info_tensor(input.strides, input.shape);
|
||||||
|
handles.push(input.handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut num_elems_output = 0;
|
||||||
|
|
||||||
|
// Then we follow with the outputs.
|
||||||
|
for output in outputs.iter() {
|
||||||
|
let num_elems = calculate_num_elems_dyn_rank(output.shape);
|
||||||
|
if num_elems > num_elems_output {
|
||||||
|
num_elems_output = num_elems;
|
||||||
|
}
|
||||||
|
register_info_tensor(output.strides, output.shape);
|
||||||
|
handles.push(output.handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
let info = &client.create(bytemuck::cast_slice(&info));
|
||||||
|
handles.push(info);
|
||||||
|
|
||||||
|
// Finally we finish with the named bindings.
|
||||||
|
let mut scalars = None;
|
||||||
|
if let Some(values) = &scalar_elems {
|
||||||
|
scalars = Some(client.create(bytemuck::cast_slice(values)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(scalars) = scalars.as_ref() {
|
||||||
|
handles.push(scalars);
|
||||||
|
}
|
||||||
|
|
||||||
|
let workgroup = elemwise_workgroup(num_elems_output, WORKGROUP_DEFAULT);
|
||||||
|
let kernel = Box::new(StaticKernel::<K>::new(workgroup));
|
||||||
|
|
||||||
|
client.execute(kernel, &handles);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute the number of elements of a tensor from its shape, whatever its rank.
///
/// Returns the product of all dimensions, which is `1` for an empty shape and `0` as soon as one
/// dimension is `0`.
pub(crate) fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
    shape.iter().product()
}
|
||||||
|
|
||||||
|
fn bool_elem(elem: Elem) -> Elem {
|
||||||
|
match elem {
|
||||||
|
// I32 are used for bool tensors
|
||||||
|
Elem::Bool => Elem::I32,
|
||||||
|
_ => elem,
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
mod body;
|
||||||
|
mod function;
|
||||||
|
mod kernel;
|
||||||
|
mod operator;
|
||||||
|
mod shader;
|
||||||
|
mod variable;
|
||||||
|
|
||||||
|
pub(crate) use body::*;
|
||||||
|
pub(crate) use function::*;
|
||||||
|
pub(crate) use kernel::*;
|
||||||
|
pub(crate) use operator::*;
|
||||||
|
pub(crate) use shader::*;
|
||||||
|
pub(crate) use variable::*;
|
|
@ -1,8 +1,9 @@
|
||||||
use super::Variable;
|
use super::variable::Variable;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
|
|
||||||
/// All operators that can be fused in a WGSL compute shader.
|
/// All operators that can be fused in a WGSL compute shader.
|
||||||
#[derive(Debug, Hash, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
#[allow(dead_code)] // Some variants might not be used with different flags
|
||||||
pub enum Operator {
|
pub enum Operator {
|
||||||
Add {
|
Add {
|
||||||
lhs: Variable,
|
lhs: Variable,
|
||||||
|
@ -57,6 +58,10 @@ pub enum Operator {
|
||||||
rhs: Variable,
|
rhs: Variable,
|
||||||
out: Variable,
|
out: Variable,
|
||||||
},
|
},
|
||||||
|
Sqrt {
|
||||||
|
input: Variable,
|
||||||
|
out: Variable,
|
||||||
|
},
|
||||||
Erf {
|
Erf {
|
||||||
input: Variable,
|
input: Variable,
|
||||||
out: Variable,
|
out: Variable,
|
||||||
|
@ -75,6 +80,12 @@ pub enum Operator {
|
||||||
rhs: Variable,
|
rhs: Variable,
|
||||||
out: Variable,
|
out: Variable,
|
||||||
},
|
},
|
||||||
|
Clamp {
|
||||||
|
input: Variable,
|
||||||
|
min_value: Variable,
|
||||||
|
max_value: Variable,
|
||||||
|
out: Variable,
|
||||||
|
},
|
||||||
Greater {
|
Greater {
|
||||||
lhs: Variable,
|
lhs: Variable,
|
||||||
rhs: Variable,
|
rhs: Variable,
|
||||||
|
@ -100,8 +111,15 @@ pub enum Operator {
|
||||||
input: Variable,
|
input: Variable,
|
||||||
out: Variable,
|
out: Variable,
|
||||||
},
|
},
|
||||||
|
AssignLocal {
|
||||||
|
input: Variable,
|
||||||
|
out: Variable,
|
||||||
|
},
|
||||||
ReadGlobal {
|
ReadGlobal {
|
||||||
variable: Variable,
|
variable: Variable,
|
||||||
|
},
|
||||||
|
ReadGlobalIntoContiguous {
|
||||||
|
variable: Variable,
|
||||||
position: usize,
|
position: usize,
|
||||||
position_out: usize,
|
position_out: usize,
|
||||||
},
|
},
|
||||||
|
@ -125,9 +143,20 @@ impl Display for Operator {
|
||||||
Operator::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")),
|
Operator::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")),
|
||||||
Operator::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")),
|
Operator::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")),
|
||||||
Operator::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")),
|
Operator::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")),
|
||||||
|
Operator::Clamp {
|
||||||
|
input,
|
||||||
|
min_value,
|
||||||
|
max_value,
|
||||||
|
out,
|
||||||
|
} => f.write_fmt(format_args!(
|
||||||
|
"let {out} = clamp({input}, {min_value}, {max_value});"
|
||||||
|
)),
|
||||||
Operator::Powf { lhs, rhs, out } => {
|
Operator::Powf { lhs, rhs, out } => {
|
||||||
f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});"))
|
f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});"))
|
||||||
}
|
}
|
||||||
|
Operator::Sqrt { input, out } => {
|
||||||
|
f.write_fmt(format_args!("let {out} = sqrt({input});"))
|
||||||
|
}
|
||||||
Operator::Log1p { input, out } => {
|
Operator::Log1p { input, out } => {
|
||||||
f.write_fmt(format_args!("let {out} = log({input} + 1.0);"))
|
f.write_fmt(format_args!("let {out} = log({input} + 1.0);"))
|
||||||
}
|
}
|
||||||
|
@ -159,7 +188,21 @@ impl Display for Operator {
|
||||||
let elem = out.elem();
|
let elem = out.elem();
|
||||||
f.write_fmt(format_args!("{out}_global[id] = {elem}({input});"))
|
f.write_fmt(format_args!("{out}_global[id] = {elem}({input});"))
|
||||||
}
|
}
|
||||||
Operator::ReadGlobal {
|
Operator::AssignLocal { input, out } => {
|
||||||
|
let elem = out.elem();
|
||||||
|
f.write_fmt(format_args!("let {out} = {elem}({input});"))
|
||||||
|
}
|
||||||
|
Operator::ReadGlobal { variable } => match variable {
|
||||||
|
Variable::Input(number, _elem) => f.write_fmt(format_args!(
|
||||||
|
"let input_{number} = input_{number}_global[id];"
|
||||||
|
)),
|
||||||
|
Variable::Local(_, _) => panic!("can't read global local variable."),
|
||||||
|
Variable::Output(number, _elem) => f.write_fmt(format_args!(
|
||||||
|
"let output_{number} = output_{number}_global[id];"
|
||||||
|
)),
|
||||||
|
Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
|
||||||
|
},
|
||||||
|
Operator::ReadGlobalIntoContiguous {
|
||||||
variable,
|
variable,
|
||||||
position,
|
position,
|
||||||
position_out,
|
position_out,
|
|
@ -1,34 +1,29 @@
|
||||||
use super::{Body, Function};
|
use super::{Body, Function};
|
||||||
use crate::kernel::{DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT};
|
use crate::kernel::WORKGROUP_DEFAULT;
|
||||||
use std::{
|
use std::fmt::Display;
|
||||||
collections::hash_map::DefaultHasher,
|
|
||||||
fmt::Display,
|
|
||||||
hash::{Hash, Hasher},
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Hash, PartialEq, Eq)]
|
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||||
pub enum Location {
|
pub enum Location {
|
||||||
Storage,
|
Storage,
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
Workgroup,
|
Workgroup,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Hash, PartialEq, Eq)]
|
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||||
pub enum Visibility {
|
pub enum Visibility {
|
||||||
Read,
|
Read,
|
||||||
ReadWrite,
|
ReadWrite,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
|
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||||
pub enum Elem {
|
pub enum Elem {
|
||||||
F32,
|
F32,
|
||||||
#[allow(dead_code)]
|
|
||||||
I32,
|
I32,
|
||||||
U32,
|
U32,
|
||||||
Bool,
|
Bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Hash, PartialEq, Eq)]
|
#[derive(PartialEq, Eq, Clone)]
|
||||||
pub struct Binding {
|
pub struct Binding {
|
||||||
pub location: Location,
|
pub location: Location,
|
||||||
pub visibility: Visibility,
|
pub visibility: Visibility,
|
||||||
|
@ -36,7 +31,7 @@ pub struct Binding {
|
||||||
pub size: Option<usize>,
|
pub size: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Hash, PartialEq, Eq)]
|
#[derive(PartialEq, Eq)]
|
||||||
pub struct WorkgroupSize {
|
pub struct WorkgroupSize {
|
||||||
pub x: usize,
|
pub x: usize,
|
||||||
pub y: usize,
|
pub y: usize,
|
||||||
|
@ -53,7 +48,6 @@ impl Default for WorkgroupSize {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Hash)]
|
|
||||||
pub struct ComputeShader {
|
pub struct ComputeShader {
|
||||||
pub inputs: Vec<Binding>,
|
pub inputs: Vec<Binding>,
|
||||||
pub outputs: Vec<Binding>,
|
pub outputs: Vec<Binding>,
|
||||||
|
@ -65,19 +59,6 @@ pub struct ComputeShader {
|
||||||
pub functions: Vec<Function>,
|
pub functions: Vec<Function>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DynamicKernelSource for ComputeShader {
|
|
||||||
fn source(&self) -> SourceTemplate {
|
|
||||||
SourceTemplate::new(self.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn id(&self) -> String {
|
|
||||||
let mut s = DefaultHasher::new();
|
|
||||||
self.hash(&mut s);
|
|
||||||
|
|
||||||
s.finish().to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for ComputeShader {
|
impl Display for ComputeShader {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
Self::format_bindings(f, "input", &self.inputs, 0)?;
|
Self::format_bindings(f, "input", &self.inputs, 0)?;
|
|
@ -1,7 +1,7 @@
|
||||||
use super::Elem;
|
use super::Elem;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
|
|
||||||
#[derive(Debug, Hash, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Variable {
|
pub enum Variable {
|
||||||
Input(u16, Elem),
|
Input(u16, Elem),
|
||||||
Scalar(u16, Elem),
|
Scalar(u16, Elem),
|
|
@ -9,8 +9,7 @@ where
|
||||||
fn type_name() -> &'static str;
|
fn type_name() -> &'static str;
|
||||||
fn as_bytes(slice: &[Self]) -> &[u8];
|
fn as_bytes(slice: &[Self]) -> &[u8];
|
||||||
fn from_bytes(bytes: &[u8]) -> &[Self];
|
fn from_bytes(bytes: &[u8]) -> &[Self];
|
||||||
#[cfg(any(feature = "fusion", test))]
|
fn elem_type() -> crate::codegen::Elem;
|
||||||
fn elem_type() -> crate::fusion::codegen::Elem;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The float element type for the wgpu backend.
|
/// The float element type for the wgpu backend.
|
||||||
|
@ -29,9 +28,8 @@ impl WgpuElement for u32 {
|
||||||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||||
bytemuck::cast_slice(bytes)
|
bytemuck::cast_slice(bytes)
|
||||||
}
|
}
|
||||||
#[cfg(any(feature = "fusion", test))]
|
fn elem_type() -> crate::codegen::Elem {
|
||||||
fn elem_type() -> crate::fusion::codegen::Elem {
|
crate::codegen::Elem::U32
|
||||||
crate::fusion::codegen::Elem::U32
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,9 +43,8 @@ impl WgpuElement for i32 {
|
||||||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||||
bytemuck::cast_slice(bytes)
|
bytemuck::cast_slice(bytes)
|
||||||
}
|
}
|
||||||
#[cfg(any(feature = "fusion", test))]
|
fn elem_type() -> crate::codegen::Elem {
|
||||||
fn elem_type() -> crate::fusion::codegen::Elem {
|
crate::codegen::Elem::I32
|
||||||
crate::fusion::codegen::Elem::I32
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,9 +59,8 @@ impl WgpuElement for f32 {
|
||||||
bytemuck::cast_slice(bytes)
|
bytemuck::cast_slice(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(any(feature = "fusion", test))]
|
fn elem_type() -> crate::codegen::Elem {
|
||||||
fn elem_type() -> crate::fusion::codegen::Elem {
|
crate::codegen::Elem::F32
|
||||||
crate::fusion::codegen::Elem::F32
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -81,14 +81,6 @@ pub fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
|
||||||
strides
|
strides
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
|
|
||||||
let mut num_elems = 1;
|
|
||||||
for i in shape.iter() {
|
|
||||||
num_elems *= i;
|
|
||||||
}
|
|
||||||
num_elems
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(new, Debug, Clone)]
|
#[derive(new, Debug, Clone)]
|
||||||
/// Handle to be used when fusing operations.
|
/// Handle to be used when fusing operations.
|
||||||
pub struct WgpuFusionHandle {
|
pub struct WgpuFusionHandle {
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
use crate::{
|
||||||
|
codegen::ComputeShader,
|
||||||
|
kernel::{DynamicKernelSource, SourceTemplate},
|
||||||
|
};
|
||||||
|
use hashbrown::HashSet;
|
||||||
|
|
||||||
|
/// This cache ensures that the generation of the source code is only done once when the kernel is
|
||||||
|
/// executed for the first time. Following, we only include the ID in the dynamic kernel source,
|
||||||
|
/// since we rely on the compilation cache of the WGPU compute server.
|
||||||
|
///
|
||||||
|
/// If it ever causes problems, we could cache the compute shader and put it into an Arc to avoid deep
|
||||||
|
/// cloning.
|
||||||
|
#[derive(Default, Debug)]
|
||||||
|
pub struct KernelCompilationCache {
|
||||||
|
already_compiled_ids: HashSet<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(new)]
|
||||||
|
pub enum FusedKernelSource {
|
||||||
|
AlreadyCompiled { id: String },
|
||||||
|
NewKernel { id: String, shader: ComputeShader },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DynamicKernelSource for FusedKernelSource {
|
||||||
|
fn source(&self) -> SourceTemplate {
|
||||||
|
match self {
|
||||||
|
FusedKernelSource::AlreadyCompiled { id: _ } => {
|
||||||
|
panic!("Can't get the source of an already compiled kernel.")
|
||||||
|
}
|
||||||
|
FusedKernelSource::NewKernel {
|
||||||
|
id: _,
|
||||||
|
shader: source,
|
||||||
|
} => SourceTemplate::new(source.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn id(&self) -> String {
|
||||||
|
match self {
|
||||||
|
FusedKernelSource::AlreadyCompiled { id } => id.clone(),
|
||||||
|
FusedKernelSource::NewKernel { id, shader: _ } => id.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KernelCompilationCache {
|
||||||
|
pub fn get(&self, id: &str) -> Option<FusedKernelSource> {
|
||||||
|
if self.already_compiled_ids.contains(id) {
|
||||||
|
return Some(FusedKernelSource::AlreadyCompiled { id: id.to_string() });
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert(&mut self, id: String) {
|
||||||
|
self.already_compiled_ids.insert(id);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,11 +0,0 @@
|
||||||
mod body;
|
|
||||||
mod function;
|
|
||||||
mod operator;
|
|
||||||
mod shader;
|
|
||||||
mod variable;
|
|
||||||
|
|
||||||
pub use body::*;
|
|
||||||
pub use function::*;
|
|
||||||
pub use operator::*;
|
|
||||||
pub use shader::*;
|
|
||||||
pub use variable::*;
|
|
|
@ -1,8 +1,10 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
|
codegen::{Elem, Operator, Variable},
|
||||||
element::WgpuElement,
|
element::WgpuElement,
|
||||||
fusion::codegen::{Elem, Operator, Variable},
|
fusion::cache::KernelCompilationCache,
|
||||||
FloatElement, GraphicsApi, IntElement, Wgpu,
|
FloatElement, GraphicsApi, IntElement, Wgpu,
|
||||||
};
|
};
|
||||||
|
use burn_common::id::IdGenerator;
|
||||||
use burn_fusion::{
|
use burn_fusion::{
|
||||||
graph::{
|
graph::{
|
||||||
BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription,
|
BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription,
|
||||||
|
@ -84,12 +86,14 @@ where
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
Box::new(FloatElementWise {
|
Box::new(FloatElementWise {
|
||||||
|
id: IdGenerator::generate(),
|
||||||
inputs,
|
inputs,
|
||||||
outputs,
|
outputs,
|
||||||
locals,
|
locals,
|
||||||
operators: self.operators.clone(),
|
operators: self.operators.clone(),
|
||||||
scalars_f32: self.scalars_f32,
|
scalars_f32: self.scalars_f32,
|
||||||
device: self.device.clone(),
|
device: self.device.clone(),
|
||||||
|
cache: KernelCompilationCache::default(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,13 +187,19 @@ where
|
||||||
Operator::AssignGlobal { input: _, out: _ } => {
|
Operator::AssignGlobal { input: _, out: _ } => {
|
||||||
// Nothing to do here.
|
// Nothing to do here.
|
||||||
}
|
}
|
||||||
Operator::ReadGlobal {
|
Operator::AssignLocal { input: _, out: _ } => {
|
||||||
|
// Nothing to do here.
|
||||||
|
}
|
||||||
|
Operator::ReadGlobalIntoContiguous {
|
||||||
variable: _,
|
variable: _,
|
||||||
position: _,
|
position: _,
|
||||||
position_out: _,
|
position_out: _,
|
||||||
} => {
|
} => {
|
||||||
// Nothing to do here.
|
// Nothing to do here.
|
||||||
}
|
}
|
||||||
|
Operator::ReadGlobal { variable: _ } => {
|
||||||
|
// Nothing to do here.
|
||||||
|
}
|
||||||
Operator::Add { lhs, rhs, out } => {
|
Operator::Add { lhs, rhs, out } => {
|
||||||
mark(lhs, &mut local_tensor_ids_input);
|
mark(lhs, &mut local_tensor_ids_input);
|
||||||
mark(rhs, &mut local_tensor_ids_input);
|
mark(rhs, &mut local_tensor_ids_input);
|
||||||
|
@ -242,6 +252,15 @@ where
|
||||||
mark(input, &mut local_tensor_ids_input);
|
mark(input, &mut local_tensor_ids_input);
|
||||||
mark(out, &mut local_tensor_ids_output);
|
mark(out, &mut local_tensor_ids_output);
|
||||||
}
|
}
|
||||||
|
Operator::Clamp {
|
||||||
|
input,
|
||||||
|
min_value: _,
|
||||||
|
max_value: _,
|
||||||
|
out,
|
||||||
|
} => {
|
||||||
|
mark(input, &mut local_tensor_ids_input);
|
||||||
|
mark(out, &mut local_tensor_ids_output);
|
||||||
|
}
|
||||||
Operator::Powf { lhs, rhs, out } => {
|
Operator::Powf { lhs, rhs, out } => {
|
||||||
mark(lhs, &mut local_tensor_ids_input);
|
mark(lhs, &mut local_tensor_ids_input);
|
||||||
mark(rhs, &mut local_tensor_ids_input);
|
mark(rhs, &mut local_tensor_ids_input);
|
||||||
|
@ -287,6 +306,10 @@ where
|
||||||
mark(rhs, &mut local_tensor_ids_input);
|
mark(rhs, &mut local_tensor_ids_input);
|
||||||
mark(out, &mut local_tensor_ids_output);
|
mark(out, &mut local_tensor_ids_output);
|
||||||
}
|
}
|
||||||
|
Operator::Sqrt { input, out } => {
|
||||||
|
mark(input, &mut local_tensor_ids_input);
|
||||||
|
mark(out, &mut local_tensor_ids_output);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,24 +1,73 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
fusion::codegen::{Elem, Operator},
|
codegen::{
|
||||||
fusion::kernel::FusionKernel,
|
ComputeShader, Elem, ElemWiseKernelCodegen, Input, Operator, Output, ReadingStrategy,
|
||||||
|
Visibility,
|
||||||
|
},
|
||||||
|
fusion::{
|
||||||
|
cache::{FusedKernelSource, KernelCompilationCache},
|
||||||
|
kernel,
|
||||||
|
},
|
||||||
FloatElement, GraphicsApi, IntElement, Wgpu,
|
FloatElement, GraphicsApi, IntElement, Wgpu,
|
||||||
};
|
};
|
||||||
use burn_fusion::{graph::Context, Optimization, TensorDescription};
|
use burn_fusion::{graph::Context, Optimization, TensorDescription};
|
||||||
use burn_tensor::Device;
|
use burn_tensor::Device;
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub(crate) struct FloatElementWise<G, F, I>
|
pub(crate) struct FloatElementWise<G, F, I>
|
||||||
where
|
where
|
||||||
G: GraphicsApi,
|
G: GraphicsApi,
|
||||||
F: FloatElement,
|
F: FloatElement,
|
||||||
I: IntElement,
|
I: IntElement,
|
||||||
{
|
{
|
||||||
|
pub(crate) id: String,
|
||||||
pub(crate) inputs: Vec<(TensorDescription, Elem)>,
|
pub(crate) inputs: Vec<(TensorDescription, Elem)>,
|
||||||
pub(crate) outputs: Vec<(TensorDescription, Elem)>,
|
pub(crate) outputs: Vec<(TensorDescription, Elem)>,
|
||||||
pub(crate) locals: Vec<u16>,
|
pub(crate) locals: Vec<u16>,
|
||||||
pub(crate) operators: Vec<Operator>,
|
pub(crate) operators: Vec<Operator>,
|
||||||
pub(crate) scalars_f32: usize,
|
pub(crate) scalars_f32: usize,
|
||||||
pub(crate) device: Device<Wgpu<G, F, I>>,
|
pub(crate) device: Device<Wgpu<G, F, I>>,
|
||||||
|
pub(crate) cache: KernelCompilationCache,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<G, F, I> FloatElementWise<G, F, I>
|
||||||
|
where
|
||||||
|
G: GraphicsApi,
|
||||||
|
F: FloatElement,
|
||||||
|
I: IntElement,
|
||||||
|
{
|
||||||
|
pub fn compile(&mut self) -> ComputeShader {
|
||||||
|
let mut inputs = self
|
||||||
|
.inputs
|
||||||
|
.iter()
|
||||||
|
.map(|(_tensor, elem)| Input::Array {
|
||||||
|
elem: *elem,
|
||||||
|
visibility: Visibility::Read,
|
||||||
|
strategy: ReadingStrategy::IntoContiguous,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let outputs = self
|
||||||
|
.outputs
|
||||||
|
.iter()
|
||||||
|
.zip(self.locals.iter())
|
||||||
|
.map(|((_tensor, elem), local)| Output::Array {
|
||||||
|
elem: *elem,
|
||||||
|
local: *local,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
if self.scalars_f32 > 0 {
|
||||||
|
inputs.push(Input::Scalar {
|
||||||
|
elem: Elem::F32,
|
||||||
|
size: self.scalars_f32,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
ElemWiseKernelCodegen::new()
|
||||||
|
.inputs(&inputs)
|
||||||
|
.body(&self.operators)
|
||||||
|
.outputs(&outputs)
|
||||||
|
.compile()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<G, F, I> Optimization<Wgpu<G, F, I>> for FloatElementWise<G, F, I>
|
impl<G, F, I> Optimization<Wgpu<G, F, I>> for FloatElementWise<G, F, I>
|
||||||
|
@ -27,27 +76,33 @@ where
|
||||||
F: FloatElement,
|
F: FloatElement,
|
||||||
I: IntElement,
|
I: IntElement,
|
||||||
{
|
{
|
||||||
fn execute(&self, context: &mut Context<'_, Wgpu<G, F, I>>) {
|
fn execute(&mut self, context: &mut Context<'_, Wgpu<G, F, I>>) {
|
||||||
let inputs = self
|
if let Some(kernel) = self.cache.get(&self.id) {
|
||||||
.inputs
|
kernel::execute_fusion(
|
||||||
.iter()
|
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||||
.map(|(tensor, elem)| (context.tensors.get(&tensor.id).unwrap(), *elem))
|
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||||
.collect::<Vec<_>>();
|
self.scalars_f32,
|
||||||
|
kernel,
|
||||||
|
context,
|
||||||
|
self.device.clone(),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
let shader = self.compile();
|
||||||
|
|
||||||
let outputs = self
|
kernel::execute_fusion(
|
||||||
.outputs
|
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||||
.iter()
|
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||||
.map(|(tensor, elem)| (context.tensors.get(&tensor.id).unwrap(), *elem))
|
self.scalars_f32,
|
||||||
.collect::<Vec<_>>();
|
FusedKernelSource::NewKernel {
|
||||||
|
id: self.id.to_string(),
|
||||||
|
shader,
|
||||||
|
},
|
||||||
|
context,
|
||||||
|
self.device.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
// The context may contain scalars for the end condition, which may vary.
|
self.cache.insert(self.id.clone());
|
||||||
let scalars_f32 = &context.scalar_floats[0..self.scalars_f32];
|
}
|
||||||
|
|
||||||
FusionKernel::new(&self.device)
|
|
||||||
.inputs(&inputs, scalars_f32)
|
|
||||||
.body(&self.operators)
|
|
||||||
.outputs(&outputs, &self.locals)
|
|
||||||
.execute(context.handles);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn len(&self) -> usize {
|
fn len(&self) -> usize {
|
||||||
|
@ -144,6 +199,7 @@ mod tests {
|
||||||
Variant1,
|
Variant1,
|
||||||
Variant2,
|
Variant2,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn execute<B: Backend>(
|
fn execute<B: Backend>(
|
||||||
data_1: Data<f32, 2>,
|
data_1: Data<f32, 2>,
|
||||||
data_2: Data<f32, 2>,
|
data_2: Data<f32, 2>,
|
||||||
|
|
|
@ -1,355 +1,82 @@
|
||||||
use super::codegen::Body;
|
use super::cache::FusedKernelSource;
|
||||||
use crate::compute::{compute_client, DynamicKernel, WgpuComputeClient};
|
use crate::codegen::calculate_num_elems_dyn_rank;
|
||||||
use crate::fusion::codegen::Function;
|
use crate::compute::{compute_client, DynamicKernel};
|
||||||
use crate::fusion::{calculate_num_elems_dyn_rank, strides_dyn_rank};
|
use crate::fusion::strides_dyn_rank;
|
||||||
use crate::fusion::{
|
use crate::fusion::WgpuFusionHandle;
|
||||||
codegen::{
|
|
||||||
Binding, ComputeShader, Elem, Location, Operator, Variable, Visibility, WorkgroupSize,
|
|
||||||
},
|
|
||||||
WgpuFusionHandle,
|
|
||||||
};
|
|
||||||
use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT};
|
use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT};
|
||||||
use crate::{FloatElement, GraphicsApi, IntElement, Wgpu};
|
use crate::{FloatElement, GraphicsApi, IntElement, Wgpu};
|
||||||
use burn_fusion::{HandleContainer, TensorDescription};
|
use burn_fusion::graph::Context;
|
||||||
|
use burn_fusion::TensorDescription;
|
||||||
use burn_tensor::Device;
|
use burn_tensor::Device;
|
||||||
use std::marker::PhantomData;
|
|
||||||
|
|
||||||
/// Kernel creation input phase, see [fusion kernel](FusionKernel) for more details.
|
pub fn execute_fusion<G: GraphicsApi, F: FloatElement, I: IntElement>(
|
||||||
pub struct InputPhase;
|
inputs: &[&TensorDescription],
|
||||||
/// Kernel creation body phase, see [fusion kernel](FusionKernel) for more details.
|
outputs: &[&TensorDescription],
|
||||||
pub struct BodyPhase;
|
scalars_f32: usize,
|
||||||
/// Kernel creation output phase, see [fusion kernel](FusionKernel) for more details.
|
kernel: FusedKernelSource,
|
||||||
pub struct OutputPhase;
|
context: &mut Context<'_, Wgpu<G, F, I>>,
|
||||||
/// Kernel execution phase, see [fusion kernel](FusionKernel) for more details.
|
|
||||||
pub struct ExecutionPhase;
|
|
||||||
|
|
||||||
/// Allows to create custom wgsl kernels based on configured inputs, body and outputs.
|
|
||||||
///
|
|
||||||
/// This type has 4 phases that must be executed in order, but no worry the type system won't allow
|
|
||||||
/// you to make mistakes.
|
|
||||||
///
|
|
||||||
/// 1. [Input Phase](InputPhase)
|
|
||||||
/// This phase focuses on registering the input tensor descriptions that are going to be used by
|
|
||||||
/// the fused kernel.
|
|
||||||
/// 2. [Body Phase](BodyPhase)
|
|
||||||
/// After the input phase is done, all the operations that happen in the body must be
|
|
||||||
/// registered.
|
|
||||||
/// 3. [Output Phase](OutputPhase)
|
|
||||||
/// This step focuses on registering all tensor descriptions that the kernel needs to write to.
|
|
||||||
/// 4. [Execution Phase](ExecutionPhase)
|
|
||||||
/// Now that all other phases are completed, we can actually run the kernel on the given
|
|
||||||
/// [handles](HandleContainer). Note that the actual chosen kernel may vary based on the
|
|
||||||
/// handles provided.
|
|
||||||
pub struct FusionKernel<G, F, I, Phase = InputPhase>
|
|
||||||
where
|
|
||||||
G: GraphicsApi,
|
|
||||||
F: FloatElement,
|
|
||||||
I: IntElement,
|
|
||||||
{
|
|
||||||
operations: Vec<Operator>,
|
|
||||||
input_bindings: Vec<(Binding, TensorDescription)>,
|
|
||||||
output_bindings: Vec<(Binding, TensorDescription)>,
|
|
||||||
named_bindings: Vec<(String, Binding, DataBuffer)>,
|
|
||||||
functions: Vec<Function>,
|
|
||||||
num_elems_output: usize,
|
|
||||||
device: Device<Wgpu<G, F, I>>,
|
device: Device<Wgpu<G, F, I>>,
|
||||||
client: WgpuComputeClient,
|
) {
|
||||||
_phase: PhantomData<Phase>,
|
let client = compute_client::<G>(&device);
|
||||||
}
|
let mut info = Vec::new();
|
||||||
|
let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2);
|
||||||
|
|
||||||
enum DataBuffer {
|
// Inner function to fill the info buffer.
|
||||||
F32(Vec<f32>),
|
let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| {
|
||||||
U32(Vec<u32>),
|
if info.is_empty() {
|
||||||
}
|
info.push(handle.strides.len() as u32);
|
||||||
|
}
|
||||||
|
|
||||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, InputPhase> {
|
for s in handle.strides.iter() {
|
||||||
/// Create a new fusion kernel on the given device.
|
info.push(*s as u32);
|
||||||
pub fn new(device: &Device<Wgpu<G, F, I>>) -> Self {
|
}
|
||||||
let client = compute_client::<G>(device);
|
for s in tensor.shape.iter() {
|
||||||
|
info.push(*s as u32);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
Self {
|
// We start by registering the inputs.
|
||||||
operations: Vec::new(),
|
for tensor in inputs.iter() {
|
||||||
input_bindings: Vec::new(),
|
let tensor = context.tensors.get(&tensor.id).unwrap();
|
||||||
output_bindings: Vec::new(),
|
let handle = context.handles.get_handle(tensor);
|
||||||
named_bindings: Vec::new(),
|
|
||||||
functions: Vec::new(),
|
register_info_tensor(tensor, &handle);
|
||||||
num_elems_output: 0,
|
handles.push(handle.handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut num_elems_output = 0;
|
||||||
|
|
||||||
|
// Then we follow with the outputs.
|
||||||
|
for tensor in outputs.iter() {
|
||||||
|
let tensor = context.tensors.get(&tensor.id).unwrap();
|
||||||
|
|
||||||
|
let num_elems = calculate_num_elems_dyn_rank(&tensor.shape);
|
||||||
|
if num_elems > num_elems_output {
|
||||||
|
num_elems_output = num_elems;
|
||||||
|
}
|
||||||
|
let handle_fusion = WgpuFusionHandle {
|
||||||
|
client: client.clone(),
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
client,
|
strides: strides_dyn_rank(&tensor.shape),
|
||||||
_phase: PhantomData,
|
handle: client.empty(core::mem::size_of::<F>() * num_elems),
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Register the inputs used by the kernel.
|
|
||||||
pub fn inputs(
|
|
||||||
mut self,
|
|
||||||
inputs_tensor: &[(&TensorDescription, Elem)],
|
|
||||||
inputs_scalar_f32: &[f32],
|
|
||||||
) -> FusionKernel<G, F, I, BodyPhase> {
|
|
||||||
for (i, (input, elem)) in inputs_tensor.iter().enumerate() {
|
|
||||||
if elem != &Elem::Bool {
|
|
||||||
self.input_bindings.push((
|
|
||||||
Binding {
|
|
||||||
elem: *elem,
|
|
||||||
visibility: Visibility::Read,
|
|
||||||
location: Location::Storage,
|
|
||||||
size: None,
|
|
||||||
},
|
|
||||||
(*input).clone(),
|
|
||||||
));
|
|
||||||
|
|
||||||
self.operations.push(Operator::ReadGlobal {
|
|
||||||
variable: Variable::Input(i as u16, *elem),
|
|
||||||
position: i,
|
|
||||||
position_out: inputs_tensor.len(), // First output
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
self.input_bindings.push((
|
|
||||||
Binding {
|
|
||||||
elem: Elem::I32,
|
|
||||||
visibility: Visibility::Read,
|
|
||||||
location: Location::Storage,
|
|
||||||
size: None,
|
|
||||||
},
|
|
||||||
(*input).clone(),
|
|
||||||
));
|
|
||||||
|
|
||||||
self.operations.push(Operator::ReadGlobal {
|
|
||||||
variable: Variable::Input(i as u16, *elem),
|
|
||||||
position: i,
|
|
||||||
position_out: inputs_tensor.len(), // First output
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !inputs_scalar_f32.is_empty() {
|
|
||||||
self.named_bindings.push((
|
|
||||||
"scalars_f32".to_string(),
|
|
||||||
Binding {
|
|
||||||
elem: Elem::F32,
|
|
||||||
visibility: Visibility::Read,
|
|
||||||
location: Location::Storage,
|
|
||||||
size: Some(inputs_scalar_f32.len()),
|
|
||||||
},
|
|
||||||
DataBuffer::F32(inputs_scalar_f32.to_vec()),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
FusionKernel {
|
|
||||||
operations: self.operations,
|
|
||||||
input_bindings: self.input_bindings,
|
|
||||||
output_bindings: self.output_bindings,
|
|
||||||
named_bindings: self.named_bindings,
|
|
||||||
functions: self.functions,
|
|
||||||
num_elems_output: self.num_elems_output,
|
|
||||||
device: self.device,
|
|
||||||
client: self.client,
|
|
||||||
_phase: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, BodyPhase> {
|
|
||||||
/// Register the [operators](Operator) that the kernel must execute in the order provided.
|
|
||||||
pub fn body(mut self, operators: &[Operator]) -> FusionKernel<G, F, I, OutputPhase> {
|
|
||||||
let mut register_function = |function: Function| {
|
|
||||||
if !self.functions.contains(&function) {
|
|
||||||
self.functions.push(function);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Since not all operators are native to WGSL, we need to add the custom ones.
|
register_info_tensor(tensor, &handle_fusion);
|
||||||
for ops in operators.iter() {
|
|
||||||
match ops {
|
|
||||||
Operator::Powf {
|
|
||||||
lhs: _,
|
|
||||||
rhs: _,
|
|
||||||
out: _,
|
|
||||||
} => {
|
|
||||||
register_function(Function::Powf(Elem::F32));
|
|
||||||
}
|
|
||||||
Operator::Erf { input: _, out: _ } => {
|
|
||||||
register_function(Function::Erf(Elem::F32));
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
self.operations.push(ops.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
FusionKernel {
|
handles.push(handle_fusion.handle.clone());
|
||||||
operations: self.operations,
|
context
|
||||||
input_bindings: self.input_bindings,
|
.handles
|
||||||
output_bindings: self.output_bindings,
|
.register_handle(tensor.id.clone(), handle_fusion);
|
||||||
named_bindings: self.named_bindings,
|
|
||||||
functions: self.functions,
|
|
||||||
num_elems_output: self.num_elems_output,
|
|
||||||
device: self.device,
|
|
||||||
client: self.client,
|
|
||||||
_phase: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, OutputPhase> {
|
|
||||||
/// Register the outputs with their local variable index.
|
|
||||||
///
|
|
||||||
/// Note that the index corresponds to the registered [operator](Operator) number at the
|
|
||||||
/// [body phase](BodyPhase).
|
|
||||||
/// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0).
|
|
||||||
pub fn outputs(
|
|
||||||
mut self,
|
|
||||||
outputs: &[(&TensorDescription, Elem)],
|
|
||||||
locals: &[u16],
|
|
||||||
) -> FusionKernel<G, F, I, ExecutionPhase> {
|
|
||||||
let mut num_elems_launch_option = 0;
|
|
||||||
|
|
||||||
for (i, ((output, elem), local)) in outputs.iter().zip(locals).enumerate() {
|
|
||||||
let num_elems_output = calculate_num_elems_dyn_rank(&output.shape);
|
|
||||||
if num_elems_output > num_elems_launch_option {
|
|
||||||
num_elems_launch_option = num_elems_output;
|
|
||||||
}
|
|
||||||
|
|
||||||
if elem != &Elem::Bool {
|
|
||||||
self.output_bindings.push((
|
|
||||||
Binding {
|
|
||||||
elem: *elem,
|
|
||||||
visibility: Visibility::ReadWrite,
|
|
||||||
location: Location::Storage,
|
|
||||||
size: None,
|
|
||||||
},
|
|
||||||
(*output).clone(),
|
|
||||||
));
|
|
||||||
|
|
||||||
self.operations.push(Operator::AssignGlobal {
|
|
||||||
input: Variable::Local(*local, *elem),
|
|
||||||
out: Variable::Output(i as u16, *elem),
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
self.output_bindings.push((
|
|
||||||
Binding {
|
|
||||||
elem: Elem::I32, // I32 are used for bool tensors
|
|
||||||
visibility: Visibility::ReadWrite,
|
|
||||||
location: Location::Storage,
|
|
||||||
size: None,
|
|
||||||
},
|
|
||||||
(*output).clone(),
|
|
||||||
));
|
|
||||||
|
|
||||||
self.operations.push(Operator::AssignGlobal {
|
|
||||||
input: Variable::Local(*local, *elem),
|
|
||||||
out: Variable::Output(i as u16, Elem::I32),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
self.num_elems_output = num_elems_launch_option;
|
|
||||||
|
|
||||||
FusionKernel {
|
|
||||||
operations: self.operations,
|
|
||||||
input_bindings: self.input_bindings,
|
|
||||||
output_bindings: self.output_bindings,
|
|
||||||
named_bindings: self.named_bindings,
|
|
||||||
functions: self.functions,
|
|
||||||
num_elems_output: self.num_elems_output,
|
|
||||||
device: self.device,
|
|
||||||
client: self.client,
|
|
||||||
_phase: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, ExecutionPhase> {
|
|
||||||
/// Execute the kernel on the provided [handles](HandleContainer).
|
|
||||||
pub fn execute(mut self, handle_container: &mut HandleContainer<Wgpu<G, F, I>>) {
|
|
||||||
let mut inputs = Vec::with_capacity(self.input_bindings.len());
|
|
||||||
let mut outputs = Vec::with_capacity(self.output_bindings.len());
|
|
||||||
let mut named = Vec::with_capacity(2);
|
|
||||||
let mut info = Vec::new();
|
|
||||||
let mut handles =
|
|
||||||
Vec::with_capacity(inputs.capacity() + outputs.capacity() + named.capacity());
|
|
||||||
|
|
||||||
// Inner function to fill the info buffer.
|
|
||||||
let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| {
|
|
||||||
if info.is_empty() {
|
|
||||||
info.push(handle.strides.len() as u32);
|
|
||||||
}
|
|
||||||
|
|
||||||
for s in handle.strides.iter() {
|
|
||||||
info.push(*s as u32);
|
|
||||||
}
|
|
||||||
for s in tensor.shape.iter() {
|
|
||||||
info.push(*s as u32);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// We start by registering the inputs.
|
|
||||||
for (binding, tensor) in self.input_bindings.into_iter() {
|
|
||||||
let handle = handle_container.get_handle(&tensor);
|
|
||||||
register_info_tensor(&tensor, &handle);
|
|
||||||
|
|
||||||
inputs.push(binding);
|
|
||||||
handles.push(handle.handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then we follow with the outputs.
|
|
||||||
for (binding, tensor) in self.output_bindings {
|
|
||||||
let num_elems = calculate_num_elems_dyn_rank(&tensor.shape);
|
|
||||||
let handle_fusion = WgpuFusionHandle {
|
|
||||||
client: self.client.clone(),
|
|
||||||
device: self.device.clone(),
|
|
||||||
strides: strides_dyn_rank(&tensor.shape),
|
|
||||||
handle: self.client.empty(core::mem::size_of::<F>() * num_elems),
|
|
||||||
};
|
|
||||||
register_info_tensor(&tensor, &handle_fusion);
|
|
||||||
|
|
||||||
handles.push(handle_fusion.handle.clone());
|
|
||||||
handle_container.register_handle(tensor.id, handle_fusion);
|
|
||||||
outputs.push(binding);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now we can create the info handle.
|
|
||||||
Self::build_info_handle(&mut self.named_bindings, info);
|
|
||||||
|
|
||||||
// Finally we finish with the named bindings.
|
|
||||||
for (name, binding, data) in self.named_bindings {
|
|
||||||
let handle = self.client.create(match &data {
|
|
||||||
DataBuffer::F32(values) => bytemuck::cast_slice(values),
|
|
||||||
DataBuffer::U32(values) => bytemuck::cast_slice(values),
|
|
||||||
});
|
|
||||||
named.push((name, binding));
|
|
||||||
handles.push(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
// We create the shader codegen type and launch the kernel.
|
|
||||||
let kernel = ComputeShader {
|
|
||||||
inputs,
|
|
||||||
outputs,
|
|
||||||
named,
|
|
||||||
workgroup_size: WorkgroupSize::default(),
|
|
||||||
body: Body::new(self.operations),
|
|
||||||
num_workgroups: true,
|
|
||||||
global_invocation_id: true,
|
|
||||||
functions: self.functions,
|
|
||||||
};
|
|
||||||
|
|
||||||
let workgroup = elemwise_workgroup(self.num_elems_output, WORKGROUP_DEFAULT);
|
|
||||||
let kernel = Box::new(DynamicKernel::new(kernel, workgroup));
|
|
||||||
|
|
||||||
self.client
|
|
||||||
.execute(kernel, &handles.iter().collect::<Vec<_>>());
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_info_handle(named_bindings: &mut Vec<(String, Binding, DataBuffer)>, info: Vec<u32>) {
|
|
||||||
named_bindings.push((
|
|
||||||
"info".to_string(),
|
|
||||||
Binding {
|
|
||||||
elem: Elem::U32,
|
|
||||||
visibility: Visibility::Read,
|
|
||||||
location: Location::Storage,
|
|
||||||
size: None, // We avoid putting the length here since it will force a new kernel
|
|
||||||
// for each tensor rank.
|
|
||||||
},
|
|
||||||
DataBuffer::U32(info),
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handles.push(client.create(bytemuck::cast_slice(&info)));
|
||||||
|
|
||||||
|
// Finally we finish with the named bindings.
|
||||||
|
if scalars_f32 > 0 {
|
||||||
|
handles.push(client.create(bytemuck::cast_slice(&context.scalar_floats[0..scalars_f32])));
|
||||||
|
}
|
||||||
|
|
||||||
|
let workgroup = elemwise_workgroup(num_elems_output, WORKGROUP_DEFAULT);
|
||||||
|
let kernel = Box::new(DynamicKernel::new(kernel, workgroup));
|
||||||
|
client.execute(kernel, &handles.iter().collect::<Vec<_>>());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
mod base;
|
mod base;
|
||||||
mod elemwise;
|
mod elemwise;
|
||||||
|
|
||||||
pub(crate) mod codegen;
|
pub(crate) mod cache;
|
||||||
pub(crate) mod kernel;
|
pub(crate) mod kernel;
|
||||||
|
|
||||||
pub use base::*;
|
pub use base::*;
|
||||||
|
|
|
@ -1,78 +1,27 @@
|
||||||
|
use super::unary;
|
||||||
use crate::{
|
use crate::{
|
||||||
compute::StaticKernel,
|
codegen::{Operator, Variable},
|
||||||
element::WgpuElement,
|
element::WgpuElement,
|
||||||
kernel::{unary_scalar, unary_scalar_inplace_default, WORKGROUP_DEFAULT},
|
|
||||||
kernel_wgsl,
|
|
||||||
ops::numeric::empty_device,
|
|
||||||
tensor::WgpuTensor,
|
tensor::WgpuTensor,
|
||||||
unary_scalar, unary_scalar_inplace,
|
unary,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{elemwise_workgroup, KernelSettings};
|
unary!(
|
||||||
|
|elem| Operator::Clamp {
|
||||||
kernel_wgsl!(Clamp, "../template/clamp/clamp.wgsl");
|
input: Variable::Input(0, elem),
|
||||||
kernel_wgsl!(ClampInplace, "../template/clamp/clamp_inplace.wgsl");
|
min_value: Variable::Scalar(0, elem),
|
||||||
|
max_value: Variable::Scalar(1, elem),
|
||||||
pub(crate) fn clamp_min<E: WgpuElement, const D: usize>(
|
out: Variable::Local(0, elem),
|
||||||
input: WgpuTensor<E, D>,
|
},
|
||||||
min_value: E,
|
scalar 2
|
||||||
) -> WgpuTensor<E, D> {
|
);
|
||||||
unary_scalar!(ClampMin, func "max");
|
|
||||||
unary_scalar_inplace!(ClampMinInplace, func "max");
|
|
||||||
|
|
||||||
if input.can_mut() {
|
|
||||||
return unary_scalar_inplace_default::<ClampMinInplace, E, D>(input, min_value);
|
|
||||||
}
|
|
||||||
|
|
||||||
unary_scalar::<ClampMin, E, D, WORKGROUP_DEFAULT>(input, min_value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn clamp_max<E: WgpuElement, const D: usize>(
|
|
||||||
input: WgpuTensor<E, D>,
|
|
||||||
max_value: E,
|
|
||||||
) -> WgpuTensor<E, D> {
|
|
||||||
unary_scalar!(ClampMax, func "min");
|
|
||||||
unary_scalar_inplace!(ClampMaxInPlace, func "min");
|
|
||||||
|
|
||||||
if input.can_mut() {
|
|
||||||
return unary_scalar_inplace_default::<ClampMaxInPlace, E, D>(input, max_value);
|
|
||||||
}
|
|
||||||
|
|
||||||
unary_scalar::<ClampMax, E, D, WORKGROUP_DEFAULT>(input, max_value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn clamp<E: WgpuElement, const D: usize>(
|
pub(crate) fn clamp<E: WgpuElement, const D: usize>(
|
||||||
input: WgpuTensor<E, D>,
|
input: WgpuTensor<E, D>,
|
||||||
min_value: E,
|
min_value: E,
|
||||||
max_value: E,
|
max_value: E,
|
||||||
) -> WgpuTensor<E, D> {
|
) -> WgpuTensor<E, D> {
|
||||||
let num_elems = input.shape.num_elements();
|
unary::<Ops<E>, OpsInplace<E>, E, D>(input, Some(&[min_value, max_value]))
|
||||||
let min_handle = input.client.create(E::as_bytes(&[min_value]));
|
|
||||||
let max_handle = input.client.create(E::as_bytes(&[max_value]));
|
|
||||||
|
|
||||||
if input.can_mut() {
|
|
||||||
let kernel = StaticKernel::<
|
|
||||||
KernelSettings<ClampInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
|
||||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
|
||||||
|
|
||||||
input
|
|
||||||
.client
|
|
||||||
.execute(Box::new(kernel), &[&input.handle, &min_handle, &max_handle]);
|
|
||||||
|
|
||||||
return input;
|
|
||||||
}
|
|
||||||
|
|
||||||
let output = empty_device(input.client.clone(), input.device.clone(), input.shape);
|
|
||||||
let kernel = StaticKernel::<
|
|
||||||
KernelSettings<Clamp, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
|
||||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
|
||||||
|
|
||||||
input.client.execute(
|
|
||||||
Box::new(kernel),
|
|
||||||
&[&input.handle, &output.handle, &min_handle, &max_handle],
|
|
||||||
);
|
|
||||||
|
|
||||||
output
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -80,30 +29,6 @@ mod tests {
|
||||||
use crate::tests::{ReferenceBackend, TestBackend};
|
use crate::tests::{ReferenceBackend, TestBackend};
|
||||||
use burn_tensor::{Distribution, Tensor};
|
use burn_tensor::{Distribution, Tensor};
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn clamp_min_should_match_reference() {
|
|
||||||
let input = Tensor::<TestBackend, 4>::random([1, 5, 32, 32], Distribution::Default);
|
|
||||||
let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data());
|
|
||||||
|
|
||||||
let output = input.clamp_min(0.5);
|
|
||||||
|
|
||||||
output
|
|
||||||
.into_data()
|
|
||||||
.assert_approx_eq(&input_ref.clamp_min(0.5).into_data(), 3);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn clamp_max_should_match_reference() {
|
|
||||||
let input = Tensor::<TestBackend, 4>::random([1, 5, 32, 32], Distribution::Default);
|
|
||||||
let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data());
|
|
||||||
|
|
||||||
let output = input.clamp_max(0.5);
|
|
||||||
|
|
||||||
output
|
|
||||||
.into_data()
|
|
||||||
.assert_approx_eq(&input_ref.clamp_max(0.5).into_data(), 3);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn clamp_should_match_reference() {
|
fn clamp_should_match_reference() {
|
||||||
let input = Tensor::<TestBackend, 4>::random([1, 5, 32, 32], Distribution::Default);
|
let input = Tensor::<TestBackend, 4>::random([1, 5, 32, 32], Distribution::Default);
|
||||||
|
|
|
@ -8,14 +8,12 @@ mod index;
|
||||||
mod mask;
|
mod mask;
|
||||||
mod source;
|
mod source;
|
||||||
mod unary;
|
mod unary;
|
||||||
mod unary_scalar;
|
|
||||||
|
|
||||||
pub use base::*;
|
pub use base::*;
|
||||||
pub use binary_elemwise::*;
|
pub use binary_elemwise::*;
|
||||||
pub use cast::*;
|
pub use cast::*;
|
||||||
pub use source::*;
|
pub use source::*;
|
||||||
pub use unary::*;
|
pub use unary::*;
|
||||||
pub use unary_scalar::*;
|
|
||||||
|
|
||||||
/// Convolution kernels
|
/// Convolution kernels
|
||||||
pub mod conv;
|
pub mod conv;
|
||||||
|
|
|
@ -1,169 +1,220 @@
|
||||||
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT};
|
use super::StaticKernelSource;
|
||||||
use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
|
use crate::{
|
||||||
|
codegen::{execute_static, StaticHandle},
|
||||||
kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
|
element::WgpuElement,
|
||||||
kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl");
|
tensor::WgpuTensor,
|
||||||
|
};
|
||||||
|
|
||||||
/// Creates a unary kernel.
|
/// Creates a unary kernel.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! unary {
|
macro_rules! unary {
|
||||||
(
|
(
|
||||||
$struct:ident,
|
operator: $ops:expr,
|
||||||
func $func:expr
|
input: $input:expr,
|
||||||
) => {
|
elem: $elem:ty
|
||||||
pub struct $struct;
|
) => {{
|
||||||
|
unary!($ops);
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, None)
|
||||||
|
}};
|
||||||
|
(
|
||||||
|
operator: $ops:expr,
|
||||||
|
input: $input:expr; $scalar:expr,
|
||||||
|
elem: $elem:ty
|
||||||
|
) => {{
|
||||||
|
unary!($ops, scalar 1);
|
||||||
|
|
||||||
|
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, Some(&[$scalar]))
|
||||||
|
}};
|
||||||
|
|
||||||
|
(
|
||||||
|
$ops:expr
|
||||||
|
) => {
|
||||||
|
pub struct Ops<E> {
|
||||||
|
_e: core::marker::PhantomData<E>,
|
||||||
|
}
|
||||||
|
pub struct OpsInplace<E> {
|
||||||
|
_e: core::marker::PhantomData<E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::redundant_closure_call)]
|
||||||
|
impl<E: $crate::element::WgpuElement> $crate::kernel::StaticKernelSource for Ops<E> {
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
fn source() -> $crate::kernel::SourceTemplate {
|
||||||
let source = $crate::kernel::UnaryRaw::source();
|
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||||
source.register("body", format!("output[id] = {}(input[id]);", $func))
|
.inputs(&[$crate::codegen::Input::Array {
|
||||||
|
elem: E::elem_type(),
|
||||||
|
visibility: $crate::codegen::Visibility::Read,
|
||||||
|
strategy: $crate::codegen::ReadingStrategy::IntoContiguous,
|
||||||
|
}])
|
||||||
|
.body(&[$ops(E::elem_type())])
|
||||||
|
.outputs(&[$crate::codegen::Output::Array {
|
||||||
|
elem: E::elem_type(),
|
||||||
|
local: 0,
|
||||||
|
}])
|
||||||
|
.compile();
|
||||||
|
|
||||||
|
$crate::kernel::SourceTemplate::new(shader.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::redundant_closure_call)]
|
||||||
|
impl<E: $crate::element::WgpuElement> $crate::kernel::StaticKernelSource for OpsInplace<E> {
|
||||||
|
fn source() -> $crate::kernel::SourceTemplate {
|
||||||
|
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||||
|
.inputs(&[$crate::codegen::Input::Array {
|
||||||
|
elem: E::elem_type(),
|
||||||
|
visibility: $crate::codegen::Visibility::ReadWrite,
|
||||||
|
strategy: $crate::codegen::ReadingStrategy::Plain,
|
||||||
|
}])
|
||||||
|
.body(&[$ops(E::elem_type())])
|
||||||
|
.outputs(&[$crate::codegen::Output::Input {
|
||||||
|
elem: E::elem_type(),
|
||||||
|
input: 0,
|
||||||
|
local: 0,
|
||||||
|
}])
|
||||||
|
.compile();
|
||||||
|
|
||||||
|
$crate::kernel::SourceTemplate::new(shader.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
(
|
(
|
||||||
$struct:ident,
|
$ops:expr,
|
||||||
body $body:expr
|
scalar $num:expr
|
||||||
) => {
|
) => {
|
||||||
pub struct $struct;
|
pub struct Ops<E> {
|
||||||
|
_e: core::marker::PhantomData<E>,
|
||||||
|
}
|
||||||
|
pub struct OpsInplace<E> {
|
||||||
|
_e: core::marker::PhantomData<E>,
|
||||||
|
}
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
#[allow(clippy::redundant_closure_call)]
|
||||||
|
impl<E: $crate::element::WgpuElement> $crate::kernel::StaticKernelSource for Ops<E> {
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
fn source() -> $crate::kernel::SourceTemplate {
|
||||||
$crate::kernel::UnaryRaw::source().register("body", $body)
|
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||||
|
.inputs(&[
|
||||||
|
$crate::codegen::Input::Array {
|
||||||
|
elem: E::elem_type(),
|
||||||
|
visibility: $crate::codegen::Visibility::Read,
|
||||||
|
strategy: $crate::codegen::ReadingStrategy::IntoContiguous,
|
||||||
|
},
|
||||||
|
$crate::codegen::Input::Scalar {
|
||||||
|
elem: E::elem_type(),
|
||||||
|
size: $num,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
.body(&[$ops(E::elem_type())])
|
||||||
|
.outputs(&[$crate::codegen::Output::Array {
|
||||||
|
elem: E::elem_type(),
|
||||||
|
local: 0,
|
||||||
|
}])
|
||||||
|
.compile();
|
||||||
|
|
||||||
|
$crate::kernel::SourceTemplate::new(shader.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
|
||||||
(
|
|
||||||
$struct:ident,
|
|
||||||
func $func:expr,
|
|
||||||
include $file:expr
|
|
||||||
) => {
|
|
||||||
pub struct $struct;
|
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
#[allow(clippy::redundant_closure_call)]
|
||||||
|
impl<E: $crate::element::WgpuElement> $crate::kernel::StaticKernelSource for OpsInplace<E> {
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
fn source() -> $crate::kernel::SourceTemplate {
|
||||||
$crate::kernel::UnaryRaw::source()
|
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||||
.register("body", format!("output[id] = {}(input[id]);", $func))
|
.inputs(&[
|
||||||
.add_template(include_str!($file))
|
$crate::codegen::Input::Array {
|
||||||
|
elem: E::elem_type(),
|
||||||
|
visibility: $crate::codegen::Visibility::ReadWrite,
|
||||||
|
strategy: $crate::codegen::ReadingStrategy::Plain,
|
||||||
|
},
|
||||||
|
$crate::codegen::Input::Scalar {
|
||||||
|
elem: E::elem_type(),
|
||||||
|
size: $num,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
.body(&[$ops(E::elem_type())])
|
||||||
|
.outputs(&[$crate::codegen::Output::Input {
|
||||||
|
elem: E::elem_type(),
|
||||||
|
input: 0,
|
||||||
|
local: 0,
|
||||||
|
}])
|
||||||
|
.compile();
|
||||||
|
|
||||||
|
$crate::kernel::SourceTemplate::new(shader.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a unary inplace kernel.
|
/// Launch an unary operation.
|
||||||
#[macro_export]
|
pub fn unary<K, KI, E, const D: usize>(
|
||||||
macro_rules! unary_inplace {
|
tensor: WgpuTensor<E, D>,
|
||||||
(
|
scalars: Option<&[E]>,
|
||||||
$struct:ident,
|
) -> WgpuTensor<E, D>
|
||||||
func $func:expr
|
where
|
||||||
) => {
|
|
||||||
pub struct $struct;
|
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
|
||||||
$crate::kernel::UnaryInplaceRaw::source()
|
|
||||||
.register("body", format!("input[id] = {}(input[id]);", $func))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
(
|
|
||||||
$struct:ident,
|
|
||||||
body $body:expr
|
|
||||||
) => {
|
|
||||||
pub struct $struct;
|
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
|
||||||
$crate::kernel::UnaryInplaceRaw::source().register("body", $body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
(
|
|
||||||
$struct:ident,
|
|
||||||
func $func:expr,
|
|
||||||
include $file:expr
|
|
||||||
) => {
|
|
||||||
pub struct $struct;
|
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
|
||||||
$crate::kernel::UnaryInplaceRaw::source()
|
|
||||||
.register("body", format!("input[id] = {}(input[id]);", $func))
|
|
||||||
.add_template(include_str!($file))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute a unary kernel using the default settings.
|
|
||||||
pub fn unary_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
|
||||||
input: WgpuTensor<E, D>,
|
|
||||||
) -> WgpuTensor<E, D> {
|
|
||||||
unary::<K, E, D, WORKGROUP_DEFAULT>(input)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute a unary inplace kernel using the default settings.
|
|
||||||
pub fn unary_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
|
||||||
input: WgpuTensor<E, D>,
|
|
||||||
) -> WgpuTensor<E, D> {
|
|
||||||
unary_inplace::<K, E, D, WORKGROUP_DEFAULT>(input)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute a unary inplace kernel using the provided WORKGROUP.
|
|
||||||
pub fn unary_inplace<
|
|
||||||
K: StaticKernelSource,
|
K: StaticKernelSource,
|
||||||
|
KI: StaticKernelSource,
|
||||||
E: WgpuElement,
|
E: WgpuElement,
|
||||||
const D: usize,
|
{
|
||||||
const WORKGROUP: usize,
|
if !tensor.can_mut() {
|
||||||
>(
|
let num_elems = tensor.shape.num_elements();
|
||||||
input: WgpuTensor<E, D>,
|
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
|
||||||
) -> WgpuTensor<E, D> {
|
let output = WgpuTensor::new(
|
||||||
let num_elems = input.shape.num_elements();
|
tensor.client.clone(),
|
||||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
tensor.device,
|
||||||
elemwise_workgroup(num_elems, WORKGROUP),
|
tensor.shape.clone(),
|
||||||
);
|
buffer,
|
||||||
|
);
|
||||||
|
|
||||||
input.client.execute(Box::new(kernel), &[&input.handle]);
|
execute_static::<K, E>(
|
||||||
|
&[StaticHandle::new(
|
||||||
|
&tensor.handle,
|
||||||
|
&tensor.strides,
|
||||||
|
&tensor.shape.dims,
|
||||||
|
)],
|
||||||
|
&[StaticHandle::new(
|
||||||
|
&output.handle,
|
||||||
|
&output.strides,
|
||||||
|
&output.shape.dims,
|
||||||
|
)],
|
||||||
|
scalars,
|
||||||
|
tensor.client,
|
||||||
|
);
|
||||||
|
|
||||||
input
|
output
|
||||||
}
|
} else {
|
||||||
|
execute_static::<KI, E>(
|
||||||
|
&[],
|
||||||
|
&[StaticHandle::new(
|
||||||
|
&tensor.handle,
|
||||||
|
&tensor.strides,
|
||||||
|
&tensor.shape.dims,
|
||||||
|
)],
|
||||||
|
scalars,
|
||||||
|
tensor.client.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
/// Execute a unary kernel using the provided WORKGROUP.
|
tensor
|
||||||
pub fn unary<K: StaticKernelSource, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
|
}
|
||||||
input: WgpuTensor<E, D>,
|
|
||||||
) -> WgpuTensor<E, D> {
|
|
||||||
let num_elems = input.shape.num_elements();
|
|
||||||
let buffer = input.client.empty(num_elems * core::mem::size_of::<E>());
|
|
||||||
let mut output = WgpuTensor::new(input.client.clone(), input.device, input.shape, buffer);
|
|
||||||
// Since we don't handle the stride inside the kernel, the output tensor have the same strides
|
|
||||||
// as the input tensor. It might not be in the default format.
|
|
||||||
output.strides = input.strides;
|
|
||||||
|
|
||||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
|
||||||
elemwise_workgroup(num_elems, WORKGROUP),
|
|
||||||
);
|
|
||||||
input
|
|
||||||
.client
|
|
||||||
.execute(Box::new(kernel), &[&input.handle, &output.handle]);
|
|
||||||
|
|
||||||
output
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::codegen::{Operator, Variable};
|
||||||
use crate::tests::{ReferenceBackend, TestBackend};
|
use crate::tests::{ReferenceBackend, TestBackend};
|
||||||
use burn_tensor::{Distribution, Tensor};
|
use burn_tensor::{Distribution, Tensor};
|
||||||
|
|
||||||
unary!(TestKernel, func "log");
|
unary!(|elem| Operator::Tanh {
|
||||||
unary_inplace!(TestKernelInplace, func "log");
|
input: Variable::Input(0, elem),
|
||||||
|
out: Variable::Local(0, elem),
|
||||||
|
});
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn unary_should_work_with_multiple_invocations() {
|
fn unary_should_work_with_multiple_invocations() {
|
||||||
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
||||||
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
|
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
|
||||||
|
|
||||||
let actual = unary::<TestKernel, _, 2, 16>(tensor.into_primitive());
|
let actual = unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None);
|
||||||
let expected = tensor_ref.log();
|
let expected = tensor_ref.tanh();
|
||||||
|
|
||||||
expected.into_data().assert_approx_eq(
|
expected.into_data().assert_approx_eq(
|
||||||
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
|
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
|
||||||
|
@ -176,8 +227,8 @@ mod tests {
|
||||||
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
||||||
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
|
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
|
||||||
|
|
||||||
let actual = unary_inplace::<TestKernelInplace, _, 2, 16>(tensor.into_primitive());
|
let actual = unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None);
|
||||||
let expected = tensor_ref.log();
|
let expected = tensor_ref.tanh();
|
||||||
|
|
||||||
expected.into_data().assert_approx_eq(
|
expected.into_data().assert_approx_eq(
|
||||||
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
|
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
|
||||||
|
|
|
@ -1,220 +0,0 @@
|
||||||
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT};
|
|
||||||
use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
|
|
||||||
|
|
||||||
kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl");
|
|
||||||
kernel_wgsl!(
|
|
||||||
UnaryScalarInplaceRaw,
|
|
||||||
"../template/unary_scalar_inplace.wgsl"
|
|
||||||
);
|
|
||||||
|
|
||||||
/// Creates a unary scalar kernel.
|
|
||||||
#[macro_export]
|
|
||||||
macro_rules! unary_scalar {
|
|
||||||
(
|
|
||||||
$struct:ident,
|
|
||||||
ops $ops:expr
|
|
||||||
) => {
|
|
||||||
pub struct $struct;
|
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
|
||||||
$crate::kernel::UnaryScalarRaw::source()
|
|
||||||
.register("body", format!("output[id] = lhs[id] {} rhs;", $ops))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
(
|
|
||||||
$struct:ident,
|
|
||||||
func $func:expr
|
|
||||||
) => {
|
|
||||||
pub struct $struct;
|
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
|
||||||
$crate::kernel::UnaryScalarRaw::source()
|
|
||||||
.register("body", format!("output[id] = {}(lhs[id], rhs);", $func))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
(
|
|
||||||
$struct:ident,
|
|
||||||
func $func:expr,
|
|
||||||
include $file:expr
|
|
||||||
) => {
|
|
||||||
pub struct $struct;
|
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
|
||||||
$crate::kernel::UnaryScalarRaw::source()
|
|
||||||
.register("body", format!("output[id] = {}(lhs[id], rhs);", $func))
|
|
||||||
.add_template(include_str!($file))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a unary scalar inplace kernel.
|
|
||||||
#[macro_export]
|
|
||||||
macro_rules! unary_scalar_inplace {
|
|
||||||
(
|
|
||||||
$struct:ident,
|
|
||||||
ops $ops:expr
|
|
||||||
) => {
|
|
||||||
pub struct $struct;
|
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
|
||||||
$crate::kernel::UnaryScalarInplaceRaw::source()
|
|
||||||
.register("body", format!("lhs[id] = lhs[id] {} rhs;", $ops))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
(
|
|
||||||
$struct:ident,
|
|
||||||
body $body:expr
|
|
||||||
) => {
|
|
||||||
pub struct $struct;
|
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
|
||||||
$crate::kernel::UnaryScalarInplaceRaw::source().register("body", $body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
(
|
|
||||||
$struct:ident,
|
|
||||||
func $func:expr
|
|
||||||
) => {
|
|
||||||
pub struct $struct;
|
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
|
||||||
$crate::kernel::UnaryScalarInplaceRaw::source()
|
|
||||||
.register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
(
|
|
||||||
$struct:ident,
|
|
||||||
func $func:expr,
|
|
||||||
include $file:expr
|
|
||||||
) => {
|
|
||||||
pub struct $struct;
|
|
||||||
|
|
||||||
impl $crate::kernel::StaticKernelSource for $struct {
|
|
||||||
fn source() -> $crate::kernel::SourceTemplate {
|
|
||||||
$crate::kernel::UnaryScalarInplaceRaw::source()
|
|
||||||
.register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func))
|
|
||||||
.add_template(include_str!($file))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute a unary scalar kernel using the default settings.
|
|
||||||
pub fn unary_scalar_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
|
||||||
lhs: WgpuTensor<E, D>,
|
|
||||||
scalar: E,
|
|
||||||
) -> WgpuTensor<E, D> {
|
|
||||||
unary_scalar::<K, E, D, WORKGROUP_DEFAULT>(lhs, scalar)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute a unary scalar kernel using the provided WORKGROUP.
|
|
||||||
pub fn unary_scalar<
|
|
||||||
K: StaticKernelSource,
|
|
||||||
E: WgpuElement,
|
|
||||||
const D: usize,
|
|
||||||
const WORKGROUP: usize,
|
|
||||||
>(
|
|
||||||
lhs: WgpuTensor<E, D>,
|
|
||||||
scalar: E,
|
|
||||||
) -> WgpuTensor<E, D> {
|
|
||||||
let num_elems = lhs.shape.num_elements();
|
|
||||||
let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
|
|
||||||
let output = WgpuTensor::new(lhs.client.clone(), lhs.device, lhs.shape, buffer);
|
|
||||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
|
||||||
elemwise_workgroup(num_elems, WORKGROUP),
|
|
||||||
);
|
|
||||||
let rhs_handle = lhs.client.create(E::as_bytes(&[scalar]));
|
|
||||||
|
|
||||||
lhs.client.execute(
|
|
||||||
Box::new(kernel),
|
|
||||||
&[&lhs.handle, &rhs_handle, &output.handle],
|
|
||||||
);
|
|
||||||
|
|
||||||
output
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute a unary scalar inplace kernel using the default settings.
|
|
||||||
pub fn unary_scalar_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
|
||||||
lhs: WgpuTensor<E, D>,
|
|
||||||
scalar: E,
|
|
||||||
) -> WgpuTensor<E, D> {
|
|
||||||
unary_scalar_inplace::<K, E, D, WORKGROUP_DEFAULT>(lhs, scalar)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute a unary scalar inplace kernel using the provided WORKGROUP.
|
|
||||||
pub fn unary_scalar_inplace<
|
|
||||||
K: StaticKernelSource,
|
|
||||||
E: WgpuElement,
|
|
||||||
const D: usize,
|
|
||||||
const WORKGROUP: usize,
|
|
||||||
>(
|
|
||||||
lhs: WgpuTensor<E, D>,
|
|
||||||
scalar: E,
|
|
||||||
) -> WgpuTensor<E, D> {
|
|
||||||
let num_elems = lhs.shape.num_elements();
|
|
||||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
|
||||||
elemwise_workgroup(num_elems, WORKGROUP),
|
|
||||||
);
|
|
||||||
let rhs_handle = lhs.client.create(E::as_bytes(&[scalar]));
|
|
||||||
|
|
||||||
lhs.client
|
|
||||||
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]);
|
|
||||||
|
|
||||||
lhs
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::tests::{ReferenceBackend, TestBackend};
|
|
||||||
use burn_tensor::{Distribution, Tensor};
|
|
||||||
|
|
||||||
unary_scalar!(TestKernel, ops "*");
|
|
||||||
unary_scalar_inplace!(TestKernelInplace, ops "*");
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn unary_scalar_should_work_with_multiple_invocations() {
|
|
||||||
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
|
||||||
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
|
|
||||||
|
|
||||||
let actual = unary_scalar::<TestKernel, _, 2, 16>(tensor.into_primitive(), 5.0);
|
|
||||||
let expected = tensor_ref.mul_scalar(5.0);
|
|
||||||
|
|
||||||
expected.into_data().assert_approx_eq(
|
|
||||||
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
|
|
||||||
3,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn unary_scalar_inplace_should_work_with_multiple_invocations() {
|
|
||||||
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
|
||||||
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
|
|
||||||
|
|
||||||
let actual =
|
|
||||||
unary_scalar_inplace::<TestKernelInplace, _, 2, 16>(tensor.into_primitive(), 5.0);
|
|
||||||
let expected = tensor_ref.mul_scalar(5.0);
|
|
||||||
|
|
||||||
expected.into_data().assert_approx_eq(
|
|
||||||
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
|
|
||||||
3,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -15,6 +15,8 @@ pub mod kernel;
|
||||||
/// Tensor module.
|
/// Tensor module.
|
||||||
pub mod tensor;
|
pub mod tensor;
|
||||||
|
|
||||||
|
pub(crate) mod codegen;
|
||||||
|
|
||||||
mod element;
|
mod element;
|
||||||
pub use element::{FloatElement, IntElement};
|
pub use element::{FloatElement, IntElement};
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
use burn_tensor::ops::{ActivationOps, FloatTensor};
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
element::{FloatElement, IntElement},
|
element::{FloatElement, IntElement},
|
||||||
kernel::{unary_default, unary_inplace_default},
|
GraphicsApi, Wgpu,
|
||||||
unary, unary_inplace, GraphicsApi, Wgpu,
|
|
||||||
};
|
};
|
||||||
|
use burn_tensor::ops::ActivationOps;
|
||||||
|
|
||||||
impl<G, F, I> ActivationOps<Wgpu<G, F, I>> for Wgpu<G, F, I>
|
impl<G, F, I> ActivationOps<Wgpu<G, F, I>> for Wgpu<G, F, I>
|
||||||
where
|
where
|
||||||
|
@ -12,14 +10,4 @@ where
|
||||||
F: FloatElement,
|
F: FloatElement,
|
||||||
I: IntElement,
|
I: IntElement,
|
||||||
{
|
{
|
||||||
fn relu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
|
||||||
unary!(Relu, body "output[id] = max(input[id], 0.0);");
|
|
||||||
unary_inplace!(ReluInplace, body "input[id] = max(input[id], 0.0);");
|
|
||||||
|
|
||||||
if tensor.can_mut() {
|
|
||||||
return unary_inplace_default::<ReluInplace, F, D>(tensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
unary_default::<Relu, F, D>(tensor)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use super::numeric;
|
use super::numeric;
|
||||||
|
use crate::codegen::{Elem, Operator, Variable};
|
||||||
#[cfg(not(feature = "autotune"))]
|
#[cfg(not(feature = "autotune"))]
|
||||||
use crate::kernel::matmul::init_matmul_output;
|
use crate::kernel::matmul::init_matmul_output;
|
||||||
#[cfg(feature = "autotune")]
|
#[cfg(feature = "autotune")]
|
||||||
|
@ -8,18 +9,14 @@ use crate::kernel::matmul::vec4::matmul_tiling_2d_vec4;
|
||||||
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
|
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
|
||||||
#[cfg(not(feature = "autotune"))]
|
#[cfg(not(feature = "autotune"))]
|
||||||
use crate::kernel::reduce::init_reduce_output;
|
use crate::kernel::reduce::init_reduce_output;
|
||||||
use crate::kernel::{
|
use crate::kernel::{self, reduce};
|
||||||
self, reduce, unary_default, unary_inplace_default, unary_scalar_default,
|
use crate::WgpuDevice;
|
||||||
unary_scalar_inplace_default,
|
use crate::{unary, FloatElement, GraphicsApi, IntElement, Wgpu};
|
||||||
};
|
|
||||||
use crate::{unary, unary_inplace, unary_scalar, FloatElement, GraphicsApi, IntElement, Wgpu};
|
|
||||||
use crate::{unary_scalar_inplace, WgpuDevice};
|
|
||||||
use burn_tensor::ops::{
|
use burn_tensor::ops::{
|
||||||
BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor,
|
BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor,
|
||||||
};
|
};
|
||||||
use burn_tensor::{ops::TensorOps, Data, Distribution, Shape};
|
use burn_tensor::{ops::TensorOps, Data, Distribution, Shape};
|
||||||
use burn_tensor::{ElementConversion, Reader};
|
use burn_tensor::{ElementConversion, Reader};
|
||||||
|
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
|
|
||||||
impl<G, F, I> TensorOps<Wgpu<G, F, I>> for Wgpu<G, F, I>
|
impl<G, F, I> TensorOps<Wgpu<G, F, I>> for Wgpu<G, F, I>
|
||||||
|
@ -357,122 +354,115 @@ where
|
||||||
kernel::cast(tensor)
|
kernel::cast(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn exp<const D: usize>(lhs: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn exp<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
unary!(Exp, func "exp");
|
unary!(
|
||||||
unary_inplace!(ExpInplace, func "exp");
|
operator: |elem: Elem| Operator::Exp {
|
||||||
|
input: Variable::Input(0, elem),
|
||||||
if lhs.can_mut() {
|
out: Variable::Local(0, elem),
|
||||||
return unary_inplace_default::<ExpInplace, F, D>(lhs);
|
},
|
||||||
}
|
input: tensor,
|
||||||
|
elem: F
|
||||||
unary_default::<Exp, F, D>(lhs)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
unary!(Log, func "log");
|
unary!(
|
||||||
unary_inplace!(LogInplace, func "log");
|
operator: |elem: Elem| Operator::Log {
|
||||||
|
input: Variable::Input(0, elem),
|
||||||
if tensor.can_mut() {
|
out: Variable::Local(0, elem),
|
||||||
return unary_inplace_default::<LogInplace, F, D>(tensor);
|
},
|
||||||
}
|
input: tensor,
|
||||||
|
elem: F
|
||||||
unary_default::<Log, F, D>(tensor)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
unary!(Log1p, body "output[id] = log(1.0 + input[id]);");
|
unary!(
|
||||||
unary_inplace!(Log1pInplace, body "input[id] = log(1.0 + input[id]);");
|
operator: |elem: Elem| Operator::Log1p {
|
||||||
|
input: Variable::Input(0, elem),
|
||||||
if tensor.can_mut() {
|
out: Variable::Local(0, elem),
|
||||||
return unary_inplace_default::<Log1pInplace, F, D>(tensor);
|
},
|
||||||
}
|
input: tensor,
|
||||||
|
elem: F
|
||||||
unary_default::<Log1p, F, D>(tensor)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn powf<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
|
fn powf<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
|
||||||
unary_scalar!(Powf, func "powf", include "../template/powf.wgsl");
|
unary!(
|
||||||
unary_scalar_inplace!(PowfInplace, func "powf", include "../template/powf.wgsl");
|
operator: |elem: Elem| Operator::Powf {
|
||||||
|
lhs: Variable::Input(0, elem),
|
||||||
if lhs.can_mut() {
|
rhs: Variable::Scalar(0, elem),
|
||||||
return unary_scalar_inplace_default::<PowfInplace, F, D>(lhs, rhs.elem());
|
out: Variable::Local(0, elem),
|
||||||
}
|
},
|
||||||
|
input: lhs; rhs.elem(),
|
||||||
unary_scalar_default::<Powf, F, D>(lhs, rhs.elem())
|
elem: F
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
unary!(Sqrt, func "sqrt");
|
unary!(
|
||||||
unary_inplace!(SqrtInplace, func "sqrt");
|
operator: |elem: Elem| Operator::Sqrt {
|
||||||
|
input: Variable::Input(0, elem),
|
||||||
if tensor.can_mut() {
|
out: Variable::Local(0, elem),
|
||||||
return unary_inplace_default::<SqrtInplace, F, D>(tensor);
|
},
|
||||||
}
|
input: tensor,
|
||||||
|
elem: F
|
||||||
unary_default::<Sqrt, F, D>(tensor)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
unary!(Abs, func "abs");
|
unary!(
|
||||||
unary_inplace!(AbsInplace, func "abs");
|
operator: |elem: Elem| Operator::Abs {
|
||||||
|
input: Variable::Input(0, elem),
|
||||||
if tensor.can_mut() {
|
out: Variable::Local(0, elem),
|
||||||
return unary_inplace_default::<AbsInplace, F, D>(tensor);
|
},
|
||||||
}
|
input: tensor,
|
||||||
|
elem: F
|
||||||
unary_default::<Abs, F, D>(tensor)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
unary!(Cos, func "cos");
|
unary!(
|
||||||
unary_inplace!(CosInplace, func "cos");
|
operator: |elem: Elem| Operator::Cos {
|
||||||
|
input: Variable::Input(0, elem),
|
||||||
if tensor.can_mut() {
|
out: Variable::Local(0, elem),
|
||||||
return unary_inplace_default::<CosInplace, F, D>(tensor);
|
},
|
||||||
}
|
input: tensor,
|
||||||
|
elem: F
|
||||||
unary_default::<Cos, F, D>(tensor)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
unary!(Sin, func "sin");
|
unary!(
|
||||||
unary_inplace!(SinInplace, func "sin");
|
operator: |elem: Elem| Operator::Sin {
|
||||||
|
input: Variable::Input(0, elem),
|
||||||
if tensor.can_mut() {
|
out: Variable::Local(0, elem),
|
||||||
return unary_inplace_default::<SinInplace, F, D>(tensor);
|
},
|
||||||
}
|
input: tensor,
|
||||||
|
elem: F
|
||||||
unary_default::<Sin, F, D>(tensor)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
// Metal has a weird numerical behaviour with tanh which require a new function
|
unary!(
|
||||||
#[cfg(target_os = "macos")]
|
operator: |elem: Elem| Operator::Tanh {
|
||||||
unary!(Tanh, func "safe_tanh", include "../template/safe_tanh.wgsl");
|
input: Variable::Input(0, elem),
|
||||||
#[cfg(target_os = "macos")]
|
out: Variable::Local(0, elem),
|
||||||
unary_inplace!(TanhInplace, func "safe_tanh", include "../template/safe_tanh.wgsl");
|
},
|
||||||
|
input: tensor,
|
||||||
#[cfg(not(target_os = "macos"))]
|
elem: F
|
||||||
unary!(Tanh, func "tanh");
|
)
|
||||||
#[cfg(not(target_os = "macos"))]
|
|
||||||
unary_inplace!(TanhInplace, func "tanh");
|
|
||||||
|
|
||||||
if tensor.can_mut() {
|
|
||||||
return unary_inplace_default::<TanhInplace, F, D>(tensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
unary_default::<Tanh, F, D>(tensor)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
unary!(Erf, func "erf", include "../template/erf.wgsl");
|
unary!(
|
||||||
unary_inplace!(ErfInplace, func "erf", include "../template/erf.wgsl");
|
operator: |elem: Elem| Operator::Erf {
|
||||||
|
input: Variable::Input(0, elem),
|
||||||
if tensor.can_mut() {
|
out: Variable::Local(0, elem),
|
||||||
return unary_inplace_default::<ErfInplace, F, D>(tensor);
|
},
|
||||||
}
|
input: tensor,
|
||||||
|
elem: F
|
||||||
unary_default::<Erf, F, D>(tensor)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cat<const D: usize>(tensors: Vec<FloatTensor<Self, D>>, dim: usize) -> FloatTensor<Self, D> {
|
fn cat<const D: usize>(tensors: Vec<FloatTensor<Self, D>>, dim: usize) -> FloatTensor<Self, D> {
|
||||||
|
@ -491,20 +481,6 @@ where
|
||||||
kernel::cast(tensor)
|
kernel::cast(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn clamp_min<const D: usize>(
|
|
||||||
tensor: FloatTensor<Self, D>,
|
|
||||||
min: FloatElem<Self>,
|
|
||||||
) -> FloatTensor<Self, D> {
|
|
||||||
kernel::clamp_min(tensor, min)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn clamp_max<const D: usize>(
|
|
||||||
tensor: FloatTensor<Self, D>,
|
|
||||||
max: FloatElem<Self>,
|
|
||||||
) -> FloatTensor<Self, D> {
|
|
||||||
kernel::clamp_max(tensor, max)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn clamp<const D: usize>(
|
fn clamp<const D: usize>(
|
||||||
tensor: FloatTensor<Self, D>,
|
tensor: FloatTensor<Self, D>,
|
||||||
min: FloatElem<Self>,
|
min: FloatElem<Self>,
|
||||||
|
@ -516,14 +492,14 @@ where
|
||||||
fn recip<const D: usize>(
|
fn recip<const D: usize>(
|
||||||
tensor: FloatTensor<Wgpu<G, F, I>, D>,
|
tensor: FloatTensor<Wgpu<G, F, I>, D>,
|
||||||
) -> FloatTensor<Wgpu<G, F, I>, D> {
|
) -> FloatTensor<Wgpu<G, F, I>, D> {
|
||||||
unary!(Recip, func "1.0 /");
|
unary!(
|
||||||
unary_inplace!(RecipInplace, func "1.0 /");
|
operator: |elem: Elem| Operator::Recip {
|
||||||
|
input: Variable::Input(0, elem),
|
||||||
if tensor.can_mut() {
|
out: Variable::Local(0, elem),
|
||||||
return unary_inplace_default::<RecipInplace, F, D>(tensor);
|
},
|
||||||
}
|
input: tensor,
|
||||||
|
elem: F
|
||||||
unary_default::<Recip, F, D>(tensor)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn repeat<const D: usize>(
|
fn repeat<const D: usize>(
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
use super::numeric;
|
use super::numeric;
|
||||||
|
|
||||||
|
use crate::codegen::{Elem, Operator, Variable};
|
||||||
use crate::kernel::reduce::{self, init_reduce_output};
|
use crate::kernel::reduce::{self, init_reduce_output};
|
||||||
use crate::kernel::{unary_default, unary_inplace_default};
|
|
||||||
use crate::{
|
use crate::{
|
||||||
element::{FloatElement, IntElement},
|
element::{FloatElement, IntElement},
|
||||||
kernel, unary, unary_inplace, GraphicsApi, Wgpu,
|
kernel, unary, GraphicsApi, Wgpu,
|
||||||
};
|
};
|
||||||
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
||||||
|
|
||||||
|
@ -280,20 +280,6 @@ where
|
||||||
kernel::reduce::argmin(tensor, dim)
|
kernel::reduce::argmin(tensor, dim)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn int_clamp_min<const D: usize>(
|
|
||||||
tensor: IntTensor<Self, D>,
|
|
||||||
min: IntElem<Self>,
|
|
||||||
) -> IntTensor<Self, D> {
|
|
||||||
kernel::clamp_min(tensor, min)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn int_clamp_max<const D: usize>(
|
|
||||||
tensor: IntTensor<Self, D>,
|
|
||||||
max: IntElem<Self>,
|
|
||||||
) -> IntTensor<Self, D> {
|
|
||||||
kernel::clamp_max(tensor, max)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn int_clamp<const D: usize>(
|
fn int_clamp<const D: usize>(
|
||||||
tensor: IntTensor<Self, D>,
|
tensor: IntTensor<Self, D>,
|
||||||
min: IntElem<Self>,
|
min: IntElem<Self>,
|
||||||
|
@ -303,14 +289,14 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
|
fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
|
||||||
unary!(IntAbs, func "abs");
|
unary!(
|
||||||
unary_inplace!(IntAbsInplace, func "abs");
|
operator: |elem: Elem| Operator::Abs {
|
||||||
|
input: Variable::Input(0, elem),
|
||||||
if tensor.can_mut() {
|
out: Variable::Local(0, elem),
|
||||||
return unary_inplace_default::<IntAbsInplace, I, D>(tensor);
|
},
|
||||||
}
|
input: tensor,
|
||||||
|
elem: I
|
||||||
unary_default::<IntAbs, I, D>(tensor)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn int_into_float<const D: usize>(tensor: IntTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn int_into_float<const D: usize>(tensor: IntTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
|
|
|
@ -1,11 +1,8 @@
|
||||||
|
use crate::codegen::{Elem, Operator, Variable};
|
||||||
use crate::compute::{compute_client, WgpuComputeClient};
|
use crate::compute::{compute_client, WgpuComputeClient};
|
||||||
use crate::kernel::{
|
use crate::kernel::{binary_elemwise_default, binary_elemwise_inplace_default};
|
||||||
binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default,
|
|
||||||
unary_scalar_inplace_default,
|
|
||||||
};
|
|
||||||
use crate::{
|
use crate::{
|
||||||
binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor,
|
binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor, unary,
|
||||||
unary_scalar, unary_scalar_inplace,
|
|
||||||
};
|
};
|
||||||
use crate::{GraphicsApi, WgpuDevice};
|
use crate::{GraphicsApi, WgpuDevice};
|
||||||
use burn_tensor::{Element, ElementConversion, Shape};
|
use burn_tensor::{Element, ElementConversion, Shape};
|
||||||
|
@ -28,8 +25,14 @@ pub fn full_device<E: WgpuElement + Element, const D: usize>(
|
||||||
) -> WgpuTensor<E, D> {
|
) -> WgpuTensor<E, D> {
|
||||||
let empty = empty_device(client, device, shape);
|
let empty = empty_device(client, device, shape);
|
||||||
|
|
||||||
unary_scalar_inplace!(Full, body "lhs[id] = rhs;");
|
unary!(
|
||||||
unary_scalar_inplace_default::<Full, E, D>(empty, value)
|
operator: |elem: Elem| Operator::AssignLocal {
|
||||||
|
input: Variable::Scalar(0, elem),
|
||||||
|
out: Variable::Local(0, elem),
|
||||||
|
},
|
||||||
|
input: empty; value,
|
||||||
|
elem: E
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn zeros<G: GraphicsApi, E: WgpuElement + Element, const D: usize>(
|
pub fn zeros<G: GraphicsApi, E: WgpuElement + Element, const D: usize>(
|
||||||
|
@ -98,14 +101,15 @@ pub fn add_scalar<E: WgpuElement, const D: usize>(
|
||||||
lhs: WgpuTensor<E, D>,
|
lhs: WgpuTensor<E, D>,
|
||||||
rhs: E,
|
rhs: E,
|
||||||
) -> WgpuTensor<E, D> {
|
) -> WgpuTensor<E, D> {
|
||||||
unary_scalar!(AddScalar, ops "+");
|
unary!(
|
||||||
unary_scalar_inplace!(AddScalarInplace, ops "+");
|
operator: |elem: Elem| Operator::Add {
|
||||||
|
lhs: Variable::Input(0, elem),
|
||||||
if lhs.can_mut() {
|
rhs: Variable::Scalar(0, elem),
|
||||||
return unary_scalar_inplace_default::<AddScalarInplace, E, D>(lhs, rhs);
|
out: Variable::Local(0, elem),
|
||||||
}
|
},
|
||||||
|
input: lhs; rhs,
|
||||||
unary_scalar_default::<AddScalar, E, D>(lhs, rhs)
|
elem: E
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sub<E: WgpuElement, const D: usize>(
|
pub fn sub<E: WgpuElement, const D: usize>(
|
||||||
|
@ -126,14 +130,15 @@ pub fn sub_scalar<E: WgpuElement, const D: usize>(
|
||||||
lhs: WgpuTensor<E, D>,
|
lhs: WgpuTensor<E, D>,
|
||||||
rhs: E,
|
rhs: E,
|
||||||
) -> WgpuTensor<E, D> {
|
) -> WgpuTensor<E, D> {
|
||||||
unary_scalar!(SubScalar, ops "-");
|
unary!(
|
||||||
unary_scalar_inplace!(SubScalarInplace, ops "-");
|
operator: |elem: Elem| Operator::Sub {
|
||||||
|
lhs: Variable::Input(0, elem),
|
||||||
if lhs.can_mut() {
|
rhs: Variable::Scalar(0, elem),
|
||||||
return unary_scalar_inplace_default::<SubScalarInplace, E, D>(lhs, rhs);
|
out: Variable::Local(0, elem),
|
||||||
}
|
},
|
||||||
|
input: lhs; rhs,
|
||||||
unary_scalar_default::<SubScalar, E, D>(lhs, rhs)
|
elem: E
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn mul<E: WgpuElement, const D: usize>(
|
pub fn mul<E: WgpuElement, const D: usize>(
|
||||||
|
@ -158,14 +163,15 @@ pub fn mul_scalar<E: WgpuElement, const D: usize>(
|
||||||
lhs: WgpuTensor<E, D>,
|
lhs: WgpuTensor<E, D>,
|
||||||
rhs: E,
|
rhs: E,
|
||||||
) -> WgpuTensor<E, D> {
|
) -> WgpuTensor<E, D> {
|
||||||
unary_scalar!(MulScalar, ops "*");
|
unary!(
|
||||||
unary_scalar_inplace!(MulScalarInplace, ops "*");
|
operator: |elem: Elem| Operator::Mul {
|
||||||
|
lhs: Variable::Input(0, elem),
|
||||||
if lhs.can_mut() {
|
rhs: Variable::Scalar(0, elem),
|
||||||
return unary_scalar_inplace_default::<MulScalarInplace, E, D>(lhs, rhs);
|
out: Variable::Local(0, elem),
|
||||||
}
|
},
|
||||||
|
input: lhs; rhs,
|
||||||
unary_scalar_default::<MulScalar, E, D>(lhs, rhs)
|
elem: E
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn div<E: WgpuElement, const D: usize>(
|
pub fn div<E: WgpuElement, const D: usize>(
|
||||||
|
@ -186,12 +192,13 @@ pub fn div_scalar<E: WgpuElement, const D: usize>(
|
||||||
lhs: WgpuTensor<E, D>,
|
lhs: WgpuTensor<E, D>,
|
||||||
rhs: E,
|
rhs: E,
|
||||||
) -> WgpuTensor<E, D> {
|
) -> WgpuTensor<E, D> {
|
||||||
unary_scalar!(DivScalar, ops "/");
|
unary!(
|
||||||
unary_scalar_inplace!(DivScalarInplace, ops "/");
|
operator: |elem: Elem| Operator::Div {
|
||||||
|
lhs: Variable::Input(0, elem),
|
||||||
if lhs.can_mut() {
|
rhs: Variable::Scalar(0, elem),
|
||||||
return unary_scalar_inplace_default::<DivScalarInplace, E, D>(lhs, rhs);
|
out: Variable::Local(0, elem),
|
||||||
}
|
},
|
||||||
|
input: lhs; rhs,
|
||||||
unary_scalar_default::<DivScalar, E, D>(lhs, rhs)
|
elem: E
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,25 +0,0 @@
|
||||||
@group(0)
|
|
||||||
@binding(0)
|
|
||||||
var<storage, read> input: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(1)
|
|
||||||
var<storage, read_write> output: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(2)
|
|
||||||
var<storage, read> min_value: {{ elem }};
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(3)
|
|
||||||
var<storage, read> max_value: {{ elem }};
|
|
||||||
|
|
||||||
@compute
|
|
||||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
|
||||||
fn main(
|
|
||||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
|
||||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
|
||||||
) {
|
|
||||||
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
|
|
||||||
output[id] = clamp(input[id], min_value, max_value);
|
|
||||||
}
|
|
|
@ -1,21 +0,0 @@
|
||||||
@group(0)
|
|
||||||
@binding(0)
|
|
||||||
var<storage, read_write> input: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(1)
|
|
||||||
var<storage, read> min_value: {{ elem }};
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(2)
|
|
||||||
var<storage, read> max_value: {{ elem }};
|
|
||||||
|
|
||||||
@compute
|
|
||||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
|
||||||
fn main(
|
|
||||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
|
||||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
|
||||||
) {
|
|
||||||
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
|
|
||||||
input[id] = clamp(input[id], min_value, max_value);
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
/// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
|
|
||||||
///
|
|
||||||
/// > (maximum error: 1.5×10−7)
|
|
||||||
/// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf x is an odd function, so erf x = −erf(−x).
|
|
||||||
fn erf_positive(x: {{ elem }}) -> {{ elem }} {
|
|
||||||
let p = 0.3275911;
|
|
||||||
let a1 = 0.254829592;
|
|
||||||
let a2 = -0.284496736;
|
|
||||||
let a3 = 1.421413741;
|
|
||||||
let a4 = -1.453152027;
|
|
||||||
let a5 = 1.061405429;
|
|
||||||
|
|
||||||
let t = 1.0 / (1.0 + p * abs(x));
|
|
||||||
let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;
|
|
||||||
|
|
||||||
return 1.0 - (tmp * t * exp(-x * x));
|
|
||||||
}
|
|
||||||
|
|
||||||
fn erf(x: {{ elem }}) -> {{ elem }} {
|
|
||||||
if (x < 0.0) {
|
|
||||||
return -1.0 * erf_positive(-1.0 * x);
|
|
||||||
}
|
|
||||||
|
|
||||||
return erf_positive(x);
|
|
||||||
}
|
|
|
@ -1,14 +0,0 @@
|
||||||
fn powf(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {
|
|
||||||
let modulo = rhs % 2.0;
|
|
||||||
|
|
||||||
if (modulo == 0.0) {
|
|
||||||
// Even number
|
|
||||||
return pow(abs(lhs), rhs);
|
|
||||||
} else if (modulo == 1.0 && lhs < 0.0) {
|
|
||||||
// Odd number
|
|
||||||
return -1.0 * pow(-1.0 * lhs, rhs);
|
|
||||||
} else {
|
|
||||||
// Float number
|
|
||||||
return pow(lhs, rhs);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,8 +0,0 @@
|
||||||
/// Metal has a weird numerical behaviour with tanh for inputs over 43.0
|
|
||||||
fn safe_tanh(x: {{ elem }}) -> {{ elem }} {
|
|
||||||
if x > 43.0 {
|
|
||||||
return 1.0;
|
|
||||||
} else {
|
|
||||||
return tanh(x);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,17 +0,0 @@
|
||||||
@group(0)
|
|
||||||
@binding(0)
|
|
||||||
var<storage, read> input: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(1)
|
|
||||||
var<storage, read_write> output: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@compute
|
|
||||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
|
||||||
fn main(
|
|
||||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
|
||||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
|
||||||
) {
|
|
||||||
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
|
|
||||||
{{ body }}
|
|
||||||
}
|
|
|
@ -1,13 +0,0 @@
|
||||||
@group(0)
|
|
||||||
@binding(0)
|
|
||||||
var<storage, read_write> input: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@compute
|
|
||||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
|
||||||
fn main(
|
|
||||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
|
||||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
|
||||||
) {
|
|
||||||
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
|
|
||||||
{{ body }}
|
|
||||||
}
|
|
|
@ -1,21 +0,0 @@
|
||||||
@group(0)
|
|
||||||
@binding(0)
|
|
||||||
var<storage, read> lhs: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(1)
|
|
||||||
var<storage, read> rhs: {{ elem }};
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(2)
|
|
||||||
var<storage, read_write> output: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@compute
|
|
||||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
|
||||||
fn main(
|
|
||||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
|
||||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
|
||||||
) {
|
|
||||||
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
|
|
||||||
{{ body }}
|
|
||||||
}
|
|
|
@ -1,17 +0,0 @@
|
||||||
@group(0)
|
|
||||||
@binding(0)
|
|
||||||
var<storage, read_write> lhs: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(1)
|
|
||||||
var<storage, read> rhs: {{ elem }};
|
|
||||||
|
|
||||||
@compute
|
|
||||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
|
||||||
fn main(
|
|
||||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
|
||||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
|
||||||
) {
|
|
||||||
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
|
|
||||||
{{ body }}
|
|
||||||
}
|
|
|
@ -1,8 +1,9 @@
|
||||||
|
use crate::codegen::{Elem, Operator, Variable};
|
||||||
|
use crate::element::WgpuElement;
|
||||||
use crate::{
|
use crate::{
|
||||||
compute::{WgpuComputeClient, WgpuHandle},
|
compute::{WgpuComputeClient, WgpuHandle},
|
||||||
unary, WgpuDevice,
|
unary, WgpuDevice,
|
||||||
};
|
};
|
||||||
use crate::{element::WgpuElement, kernel::unary_default};
|
|
||||||
use burn_tensor::Shape;
|
use burn_tensor::Shape;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
@ -96,8 +97,14 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
|
||||||
// slowdowns.
|
// slowdowns.
|
||||||
//
|
//
|
||||||
// The solution is just to use a simple unary compute shader.
|
// The solution is just to use a simple unary compute shader.
|
||||||
unary!(CopyBuffer, body "output[id] = input[id];");
|
unary!(
|
||||||
unary_default::<CopyBuffer, E, D>(self.clone())
|
operator: |elem: Elem| Operator::AssignLocal {
|
||||||
|
input: Variable::Input(0, elem),
|
||||||
|
out: Variable::Local(0, elem),
|
||||||
|
},
|
||||||
|
input: self.clone(),
|
||||||
|
elem: E
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if the tensor is safe to mutate.
|
/// Check if the tensor is safe to mutate.
|
||||||
|
|
Loading…
Reference in New Issue