Refactor/wgpu/binary (#1078)

* Refactor binary

* Fix

* Oups

* Preparation for cmp

* Refactor comparison

* Remove templates

* Cleanup

* Cleanup

* Fix typo

* Code review
Nathaniel Simard 2023-12-19 14:44:55 -05:00 committed by GitHub
parent b5c49c5bf7
commit 75062c51e0
25 changed files with 639 additions and 995 deletions

View File

@ -1,6 +1,7 @@
#[burn_tensor_testgen::testgen(add)]
mod tests {
use super::*;
use burn_tensor::backend::Backend;
use burn_tensor::{Data, Int, Tensor};
#[test]
@ -29,6 +30,54 @@ mod tests {
assert_eq!(data_expected, data_actual);
}
#[test]
fn test_add_different_strides_rhs() {
let data_1 = Data::from([[0.0, 1.0], [2.0, 3.0]]);
let data_2 = Data::from([[4.0, 5.0], [6.0, 7.0]]);
// We need to execute an operation after `from_data` to trigger inplace in some backends,
// since that operation is the one that might be problematic in this case.
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1) * 1;
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2) * 1;
let data_actual = (tensor_1 + tensor_2.transpose()).into_data();
let data_expected = Data::from([[4.0, 7.0], [7.0, 10.0]]);
assert_eq!(data_expected, data_actual);
}
#[test]
fn test_add_different_strides_lhs() {
let data_1 = Data::from([[0.0, 1.0], [2.0, 3.0]]);
let data_2 = Data::from([[4.0, 5.0], [6.0, 7.0]]);
// We need to execute an operation after `from_data` to trigger inplace in some backends,
// since that operation is the one that might be problematic in this case.
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1) * 1;
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2) * 1;
let data_actual = (tensor_1.transpose() + tensor_2).into_data();
let data_expected = Data::from([[4.0, 7.0], [7.0, 10.0]]);
assert_eq!(data_expected, data_actual);
}
#[test]
fn test_add_different_strides_broadcast() {
let data_1 = Data::from([[0.0, 1.0], [2.0, 3.0]]);
let data_2 = Data::from([[4.0, 5.0]]);
// We need to execute an operation after `from_data` to trigger inplace in some backends,
// since that operation is the one that might be problematic in this case.
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1) * 1;
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2) * 1;
let data_actual = (tensor_1.transpose() + tensor_2).into_data();
let data_expected = Data::from([[4.0, 7.0], [5.0, 8.0]]);
assert_eq!(data_expected, data_actual);
}
#[test]
fn should_support_add_scalar_ops() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

View File

@ -53,7 +53,9 @@ pub enum Input {
}
pub enum ReadingStrategy {
IntoContiguous,
/// Each element will be read in a way that is compatible with the output layout.
OutputLayout,
/// Keep the current layout.
Plain,
}
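
A rough Rust sketch of the indexing that `OutputLayout` stands for, mirroring the WGSL templates deleted later in this diff (the function name is illustrative, not part of the codebase): the linear element id is interpreted with the layout tensor's strides, then re-projected onto the input's own strides and shape (the modulo by the input shape handles broadcast dimensions of size 1).

fn read_index(id: usize, strides_in: &[usize], shape_in: &[usize], strides_layout: &[usize]) -> usize {
    strides_in
        .iter()
        .zip(shape_in)
        .zip(strides_layout)
        // Same arithmetic as `id / stride_output % shape_lhs * stride_lhs`
        // in the removed binary_elemwise.wgsl template, summed over dims.
        .map(|((stride_in, shape), stride_out)| id / stride_out % shape * stride_in)
        .sum()
}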
@ -79,18 +81,6 @@ impl ElemWiseKernelCodegen<InputPhase> {
pub fn inputs(mut self, inputs: &[Input]) -> ElemWiseKernelCodegen<BodyPhase> {
let mut index: u16 = 0;
let first_output_index = inputs
.iter()
.filter(|input| match input {
Input::Array {
elem: _,
visibility: _,
strategy: _,
} => true,
Input::Scalar { elem: _, size: _ } => false,
})
.count();
for input in inputs {
match input {
Input::Array {
@ -106,11 +96,12 @@ impl ElemWiseKernelCodegen<InputPhase> {
});
match strategy {
ReadingStrategy::IntoContiguous => {
self.operations.push(Operator::ReadGlobalIntoContiguous {
ReadingStrategy::OutputLayout => {
self.operations.push(Operator::ReadGlobalWithLayout {
variable: Variable::Input(index, *elem),
position: index as usize,
position_out: first_output_index, // First output
tensor_read_pos: index as usize,
tensor_layout_pos: 0, // Will be set to the correct value during the output phase.
});
}
ReadingStrategy::Plain => {
@ -195,6 +186,7 @@ impl ElemWiseKernelCodegen<OutputPhase> {
/// So the 4th operator registered creates the local variable 3 (N-1, since the first index is 0).
pub fn outputs(mut self, outputs: &[Output]) -> ElemWiseKernelCodegen<CompilationPhase> {
let mut index = 0;
let mut position_out = 0;
for array in outputs {
match array {
@ -212,16 +204,36 @@ impl ElemWiseKernelCodegen<OutputPhase> {
out: Variable::Output(index, elem_adapted),
});
index += 1;
if index == 1 {
position_out = self.input_bindings.len(); // First output when a new array is created for the output.
}
}
Output::Input { elem, input, local } => {
self.operations.push(Operator::AssignGlobal {
input: Variable::Local(*local, *elem),
out: Variable::Input(*input, bool_elem(*elem)),
});
position_out = *input as usize; // Input position when using an inplace operation.
}
}
}
// We set the position of the tensor whose layout will be used for the stride calculation.
for i in 0..self.input_bindings.len() {
if let Some(Operator::ReadGlobalWithLayout {
variable: _,
tensor_read_pos: _,
tensor_layout_pos,
}) = self.operations.get_mut(i)
{
*tensor_layout_pos = position_out;
}
}
ElemWiseKernelCodegen {
operations: self.operations,
input_bindings: self.input_bindings,
@ -275,6 +287,12 @@ pub struct StaticHandle<'a> {
shape: &'a [usize],
}
/// The position of the input or output tensor used to calculate the number of workgroups to launch.
pub enum WorkgroupLaunch {
Input { pos: usize },
Output { pos: usize },
}
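
How the selected position translates into a dispatch size is not shown in this hunk; the sketch below is a plausible elementwise grid calculation (the name `workgroup_grid` and the exact math are assumptions, loosely following the crate's elemwise workgroup helper): the element count of the chosen handle is split into a roughly square (x, y, 1) grid.

fn workgroup_grid(num_elems: usize, workgroup_size: usize) -> (u32, u32, u32) {
    // One invocation per element; each workgroup covers workgroup_size^2 invocations.
    let per_group = (workgroup_size * workgroup_size) as f32;
    let num_groups = (num_elems as f32 / per_group).ceil();
    // Spread the groups over a roughly square (x, y) grid.
    let x = num_groups.sqrt().ceil();
    let y = (num_groups / x).ceil();
    (x as u32, y as u32, 1)
}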
/// Execute a static kernel.
///
@ -284,6 +302,7 @@ pub fn execute_static<K, E: WgpuElement>(
inputs: &[StaticHandle],
outputs: &[StaticHandle],
scalar_elems: Option<&[E]>,
launch: WorkgroupLaunch,
client: WgpuComputeClient,
) where
K: StaticKernelSource + 'static,
@ -305,20 +324,26 @@ pub fn execute_static<K, E: WgpuElement>(
}
};
let mut num_elems_output = 0;
// We start by registering the inputs.
for input in inputs.iter() {
for (i, input) in inputs.iter().enumerate() {
if let WorkgroupLaunch::Input { pos } = &launch {
if i == *pos {
num_elems_output = calculate_num_elems_dyn_rank(input.shape);
}
};
register_info_tensor(input.strides, input.shape);
handles.push(input.handle);
}
let mut num_elems_output = 0;
// Then we follow with the outputs.
for output in outputs.iter() {
let num_elems = calculate_num_elems_dyn_rank(output.shape);
if num_elems > num_elems_output {
num_elems_output = num_elems;
}
for (i, output) in outputs.iter().enumerate() {
if let WorkgroupLaunch::Output { pos } = &launch {
if i == *pos {
num_elems_output = calculate_num_elems_dyn_rank(output.shape);
}
};
register_info_tensor(output.strides, output.shape);
handles.push(output.handle);
}
@ -352,8 +377,8 @@ pub(crate) fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
fn bool_elem(elem: Elem) -> Elem {
match elem {
// I32 are used for bool tensors
Elem::Bool => Elem::I32,
// U32 is used for bool tensors
Elem::Bool => Elem::U32,
_ => elem,
}
}

View File

@ -118,10 +118,11 @@ pub enum Operator {
ReadGlobal {
variable: Variable,
},
ReadGlobalIntoContiguous {
/// Read the tensor in a way that is compatible with another tensor's layout.
ReadGlobalWithLayout {
variable: Variable,
position: usize,
position_out: usize,
tensor_read_pos: usize,
tensor_layout_pos: usize,
},
}
@ -202,10 +203,10 @@ impl Display for Operator {
)),
Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
},
Operator::ReadGlobalIntoContiguous {
Operator::ReadGlobalWithLayout {
variable,
position,
position_out,
tensor_read_pos: position,
tensor_layout_pos: position_out,
} => {
let (global, local, elem) = match variable {
Variable::Input(number, elem) => (

View File

@ -74,26 +74,38 @@ where
mod tests {
use super::*;
use crate::{
binary_elemwise, compute::compute_client, kernel::KernelSettings, AutoGraphicsApi,
WgpuDevice,
binary,
codegen::{Elem, Operator, Variable},
compute::compute_client,
kernel::{KernelSettings, WORKGROUP_DEFAULT},
AutoGraphicsApi, WgpuDevice,
};
#[test]
fn can_run_kernel() {
binary_elemwise!(Add, "+");
binary!(
operator: |elem: Elem| Operator::Add {
lhs: Variable::Input(0, elem),
rhs: Variable::Input(1, elem),
out: Variable::Local(0, elem),
},
elem_in: f32,
elem_out: f32
);
let client = compute_client::<AutoGraphicsApi>(&WgpuDevice::default());
let lhs: Vec<f32> = vec![0., 1., 2., 3., 4., 5., 6., 7.];
let rhs: Vec<f32> = vec![10., 11., 12., 6., 7., 3., 1., 0.];
let info: Vec<u32> = vec![1, 1, 1, 1, 8, 8, 8];
let info: Vec<u32> = vec![1, 1, 8, 1, 8, 1, 8];
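// The updated `info` layout appears to be: rank, then strides and shape per
// tensor in binding order (lhs, rhs, output). For rank 1 here: stride 1 and
// 8 elements for each of the three buffers. (Reading inferred from this test.)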
let lhs = client.create(bytemuck::cast_slice(&lhs));
let rhs = client.create(bytemuck::cast_slice(&rhs));
let out = client.empty(core::mem::size_of::<f32>() * 8);
let info = client.create(bytemuck::cast_slice(&info));
type Kernel = KernelSettings<Add, f32, i32, 16, 16, 1>;
type Kernel =
KernelSettings<Ops<f32, f32>, f32, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>;
let kernel = Box::new(StaticKernel::<Kernel>::new(WorkGroup::new(1, 1, 1)));
client.execute(kernel, &[&lhs, &rhs, &out, &info]);

View File

@ -190,10 +190,10 @@ where
Operator::AssignLocal { input: _, out: _ } => {
// Nothing to do here.
}
Operator::ReadGlobalIntoContiguous {
Operator::ReadGlobalWithLayout {
variable: _,
position: _,
position_out: _,
tensor_read_pos: _,
tensor_layout_pos: _,
} => {
// Nothing to do here.
}

View File

@ -41,7 +41,7 @@ where
.map(|(_tensor, elem)| Input::Array {
elem: *elem,
visibility: Visibility::Read,
strategy: ReadingStrategy::IntoContiguous,
strategy: ReadingStrategy::OutputLayout,
})
.collect::<Vec<_>>();

View File

@ -264,11 +264,11 @@ mod tests {
#[test]
fn test_kernel_type_id() {
kernel_wgsl!(Add, "../template/binary_elemwise.wgsl");
kernel_wgsl!(Cat, "../template/cat.wgsl");
let type_id_1 = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 4>>();
let type_id_2 = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 5>>();
let type_id_3 = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 4>>();
let type_id_1 = TypeId::of::<KernelSettings<Cat, f32, i32, 2, 3, 4>>();
let type_id_2 = TypeId::of::<KernelSettings<Cat, f32, i32, 2, 3, 5>>();
let type_id_3 = TypeId::of::<KernelSettings<Cat, f32, i32, 2, 3, 4>>();
assert_ne!(type_id_1, type_id_2);
assert_eq!(type_id_1, type_id_3);

View File

@ -0,0 +1,211 @@
use crate::{
codegen::{execute_static, StaticHandle, WorkgroupLaunch},
element::WgpuElement,
tensor::WgpuTensor,
};
use burn_tensor::Shape;
/// Creates a binary kernel.
#[macro_export]
macro_rules! binary {
(
operator: $ops:expr,
input: $lhs:expr; $rhs:expr,
elem: $elem:ty
) => {{
binary!(operator: $ops, elem_in: $elem, elem_out: $elem);
$crate::kernel::binary::<Ops<$elem, $elem>, OpsInplaceLhs<$elem, $elem>, OpsInplaceRhs<$elem, $elem>, $elem, D>(
$lhs, $rhs, true
)
}};
(
operator: $ops:expr,
elem_in: $elem_in:ty,
elem_out: $elem_out:ty
) => {
pub struct Ops<I, O> {
_i: core::marker::PhantomData<I>,
_o: core::marker::PhantomData<O>,
}
pub struct OpsInplaceLhs<I, O> {
_i: core::marker::PhantomData<I>,
_o: core::marker::PhantomData<O>,
}
pub struct OpsInplaceRhs<I, O> {
_i: core::marker::PhantomData<I>,
_o: core::marker::PhantomData<O>,
}
#[allow(clippy::redundant_closure_call)]
impl<I, O> $crate::kernel::StaticKernelSource for Ops<I, O>
where
I: $crate::element::WgpuElement,
O: $crate::element::WgpuElement
{
fn source() -> $crate::kernel::SourceTemplate {
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[
$crate::codegen::Input::Array {
elem: I::elem_type(),
visibility: $crate::codegen::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
},
$crate::codegen::Input::Array {
elem: I::elem_type(),
visibility: $crate::codegen::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
},
])
.body(&[$ops(I::elem_type())])
.outputs(&[$crate::codegen::Output::Array {
elem: O::elem_type(),
local: 0,
}])
.compile();
$crate::kernel::SourceTemplate::new(shader.to_string())
}
}
#[allow(clippy::redundant_closure_call)]
impl<I, O> $crate::kernel::StaticKernelSource
for OpsInplaceLhs<I, O>
where
I: $crate::element::WgpuElement,
O: $crate::element::WgpuElement
{
fn source() -> $crate::kernel::SourceTemplate {
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[
$crate::codegen::Input::Array {
elem: I::elem_type(),
visibility: $crate::codegen::Visibility::ReadWrite,
strategy: $crate::codegen::ReadingStrategy::Plain,
},
$crate::codegen::Input::Array {
elem: I::elem_type(),
visibility: $crate::codegen::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
},
])
.body(&[$ops(I::elem_type())])
.outputs(&[$crate::codegen::Output::Input {
elem: I::elem_type(),
input: 0,
local: 0,
}])
.compile();
$crate::kernel::SourceTemplate::new(shader.to_string())
}
}
#[allow(clippy::redundant_closure_call)]
impl<I, O> $crate::kernel::StaticKernelSource
for OpsInplaceRhs<I, O>
where
I: $crate::element::WgpuElement,
O: $crate::element::WgpuElement
{
fn source() -> $crate::kernel::SourceTemplate {
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[
$crate::codegen::Input::Array {
elem: I::elem_type(),
visibility: $crate::codegen::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
},
$crate::codegen::Input::Array {
elem: I::elem_type(),
visibility: $crate::codegen::Visibility::ReadWrite,
strategy: $crate::codegen::ReadingStrategy::Plain,
},
])
.body(&[$ops(I::elem_type())])
.outputs(&[$crate::codegen::Output::Input {
elem: I::elem_type(),
input: 1,
local: 0,
}])
.compile();
$crate::kernel::SourceTemplate::new(shader.to_string())
}
}
};
}
/// Launch a binary operation.
pub fn binary<Kernel, KernelInplaceLhs, KernelInplaceRhs, E, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
inplace_enabled: bool,
) -> WgpuTensor<E, D>
where
Kernel: crate::kernel::StaticKernelSource,
KernelInplaceLhs: crate::kernel::StaticKernelSource,
KernelInplaceRhs: crate::kernel::StaticKernelSource,
E: WgpuElement,
{
if inplace_enabled && lhs.can_mut_broadcast(&rhs) {
execute_static::<KernelInplaceLhs, E>(
&[
StaticHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
StaticHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
],
&[],
None,
WorkgroupLaunch::Input { pos: 0 },
rhs.client,
);
lhs
} else if inplace_enabled && rhs.can_mut_broadcast(&lhs) {
execute_static::<KernelInplaceRhs, E>(
&[
StaticHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
StaticHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
],
&[],
None,
WorkgroupLaunch::Input { pos: 1 },
lhs.client,
);
rhs
} else {
let mut shape_out = [0; D];
lhs.shape
.dims
.iter()
.zip(rhs.shape.dims.iter())
.enumerate()
.for_each(|(index, (dim_lhs, dim_rhs))| {
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
});
let shape_out = Shape::new(shape_out);
let num_elems = shape_out.num_elements();
let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
let out = WgpuTensor::new(lhs.client.clone(), lhs.device, shape_out, buffer);
execute_static::<Kernel, E>(
&[
StaticHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
StaticHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
],
&[StaticHandle::new(
&out.handle,
&out.strides,
&out.shape.dims,
)],
None,
WorkgroupLaunch::Output { pos: 0 },
lhs.client,
);
out
}
}

View File

@ -1,181 +0,0 @@
use super::{
build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT,
};
use crate::compute::StaticKernel;
use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
use burn_tensor::Shape;
kernel_wgsl!(BinaryElemwiseRaw, "../template/binary_elemwise.wgsl");
kernel_wgsl!(
BinaryElemwiseInplaceRaw,
"../template/binary_elemwise_inplace.wgsl"
);
/// Creates a binary elementwise kernel.
#[macro_export]
macro_rules! binary_elemwise {
(
$struct:ident,
$ops:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::BinaryElemwiseRaw::source().register(
"body",
format!("output[id] = lhs[index_lhs] {} rhs[index_rhs];", $ops),
)
}
}
};
}
/// Creates a binary elementwise inplace kernel.
#[macro_export]
macro_rules! binary_elemwise_inplace {
(
$struct:ident,
$ops:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::BinaryElemwiseInplaceRaw::source().register(
"body",
format!("lhs[id] = lhs[id] {} rhs[index_rhs];", $ops),
)
}
}
};
}
/// Execute a binary kernel using the default settings.
pub fn binary_elemwise_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
binary_elemwise::<K, E, D, WORKGROUP_DEFAULT>(lhs, rhs)
}
/// Execute a binary kernel using the provided WORKGROUP.
pub fn binary_elemwise<
K: StaticKernelSource,
E: WgpuElement,
const D: usize,
const WORKGROUP: usize,
>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
lhs.assert_is_on_same_device(&rhs);
let mut shape_out = [0; D];
lhs.shape
.dims
.iter()
.zip(rhs.shape.dims.iter())
.enumerate()
.for_each(|(index, (dim_lhs, dim_rhs))| {
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
});
let shape_out = Shape::new(shape_out);
let num_elems = shape_out.num_elements();
let handle = lhs.client.empty(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, handle);
let info = build_info(&[&lhs, &rhs, &output]);
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
lhs.client.execute(
Box::new(kernel),
&[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
);
output
}
/// Execute a binary inplace kernel using the default settings.
pub fn binary_elemwise_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
binary_elemwise_inplace::<K, E, D, WORKGROUP_DEFAULT>(lhs, rhs)
}
/// Execute a binary inplace kernel using the provided WORKGROUP.
pub fn binary_elemwise_inplace<
K: StaticKernelSource,
E: WgpuElement,
const D: usize,
const WORKGROUP: usize,
>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
lhs.assert_is_on_same_device(&rhs);
let info = build_info(&[&lhs, &rhs]);
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP),
);
lhs.client
.execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]);
lhs
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{Distribution, Tensor};
binary_elemwise!(TestKernel, "*");
binary_elemwise_inplace!(TestKernelInplace, "*");
#[test]
fn binary_should_work_with_multiple_invocations() {
let lhs = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
let rhs = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
let lhs_ref = Tensor::<ReferenceBackend, 2>::from_data(lhs.to_data());
let rhs_ref = Tensor::<ReferenceBackend, 2>::from_data(rhs.to_data());
let actual =
binary_elemwise::<TestKernel, _, 2, 16>(lhs.into_primitive(), rhs.into_primitive());
let expected = lhs_ref * rhs_ref;
expected.into_data().assert_approx_eq(
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
3,
);
}
#[test]
fn binary_inplace_should_work_with_multiple_invocations() {
let lhs = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
let rhs = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
let lhs_ref = Tensor::<ReferenceBackend, 2>::from_data(lhs.to_data());
let rhs_ref = Tensor::<ReferenceBackend, 2>::from_data(rhs.to_data());
let actual = binary_elemwise_inplace::<TestKernelInplace, _, 2, 16>(
lhs.into_primitive(),
rhs.into_primitive(),
);
let expected = lhs_ref * rhs_ref;
expected.into_data().assert_approx_eq(
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
3,
);
}
}

View File

@ -21,7 +21,7 @@ pub(crate) fn clamp<E: WgpuElement, const D: usize>(
min_value: E,
max_value: E,
) -> WgpuTensor<E, D> {
unary::<Ops<E>, OpsInplace<E>, E, D>(input, Some(&[min_value, max_value]))
unary::<Ops<E>, OpsInplace<E>, E, D>(input, Some(&[min_value, max_value]), true)
}
#[cfg(test)]

View File

@ -0,0 +1,219 @@
use crate::{
binary,
codegen::{Elem, Operator, Variable},
element::WgpuElement,
kernel::StaticKernelSource,
kernel::{binary::binary, unary::unary},
tensor::WgpuTensor,
unary,
};
use std::mem;
macro_rules! comparison {
(
binary: $ops:expr,
input: $lhs:expr; $rhs:expr,
elem: $elem:ty
) => {{
binary!(operator: $ops, elem_in: $elem, elem_out: $elem);
launch_binary::<Ops<E, u32>, OpsInplaceLhs<E, u32>, OpsInplaceRhs<E, u32>, E, D>($lhs, $rhs)
}};
(
unary: $ops:expr,
input: $lhs:expr; $rhs:expr,
elem: $elem:ty
) => {{
unary!($ops, scalar 1);
launch_unary::<Ops<E>, OpsInplace<E>, E, D>($lhs, $rhs)
}};
}
pub fn equal<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
comparison!(
binary: |elem: Elem| Operator::Equal {
lhs: Variable::Input(0, elem),
rhs: Variable::Input(1, elem),
out: Variable::Local(0, Elem::Bool),
},
input: lhs; rhs,
elem: E
)
}
pub fn greater<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
comparison!(
binary: |elem: Elem| Operator::Greater {
lhs: Variable::Input(0, elem),
rhs: Variable::Input(1, elem),
out: Variable::Local(0, Elem::Bool),
},
input: lhs; rhs,
elem: E
)
}
pub fn greater_equal<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
comparison!(
binary: |elem: Elem| Operator::GreaterEqual {
lhs: Variable::Input(0, elem),
rhs: Variable::Input(1, elem),
out: Variable::Local(0, Elem::Bool),
},
input: lhs; rhs,
elem: E
)
}
pub fn lower<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
comparison!(
binary: |elem: Elem| Operator::Lower {
lhs: Variable::Input(0, elem),
rhs: Variable::Input(1, elem),
out: Variable::Local(0, Elem::Bool),
},
input: lhs; rhs,
elem: E
)
}
pub fn lower_equal<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
comparison!(
binary: |elem: Elem| Operator::LowerEqual {
lhs: Variable::Input(0, elem),
rhs: Variable::Input(1, elem),
out: Variable::Local(0, Elem::Bool),
},
input: lhs; rhs,
elem: E
)
}
pub fn equal_elem<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
comparison!(
unary: |elem: Elem| Operator::Equal {
lhs: Variable::Input(0, elem),
rhs: Variable::Scalar(0, elem),
out: Variable::Local(0, Elem::Bool),
},
input: lhs; rhs,
elem: E
)
}
pub fn greater_elem<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
comparison!(
unary: |elem: Elem| Operator::Greater {
lhs: Variable::Input(0, elem),
rhs: Variable::Scalar(0, elem),
out: Variable::Local(0, Elem::Bool),
},
input: lhs; rhs,
elem: E
)
}
pub fn lower_elem<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
comparison!(
unary: |elem: Elem| Operator::Lower {
lhs: Variable::Input(0, elem),
rhs: Variable::Scalar(0, elem),
out: Variable::Local(0, Elem::Bool),
},
input: lhs; rhs,
elem: E
)
}
pub fn greater_equal_elem<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
comparison!(
unary: |elem: Elem| Operator::GreaterEqual {
lhs: Variable::Input(0, elem),
rhs: Variable::Scalar(0, elem),
out: Variable::Local(0, Elem::Bool),
},
input: lhs; rhs,
elem: E
)
}
pub fn lower_equal_elem<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
comparison!(
unary: |elem: Elem| Operator::LowerEqual {
lhs: Variable::Input(0, elem),
rhs: Variable::Scalar(0, elem),
out: Variable::Local(0, Elem::Bool),
},
input: lhs; rhs,
elem: E
)
}
fn launch_binary<Kernel, KernelInplaceLhs, KernelInplaceRhs, E, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D>
where
Kernel: StaticKernelSource,
KernelInplaceLhs: StaticKernelSource,
KernelInplaceRhs: StaticKernelSource,
E: WgpuElement,
{
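// Assumption drawn from the recast below: the inplace kernels write u32 "bool"
// results into the E-typed input buffer, which is only sound when the element
// sizes match; hence the size_of check gates the inplace path.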
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
let output =
binary::<Kernel, KernelInplaceLhs, KernelInplaceRhs, E, D>(lhs, rhs, can_be_used_as_bool);
// We recast the tensor type to u32 (bool).
WgpuTensor::new(output.client, output.device, output.shape, output.handle)
}
fn launch_unary<Kernel, KernelInplace, E, const D: usize>(
tensor: WgpuTensor<E, D>,
scalars: E,
) -> WgpuTensor<u32, D>
where
Kernel: StaticKernelSource,
KernelInplace: StaticKernelSource,
E: WgpuElement,
{
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
let output =
unary::<Kernel, KernelInplace, E, D>(tensor, Some(&[scalars]), can_be_used_as_bool);
// We recast the tensor type to u32 (bool).
WgpuTensor::new(output.client, output.device, output.shape, output.handle)
}

View File

@ -1,166 +0,0 @@
use crate::{
comparison, comparison_elem, comparison_elem_inplace, comparison_inplace,
element::WgpuElement,
kernel::{comparison, comparison_elem, comparison_elem_inplace, comparison_inplace},
tensor::WgpuTensor,
};
use std::mem;
comparison!(Equal, "==");
comparison!(Greater, ">");
comparison!(GreaterEqual, ">=");
comparison!(Lower, "<");
comparison!(LowerEqual, "<=");
comparison_inplace!(EqualInplace, "==");
comparison_inplace!(GreaterInplace, ">");
comparison_inplace!(GreaterEqualInplace, ">=");
comparison_inplace!(LowerInplace, "<");
comparison_inplace!(LowerEqualInplace, "<=");
comparison_elem!(EqualElem, "==");
comparison_elem!(GreaterElem, ">");
comparison_elem!(GreaterEqualElem, ">=");
comparison_elem!(LowerElem, "<");
comparison_elem!(LowerEqualElem, "<=");
comparison_elem_inplace!(EqualElemInplace, "==");
comparison_elem_inplace!(GreaterElemInplace, ">");
comparison_elem_inplace!(GreaterEqualElemInplace, ">=");
comparison_elem_inplace!(LowerElemInplace, "<");
comparison_elem_inplace!(LowerEqualElemInplace, "<=");
pub fn equal<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) {
return comparison_inplace::<EqualInplace, E, D>(lhs, rhs);
}
if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) {
return comparison_inplace::<EqualInplace, E, D>(rhs, lhs);
}
comparison::<Equal, E, D>(lhs, rhs)
}
pub fn greater<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) {
return comparison_inplace::<GreaterInplace, E, D>(lhs, rhs);
}
if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) {
return comparison_inplace::<LowerInplace, E, D>(rhs, lhs);
}
comparison::<Greater, E, D>(lhs, rhs)
}
pub fn greater_equal<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) {
return comparison_inplace::<GreaterEqualInplace, E, D>(lhs, rhs);
}
if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) {
return comparison_inplace::<LowerEqualInplace, E, D>(rhs, lhs);
}
comparison::<GreaterEqual, E, D>(lhs, rhs)
}
pub fn lower<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) {
return comparison_inplace::<LowerInplace, E, D>(lhs, rhs);
}
if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) {
return comparison_inplace::<GreaterInplace, E, D>(rhs, lhs);
}
comparison::<Lower, E, D>(lhs, rhs)
}
pub fn lower_equal<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) {
return comparison_inplace::<LowerEqualInplace, E, D>(lhs, rhs);
}
if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) {
return comparison_inplace::<GreaterEqualInplace, E, D>(rhs, lhs);
}
comparison::<LowerEqual, E, D>(lhs, rhs)
}
pub fn equal_elem<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
if mem::size_of::<E>() == mem::size_of::<u32>() && lhs.can_mut() {
return comparison_elem_inplace::<EqualElemInplace, E, D>(lhs, rhs);
}
comparison_elem::<EqualElem, E, D>(lhs, rhs)
}
pub fn greater_elem<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
if mem::size_of::<E>() == mem::size_of::<u32>() && lhs.can_mut() {
return comparison_elem_inplace::<GreaterElemInplace, E, D>(lhs, rhs);
}
comparison_elem::<GreaterElem, E, D>(lhs, rhs)
}
pub fn lower_elem<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
if mem::size_of::<E>() == mem::size_of::<u32>() && lhs.can_mut() {
return comparison_elem_inplace::<LowerElemInplace, E, D>(lhs, rhs);
}
comparison_elem::<LowerElem, E, D>(lhs, rhs)
}
pub fn greater_equal_elem<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
if mem::size_of::<E>() == mem::size_of::<u32>() && lhs.can_mut() {
return comparison_elem_inplace::<GreaterEqualElemInplace, E, D>(lhs, rhs);
}
comparison_elem::<GreaterEqualElem, E, D>(lhs, rhs)
}
pub fn lower_equal_elem<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
if mem::size_of::<E>() == mem::size_of::<u32>() && lhs.can_mut() {
return comparison_elem_inplace::<LowerEqualElemInplace, E, D>(lhs, rhs);
}
comparison_elem::<LowerEqualElem, E, D>(lhs, rhs)
}

View File

@ -1,177 +0,0 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{
build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT,
},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
use burn_tensor::Shape;
kernel_wgsl!(ComparisonRaw, "../../template/comparison/binary.wgsl");
kernel_wgsl!(
ComparisonInplaceRaw,
"../../template/comparison/binary_inplace.wgsl"
);
/// Creates a comparison kernel.
#[macro_export]
macro_rules! comparison {
(
$struct:ident,
$ops:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::ComparisonRaw::source().register(
"body",
format!("output[id] = u32(lhs[index_lhs] {} rhs[index_rhs]);", $ops),
)
}
}
};
}
/// Creates a comparison inplace kernel.
#[macro_export]
macro_rules! comparison_inplace {
(
$struct:ident,
$ops:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::ComparisonInplaceRaw::source()
.register(
"body",
"lhs[index_lhs] = compare(lhs[index_lhs], rhs[index_rhs]);",
)
.add_template(format!(
"{}return {{{{ elem }}}}(lhs {} rhs);{}",
"fn compare(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {\n",
$ops,
"\n}\n"
))
}
}
};
}
pub fn comparison<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
lhs.assert_is_on_same_device(&rhs);
let mut shape_out = [0; D];
lhs.shape
.dims
.iter()
.zip(rhs.shape.dims.iter())
.enumerate()
.for_each(|(index, (dim_lhs, dim_rhs))| {
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
});
let shape_out = Shape::new(shape_out);
let num_elems = shape_out.num_elements();
let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out);
let kernel =
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
);
let info = build_info(&[&lhs, &rhs, &output]);
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
lhs.client.execute(
Box::new(kernel),
&[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
);
WgpuTensor::new(output.client, output.device, output.shape, output.handle)
}
pub fn comparison_inplace<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
lhs.assert_is_on_same_device(&rhs);
let kernel =
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT),
);
let info = build_info(&[&lhs, &rhs]);
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
lhs.client
.execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]);
WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{backend::Backend, Bool, Distribution, Tensor};
comparison!(LowerEqual, "<=");
comparison_inplace!(LowerEqualInplace, "<=");
#[test]
fn comparison_should_work_with_multiple_invocations() {
let (lhs, rhs, lhs_ref, rhs_ref) = inputs();
let value = Tensor::<TestBackend, 3, Bool>::from_primitive(
comparison::<LowerEqual, f32, 3>(lhs.into_primitive(), rhs.into_primitive()),
);
let value_ref = lhs_ref.lower_equal(rhs_ref);
value
.into_data()
.assert_approx_eq(&value_ref.into_data(), 3);
}
#[test]
fn comparison_inplace_should_work_with_multiple_invocations() {
let (lhs, rhs, lhs_ref, rhs_ref) = inputs();
let value = Tensor::<TestBackend, 3, Bool>::from_primitive(comparison_inplace::<
LowerEqualInplace,
f32,
3,
>(
lhs.into_primitive(),
rhs.into_primitive(),
));
let value_ref = lhs_ref.lower_equal(rhs_ref);
value
.into_data()
.assert_approx_eq(&value_ref.into_data(), 3);
}
#[allow(clippy::type_complexity)]
fn inputs() -> (
Tensor<TestBackend, 3>,
Tensor<TestBackend, 3>,
Tensor<ReferenceBackend, 3>,
Tensor<ReferenceBackend, 3>,
) {
TestBackend::seed(0);
let lhs = Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Uniform(0.0, 1.0));
let rhs = Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Uniform(0.0, 1.0));
let lhs_ref = Tensor::<ReferenceBackend, 3>::from_data(lhs.to_data());
let rhs_ref = Tensor::<ReferenceBackend, 3>::from_data(rhs.to_data());
(lhs, rhs, lhs_ref, rhs_ref)
}
}

View File

@ -1,141 +0,0 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT},
kernel_wgsl,
tensor::WgpuTensor,
};
kernel_wgsl!(ComparisonElemRaw, "../../template/comparison/elem.wgsl");
kernel_wgsl!(
ComparisonElemInplaceRaw,
"../../template/comparison/elem_inplace.wgsl"
);
/// Creates a comparison elementwise kernel.
#[macro_export]
macro_rules! comparison_elem {
(
$struct:ident,
$ops:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::ComparisonElemRaw::source()
.register("body", format!("output[id] = u32(lhs[id] {} rhs);", $ops))
}
}
};
}
/// Creates a comparison elementwise inplace kernel.
#[macro_export]
macro_rules! comparison_elem_inplace {
(
$struct:ident,
$ops:expr
) => {
pub struct $struct;
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::ComparisonElemInplaceRaw::source()
.register("body", "lhs[id] = compare(lhs[id], rhs);")
.add_template(format!(
"{}return {{{{ elem }}}}(lhs {} rhs);{}",
"fn compare(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {\n",
$ops,
"\n}\n"
))
}
}
};
}
pub fn comparison_elem<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
let num_elems = lhs.shape.num_elements();
let handle = lhs.client.empty(num_elems * core::mem::size_of::<u32>());
let rhs_handle = lhs.client.create(E::as_bytes(&[rhs]));
let kernel =
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
);
lhs.client
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle, &handle]);
WgpuTensor::new(lhs.client, lhs.device, lhs.shape, handle)
}
pub fn comparison_elem_inplace<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
let kernel =
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT),
);
let rhs_handle = lhs.client.create(E::as_bytes(&[rhs]));
lhs.client
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]);
WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{backend::Backend, Bool, Distribution, Tensor};
comparison_elem!(LowerEqual, "<=");
comparison_elem_inplace!(LowerEqualInplace, "<=");
#[test]
fn comparison_elem_should_work_with_multiple_invocations() {
let (lhs, lhs_ref, rhs) = inputs();
let value =
Tensor::<TestBackend, 3, Bool>::from_primitive(comparison_elem::<LowerEqual, f32, 3>(
lhs.into_primitive(),
rhs,
));
let value_ref = lhs_ref.lower_equal_elem(rhs);
value
.into_data()
.assert_approx_eq(&value_ref.into_data(), 3);
}
#[test]
fn comparison_elem_inplace_should_work_with_multiple_invocations() {
let (lhs, lhs_ref, rhs) = inputs();
let value =
Tensor::<TestBackend, 3, Bool>::from_primitive(comparison_elem_inplace::<
LowerEqualInplace,
f32,
3,
>(lhs.into_primitive(), rhs));
let value_ref = lhs_ref.lower_equal_elem(rhs);
value
.into_data()
.assert_approx_eq(&value_ref.into_data(), 3);
}
#[allow(clippy::type_complexity)]
fn inputs() -> (Tensor<TestBackend, 3>, Tensor<ReferenceBackend, 3>, f32) {
TestBackend::seed(0);
let lhs = Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Uniform(0.0, 1.0));
let lhs_ref = Tensor::<ReferenceBackend, 3>::from_data(lhs.to_data());
(lhs, lhs_ref, 5.0)
}
}

View File

@ -1,7 +0,0 @@
mod base;
mod binary;
mod elem;
pub use base::*;
pub use binary::*;
pub use elem::*;

View File

@ -1,5 +1,5 @@
mod base;
mod binary_elemwise;
mod binary;
mod cast;
mod cat;
mod clamp;
@ -10,7 +10,7 @@ mod source;
mod unary;
pub use base::*;
pub use binary_elemwise::*;
pub use binary::*;
pub use cast::*;
pub use source::*;
pub use unary::*;

View File

@ -1,6 +1,6 @@
use super::StaticKernelSource;
use crate::{
codegen::{execute_static, StaticHandle},
codegen::{execute_static, StaticHandle, WorkgroupLaunch},
element::WgpuElement,
tensor::WgpuTensor,
};
@ -15,7 +15,7 @@ macro_rules! unary {
) => {{
unary!($ops);
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, None)
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, None, true)
}};
(
operator: $ops:expr,
@ -24,7 +24,7 @@ macro_rules! unary {
) => {{
unary!($ops, scalar 1);
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, Some(&[$scalar]))
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, Some(&[$scalar]), true)
}};
(
@ -44,7 +44,7 @@ macro_rules! unary {
.inputs(&[$crate::codegen::Input::Array {
elem: E::elem_type(),
visibility: $crate::codegen::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::IntoContiguous,
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
}])
.body(&[$ops(E::elem_type())])
.outputs(&[$crate::codegen::Output::Array {
@ -97,7 +97,7 @@ macro_rules! unary {
$crate::codegen::Input::Array {
elem: E::elem_type(),
visibility: $crate::codegen::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::IntoContiguous,
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
},
$crate::codegen::Input::Scalar {
elem: E::elem_type(),
@ -145,16 +145,31 @@ macro_rules! unary {
}
/// Launch a unary operation.
pub fn unary<K, KI, E, const D: usize>(
pub fn unary<Kernel, KernelInplace, E, const D: usize>(
tensor: WgpuTensor<E, D>,
scalars: Option<&[E]>,
inplace_enabled: bool,
) -> WgpuTensor<E, D>
where
K: StaticKernelSource,
KI: StaticKernelSource,
Kernel: StaticKernelSource,
KernelInplace: StaticKernelSource,
E: WgpuElement,
{
if !tensor.can_mut() {
if inplace_enabled && tensor.can_mut() {
execute_static::<KernelInplace, E>(
&[StaticHandle::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[],
scalars,
WorkgroupLaunch::Input { pos: 0 },
tensor.client.clone(),
);
tensor
} else {
let num_elems = tensor.shape.num_elements();
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(
@ -164,7 +179,7 @@ where
buffer,
);
execute_static::<K, E>(
execute_static::<Kernel, E>(
&[StaticHandle::new(
&tensor.handle,
&tensor.strides,
@ -176,23 +191,11 @@ where
&output.shape.dims,
)],
scalars,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
output
} else {
execute_static::<KI, E>(
&[],
&[StaticHandle::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
scalars,
tensor.client.clone(),
);
tensor
}
}
@ -213,7 +216,8 @@ mod tests {
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let actual = unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None);
let actual =
unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None, true);
let expected = tensor_ref.tanh();
expected.into_data().assert_approx_eq(
@ -227,7 +231,8 @@ mod tests {
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let actual = unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None);
let actual =
unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None, true);
let expected = tensor_ref.tanh();
expected.into_data().assert_approx_eq(

View File

@ -1,10 +1,7 @@
use crate::codegen::{Elem, Operator, Variable};
use crate::compute::{compute_client, WgpuComputeClient};
use crate::kernel::{binary_elemwise_default, binary_elemwise_inplace_default};
use crate::{
binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor, unary,
};
use crate::{GraphicsApi, WgpuDevice};
use crate::{binary, GraphicsApi, WgpuDevice};
use crate::{element::WgpuElement, tensor::WgpuTensor, unary};
use burn_tensor::{Element, ElementConversion, Shape};
pub fn full<G: GraphicsApi, E: WgpuElement + Element, const D: usize>(
@ -83,18 +80,15 @@ pub fn add<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
binary_elemwise!(Add, "+");
binary_elemwise_inplace!(AddInplace, "+");
if lhs.can_mut_broadcast(&rhs) {
return binary_elemwise_inplace_default::<AddInplace, E, D>(lhs, rhs);
}
if rhs.can_mut_broadcast(&lhs) {
return binary_elemwise_inplace_default::<AddInplace, E, D>(rhs, lhs);
}
binary_elemwise_default::<Add, E, D>(lhs, rhs)
binary!(
operator: |elem: Elem| Operator::Add {
lhs: Variable::Input(0, elem),
rhs: Variable::Input(1, elem),
out: Variable::Local(0, elem),
},
input: lhs; rhs,
elem: E
)
}
pub fn add_scalar<E: WgpuElement, const D: usize>(
@ -116,14 +110,15 @@ pub fn sub<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
binary_elemwise!(Sub, "-");
binary_elemwise_inplace!(SubInplace, "-");
if lhs.can_mut_broadcast(&rhs) {
return binary_elemwise_inplace_default::<SubInplace, E, D>(lhs, rhs);
}
binary_elemwise_default::<Sub, E, D>(lhs, rhs)
binary!(
operator: |elem: Elem| Operator::Sub {
lhs: Variable::Input(0, elem),
rhs: Variable::Input(1, elem),
out: Variable::Local(0, elem),
},
input: lhs; rhs,
elem: E
)
}
pub fn sub_scalar<E: WgpuElement, const D: usize>(
@ -145,18 +140,15 @@ pub fn mul<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
binary_elemwise!(Mul, "*");
binary_elemwise_inplace!(MulInplace, "*");
if lhs.can_mut_broadcast(&rhs) {
return binary_elemwise_inplace_default::<MulInplace, E, D>(lhs, rhs);
}
if rhs.can_mut_broadcast(&lhs) {
return binary_elemwise_inplace_default::<MulInplace, E, D>(rhs, lhs);
}
binary_elemwise_default::<Mul, E, D>(lhs, rhs)
binary!(
operator: |elem: Elem| Operator::Mul {
lhs: Variable::Input(0, elem),
rhs: Variable::Input(1, elem),
out: Variable::Local(0, elem),
},
input: lhs; rhs,
elem: E
)
}
pub fn mul_scalar<E: WgpuElement, const D: usize>(
@ -178,14 +170,15 @@ pub fn div<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
binary_elemwise!(Div, "/");
binary_elemwise_inplace!(DivInplace, "/");
if lhs.can_mut_broadcast(&rhs) {
return binary_elemwise_inplace_default::<DivInplace, E, D>(lhs, rhs);
}
binary_elemwise_default::<Div, E, D>(lhs, rhs)
binary!(
operator: |elem: Elem| Operator::Div {
lhs: Variable::Input(0, elem),
rhs: Variable::Input(1, elem),
out: Variable::Local(0, elem),
},
input: lhs; rhs,
elem: E
)
}
pub fn div_scalar<E: WgpuElement, const D: usize>(

View File

@ -1,42 +0,0 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let dim: u32 = info[0];
var index_lhs: u32 = 0u;
var index_rhs: u32 = 0u;
for (var i: u32 = 1u; i <= dim; i++) {
let stride_lhs = info[i];
let stride_rhs = info[i + dim];
let stride_output = info[i + 2u * dim];
let shape_lhs = info[i + 3u * dim];
let shape_rhs = info[i + 4u * dim];
index_lhs += id / stride_output % shape_lhs * stride_lhs;
index_rhs += id / stride_output % shape_rhs * stride_rhs;
}
{{ body }}
}

View File

@ -1,34 +0,0 @@
@group(0)
@binding(0)
var<storage, read_write> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let dim: u32 = info[0];
var index_rhs: u32 = 0u;
for (var i: u32 = 1u; i <= dim; i++) {
let stride_lhs = info[i];
let stride_rhs = info[i + dim];
let shape_rhs = info[i + 3u * dim];
index_rhs += id / stride_lhs % shape_rhs * stride_rhs;
}
{{ body }}
}

View File

@ -1,43 +0,0 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read_write> output: array<u32>;
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let dim: u32 = info[0];
var index_lhs: u32 = 0u;
var index_rhs: u32 = 0u;
for (var i: u32 = 1u; i <= dim; i++) {
let stride_lhs = info[i];
let stride_rhs = info[i + dim];
let stride_output = info[i + 2u * dim];
let shape_lhs = info[i + 3u * dim];
let shape_rhs = info[i + 4u * dim];
index_lhs += id / stride_output % shape_lhs * stride_lhs;
index_rhs += id / stride_output % shape_rhs * stride_rhs;
}
{{ body }}
}

View File

@ -1,41 +0,0 @@
@group(0)
@binding(0)
var<storage, read_write> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let dim: u32 = info[0];
var index_lhs: u32 = 0u;
var index_rhs: u32 = 0u;
var num_elem = 1u;
for (var i: u32 = 1u; i <= dim; i++) {
let stride_lhs = info[i];
let stride_rhs = info[i + dim];
let shape_lhs = info[i + 2u * dim];
let shape_rhs = info[i + 3u * dim];
num_elem *= shape_lhs;
index_lhs += id / stride_lhs % shape_lhs * stride_lhs;
index_rhs += id / stride_lhs % shape_rhs * stride_rhs;
}
if id < num_elem {
{{ body }}
}
}

View File

@ -1,23 +0,0 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: {{ elem }};
@group(0)
@binding(2)
var<storage, read_write> output: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
{{ body }}
}

View File

@ -1,19 +0,0 @@
@group(0)
@binding(0)
var<storage, read_write> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: {{ elem }};
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
{{ body }}
}

View File

@ -73,14 +73,17 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
}
}
pub(crate) fn can_mut_broadcast(&self, tensor_other: &WgpuTensor<E, D>) -> bool {
pub(crate) fn can_mut_broadcast(&self, rhs: &WgpuTensor<E, D>) -> bool {
if !self.handle.can_mut() {
return false;
}
for i in 0..D {
let shape_lhs = self.shape.dims[i];
let shape_rhs = rhs.shape.dims[i];
// The output tensor would be larger than the mutable tensor.
if self.shape.dims[i] < tensor_other.shape.dims[i] {
if shape_lhs < shape_rhs {
return false;
}
}
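
A small self-contained illustration of the rule above (the helper name is hypothetical, and the real check also requires the handle itself to be mutable): lhs can host the output in place only if no dimension of lhs is smaller than the corresponding dimension of rhs.

fn can_host_output(lhs_dims: &[usize], rhs_dims: &[usize]) -> bool {
    // The output shape is the elementwise max of the two shapes; lhs can be
    // reused only if it already matches that shape.
    lhs_dims.iter().zip(rhs_dims).all(|(l, r)| l >= r)
}

fn main() {
    assert!(can_host_output(&[2, 2], &[1, 2])); // rhs broadcasts into lhs
    assert!(!can_host_output(&[1, 2], &[2, 2])); // output would outgrow lhs
}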