mirror of https://github.com/tracel-ai/burn.git
Refactor/wgpu/binary (#1078)
* Refactor binary * Fix * Oups * Preparation for cmp * Refactor comparison * Remove templates * Cleanup * Cleanup * Fix typo * Code review
This commit is contained in:
parent
b5c49c5bf7
commit
75062c51e0
|
@ -1,6 +1,7 @@
|
|||
#[burn_tensor_testgen::testgen(add)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::{Data, Int, Tensor};
|
||||
|
||||
#[test]
|
||||
|
@ -29,6 +30,54 @@ mod tests {
|
|||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_different_strides_rhs() {
|
||||
let data_1 = Data::from([[0.0, 1.0], [2.0, 3.0]]);
|
||||
let data_2 = Data::from([[4.0, 5.0], [6.0, 7.0]]);
|
||||
|
||||
// We need to execute an operation after `from data` to trigger inplace in some backends.
|
||||
// Which is the operation that might be problematic in this case.
|
||||
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1) * 1;
|
||||
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2) * 1;
|
||||
|
||||
let data_actual = (tensor_1 + tensor_2.transpose()).into_data();
|
||||
|
||||
let data_expected = Data::from([[4.0, 7.0], [7.0, 10.0]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_different_strides_lhs() {
|
||||
let data_1 = Data::from([[0.0, 1.0], [2.0, 3.0]]);
|
||||
let data_2 = Data::from([[4.0, 5.0], [6.0, 7.0]]);
|
||||
|
||||
// We need to execute an operation after `from data` to trigger inplace in some backends.
|
||||
// Which is the operation that might be problematic in this case.
|
||||
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1) * 1;
|
||||
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2) * 1;
|
||||
|
||||
let data_actual = (tensor_1.transpose() + tensor_2).into_data();
|
||||
|
||||
let data_expected = Data::from([[4.0, 7.0], [7.0, 10.0]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_different_strides_broadcast() {
|
||||
let data_1 = Data::from([[0.0, 1.0], [2.0, 3.0]]);
|
||||
let data_2 = Data::from([[4.0, 5.0]]);
|
||||
|
||||
// We need to execute an operation after `from data` to trigger inplace in some backends.
|
||||
// Which is the operation that might be problematic in this case.
|
||||
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1) * 1;
|
||||
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2) * 1;
|
||||
|
||||
let data_actual = (tensor_1.transpose() + tensor_2).into_data();
|
||||
|
||||
let data_expected = Data::from([[4.0, 7.0], [5.0, 8.0]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_add_scalar_ops() {
|
||||
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
|
|
@ -53,7 +53,9 @@ pub enum Input {
|
|||
}
|
||||
|
||||
pub enum ReadingStrategy {
|
||||
IntoContiguous,
|
||||
/// Each element will be read in a way to be compatible with the output layout.
|
||||
OutputLayout,
|
||||
/// Keep the current layout.
|
||||
Plain,
|
||||
}
|
||||
|
||||
|
@ -79,18 +81,6 @@ impl ElemWiseKernelCodegen<InputPhase> {
|
|||
pub fn inputs(mut self, inputs: &[Input]) -> ElemWiseKernelCodegen<BodyPhase> {
|
||||
let mut index: u16 = 0;
|
||||
|
||||
let first_output_index = inputs
|
||||
.iter()
|
||||
.filter(|input| match input {
|
||||
Input::Array {
|
||||
elem: _,
|
||||
visibility: _,
|
||||
strategy: _,
|
||||
} => true,
|
||||
Input::Scalar { elem: _, size: _ } => false,
|
||||
})
|
||||
.count();
|
||||
|
||||
for input in inputs {
|
||||
match input {
|
||||
Input::Array {
|
||||
|
@ -106,11 +96,12 @@ impl ElemWiseKernelCodegen<InputPhase> {
|
|||
});
|
||||
|
||||
match strategy {
|
||||
ReadingStrategy::IntoContiguous => {
|
||||
self.operations.push(Operator::ReadGlobalIntoContiguous {
|
||||
ReadingStrategy::OutputLayout => {
|
||||
self.operations.push(Operator::ReadGlobalWithLayout {
|
||||
variable: Variable::Input(index, *elem),
|
||||
position: index as usize,
|
||||
position_out: first_output_index, // First output
|
||||
tensor_read_pos: index as usize,
|
||||
tensor_layout_pos: 0, // Will set the right value during the output
|
||||
// phase.
|
||||
});
|
||||
}
|
||||
ReadingStrategy::Plain => {
|
||||
|
@ -195,6 +186,7 @@ impl ElemWiseKernelCodegen<OutputPhase> {
|
|||
/// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0).
|
||||
pub fn outputs(mut self, outputs: &[Output]) -> ElemWiseKernelCodegen<CompilationPhase> {
|
||||
let mut index = 0;
|
||||
let mut position_out = 0;
|
||||
|
||||
for array in outputs {
|
||||
match array {
|
||||
|
@ -212,16 +204,36 @@ impl ElemWiseKernelCodegen<OutputPhase> {
|
|||
out: Variable::Output(index, elem_adapted),
|
||||
});
|
||||
index += 1;
|
||||
|
||||
if index == 1 {
|
||||
position_out = self.input_bindings.len(); // First output when we have a
|
||||
// new array for the output.
|
||||
}
|
||||
}
|
||||
Output::Input { elem, input, local } => {
|
||||
self.operations.push(Operator::AssignGlobal {
|
||||
input: Variable::Local(*local, *elem),
|
||||
out: Variable::Input(*input, bool_elem(*elem)),
|
||||
});
|
||||
position_out = *input as usize; // Input number when we use inplace operation.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We set the output number that will be used for the stride definition.
|
||||
for i in 0..self.input_bindings.len() {
|
||||
if let Some(Operator::ReadGlobalWithLayout {
|
||||
variable: _,
|
||||
tensor_read_pos: _,
|
||||
tensor_layout_pos,
|
||||
}) = self.operations.get_mut(i)
|
||||
{
|
||||
{
|
||||
*tensor_layout_pos = position_out;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
ElemWiseKernelCodegen {
|
||||
operations: self.operations,
|
||||
input_bindings: self.input_bindings,
|
||||
|
@ -275,6 +287,12 @@ pub struct StaticHandle<'a> {
|
|||
shape: &'a [usize],
|
||||
}
|
||||
|
||||
/// The position of the input or output to calculate the number of workgroups to launch.
|
||||
pub enum WorkgroupLaunch {
|
||||
Input { pos: usize },
|
||||
Output { pos: usize },
|
||||
}
|
||||
|
||||
/// Execute a static kernel.
|
||||
///
|
||||
///
|
||||
|
@ -284,6 +302,7 @@ pub fn execute_static<K, E: WgpuElement>(
|
|||
inputs: &[StaticHandle],
|
||||
outputs: &[StaticHandle],
|
||||
scalar_elems: Option<&[E]>,
|
||||
launch: WorkgroupLaunch,
|
||||
client: WgpuComputeClient,
|
||||
) where
|
||||
K: StaticKernelSource + 'static,
|
||||
|
@ -305,20 +324,26 @@ pub fn execute_static<K, E: WgpuElement>(
|
|||
}
|
||||
};
|
||||
|
||||
let mut num_elems_output = 0;
|
||||
|
||||
// We start by registering the inputs.
|
||||
for input in inputs.iter() {
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
if let WorkgroupLaunch::Input { pos } = &launch {
|
||||
if i == *pos {
|
||||
num_elems_output = calculate_num_elems_dyn_rank(input.shape);
|
||||
}
|
||||
};
|
||||
register_info_tensor(input.strides, input.shape);
|
||||
handles.push(input.handle);
|
||||
}
|
||||
|
||||
let mut num_elems_output = 0;
|
||||
|
||||
// Then we follow with the outputs.
|
||||
for output in outputs.iter() {
|
||||
let num_elems = calculate_num_elems_dyn_rank(output.shape);
|
||||
if num_elems > num_elems_output {
|
||||
num_elems_output = num_elems;
|
||||
}
|
||||
for (i, output) in outputs.iter().enumerate() {
|
||||
if let WorkgroupLaunch::Output { pos } = &launch {
|
||||
if i == *pos {
|
||||
num_elems_output = calculate_num_elems_dyn_rank(output.shape);
|
||||
}
|
||||
};
|
||||
register_info_tensor(output.strides, output.shape);
|
||||
handles.push(output.handle);
|
||||
}
|
||||
|
@ -352,8 +377,8 @@ pub(crate) fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
|
|||
|
||||
fn bool_elem(elem: Elem) -> Elem {
|
||||
match elem {
|
||||
// I32 are used for bool tensors
|
||||
Elem::Bool => Elem::I32,
|
||||
// U32 are used for bool tensors
|
||||
Elem::Bool => Elem::U32,
|
||||
_ => elem,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -118,10 +118,11 @@ pub enum Operator {
|
|||
ReadGlobal {
|
||||
variable: Variable,
|
||||
},
|
||||
ReadGlobalIntoContiguous {
|
||||
/// Read the tensor in a way to be compatible with another tensor layout.
|
||||
ReadGlobalWithLayout {
|
||||
variable: Variable,
|
||||
position: usize,
|
||||
position_out: usize,
|
||||
tensor_read_pos: usize,
|
||||
tensor_layout_pos: usize,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -202,10 +203,10 @@ impl Display for Operator {
|
|||
)),
|
||||
Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
|
||||
},
|
||||
Operator::ReadGlobalIntoContiguous {
|
||||
Operator::ReadGlobalWithLayout {
|
||||
variable,
|
||||
position,
|
||||
position_out,
|
||||
tensor_read_pos: position,
|
||||
tensor_layout_pos: position_out,
|
||||
} => {
|
||||
let (global, local, elem) = match variable {
|
||||
Variable::Input(number, elem) => (
|
||||
|
|
|
@ -74,26 +74,38 @@ where
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
binary_elemwise, compute::compute_client, kernel::KernelSettings, AutoGraphicsApi,
|
||||
WgpuDevice,
|
||||
binary,
|
||||
codegen::{Elem, Operator, Variable},
|
||||
compute::compute_client,
|
||||
kernel::{KernelSettings, WORKGROUP_DEFAULT},
|
||||
AutoGraphicsApi, WgpuDevice,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn can_run_kernel() {
|
||||
binary_elemwise!(Add, "+");
|
||||
binary!(
|
||||
operator: |elem: Elem| Operator::Add {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Input(1, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
elem_in: f32,
|
||||
elem_out: f32
|
||||
);
|
||||
|
||||
let client = compute_client::<AutoGraphicsApi>(&WgpuDevice::default());
|
||||
|
||||
let lhs: Vec<f32> = vec![0., 1., 2., 3., 4., 5., 6., 7.];
|
||||
let rhs: Vec<f32> = vec![10., 11., 12., 6., 7., 3., 1., 0.];
|
||||
let info: Vec<u32> = vec![1, 1, 1, 1, 8, 8, 8];
|
||||
let info: Vec<u32> = vec![1, 1, 8, 1, 8, 1, 8];
|
||||
|
||||
let lhs = client.create(bytemuck::cast_slice(&lhs));
|
||||
let rhs = client.create(bytemuck::cast_slice(&rhs));
|
||||
let out = client.empty(core::mem::size_of::<f32>() * 8);
|
||||
let info = client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
type Kernel = KernelSettings<Add, f32, i32, 16, 16, 1>;
|
||||
type Kernel =
|
||||
KernelSettings<Ops<f32, f32>, f32, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>;
|
||||
let kernel = Box::new(StaticKernel::<Kernel>::new(WorkGroup::new(1, 1, 1)));
|
||||
|
||||
client.execute(kernel, &[&lhs, &rhs, &out, &info]);
|
||||
|
|
|
@ -190,10 +190,10 @@ where
|
|||
Operator::AssignLocal { input: _, out: _ } => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
Operator::ReadGlobalIntoContiguous {
|
||||
Operator::ReadGlobalWithLayout {
|
||||
variable: _,
|
||||
position: _,
|
||||
position_out: _,
|
||||
tensor_read_pos: _,
|
||||
tensor_layout_pos: _,
|
||||
} => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ where
|
|||
.map(|(_tensor, elem)| Input::Array {
|
||||
elem: *elem,
|
||||
visibility: Visibility::Read,
|
||||
strategy: ReadingStrategy::IntoContiguous,
|
||||
strategy: ReadingStrategy::OutputLayout,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
|
|
|
@ -264,11 +264,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_kernel_type_id() {
|
||||
kernel_wgsl!(Add, "../template/binary_elemwise.wgsl");
|
||||
kernel_wgsl!(Cat, "../template/cat.wgsl");
|
||||
|
||||
let type_id_1 = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 4>>();
|
||||
let type_id_2 = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 5>>();
|
||||
let type_id_3 = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 4>>();
|
||||
let type_id_1 = TypeId::of::<KernelSettings<Cat, f32, i32, 2, 3, 4>>();
|
||||
let type_id_2 = TypeId::of::<KernelSettings<Cat, f32, i32, 2, 3, 5>>();
|
||||
let type_id_3 = TypeId::of::<KernelSettings<Cat, f32, i32, 2, 3, 4>>();
|
||||
|
||||
assert_ne!(type_id_1, type_id_2);
|
||||
assert_eq!(type_id_1, type_id_3);
|
||||
|
|
|
@ -0,0 +1,211 @@
|
|||
use crate::{
|
||||
codegen::{execute_static, StaticHandle, WorkgroupLaunch},
|
||||
element::WgpuElement,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
use burn_tensor::Shape;
|
||||
|
||||
/// Creates a binary kernel.
|
||||
#[macro_export]
|
||||
macro_rules! binary {
|
||||
(
|
||||
operator: $ops:expr,
|
||||
input: $lhs:expr; $rhs:expr,
|
||||
elem: $elem:ty
|
||||
) => {{
|
||||
binary!(operator: $ops, elem_in: $elem, elem_out: $elem);
|
||||
|
||||
$crate::kernel::binary::<Ops<$elem, $elem>, OpsInplaceLhs<$elem, $elem>, OpsInplaceRhs<$elem, $elem>, $elem, D>(
|
||||
$lhs, $rhs, true
|
||||
)
|
||||
}};
|
||||
|
||||
(
|
||||
operator: $ops:expr,
|
||||
elem_in: $elem_in:ty,
|
||||
elem_out: $elem_out:ty
|
||||
) => {
|
||||
pub struct Ops<I, O> {
|
||||
_i: core::marker::PhantomData<I>,
|
||||
_o: core::marker::PhantomData<O>,
|
||||
}
|
||||
pub struct OpsInplaceLhs<I, O> {
|
||||
_i: core::marker::PhantomData<I>,
|
||||
_o: core::marker::PhantomData<O>,
|
||||
}
|
||||
pub struct OpsInplaceRhs<I, O> {
|
||||
_i: core::marker::PhantomData<I>,
|
||||
_o: core::marker::PhantomData<O>,
|
||||
}
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
impl<I, O> $crate::kernel::StaticKernelSource for Ops<I, O>
|
||||
where
|
||||
I: $crate::element::WgpuElement,
|
||||
O: $crate::element::WgpuElement
|
||||
{
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[
|
||||
$crate::codegen::Input::Array {
|
||||
elem: I::elem_type(),
|
||||
visibility: $crate::codegen::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
|
||||
},
|
||||
$crate::codegen::Input::Array {
|
||||
elem: I::elem_type(),
|
||||
visibility: $crate::codegen::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
|
||||
},
|
||||
])
|
||||
.body(&[$ops(I::elem_type())])
|
||||
.outputs(&[$crate::codegen::Output::Array {
|
||||
elem: O::elem_type(),
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
$crate::kernel::SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
impl<I, O> $crate::kernel::StaticKernelSource
|
||||
for OpsInplaceLhs<I, O>
|
||||
where
|
||||
I: $crate::element::WgpuElement,
|
||||
O: $crate::element::WgpuElement
|
||||
{
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[
|
||||
$crate::codegen::Input::Array {
|
||||
elem: I::elem_type(),
|
||||
visibility: $crate::codegen::Visibility::ReadWrite,
|
||||
strategy: $crate::codegen::ReadingStrategy::Plain,
|
||||
},
|
||||
$crate::codegen::Input::Array {
|
||||
elem: I::elem_type(),
|
||||
visibility: $crate::codegen::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
|
||||
},
|
||||
])
|
||||
.body(&[$ops(I::elem_type())])
|
||||
.outputs(&[$crate::codegen::Output::Input {
|
||||
elem: I::elem_type(),
|
||||
input: 0,
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
$crate::kernel::SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
impl<I, O> $crate::kernel::StaticKernelSource
|
||||
for OpsInplaceRhs<I, O>
|
||||
where
|
||||
I: $crate::element::WgpuElement,
|
||||
O: $crate::element::WgpuElement
|
||||
{
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[
|
||||
$crate::codegen::Input::Array {
|
||||
elem: I::elem_type(),
|
||||
visibility: $crate::codegen::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
|
||||
},
|
||||
$crate::codegen::Input::Array {
|
||||
elem: I::elem_type(),
|
||||
visibility: $crate::codegen::Visibility::ReadWrite,
|
||||
strategy: $crate::codegen::ReadingStrategy::Plain,
|
||||
},
|
||||
])
|
||||
.body(&[$ops(I::elem_type())])
|
||||
.outputs(&[$crate::codegen::Output::Input {
|
||||
elem: I::elem_type(),
|
||||
input: 1,
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
$crate::kernel::SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Launch an binary operation.
|
||||
pub fn binary<Kernel, KernelInplaceLhs, KernelInplaceRhs, E, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
inplace_enabled: bool,
|
||||
) -> WgpuTensor<E, D>
|
||||
where
|
||||
Kernel: crate::kernel::StaticKernelSource,
|
||||
KernelInplaceLhs: crate::kernel::StaticKernelSource,
|
||||
KernelInplaceRhs: crate::kernel::StaticKernelSource,
|
||||
E: WgpuElement,
|
||||
{
|
||||
if inplace_enabled && lhs.can_mut_broadcast(&rhs) {
|
||||
execute_static::<KernelInplaceLhs, E>(
|
||||
&[
|
||||
StaticHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
|
||||
StaticHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
|
||||
],
|
||||
&[],
|
||||
None,
|
||||
WorkgroupLaunch::Input { pos: 0 },
|
||||
rhs.client,
|
||||
);
|
||||
|
||||
lhs
|
||||
} else if inplace_enabled && rhs.can_mut_broadcast(&lhs) {
|
||||
execute_static::<KernelInplaceRhs, E>(
|
||||
&[
|
||||
StaticHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
|
||||
StaticHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
|
||||
],
|
||||
&[],
|
||||
None,
|
||||
WorkgroupLaunch::Input { pos: 1 },
|
||||
lhs.client,
|
||||
);
|
||||
|
||||
rhs
|
||||
} else {
|
||||
let mut shape_out = [0; D];
|
||||
lhs.shape
|
||||
.dims
|
||||
.iter()
|
||||
.zip(rhs.shape.dims.iter())
|
||||
.enumerate()
|
||||
.for_each(|(index, (dim_lhs, dim_rhs))| {
|
||||
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
|
||||
});
|
||||
|
||||
let shape_out = Shape::new(shape_out);
|
||||
let num_elems = shape_out.num_elements();
|
||||
let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
|
||||
let out = WgpuTensor::new(lhs.client.clone(), lhs.device, shape_out, buffer);
|
||||
|
||||
execute_static::<Kernel, E>(
|
||||
&[
|
||||
StaticHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
|
||||
StaticHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
|
||||
],
|
||||
&[StaticHandle::new(
|
||||
&out.handle,
|
||||
&out.strides,
|
||||
&out.shape.dims,
|
||||
)],
|
||||
None,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
lhs.client,
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
}
|
|
@ -1,181 +0,0 @@
|
|||
use super::{
|
||||
build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
};
|
||||
use crate::compute::StaticKernel;
|
||||
use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
|
||||
use burn_tensor::Shape;
|
||||
|
||||
kernel_wgsl!(BinaryElemwiseRaw, "../template/binary_elemwise.wgsl");
|
||||
kernel_wgsl!(
|
||||
BinaryElemwiseInplaceRaw,
|
||||
"../template/binary_elemwise_inplace.wgsl"
|
||||
);
|
||||
|
||||
/// Creates a binary elementwise kernel.
|
||||
#[macro_export]
|
||||
macro_rules! binary_elemwise {
|
||||
(
|
||||
$struct:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::BinaryElemwiseRaw::source().register(
|
||||
"body",
|
||||
format!("output[id] = lhs[index_lhs] {} rhs[index_rhs];", $ops),
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Creates a binary elementwise inplace kernel.
|
||||
#[macro_export]
|
||||
macro_rules! binary_elemwise_inplace {
|
||||
(
|
||||
$struct:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::BinaryElemwiseInplaceRaw::source().register(
|
||||
"body",
|
||||
format!("lhs[id] = lhs[id] {} rhs[index_rhs];", $ops),
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Execute a binary kernel using the default settings.
|
||||
pub fn binary_elemwise_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
binary_elemwise::<K, E, D, WORKGROUP_DEFAULT>(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Execute a binary kernel using the provided WORKGROUP.
|
||||
pub fn binary_elemwise<
|
||||
K: StaticKernelSource,
|
||||
E: WgpuElement,
|
||||
const D: usize,
|
||||
const WORKGROUP: usize,
|
||||
>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
lhs.assert_is_on_same_device(&rhs);
|
||||
|
||||
let mut shape_out = [0; D];
|
||||
lhs.shape
|
||||
.dims
|
||||
.iter()
|
||||
.zip(rhs.shape.dims.iter())
|
||||
.enumerate()
|
||||
.for_each(|(index, (dim_lhs, dim_rhs))| {
|
||||
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
|
||||
});
|
||||
|
||||
let shape_out = Shape::new(shape_out);
|
||||
let num_elems = shape_out.num_elements();
|
||||
|
||||
let handle = lhs.client.empty(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, handle);
|
||||
|
||||
let info = build_info(&[&lhs, &rhs, &output]);
|
||||
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
|
||||
lhs.client.execute(
|
||||
Box::new(kernel),
|
||||
&[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Execute a binary inplace kernel using the default settings.
|
||||
pub fn binary_elemwise_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
binary_elemwise_inplace::<K, E, D, WORKGROUP_DEFAULT>(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Execute a binary inplace kernel using the provided WORKGROUP.
|
||||
pub fn binary_elemwise_inplace<
|
||||
K: StaticKernelSource,
|
||||
E: WgpuElement,
|
||||
const D: usize,
|
||||
const WORKGROUP: usize,
|
||||
>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
lhs.assert_is_on_same_device(&rhs);
|
||||
|
||||
let info = build_info(&[&lhs, &rhs]);
|
||||
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
|
||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP),
|
||||
);
|
||||
|
||||
lhs.client
|
||||
.execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]);
|
||||
|
||||
lhs
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tests::{ReferenceBackend, TestBackend};
|
||||
use burn_tensor::{Distribution, Tensor};
|
||||
|
||||
binary_elemwise!(TestKernel, "*");
|
||||
binary_elemwise_inplace!(TestKernelInplace, "*");
|
||||
|
||||
#[test]
|
||||
fn binary_should_work_with_multiple_invocations() {
|
||||
let lhs = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
||||
let rhs = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
||||
let lhs_ref = Tensor::<ReferenceBackend, 2>::from_data(lhs.to_data());
|
||||
let rhs_ref = Tensor::<ReferenceBackend, 2>::from_data(rhs.to_data());
|
||||
|
||||
let actual =
|
||||
binary_elemwise::<TestKernel, _, 2, 16>(lhs.into_primitive(), rhs.into_primitive());
|
||||
let expected = lhs_ref * rhs_ref;
|
||||
|
||||
expected.into_data().assert_approx_eq(
|
||||
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
|
||||
3,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_inplace_should_work_with_multiple_invocations() {
|
||||
let lhs = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
||||
let rhs = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
||||
let lhs_ref = Tensor::<ReferenceBackend, 2>::from_data(lhs.to_data());
|
||||
let rhs_ref = Tensor::<ReferenceBackend, 2>::from_data(rhs.to_data());
|
||||
|
||||
let actual = binary_elemwise_inplace::<TestKernelInplace, _, 2, 16>(
|
||||
lhs.into_primitive(),
|
||||
rhs.into_primitive(),
|
||||
);
|
||||
let expected = lhs_ref * rhs_ref;
|
||||
|
||||
expected.into_data().assert_approx_eq(
|
||||
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
|
||||
3,
|
||||
);
|
||||
}
|
||||
}
|
|
@ -21,7 +21,7 @@ pub(crate) fn clamp<E: WgpuElement, const D: usize>(
|
|||
min_value: E,
|
||||
max_value: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary::<Ops<E>, OpsInplace<E>, E, D>(input, Some(&[min_value, max_value]))
|
||||
unary::<Ops<E>, OpsInplace<E>, E, D>(input, Some(&[min_value, max_value]), true)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -0,0 +1,219 @@
|
|||
use crate::{
|
||||
binary,
|
||||
codegen::{Elem, Operator, Variable},
|
||||
element::WgpuElement,
|
||||
kernel::StaticKernelSource,
|
||||
kernel::{binary::binary, unary::unary},
|
||||
tensor::WgpuTensor,
|
||||
unary,
|
||||
};
|
||||
use std::mem;
|
||||
|
||||
macro_rules! comparison {
|
||||
(
|
||||
binary: $ops:expr,
|
||||
input: $lhs:expr; $rhs:expr,
|
||||
elem: $elem:ty
|
||||
) => {{
|
||||
binary!(operator: $ops, elem_in: $elem, elem_out: $elem);
|
||||
|
||||
launch_binary::<Ops<E, u32>, OpsInplaceLhs<E, u32>, OpsInplaceRhs<E, u32>, E, D>($lhs, $rhs)
|
||||
}};
|
||||
|
||||
(
|
||||
unary: $ops:expr,
|
||||
input: $lhs:expr; $rhs:expr,
|
||||
elem: $elem:ty
|
||||
) => {{
|
||||
unary!($ops, scalar 1);
|
||||
|
||||
launch_unary::<Ops<E>, OpsInplace<E>, E, D>($lhs, $rhs)
|
||||
}};
|
||||
}
|
||||
|
||||
pub fn equal<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
comparison!(
|
||||
binary: |elem: Elem| Operator::Equal {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Input(1, elem),
|
||||
out: Variable::Local(0, Elem::Bool),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn greater<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
comparison!(
|
||||
binary: |elem: Elem| Operator::Greater {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Input(1, elem),
|
||||
out: Variable::Local(0, Elem::Bool),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn greater_equal<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
comparison!(
|
||||
binary: |elem: Elem| Operator::GreaterEqual {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Input(1, elem),
|
||||
out: Variable::Local(0, Elem::Bool),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn lower<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
comparison!(
|
||||
binary: |elem: Elem| Operator::Lower {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Input(1, elem),
|
||||
out: Variable::Local(0, Elem::Bool),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn lower_equal<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
comparison!(
|
||||
binary: |elem: Elem| Operator::LowerEqual {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Input(1, elem),
|
||||
out: Variable::Local(0, Elem::Bool),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn equal_elem<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
comparison!(
|
||||
unary: |elem: Elem| Operator::Equal {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Scalar(0, elem),
|
||||
out: Variable::Local(0, Elem::Bool),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn greater_elem<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
comparison!(
|
||||
unary: |elem: Elem| Operator::Greater {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Scalar(0, elem),
|
||||
out: Variable::Local(0, Elem::Bool),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn lower_elem<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
comparison!(
|
||||
unary: |elem: Elem| Operator::Lower {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Scalar(0, elem),
|
||||
out: Variable::Local(0, Elem::Bool),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn greater_equal_elem<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
comparison!(
|
||||
unary: |elem: Elem| Operator::GreaterEqual {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Scalar(0, elem),
|
||||
out: Variable::Local(0, Elem::Bool),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn lower_equal_elem<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
comparison!(
|
||||
unary: |elem: Elem| Operator::LowerEqual {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Scalar(0, elem),
|
||||
out: Variable::Local(0, Elem::Bool),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
fn launch_binary<Kernel, KernelInplaceLhs, KernelInplaceRhs, E, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D>
|
||||
where
|
||||
Kernel: StaticKernelSource,
|
||||
KernelInplaceLhs: StaticKernelSource,
|
||||
KernelInplaceRhs: StaticKernelSource,
|
||||
E: WgpuElement,
|
||||
{
|
||||
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
|
||||
|
||||
let output =
|
||||
binary::<Kernel, KernelInplaceLhs, KernelInplaceRhs, E, D>(lhs, rhs, can_be_used_as_bool);
|
||||
|
||||
// We recast the tensor type.
|
||||
WgpuTensor::new(output.client, output.device, output.shape, output.handle)
|
||||
}
|
||||
|
||||
fn launch_unary<Kernel, KernelInplace, E, const D: usize>(
|
||||
tensor: WgpuTensor<E, D>,
|
||||
scalars: E,
|
||||
) -> WgpuTensor<u32, D>
|
||||
where
|
||||
Kernel: StaticKernelSource,
|
||||
KernelInplace: StaticKernelSource,
|
||||
E: WgpuElement,
|
||||
{
|
||||
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
|
||||
|
||||
let output =
|
||||
unary::<Kernel, KernelInplace, E, D>(tensor, Some(&[scalars]), can_be_used_as_bool);
|
||||
|
||||
// We recast the tensor type.
|
||||
WgpuTensor::new(output.client, output.device, output.shape, output.handle)
|
||||
}
|
|
@ -1,166 +0,0 @@
|
|||
use crate::{
|
||||
comparison, comparison_elem, comparison_elem_inplace, comparison_inplace,
|
||||
element::WgpuElement,
|
||||
kernel::{comparison, comparison_elem, comparison_elem_inplace, comparison_inplace},
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
use std::mem;
|
||||
|
||||
comparison!(Equal, "==");
|
||||
comparison!(Greater, ">");
|
||||
comparison!(GreaterEqual, ">=");
|
||||
comparison!(Lower, "<");
|
||||
comparison!(LowerEqual, "<=");
|
||||
|
||||
comparison_inplace!(EqualInplace, "==");
|
||||
comparison_inplace!(GreaterInplace, ">");
|
||||
comparison_inplace!(GreaterEqualInplace, ">=");
|
||||
comparison_inplace!(LowerInplace, "<");
|
||||
comparison_inplace!(LowerEqualInplace, "<=");
|
||||
|
||||
comparison_elem!(EqualElem, "==");
|
||||
comparison_elem!(GreaterElem, ">");
|
||||
comparison_elem!(GreaterEqualElem, ">=");
|
||||
comparison_elem!(LowerElem, "<");
|
||||
comparison_elem!(LowerEqualElem, "<=");
|
||||
|
||||
comparison_elem_inplace!(EqualElemInplace, "==");
|
||||
comparison_elem_inplace!(GreaterElemInplace, ">");
|
||||
comparison_elem_inplace!(GreaterEqualElemInplace, ">=");
|
||||
comparison_elem_inplace!(LowerElemInplace, "<");
|
||||
comparison_elem_inplace!(LowerEqualElemInplace, "<=");
|
||||
|
||||
pub fn equal<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
|
||||
|
||||
if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) {
|
||||
return comparison_inplace::<EqualInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) {
|
||||
return comparison_inplace::<EqualInplace, E, D>(rhs, lhs);
|
||||
}
|
||||
|
||||
comparison::<Equal, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn greater<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
|
||||
|
||||
if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) {
|
||||
return comparison_inplace::<GreaterInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) {
|
||||
return comparison_inplace::<LowerInplace, E, D>(rhs, lhs);
|
||||
}
|
||||
|
||||
comparison::<Greater, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn greater_equal<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
|
||||
|
||||
if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) {
|
||||
return comparison_inplace::<GreaterEqualInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) {
|
||||
return comparison_inplace::<LowerEqualInplace, E, D>(rhs, lhs);
|
||||
}
|
||||
|
||||
comparison::<GreaterEqual, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn lower<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
|
||||
|
||||
if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) {
|
||||
return comparison_inplace::<LowerInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) {
|
||||
return comparison_inplace::<GreaterInplace, E, D>(rhs, lhs);
|
||||
}
|
||||
|
||||
comparison::<Lower, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn lower_equal<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
let can_be_used_as_bool = mem::size_of::<E>() == mem::size_of::<u32>();
|
||||
|
||||
if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) {
|
||||
return comparison_inplace::<LowerEqualInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) {
|
||||
return comparison_inplace::<GreaterEqualInplace, E, D>(rhs, lhs);
|
||||
}
|
||||
|
||||
comparison::<LowerEqual, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn equal_elem<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
if mem::size_of::<E>() == mem::size_of::<u32>() && lhs.can_mut() {
|
||||
return comparison_elem_inplace::<EqualElemInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
comparison_elem::<EqualElem, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn greater_elem<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
if mem::size_of::<E>() == mem::size_of::<u32>() && lhs.can_mut() {
|
||||
return comparison_elem_inplace::<GreaterElemInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
comparison_elem::<GreaterElem, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn lower_elem<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
if mem::size_of::<E>() == mem::size_of::<u32>() && lhs.can_mut() {
|
||||
return comparison_elem_inplace::<LowerElemInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
comparison_elem::<LowerElem, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn greater_equal_elem<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
if mem::size_of::<E>() == mem::size_of::<u32>() && lhs.can_mut() {
|
||||
return comparison_elem_inplace::<GreaterEqualElemInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
comparison_elem::<GreaterEqualElem, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn lower_equal_elem<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
if mem::size_of::<E>() == mem::size_of::<u32>() && lhs.can_mut() {
|
||||
return comparison_elem_inplace::<LowerEqualElemInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
comparison_elem::<LowerEqualElem, E, D>(lhs, rhs)
|
||||
}
|
|
@ -1,177 +0,0 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::WgpuElement,
|
||||
kernel::{
|
||||
build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
use burn_tensor::Shape;
|
||||
|
||||
kernel_wgsl!(ComparisonRaw, "../../template/comparison/binary.wgsl");
|
||||
kernel_wgsl!(
|
||||
ComparisonInplaceRaw,
|
||||
"../../template/comparison/binary_inplace.wgsl"
|
||||
);
|
||||
|
||||
/// Creates a comparison kernel.
|
||||
#[macro_export]
|
||||
macro_rules! comparison {
|
||||
(
|
||||
$struct:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::ComparisonRaw::source().register(
|
||||
"body",
|
||||
format!("output[id] = u32(lhs[index_lhs] {} rhs[index_rhs]);", $ops),
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Creates a comparison inplace kernel.
|
||||
#[macro_export]
|
||||
macro_rules! comparison_inplace {
|
||||
(
|
||||
$struct:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::ComparisonInplaceRaw::source()
|
||||
.register(
|
||||
"body",
|
||||
"lhs[index_lhs] = compare(lhs[index_lhs], rhs[index_rhs]);",
|
||||
)
|
||||
.add_template(format!(
|
||||
"{}return {{{{ elem }}}}(lhs {} rhs);{}",
|
||||
"fn compare(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {\n",
|
||||
$ops,
|
||||
"\n}\n"
|
||||
))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn comparison<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
lhs.assert_is_on_same_device(&rhs);
|
||||
let mut shape_out = [0; D];
|
||||
lhs.shape
|
||||
.dims
|
||||
.iter()
|
||||
.zip(rhs.shape.dims.iter())
|
||||
.enumerate()
|
||||
.for_each(|(index, (dim_lhs, dim_rhs))| {
|
||||
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
|
||||
});
|
||||
|
||||
let shape_out = Shape::new(shape_out);
|
||||
let num_elems = shape_out.num_elements();
|
||||
|
||||
let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out);
|
||||
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
|
||||
);
|
||||
let info = build_info(&[&lhs, &rhs, &output]);
|
||||
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
lhs.client.execute(
|
||||
Box::new(kernel),
|
||||
&[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
|
||||
);
|
||||
|
||||
WgpuTensor::new(output.client, output.device, output.shape, output.handle)
|
||||
}
|
||||
|
||||
pub fn comparison_inplace<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
lhs.assert_is_on_same_device(&rhs);
|
||||
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
|
||||
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT),
|
||||
);
|
||||
let info = build_info(&[&lhs, &rhs]);
|
||||
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
lhs.client
|
||||
.execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]);
|
||||
|
||||
WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tests::{ReferenceBackend, TestBackend};
|
||||
use burn_tensor::{backend::Backend, Bool, Distribution, Tensor};
|
||||
|
||||
comparison!(LowerEqual, "<=");
|
||||
comparison_inplace!(LowerEqualInplace, "<=");
|
||||
|
||||
#[test]
|
||||
fn comparison_should_work_with_multiple_invocations() {
|
||||
let (lhs, rhs, lhs_ref, rhs_ref) = inputs();
|
||||
|
||||
let value = Tensor::<TestBackend, 3, Bool>::from_primitive(
|
||||
comparison::<LowerEqual, f32, 3>(lhs.into_primitive(), rhs.into_primitive()),
|
||||
);
|
||||
|
||||
let value_ref = lhs_ref.lower_equal(rhs_ref);
|
||||
value
|
||||
.into_data()
|
||||
.assert_approx_eq(&value_ref.into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn comparison_inplace_should_work_with_multiple_invocations() {
|
||||
let (lhs, rhs, lhs_ref, rhs_ref) = inputs();
|
||||
|
||||
let value = Tensor::<TestBackend, 3, Bool>::from_primitive(comparison_inplace::<
|
||||
LowerEqualInplace,
|
||||
f32,
|
||||
3,
|
||||
>(
|
||||
lhs.into_primitive(),
|
||||
rhs.into_primitive(),
|
||||
));
|
||||
|
||||
let value_ref = lhs_ref.lower_equal(rhs_ref);
|
||||
value
|
||||
.into_data()
|
||||
.assert_approx_eq(&value_ref.into_data(), 3);
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn inputs() -> (
|
||||
Tensor<TestBackend, 3>,
|
||||
Tensor<TestBackend, 3>,
|
||||
Tensor<ReferenceBackend, 3>,
|
||||
Tensor<ReferenceBackend, 3>,
|
||||
) {
|
||||
TestBackend::seed(0);
|
||||
let lhs = Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Uniform(0.0, 1.0));
|
||||
let rhs = Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Uniform(0.0, 1.0));
|
||||
let lhs_ref = Tensor::<ReferenceBackend, 3>::from_data(lhs.to_data());
|
||||
let rhs_ref = Tensor::<ReferenceBackend, 3>::from_data(rhs.to_data());
|
||||
|
||||
(lhs, rhs, lhs_ref, rhs_ref)
|
||||
}
|
||||
}
|
|
@ -1,141 +0,0 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::WgpuElement,
|
||||
kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
|
||||
kernel_wgsl!(ComparisonElemRaw, "../../template/comparison/elem.wgsl");
|
||||
kernel_wgsl!(
|
||||
ComparisonElemInplaceRaw,
|
||||
"../../template/comparison/elem_inplace.wgsl"
|
||||
);
|
||||
|
||||
/// Creates a comparison elementwise kernel.
|
||||
#[macro_export]
|
||||
macro_rules! comparison_elem {
|
||||
(
|
||||
$struct:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::ComparisonElemRaw::source()
|
||||
.register("body", format!("output[id] = u32(lhs[id] {} rhs);", $ops))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Creates a comparison elementwise inplace kernel.
|
||||
#[macro_export]
|
||||
macro_rules! comparison_elem_inplace {
|
||||
(
|
||||
$struct:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::ComparisonElemInplaceRaw::source()
|
||||
.register("body", "lhs[id] = compare(lhs[id], rhs);")
|
||||
.add_template(format!(
|
||||
"{}return {{{{ elem }}}}(lhs {} rhs);{}",
|
||||
"fn compare(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {\n",
|
||||
$ops,
|
||||
"\n}\n"
|
||||
))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn comparison_elem<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
let num_elems = lhs.shape.num_elements();
|
||||
|
||||
let handle = lhs.client.empty(num_elems * core::mem::size_of::<u32>());
|
||||
let rhs_handle = lhs.client.create(E::as_bytes(&[rhs]));
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
|
||||
);
|
||||
|
||||
lhs.client
|
||||
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle, &handle]);
|
||||
|
||||
WgpuTensor::new(lhs.client, lhs.device, lhs.shape, handle)
|
||||
}
|
||||
|
||||
pub fn comparison_elem_inplace<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
|
||||
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT),
|
||||
);
|
||||
let rhs_handle = lhs.client.create(E::as_bytes(&[rhs]));
|
||||
lhs.client
|
||||
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]);
|
||||
|
||||
WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tests::{ReferenceBackend, TestBackend};
|
||||
use burn_tensor::{backend::Backend, Bool, Distribution, Tensor};
|
||||
|
||||
comparison_elem!(LowerEqual, "<=");
|
||||
comparison_elem_inplace!(LowerEqualInplace, "<=");
|
||||
|
||||
#[test]
|
||||
fn comparison_elem_should_work_with_multiple_invocations() {
|
||||
let (lhs, lhs_ref, rhs) = inputs();
|
||||
|
||||
let value =
|
||||
Tensor::<TestBackend, 3, Bool>::from_primitive(comparison_elem::<LowerEqual, f32, 3>(
|
||||
lhs.into_primitive(),
|
||||
rhs,
|
||||
));
|
||||
|
||||
let value_ref = lhs_ref.lower_equal_elem(rhs);
|
||||
value
|
||||
.into_data()
|
||||
.assert_approx_eq(&value_ref.into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn comparison_elem_inplace_should_work_with_multiple_invocations() {
|
||||
let (lhs, lhs_ref, rhs) = inputs();
|
||||
|
||||
let value =
|
||||
Tensor::<TestBackend, 3, Bool>::from_primitive(comparison_elem_inplace::<
|
||||
LowerEqualInplace,
|
||||
f32,
|
||||
3,
|
||||
>(lhs.into_primitive(), rhs));
|
||||
|
||||
let value_ref = lhs_ref.lower_equal_elem(rhs);
|
||||
value
|
||||
.into_data()
|
||||
.assert_approx_eq(&value_ref.into_data(), 3);
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn inputs() -> (Tensor<TestBackend, 3>, Tensor<ReferenceBackend, 3>, f32) {
|
||||
TestBackend::seed(0);
|
||||
let lhs = Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Uniform(0.0, 1.0));
|
||||
let lhs_ref = Tensor::<ReferenceBackend, 3>::from_data(lhs.to_data());
|
||||
|
||||
(lhs, lhs_ref, 5.0)
|
||||
}
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
mod base;
|
||||
mod binary;
|
||||
mod elem;
|
||||
|
||||
pub use base::*;
|
||||
pub use binary::*;
|
||||
pub use elem::*;
|
|
@ -1,5 +1,5 @@
|
|||
mod base;
|
||||
mod binary_elemwise;
|
||||
mod binary;
|
||||
mod cast;
|
||||
mod cat;
|
||||
mod clamp;
|
||||
|
@ -10,7 +10,7 @@ mod source;
|
|||
mod unary;
|
||||
|
||||
pub use base::*;
|
||||
pub use binary_elemwise::*;
|
||||
pub use binary::*;
|
||||
pub use cast::*;
|
||||
pub use source::*;
|
||||
pub use unary::*;
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::StaticKernelSource;
|
||||
use crate::{
|
||||
codegen::{execute_static, StaticHandle},
|
||||
codegen::{execute_static, StaticHandle, WorkgroupLaunch},
|
||||
element::WgpuElement,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
|
@ -15,7 +15,7 @@ macro_rules! unary {
|
|||
) => {{
|
||||
unary!($ops);
|
||||
|
||||
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, None)
|
||||
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, None, true)
|
||||
}};
|
||||
(
|
||||
operator: $ops:expr,
|
||||
|
@ -24,7 +24,7 @@ macro_rules! unary {
|
|||
) => {{
|
||||
unary!($ops, scalar 1);
|
||||
|
||||
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, Some(&[$scalar]))
|
||||
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, Some(&[$scalar]), true)
|
||||
}};
|
||||
|
||||
(
|
||||
|
@ -44,7 +44,7 @@ macro_rules! unary {
|
|||
.inputs(&[$crate::codegen::Input::Array {
|
||||
elem: E::elem_type(),
|
||||
visibility: $crate::codegen::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::IntoContiguous,
|
||||
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
|
||||
}])
|
||||
.body(&[$ops(E::elem_type())])
|
||||
.outputs(&[$crate::codegen::Output::Array {
|
||||
|
@ -97,7 +97,7 @@ macro_rules! unary {
|
|||
$crate::codegen::Input::Array {
|
||||
elem: E::elem_type(),
|
||||
visibility: $crate::codegen::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::IntoContiguous,
|
||||
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
|
||||
},
|
||||
$crate::codegen::Input::Scalar {
|
||||
elem: E::elem_type(),
|
||||
|
@ -145,16 +145,31 @@ macro_rules! unary {
|
|||
}
|
||||
|
||||
/// Launch an unary operation.
|
||||
pub fn unary<K, KI, E, const D: usize>(
|
||||
pub fn unary<Kernel, KernelInplace, E, const D: usize>(
|
||||
tensor: WgpuTensor<E, D>,
|
||||
scalars: Option<&[E]>,
|
||||
inplace_enabled: bool,
|
||||
) -> WgpuTensor<E, D>
|
||||
where
|
||||
K: StaticKernelSource,
|
||||
KI: StaticKernelSource,
|
||||
Kernel: StaticKernelSource,
|
||||
KernelInplace: StaticKernelSource,
|
||||
E: WgpuElement,
|
||||
{
|
||||
if !tensor.can_mut() {
|
||||
if inplace_enabled && tensor.can_mut() {
|
||||
execute_static::<KernelInplace, E>(
|
||||
&[StaticHandle::new(
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
)],
|
||||
&[],
|
||||
scalars,
|
||||
WorkgroupLaunch::Input { pos: 0 },
|
||||
tensor.client.clone(),
|
||||
);
|
||||
|
||||
tensor
|
||||
} else {
|
||||
let num_elems = tensor.shape.num_elements();
|
||||
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(
|
||||
|
@ -164,7 +179,7 @@ where
|
|||
buffer,
|
||||
);
|
||||
|
||||
execute_static::<K, E>(
|
||||
execute_static::<Kernel, E>(
|
||||
&[StaticHandle::new(
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
|
@ -176,23 +191,11 @@ where
|
|||
&output.shape.dims,
|
||||
)],
|
||||
scalars,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
tensor.client,
|
||||
);
|
||||
|
||||
output
|
||||
} else {
|
||||
execute_static::<KI, E>(
|
||||
&[],
|
||||
&[StaticHandle::new(
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
)],
|
||||
scalars,
|
||||
tensor.client.clone(),
|
||||
);
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -213,7 +216,8 @@ mod tests {
|
|||
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
||||
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
|
||||
|
||||
let actual = unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None);
|
||||
let actual =
|
||||
unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None, true);
|
||||
let expected = tensor_ref.tanh();
|
||||
|
||||
expected.into_data().assert_approx_eq(
|
||||
|
@ -227,7 +231,8 @@ mod tests {
|
|||
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
||||
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
|
||||
|
||||
let actual = unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None);
|
||||
let actual =
|
||||
unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None, true);
|
||||
let expected = tensor_ref.tanh();
|
||||
|
||||
expected.into_data().assert_approx_eq(
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
use crate::codegen::{Elem, Operator, Variable};
|
||||
use crate::compute::{compute_client, WgpuComputeClient};
|
||||
use crate::kernel::{binary_elemwise_default, binary_elemwise_inplace_default};
|
||||
use crate::{
|
||||
binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor, unary,
|
||||
};
|
||||
use crate::{GraphicsApi, WgpuDevice};
|
||||
use crate::{binary, GraphicsApi, WgpuDevice};
|
||||
use crate::{element::WgpuElement, tensor::WgpuTensor, unary};
|
||||
use burn_tensor::{Element, ElementConversion, Shape};
|
||||
|
||||
pub fn full<G: GraphicsApi, E: WgpuElement + Element, const D: usize>(
|
||||
|
@ -83,18 +80,15 @@ pub fn add<E: WgpuElement, const D: usize>(
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
binary_elemwise!(Add, "+");
|
||||
binary_elemwise_inplace!(AddInplace, "+");
|
||||
|
||||
if lhs.can_mut_broadcast(&rhs) {
|
||||
return binary_elemwise_inplace_default::<AddInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
if rhs.can_mut_broadcast(&lhs) {
|
||||
return binary_elemwise_inplace_default::<AddInplace, E, D>(rhs, lhs);
|
||||
}
|
||||
|
||||
binary_elemwise_default::<Add, E, D>(lhs, rhs)
|
||||
binary!(
|
||||
operator: |elem: Elem| Operator::Add {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Input(1, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn add_scalar<E: WgpuElement, const D: usize>(
|
||||
|
@ -116,14 +110,15 @@ pub fn sub<E: WgpuElement, const D: usize>(
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
binary_elemwise!(Sub, "-");
|
||||
binary_elemwise_inplace!(SubInplace, "-");
|
||||
|
||||
if lhs.can_mut_broadcast(&rhs) {
|
||||
return binary_elemwise_inplace_default::<SubInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
binary_elemwise_default::<Sub, E, D>(lhs, rhs)
|
||||
binary!(
|
||||
operator: |elem: Elem| Operator::Sub {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Input(1, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn sub_scalar<E: WgpuElement, const D: usize>(
|
||||
|
@ -145,18 +140,15 @@ pub fn mul<E: WgpuElement, const D: usize>(
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
binary_elemwise!(Mul, "*");
|
||||
binary_elemwise_inplace!(MulInplace, "*");
|
||||
|
||||
if lhs.can_mut_broadcast(&rhs) {
|
||||
return binary_elemwise_inplace_default::<MulInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
if rhs.can_mut_broadcast(&lhs) {
|
||||
return binary_elemwise_inplace_default::<MulInplace, E, D>(rhs, lhs);
|
||||
}
|
||||
|
||||
binary_elemwise_default::<Mul, E, D>(lhs, rhs)
|
||||
binary!(
|
||||
operator: |elem: Elem| Operator::Mul {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Input(1, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn mul_scalar<E: WgpuElement, const D: usize>(
|
||||
|
@ -178,14 +170,15 @@ pub fn div<E: WgpuElement, const D: usize>(
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
binary_elemwise!(Div, "/");
|
||||
binary_elemwise_inplace!(DivInplace, "/");
|
||||
|
||||
if lhs.can_mut_broadcast(&rhs) {
|
||||
return binary_elemwise_inplace_default::<DivInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
binary_elemwise_default::<Div, E, D>(lhs, rhs)
|
||||
binary!(
|
||||
operator: |elem: Elem| Operator::Div {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Input(1, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn div_scalar<E: WgpuElement, const D: usize>(
|
||||
|
|
|
@ -1,42 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let dim: u32 = info[0];
|
||||
var index_lhs: u32 = 0u;
|
||||
var index_rhs: u32 = 0u;
|
||||
|
||||
for (var i: u32 = 1u; i <= dim; i++) {
|
||||
let stride_lhs = info[i];
|
||||
let stride_rhs = info[i + dim];
|
||||
let stride_output = info[i + 2u * dim];
|
||||
let shape_lhs = info[i + 3u * dim];
|
||||
let shape_rhs = info[i + 4u * dim];
|
||||
|
||||
index_lhs += id / stride_output % shape_lhs * stride_lhs;
|
||||
index_rhs += id / stride_output % shape_rhs * stride_rhs;
|
||||
}
|
||||
|
||||
{{ body }}
|
||||
}
|
|
@ -1,34 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let dim: u32 = info[0];
|
||||
var index_rhs: u32 = 0u;
|
||||
|
||||
for (var i: u32 = 1u; i <= dim; i++) {
|
||||
let stride_lhs = info[i];
|
||||
let stride_rhs = info[i + dim];
|
||||
let shape_rhs = info[i + 3u * dim];
|
||||
|
||||
index_rhs += id / stride_lhs % shape_rhs * stride_rhs;
|
||||
}
|
||||
|
||||
{{ body }}
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read_write> output: array<u32>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let dim: u32 = info[0];
|
||||
var index_lhs: u32 = 0u;
|
||||
var index_rhs: u32 = 0u;
|
||||
|
||||
for (var i: u32 = 1u; i <= dim; i++) {
|
||||
let stride_lhs = info[i];
|
||||
let stride_rhs = info[i + dim];
|
||||
let stride_output = info[i + 2u * dim];
|
||||
let shape_lhs = info[i + 3u * dim];
|
||||
let shape_rhs = info[i + 4u * dim];
|
||||
|
||||
index_lhs += id / stride_output % shape_lhs * stride_lhs;
|
||||
index_rhs += id / stride_output % shape_rhs * stride_rhs;
|
||||
}
|
||||
|
||||
{{ body }}
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let dim: u32 = info[0];
|
||||
var index_lhs: u32 = 0u;
|
||||
var index_rhs: u32 = 0u;
|
||||
var num_elem = 1u;
|
||||
|
||||
for (var i: u32 = 1u; i <= dim; i++) {
|
||||
let stride_lhs = info[i];
|
||||
let stride_rhs = info[i + dim];
|
||||
let shape_lhs = info[i + 2u * dim];
|
||||
let shape_rhs = info[i + 3u * dim];
|
||||
num_elem *= shape_lhs;
|
||||
|
||||
index_lhs += id / stride_lhs % shape_lhs * stride_lhs;
|
||||
index_rhs += id / stride_lhs % shape_rhs * stride_rhs;
|
||||
}
|
||||
|
||||
if id < num_elem {
|
||||
{{ body }}
|
||||
}
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: {{ elem }};
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read_write> output: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
{{ body }}
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: {{ elem }};
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
{{ body }}
|
||||
}
|
|
@ -73,14 +73,17 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn can_mut_broadcast(&self, tensor_other: &WgpuTensor<E, D>) -> bool {
|
||||
pub(crate) fn can_mut_broadcast(&self, rhs: &WgpuTensor<E, D>) -> bool {
|
||||
if !self.handle.can_mut() {
|
||||
return false;
|
||||
}
|
||||
|
||||
for i in 0..D {
|
||||
let shape_lhs = self.shape.dims[i];
|
||||
let shape_rhs = rhs.shape.dims[i];
|
||||
|
||||
// Output tensor will be different from the mutable tensor.
|
||||
if self.shape.dims[i] < tensor_other.shape.dims[i] {
|
||||
if shape_lhs < shape_rhs {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue