diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs index 8438d5753..7f6dac0a9 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs @@ -145,6 +145,16 @@ macro_rules! gpu { gpu!(binary $lhs, $rhs, $out) )); }; + // out = lhs != rhs + ($scope:expr, $out:ident = $lhs:ident != $rhs:expr) => { + gpu!($scope, $out = not_equal($lhs, $rhs)) + }; + // out = not_equal(lhs, rhs) + ($scope:expr, $out:ident = not_equal($lhs:expr, $rhs:expr)) => { + $scope.register($crate::codegen::dialect::gpu::Operator::NotEqual( + gpu!(binary $lhs, $rhs, $out) + )); + }; // out = lhs > rhs ($scope:expr, $out:ident = $lhs:ident > $rhs:expr) => { gpu!($scope, $out = greater($lhs, $rhs)) diff --git a/crates/burn-jit/src/codegen/dialect/gpu/operation.rs b/crates/burn-jit/src/codegen/dialect/gpu/operation.rs index ddb956c52..43fe04d0b 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/operation.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/operation.rs @@ -39,6 +39,7 @@ pub enum Operator { Erf(UnaryOperator), Recip(UnaryOperator), Equal(BinaryOperator), + NotEqual(BinaryOperator), Lower(BinaryOperator), Clamp(ClampOperator), Greater(BinaryOperator), diff --git a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs index 3d06a7652..9eb277a78 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs @@ -55,6 +55,7 @@ impl Operator { Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)), Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)), Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)), + Operator::NotEqual(op) => Operator::NotEqual(op.vectorize(vectorization)), Operator::Lower(op) => Operator::Lower(op.vectorize(vectorization)), Operator::Clamp(op) => Operator::Clamp(op.vectorize(vectorization)), Operator::Greater(op) => Operator::Greater(op.vectorize(vectorization)), diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index e56ec520f..bdabf8f70 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -311,6 +311,11 @@ impl TraceBuilder { &mut local_tensor_ids_input, &mut local_tensor_ids_output, ), + gpu::Operator::NotEqual(op) => mark_binary( + op, + &mut local_tensor_ids_input, + &mut local_tensor_ids_output, + ), gpu::Operator::Sqrt(op) => mark_unary( op, &mut local_tensor_ids_input, diff --git a/crates/burn-jit/src/kernel/contiguous.rs b/crates/burn-jit/src/kernel/contiguous.rs index 5641cf98c..c9bdd5f91 100644 --- a/crates/burn-jit/src/kernel/contiguous.rs +++ b/crates/burn-jit/src/kernel/contiguous.rs @@ -105,7 +105,6 @@ impl IntoContiguousShader { let offset_input = scope.zero(Elem::UInt); - // Batch offset for the lhs & rhs matrices. IndexOffsetGlobalWithLayout { tensors: vec![tensor], indexes: vec![offset_input], diff --git a/crates/burn-jit/src/kernel/mask/mask_fill.rs b/crates/burn-jit/src/kernel/mask/mask_fill.rs index 3faf30be9..46d80ce39 100644 --- a/crates/burn-jit/src/kernel/mask/mask_fill.rs +++ b/crates/burn-jit/src/kernel/mask/mask_fill.rs @@ -1,15 +1,12 @@ use crate::{ - compute::StaticKernel, + codegen::{EagerHandle, Execution, WorkgroupLaunch}, element::JitElement, - kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, ops::numeric::empty_device, tensor::JitTensor, Runtime, }; -kernel_wgsl!(MaskFill, "../../template/mask/fill.wgsl"); -kernel_wgsl!(MaskFillInplace, "../../template/mask/fill_inplace.wgsl"); +use super::{MaskFill, MaskInplaceEagerKernel, MaskReadOnlyEagerKernel}; #[derive(Clone, Copy, Debug)] /// Define how to run the mask fill kernel. @@ -37,58 +34,52 @@ pub fn mask_fill( } } -fn mask_fill_readonly( - input: JitTensor, - mask: JitTensor, - value: E, -) -> JitTensor { - let num_elems = input.shape.num_elements(); +fn mask_fill_readonly( + input: JitTensor, + mask: JitTensor, + value: EI, +) -> JitTensor { + let client = input.client.clone(); + let kernel = MaskReadOnlyEagerKernel::::new(false); + let output = empty_device( input.client.clone(), input.device.clone(), input.shape.clone(), ); - let value_handle = output.client.create(E::as_bytes(&[value])); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle); - let info = build_info(&[&input, &mask, &output]); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[ - &input.handle, - &value_handle, - &mask.handle, + Execution::start(kernel, client) + .inputs(&[ + EagerHandle::::new(&input.handle, &input.strides, &input.shape.dims), + EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims), + ]) + .outputs(&[EagerHandle::new( &output.handle, - &info_handle, - ], - ); + &output.strides, + &output.shape.dims, + )]) + .with_scalars(&[value]) + .execute(WorkgroupLaunch::Output { pos: 0 }); output } -fn mask_fill_inplace( - input: JitTensor, - mask: JitTensor, - value: E, -) -> JitTensor { - let num_elems = input.shape.num_elements(); - let value_handle = input.client.create(E::as_bytes(&[value])); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle); - let info = build_info(&[&input, &mask]); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); +fn mask_fill_inplace( + input: JitTensor, + mask: JitTensor, + value: EI, +) -> JitTensor { + let kernel = MaskInplaceEagerKernel::::new(false); - input.client.execute( - Box::new(kernel), - &[&input.handle, &value_handle, &mask.handle, &info_handle], - ); + let client = input.client.clone(); + + Execution::start(kernel, client) + .inputs(&[ + EagerHandle::::new(&input.handle, &input.strides, &input.shape.dims), + EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims), + ]) + .with_scalars(&[value]) + .execute(WorkgroupLaunch::Input { pos: 0 }); input } diff --git a/crates/burn-jit/src/kernel/mask/mask_where.rs b/crates/burn-jit/src/kernel/mask/mask_where.rs index ad8abe407..e8f5645d3 100644 --- a/crates/burn-jit/src/kernel/mask/mask_where.rs +++ b/crates/burn-jit/src/kernel/mask/mask_where.rs @@ -1,15 +1,12 @@ use crate::{ - compute::StaticKernel, + codegen::{EagerHandle, Execution, WorkgroupLaunch}, element::JitElement, - kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, ops::numeric::empty_device, tensor::JitTensor, Runtime, }; -kernel_wgsl!(MaskWhere, "../../template/mask/where.wgsl"); -kernel_wgsl!(MaskWhereInplace, "../../template/mask/where_inplace.wgsl"); +use super::{MaskInplaceEagerKernel, MaskReadOnlyEagerKernel, MaskWhere}; #[derive(Clone, Copy, Debug)] /// Define how to run the mask where kernel. @@ -40,63 +37,53 @@ pub fn mask_where( } } -fn mask_where_readonly( - input: JitTensor, - mask: JitTensor, - value: JitTensor, -) -> JitTensor { - let num_elems = input.shape.num_elements(); +fn mask_where_readonly( + input: JitTensor, + mask: JitTensor, + value: JitTensor, +) -> JitTensor { + let client = input.client.clone(); + let kernel = MaskReadOnlyEagerKernel::::new(false); + let output = empty_device( input.client.clone(), input.device.clone(), input.shape.clone(), ); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle); - let info = build_info(&[&input, &value, &mask, &output]); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[ - &input.handle, - &value.handle, - &mask.handle, + Execution::start(kernel, client) + .inputs(&[ + EagerHandle::::new(&input.handle, &input.strides, &input.shape.dims), + EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims), + EagerHandle::new(&value.handle, &value.strides, &value.shape.dims), + ]) + .outputs(&[EagerHandle::new( &output.handle, - &info_handle, - ], - ); + &output.strides, + &output.shape.dims, + )]) + .execute(WorkgroupLaunch::Output { pos: 0 }); output } -fn mask_where_inplace( - input: JitTensor, - mask: JitTensor, - value: JitTensor, +fn mask_where_inplace( + input: JitTensor, + mask: JitTensor, + value: JitTensor, reverse: bool, -) -> JitTensor { - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - input.shape.num_elements(), - WORKGROUP_DEFAULT, - )); - let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle); - let mut info = build_info(&[&input, &value, &mask]); - info.push(match reverse { - true => 1, - false => 0, - }); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); +) -> JitTensor { + let kernel = MaskInplaceEagerKernel::::new(reverse); - input.client.execute( - Box::new(kernel), - &[&input.handle, &value.handle, &mask.handle, &info_handle], - ); + let client = input.client.clone(); + + Execution::start(kernel, client) + .inputs(&[ + EagerHandle::::new(&input.handle, &input.strides, &input.shape.dims), + EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims), + EagerHandle::new(&value.handle, &value.strides, &value.shape.dims), + ]) + .execute(WorkgroupLaunch::Input { pos: 0 }); input } diff --git a/crates/burn-jit/src/kernel/mask/mod.rs b/crates/burn-jit/src/kernel/mask/mod.rs index 0044101f2..cda0456f6 100644 --- a/crates/burn-jit/src/kernel/mask/mod.rs +++ b/crates/burn-jit/src/kernel/mask/mod.rs @@ -1,8 +1,10 @@ mod base; mod mask_fill; mod mask_where; +mod shader; pub(crate) use base::*; +pub(crate) use shader::*; pub use mask_fill::*; pub use mask_where::*; diff --git a/crates/burn-jit/src/kernel/mask/shader.rs b/crates/burn-jit/src/kernel/mask/shader.rs new file mode 100644 index 000000000..b3d8fb414 --- /dev/null +++ b/crates/burn-jit/src/kernel/mask/shader.rs @@ -0,0 +1,271 @@ +use std::marker::PhantomData; + +use crate::{ + codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo}, + gpu::{gpu, Elem, IndexOffsetGlobalWithLayout, Item, Scope, Variable, Visibility}, + kernel::{DynamicKernelSource, SourceTemplate}, + Compiler, JitElement, Runtime, +}; + +pub(crate) trait MaskStrategy: Send + Sync + 'static { + fn mask( + scope: &mut Scope, + masked_value: Variable, + value: Variable, + index: Variable, + ) -> Variable; + + fn value_info(value_item: Item) -> InputInfo; + fn value_variable(value_item: Item) -> Variable; +} + +pub(crate) struct MaskFill; + +impl MaskStrategy for MaskFill { + fn mask( + scope: &mut Scope, + masked_value: Variable, + value: Variable, + _index: Variable, + ) -> Variable { + gpu!(scope, masked_value = value); + masked_value + } + + fn value_info(value_item: Item) -> InputInfo { + InputInfo::Scalar { + elem: value_item.elem(), + size: 1, + } + } + + fn value_variable(value_item: Item) -> Variable { + Variable::GlobalScalar(0, value_item.elem()) + } +} + +pub(crate) struct MaskWhere; + +impl MaskStrategy for MaskWhere { + fn mask( + scope: &mut Scope, + masked_value: Variable, + value: Variable, + index: Variable, + ) -> Variable { + gpu!(scope, masked_value = value[index]); + masked_value + } + + fn value_info(value_item: Item) -> InputInfo { + InputInfo::Array { + item: value_item, + visibility: Visibility::Read, + } + } + + fn value_variable(value_item: Item) -> Variable { + Variable::GlobalInputArray(2, value_item) + } +} + +pub(crate) struct MaskShader { + input: Variable, + mask: Variable, + value: Variable, + output: Variable, + reversed: bool, + _mask_strategy: PhantomData, + _input_elem: PhantomData, + _mask_elem: PhantomData, +} + +#[derive(new)] +pub(crate) struct MaskReadOnlyEagerKernel< + M: MaskStrategy, + R: Runtime, + EI: JitElement, + EM: JitElement, +> { + reversed: bool, + _mask_strategy: PhantomData, + _runtime: PhantomData, + _input_elem: PhantomData, + _mask_elem: PhantomData, +} + +impl DynamicKernelSource + for MaskReadOnlyEagerKernel +{ + fn source(&self) -> crate::kernel::SourceTemplate { + let mut scope = Scope::root(); + let tensor_item = EI::gpu_elem().into(); + let mask_item = EM::gpu_elem().into(); + + let input = Variable::GlobalInputArray(0, tensor_item); + let mask = Variable::GlobalInputArray(1, mask_item); + let value = M::value_variable(tensor_item); + let output = Variable::GlobalOutputArray(0, tensor_item); + + MaskShader:: { + input, + mask, + value, + output, + reversed: self.reversed, + _mask_strategy: PhantomData::, + _input_elem: PhantomData::, + _mask_elem: PhantomData::, + } + .expand(&mut scope); + + scope.write_global_custom(output); + + let input = InputInfo::Array { + item: tensor_item, + visibility: Visibility::Read, + }; + + let mask = InputInfo::Array { + item: mask_item, + visibility: Visibility::Read, + }; + + let value = M::value_info(tensor_item); + + let out = OutputInfo::Array { item: tensor_item }; + + let info = CompilationInfo { + inputs: vec![input, mask, value], + outputs: vec![out], + scope, + }; + + let settings = CompilationSettings::default(); + let shader = Compilation::new(info).compile(settings); + let shader = ::compile(shader); + SourceTemplate::new(shader.to_string()) + } + + fn id(&self) -> String { + format!( + "{:?} rev={}", + core::any::TypeId::of::(), + self.reversed + ) + } +} + +#[derive(new)] +pub(crate) struct MaskInplaceEagerKernel< + M: MaskStrategy, + R: Runtime, + EI: JitElement, + EM: JitElement, +> { + reversed: bool, + _mask_strategy: PhantomData, + _runtime: PhantomData, + _input_elem: PhantomData, + _mask_elem: PhantomData, +} + +impl DynamicKernelSource + for MaskInplaceEagerKernel +{ + fn source(&self) -> crate::kernel::SourceTemplate { + let mut scope = Scope::root(); + let tensor_item = EI::gpu_elem().into(); + let mask_item = EM::gpu_elem().into(); + + let input = Variable::GlobalInputArray(0, tensor_item); + let mask = Variable::GlobalInputArray(1, mask_item); + let value = M::value_variable(tensor_item); + + MaskShader:: { + input, + mask, + value, + output: input, + reversed: self.reversed, + _mask_strategy: PhantomData::, + _input_elem: PhantomData::, + _mask_elem: PhantomData::, + } + .expand(&mut scope); + + let input = InputInfo::Array { + item: tensor_item, + visibility: Visibility::ReadWrite, + }; + + let mask = InputInfo::Array { + item: mask_item, + visibility: Visibility::Read, + }; + + let value = M::value_info(tensor_item); + + let info = CompilationInfo { + inputs: vec![input, mask, value], + outputs: vec![], + scope, + }; + + let settings = CompilationSettings::default(); + let shader = Compilation::new(info).compile(settings); + let shader = ::compile(shader); + SourceTemplate::new(shader.to_string()) + } + + fn id(&self) -> String { + format!( + "{:?} rev={}", + core::any::TypeId::of::(), + self.reversed + ) + } +} + +impl MaskShader { + pub(crate) fn expand(self, scope: &mut Scope) { + let id = Variable::Id; + let input = self.input; + let mask = self.mask; + let value = self.value; + let output = self.output; + + let index_input = scope.zero(Elem::UInt); + let index_mask = scope.zero(Elem::UInt); + + IndexOffsetGlobalWithLayout { + tensors: vec![input, mask], + indexes: vec![index_input, index_mask], + layout: output, + index_ref: id, + dim_start: 0u32.into(), + dim_end: Variable::Rank, + } + .expand(scope); + + // Determine if index should be masked + let value_in_mask = scope.create_local(mask.item()); + gpu!(scope, value_in_mask = mask[index_mask]); + let masked = scope.create_local(Elem::Bool); + let zero = scope.zero(value_in_mask.item()); + if self.reversed { + gpu!(scope, masked = value_in_mask == zero); + } else { + gpu!(scope, masked = value_in_mask != zero); + } + + // Assign a value at the index + let used_value = scope.create_local(output.item()); + gpu!(scope, if(masked).then(|scope| { + M::mask(scope, used_value, value, index_input ); + }).else(|scope| { + gpu!(scope, used_value = input[index_input]); + })); + gpu!(scope, output[id] = used_value); + } +} diff --git a/crates/burn-jit/src/template/mask/fill.wgsl b/crates/burn-jit/src/template/mask/fill.wgsl deleted file mode 100644 index a94661c56..000000000 --- a/crates/burn-jit/src/template/mask/fill.wgsl +++ /dev/null @@ -1,51 +0,0 @@ -@group(0) -@binding(0) -var input: array<{{ elem }}>; - -@group(0) -@binding(1) -var value: {{ elem }}; - -@group(0) -@binding(2) -var mask: array; - -@group(0) -@binding(3) -var output: array<{{ elem }}>; - -@group(0) -@binding(4) -var info: array; - -const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; - let dim = info[0]; - var index_input = 0u; - var index_mask = 0u; - - for (var i = 1u; i <= dim; i++) { - let stride_input = info[i]; - let stride_mask = info[i + dim]; - let stride_output = info[i + 2u * dim]; - let shape_input = info[i + 3u * dim]; - let shape_mask = info[i + 4u * dim]; - - index_input += id / stride_output % shape_input * stride_input; - index_mask += id / stride_output % shape_mask * stride_mask; - } - - - if mask[index_mask] != 0u { - output[id] = value; - } else { - output[id] = input[index_input]; - } -} diff --git a/crates/burn-jit/src/template/mask/fill_inplace.wgsl b/crates/burn-jit/src/template/mask/fill_inplace.wgsl deleted file mode 100644 index aafdf8860..000000000 --- a/crates/burn-jit/src/template/mask/fill_inplace.wgsl +++ /dev/null @@ -1,44 +0,0 @@ -@group(0) -@binding(0) -var input: array<{{ elem }}>; - -@group(0) -@binding(1) -var value: {{ elem }}; - - -@group(0) -@binding(2) -var mask: array; - -@group(0) -@binding(3) -var info: array; - -const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; - let dim = info[0]; - var index_input = 0u; - var index_mask = 0u; - - for (var i = 1u; i <= dim; i++) { - let stride_input = info[i]; - let stride_mask = info[i + dim]; - let shape_input = info[i + 2u * dim]; - let shape_mask = info[i + 3u * dim]; - - index_input += id / stride_input % shape_input * stride_input; - index_mask += id / stride_input % shape_mask * stride_mask; - } - - if mask[index_mask] != 0u { - input[index_input] = value; - } -} diff --git a/crates/burn-jit/src/template/mask/where.wgsl b/crates/burn-jit/src/template/mask/where.wgsl deleted file mode 100644 index 944222e11..000000000 --- a/crates/burn-jit/src/template/mask/where.wgsl +++ /dev/null @@ -1,55 +0,0 @@ -@group(0) -@binding(0) -var input: array<{{ elem }}>; - -@group(0) -@binding(1) -var value: array<{{ elem }}>; - -@group(0) -@binding(2) -var mask: array; - -@group(0) -@binding(3) -var output: array<{{ elem }}>; - -@group(0) -@binding(4) -var info: array; - -const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; - let dim = info[0]; - var index_input = 0u; - var index_value = 0u; - var index_mask = 0u; - - for (var i = 1u; i <= dim; i++) { - let stride_input = info[i]; - let stride_value = info[i + dim]; - let stride_mask = info[i + 2u * dim]; - let stride_output = info[i + 3u * dim]; - - let shape_input = info[i + 4u * dim]; - let shape_value = info[i + 5u * dim]; - let shape_mask = info[i + 6u * dim]; - - index_input += id / stride_output % shape_input * stride_input; - index_value += id / stride_output % shape_value * stride_value; - index_mask += id / stride_output % shape_mask * stride_mask; - } - - if mask[index_mask] != 0u { - output[id] = value[index_value]; - } else { - output[id] = input[index_input]; - } -} diff --git a/crates/burn-jit/src/template/mask/where_inplace.wgsl b/crates/burn-jit/src/template/mask/where_inplace.wgsl deleted file mode 100644 index d3931fee4..000000000 --- a/crates/burn-jit/src/template/mask/where_inplace.wgsl +++ /dev/null @@ -1,58 +0,0 @@ -@group(0) -@binding(0) -var input: array<{{ elem }}>; - -@group(0) -@binding(1) -var value: array<{{ elem }}>; - -@group(0) -@binding(2) -var mask: array; - -@group(0) -@binding(3) -var info: array; - -const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; - let dim = info[0]; - let reverse = info[6u * dim + 1u]; - - var index_input = 0u; - var index_value = 0u; - var index_mask = 0u; - - for (var i = 1u; i <= dim; i++) { - let stride_input = info[i]; - let stride_value = info[i + dim]; - let stride_mask = info[i + 2u * dim]; - - let shape_input = info[i + 3u * dim]; - let shape_value = info[i + 4u * dim]; - let shape_mask = info[i + 5u * dim]; - - index_input += id / stride_input % shape_input * stride_input; - index_value += id / stride_input % shape_value * stride_value; - index_mask += id / stride_input % shape_mask * stride_mask; - } - - var condition = mask[index_mask] != 0u; - - if reverse == 1u { - condition = !condition; - } - - if condition { - input[index_input] = value[index_value]; - } else { - input[index_input] = input[index_input]; - } -} diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs index 3a2f6a91d..349243505 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs @@ -484,6 +484,11 @@ impl WgslCompiler { rhs: self.compile_variable(op.rhs), out: self.compile_variable(op.out), }, + gpu::Operator::NotEqual(op) => wgsl::Instruction::NotEqual { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(op.out), + }, gpu::Operator::Assign(op) => wgsl::Instruction::Assign { input: self.compile_variable(op.input), out: self.compile_variable(op.out), diff --git a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs index 7b410b6c0..59798bdac 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs @@ -148,6 +148,11 @@ pub enum Instruction { rhs: Variable, out: Variable, }, + NotEqual { + lhs: Variable, + rhs: Variable, + out: Variable, + }, Stride { dim: Variable, position: usize, @@ -289,6 +294,7 @@ impl Display for Instruction { Instruction::Greater { lhs, rhs, out } => comparison(lhs, rhs, out, ">", f), Instruction::LowerEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "<=", f), Instruction::GreaterEqual { lhs, rhs, out } => comparison(lhs, rhs, out, ">=", f), + Instruction::NotEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "!=", f), Instruction::Assign { input, out } => match out.item() { Item::Vec4(elem) => { let input0 = input.index(0);