mirror of https://github.com/tracel-ai/burn.git
Migrate/jit/mask (#1456)
This commit is contained in:
parent
fa0dfec3c7
commit
278fcb3dad
|
@ -145,6 +145,16 @@ macro_rules! gpu {
|
|||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs != rhs
|
||||
($scope:expr, $out:ident = $lhs:ident != $rhs:expr) => {
|
||||
gpu!($scope, $out = not_equal($lhs, $rhs))
|
||||
};
|
||||
// out = not_equal(lhs, rhs)
|
||||
($scope:expr, $out:ident = not_equal($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::NotEqual(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs > rhs
|
||||
($scope:expr, $out:ident = $lhs:ident > $rhs:expr) => {
|
||||
gpu!($scope, $out = greater($lhs, $rhs))
|
||||
|
|
|
@ -39,6 +39,7 @@ pub enum Operator {
|
|||
Erf(UnaryOperator),
|
||||
Recip(UnaryOperator),
|
||||
Equal(BinaryOperator),
|
||||
NotEqual(BinaryOperator),
|
||||
Lower(BinaryOperator),
|
||||
Clamp(ClampOperator),
|
||||
Greater(BinaryOperator),
|
||||
|
|
|
@ -55,6 +55,7 @@ impl Operator {
|
|||
Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)),
|
||||
Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)),
|
||||
Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)),
|
||||
Operator::NotEqual(op) => Operator::NotEqual(op.vectorize(vectorization)),
|
||||
Operator::Lower(op) => Operator::Lower(op.vectorize(vectorization)),
|
||||
Operator::Clamp(op) => Operator::Clamp(op.vectorize(vectorization)),
|
||||
Operator::Greater(op) => Operator::Greater(op.vectorize(vectorization)),
|
||||
|
|
|
@ -311,6 +311,11 @@ impl TraceBuilder {
|
|||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::NotEqual(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Sqrt(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
|
|
|
@ -105,7 +105,6 @@ impl IntoContiguousShader {
|
|||
|
||||
let offset_input = scope.zero(Elem::UInt);
|
||||
|
||||
// Batch offset for the lhs & rhs matrices.
|
||||
IndexOffsetGlobalWithLayout {
|
||||
tensors: vec![tensor],
|
||||
indexes: vec![offset_input],
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
codegen::{EagerHandle, Execution, WorkgroupLaunch},
|
||||
element::JitElement,
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
|
||||
kernel_wgsl!(MaskFill, "../../template/mask/fill.wgsl");
|
||||
kernel_wgsl!(MaskFillInplace, "../../template/mask/fill_inplace.wgsl");
|
||||
use super::{MaskFill, MaskInplaceEagerKernel, MaskReadOnlyEagerKernel};
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
/// Define how to run the mask fill kernel.
|
||||
|
@ -37,58 +34,52 @@ pub fn mask_fill<R: Runtime, E: JitElement, const D: usize>(
|
|||
}
|
||||
}
|
||||
|
||||
fn mask_fill_readonly<R: Runtime, E: JitElement, const D: usize>(
|
||||
input: JitTensor<R, E, D>,
|
||||
mask: JitTensor<R, u32, D>,
|
||||
value: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let num_elems = input.shape.num_elements();
|
||||
fn mask_fill_readonly<R: Runtime, EI: JitElement, EM: JitElement, const D: usize>(
|
||||
input: JitTensor<R, EI, D>,
|
||||
mask: JitTensor<R, EM, D>,
|
||||
value: EI,
|
||||
) -> JitTensor<R, EI, D> {
|
||||
let client = input.client.clone();
|
||||
let kernel = MaskReadOnlyEagerKernel::<MaskFill, R, EI, EM>::new(false);
|
||||
|
||||
let output = empty_device(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
input.shape.clone(),
|
||||
);
|
||||
|
||||
let value_handle = output.client.create(E::as_bytes(&[value]));
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<MaskFill, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle);
|
||||
let info = build_info(&[&input, &mask, &output]);
|
||||
let info_handle = input.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
input.client.execute(
|
||||
Box::new(kernel),
|
||||
&[
|
||||
&input.handle,
|
||||
&value_handle,
|
||||
&mask.handle,
|
||||
Execution::start(kernel, client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
|
||||
EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
|
||||
])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&info_handle,
|
||||
],
|
||||
);
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)])
|
||||
.with_scalars(&[value])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn mask_fill_inplace<R: Runtime, E: JitElement, const D: usize>(
|
||||
input: JitTensor<R, E, D>,
|
||||
mask: JitTensor<R, u32, D>,
|
||||
value: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let num_elems = input.shape.num_elements();
|
||||
let value_handle = input.client.create(E::as_bytes(&[value]));
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<MaskFillInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle);
|
||||
let info = build_info(&[&input, &mask]);
|
||||
let info_handle = input.client.create(bytemuck::cast_slice(&info));
|
||||
fn mask_fill_inplace<R: Runtime, EI: JitElement, EM: JitElement, const D: usize>(
|
||||
input: JitTensor<R, EI, D>,
|
||||
mask: JitTensor<R, EM, D>,
|
||||
value: EI,
|
||||
) -> JitTensor<R, EI, D> {
|
||||
let kernel = MaskInplaceEagerKernel::<MaskFill, R, EI, EM>::new(false);
|
||||
|
||||
input.client.execute(
|
||||
Box::new(kernel),
|
||||
&[&input.handle, &value_handle, &mask.handle, &info_handle],
|
||||
);
|
||||
let client = input.client.clone();
|
||||
|
||||
Execution::start(kernel, client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
|
||||
EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
|
||||
])
|
||||
.with_scalars(&[value])
|
||||
.execute(WorkgroupLaunch::Input { pos: 0 });
|
||||
|
||||
input
|
||||
}
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
codegen::{EagerHandle, Execution, WorkgroupLaunch},
|
||||
element::JitElement,
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
|
||||
kernel_wgsl!(MaskWhere, "../../template/mask/where.wgsl");
|
||||
kernel_wgsl!(MaskWhereInplace, "../../template/mask/where_inplace.wgsl");
|
||||
use super::{MaskInplaceEagerKernel, MaskReadOnlyEagerKernel, MaskWhere};
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
/// Define how to run the mask where kernel.
|
||||
|
@ -40,63 +37,53 @@ pub fn mask_where<R: Runtime, E: JitElement, const D: usize>(
|
|||
}
|
||||
}
|
||||
|
||||
fn mask_where_readonly<R: Runtime, E: JitElement, const D: usize>(
|
||||
input: JitTensor<R, E, D>,
|
||||
mask: JitTensor<R, u32, D>,
|
||||
value: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let num_elems = input.shape.num_elements();
|
||||
fn mask_where_readonly<R: Runtime, EI: JitElement, EM: JitElement, const D: usize>(
|
||||
input: JitTensor<R, EI, D>,
|
||||
mask: JitTensor<R, EM, D>,
|
||||
value: JitTensor<R, EI, D>,
|
||||
) -> JitTensor<R, EI, D> {
|
||||
let client = input.client.clone();
|
||||
let kernel = MaskReadOnlyEagerKernel::<MaskWhere, R, EI, EM>::new(false);
|
||||
|
||||
let output = empty_device(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
input.shape.clone(),
|
||||
);
|
||||
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<MaskWhere, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle);
|
||||
let info = build_info(&[&input, &value, &mask, &output]);
|
||||
let info_handle = input.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
input.client.execute(
|
||||
Box::new(kernel),
|
||||
&[
|
||||
&input.handle,
|
||||
&value.handle,
|
||||
&mask.handle,
|
||||
Execution::start(kernel, client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
|
||||
EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
|
||||
EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
|
||||
])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&info_handle,
|
||||
],
|
||||
);
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn mask_where_inplace<R: Runtime, E: JitElement, const D: usize>(
|
||||
input: JitTensor<R, E, D>,
|
||||
mask: JitTensor<R, u32, D>,
|
||||
value: JitTensor<R, E, D>,
|
||||
fn mask_where_inplace<R: Runtime, EI: JitElement, EM: JitElement, const D: usize>(
|
||||
input: JitTensor<R, EI, D>,
|
||||
mask: JitTensor<R, EM, D>,
|
||||
value: JitTensor<R, EI, D>,
|
||||
reverse: bool,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<MaskWhereInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
input.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle);
|
||||
let mut info = build_info(&[&input, &value, &mask]);
|
||||
info.push(match reverse {
|
||||
true => 1,
|
||||
false => 0,
|
||||
});
|
||||
let info_handle = input.client.create(bytemuck::cast_slice(&info));
|
||||
) -> JitTensor<R, EI, D> {
|
||||
let kernel = MaskInplaceEagerKernel::<MaskWhere, R, EI, EM>::new(reverse);
|
||||
|
||||
input.client.execute(
|
||||
Box::new(kernel),
|
||||
&[&input.handle, &value.handle, &mask.handle, &info_handle],
|
||||
);
|
||||
let client = input.client.clone();
|
||||
|
||||
Execution::start(kernel, client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
|
||||
EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
|
||||
EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
|
||||
])
|
||||
.execute(WorkgroupLaunch::Input { pos: 0 });
|
||||
|
||||
input
|
||||
}
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
mod base;
|
||||
mod mask_fill;
|
||||
mod mask_where;
|
||||
mod shader;
|
||||
|
||||
pub(crate) use base::*;
|
||||
pub(crate) use shader::*;
|
||||
|
||||
pub use mask_fill::*;
|
||||
pub use mask_where::*;
|
||||
|
|
|
@ -0,0 +1,271 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo},
|
||||
gpu::{gpu, Elem, IndexOffsetGlobalWithLayout, Item, Scope, Variable, Visibility},
|
||||
kernel::{DynamicKernelSource, SourceTemplate},
|
||||
Compiler, JitElement, Runtime,
|
||||
};
|
||||
|
||||
pub(crate) trait MaskStrategy: Send + Sync + 'static {
|
||||
fn mask(
|
||||
scope: &mut Scope,
|
||||
masked_value: Variable,
|
||||
value: Variable,
|
||||
index: Variable,
|
||||
) -> Variable;
|
||||
|
||||
fn value_info(value_item: Item) -> InputInfo;
|
||||
fn value_variable(value_item: Item) -> Variable;
|
||||
}
|
||||
|
||||
pub(crate) struct MaskFill;
|
||||
|
||||
impl MaskStrategy for MaskFill {
|
||||
fn mask(
|
||||
scope: &mut Scope,
|
||||
masked_value: Variable,
|
||||
value: Variable,
|
||||
_index: Variable,
|
||||
) -> Variable {
|
||||
gpu!(scope, masked_value = value);
|
||||
masked_value
|
||||
}
|
||||
|
||||
fn value_info(value_item: Item) -> InputInfo {
|
||||
InputInfo::Scalar {
|
||||
elem: value_item.elem(),
|
||||
size: 1,
|
||||
}
|
||||
}
|
||||
|
||||
fn value_variable(value_item: Item) -> Variable {
|
||||
Variable::GlobalScalar(0, value_item.elem())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct MaskWhere;
|
||||
|
||||
impl MaskStrategy for MaskWhere {
|
||||
fn mask(
|
||||
scope: &mut Scope,
|
||||
masked_value: Variable,
|
||||
value: Variable,
|
||||
index: Variable,
|
||||
) -> Variable {
|
||||
gpu!(scope, masked_value = value[index]);
|
||||
masked_value
|
||||
}
|
||||
|
||||
fn value_info(value_item: Item) -> InputInfo {
|
||||
InputInfo::Array {
|
||||
item: value_item,
|
||||
visibility: Visibility::Read,
|
||||
}
|
||||
}
|
||||
|
||||
fn value_variable(value_item: Item) -> Variable {
|
||||
Variable::GlobalInputArray(2, value_item)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct MaskShader<EI: JitElement, EM: JitElement, M: MaskStrategy> {
|
||||
input: Variable,
|
||||
mask: Variable,
|
||||
value: Variable,
|
||||
output: Variable,
|
||||
reversed: bool,
|
||||
_mask_strategy: PhantomData<M>,
|
||||
_input_elem: PhantomData<EI>,
|
||||
_mask_elem: PhantomData<EM>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct MaskReadOnlyEagerKernel<
|
||||
M: MaskStrategy,
|
||||
R: Runtime,
|
||||
EI: JitElement,
|
||||
EM: JitElement,
|
||||
> {
|
||||
reversed: bool,
|
||||
_mask_strategy: PhantomData<M>,
|
||||
_runtime: PhantomData<R>,
|
||||
_input_elem: PhantomData<EI>,
|
||||
_mask_elem: PhantomData<EM>,
|
||||
}
|
||||
|
||||
impl<M: MaskStrategy, R: Runtime, EI: JitElement, EM: JitElement> DynamicKernelSource
|
||||
for MaskReadOnlyEagerKernel<M, R, EI, EM>
|
||||
{
|
||||
fn source(&self) -> crate::kernel::SourceTemplate {
|
||||
let mut scope = Scope::root();
|
||||
let tensor_item = EI::gpu_elem().into();
|
||||
let mask_item = EM::gpu_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, tensor_item);
|
||||
let mask = Variable::GlobalInputArray(1, mask_item);
|
||||
let value = M::value_variable(tensor_item);
|
||||
let output = Variable::GlobalOutputArray(0, tensor_item);
|
||||
|
||||
MaskShader::<EI, EM, M> {
|
||||
input,
|
||||
mask,
|
||||
value,
|
||||
output,
|
||||
reversed: self.reversed,
|
||||
_mask_strategy: PhantomData::<M>,
|
||||
_input_elem: PhantomData::<EI>,
|
||||
_mask_elem: PhantomData::<EM>,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
let input = InputInfo::Array {
|
||||
item: tensor_item,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
|
||||
let mask = InputInfo::Array {
|
||||
item: mask_item,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
|
||||
let value = M::value_info(tensor_item);
|
||||
|
||||
let out = OutputInfo::Array { item: tensor_item };
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![input, mask, value],
|
||||
outputs: vec![out],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = CompilationSettings::default();
|
||||
let shader = Compilation::new(info).compile(settings);
|
||||
let shader = <R::Compiler as Compiler>::compile(shader);
|
||||
SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!(
|
||||
"{:?} rev={}",
|
||||
core::any::TypeId::of::<Self>(),
|
||||
self.reversed
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct MaskInplaceEagerKernel<
|
||||
M: MaskStrategy,
|
||||
R: Runtime,
|
||||
EI: JitElement,
|
||||
EM: JitElement,
|
||||
> {
|
||||
reversed: bool,
|
||||
_mask_strategy: PhantomData<M>,
|
||||
_runtime: PhantomData<R>,
|
||||
_input_elem: PhantomData<EI>,
|
||||
_mask_elem: PhantomData<EM>,
|
||||
}
|
||||
|
||||
impl<M: MaskStrategy, R: Runtime, EI: JitElement, EM: JitElement> DynamicKernelSource
|
||||
for MaskInplaceEagerKernel<M, R, EI, EM>
|
||||
{
|
||||
fn source(&self) -> crate::kernel::SourceTemplate {
|
||||
let mut scope = Scope::root();
|
||||
let tensor_item = EI::gpu_elem().into();
|
||||
let mask_item = EM::gpu_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, tensor_item);
|
||||
let mask = Variable::GlobalInputArray(1, mask_item);
|
||||
let value = M::value_variable(tensor_item);
|
||||
|
||||
MaskShader::<EI, EM, M> {
|
||||
input,
|
||||
mask,
|
||||
value,
|
||||
output: input,
|
||||
reversed: self.reversed,
|
||||
_mask_strategy: PhantomData::<M>,
|
||||
_input_elem: PhantomData::<EI>,
|
||||
_mask_elem: PhantomData::<EM>,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
let input = InputInfo::Array {
|
||||
item: tensor_item,
|
||||
visibility: Visibility::ReadWrite,
|
||||
};
|
||||
|
||||
let mask = InputInfo::Array {
|
||||
item: mask_item,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
|
||||
let value = M::value_info(tensor_item);
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![input, mask, value],
|
||||
outputs: vec![],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = CompilationSettings::default();
|
||||
let shader = Compilation::new(info).compile(settings);
|
||||
let shader = <R::Compiler as Compiler>::compile(shader);
|
||||
SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!(
|
||||
"{:?} rev={}",
|
||||
core::any::TypeId::of::<Self>(),
|
||||
self.reversed
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<EI: JitElement, EM: JitElement, M: MaskStrategy> MaskShader<EI, EM, M> {
|
||||
pub(crate) fn expand(self, scope: &mut Scope) {
|
||||
let id = Variable::Id;
|
||||
let input = self.input;
|
||||
let mask = self.mask;
|
||||
let value = self.value;
|
||||
let output = self.output;
|
||||
|
||||
let index_input = scope.zero(Elem::UInt);
|
||||
let index_mask = scope.zero(Elem::UInt);
|
||||
|
||||
IndexOffsetGlobalWithLayout {
|
||||
tensors: vec![input, mask],
|
||||
indexes: vec![index_input, index_mask],
|
||||
layout: output,
|
||||
index_ref: id,
|
||||
dim_start: 0u32.into(),
|
||||
dim_end: Variable::Rank,
|
||||
}
|
||||
.expand(scope);
|
||||
|
||||
// Determine if index should be masked
|
||||
let value_in_mask = scope.create_local(mask.item());
|
||||
gpu!(scope, value_in_mask = mask[index_mask]);
|
||||
let masked = scope.create_local(Elem::Bool);
|
||||
let zero = scope.zero(value_in_mask.item());
|
||||
if self.reversed {
|
||||
gpu!(scope, masked = value_in_mask == zero);
|
||||
} else {
|
||||
gpu!(scope, masked = value_in_mask != zero);
|
||||
}
|
||||
|
||||
// Assign a value at the index
|
||||
let used_value = scope.create_local(output.item());
|
||||
gpu!(scope, if(masked).then(|scope| {
|
||||
M::mask(scope, used_value, value, index_input );
|
||||
}).else(|scope| {
|
||||
gpu!(scope, used_value = input[index_input]);
|
||||
}));
|
||||
gpu!(scope, output[id] = used_value);
|
||||
}
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> input: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> value: {{ elem }};
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> mask: array<u32>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(4)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let dim = info[0];
|
||||
var index_input = 0u;
|
||||
var index_mask = 0u;
|
||||
|
||||
for (var i = 1u; i <= dim; i++) {
|
||||
let stride_input = info[i];
|
||||
let stride_mask = info[i + dim];
|
||||
let stride_output = info[i + 2u * dim];
|
||||
let shape_input = info[i + 3u * dim];
|
||||
let shape_mask = info[i + 4u * dim];
|
||||
|
||||
index_input += id / stride_output % shape_input * stride_input;
|
||||
index_mask += id / stride_output % shape_mask * stride_mask;
|
||||
}
|
||||
|
||||
|
||||
if mask[index_mask] != 0u {
|
||||
output[id] = value;
|
||||
} else {
|
||||
output[id] = input[index_input];
|
||||
}
|
||||
}
|
|
@ -1,44 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> input: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> value: {{ elem }};
|
||||
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> mask: array<u32>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let dim = info[0];
|
||||
var index_input = 0u;
|
||||
var index_mask = 0u;
|
||||
|
||||
for (var i = 1u; i <= dim; i++) {
|
||||
let stride_input = info[i];
|
||||
let stride_mask = info[i + dim];
|
||||
let shape_input = info[i + 2u * dim];
|
||||
let shape_mask = info[i + 3u * dim];
|
||||
|
||||
index_input += id / stride_input % shape_input * stride_input;
|
||||
index_mask += id / stride_input % shape_mask * stride_mask;
|
||||
}
|
||||
|
||||
if mask[index_mask] != 0u {
|
||||
input[index_input] = value;
|
||||
}
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> input: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> value: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> mask: array<u32>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(4)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let dim = info[0];
|
||||
var index_input = 0u;
|
||||
var index_value = 0u;
|
||||
var index_mask = 0u;
|
||||
|
||||
for (var i = 1u; i <= dim; i++) {
|
||||
let stride_input = info[i];
|
||||
let stride_value = info[i + dim];
|
||||
let stride_mask = info[i + 2u * dim];
|
||||
let stride_output = info[i + 3u * dim];
|
||||
|
||||
let shape_input = info[i + 4u * dim];
|
||||
let shape_value = info[i + 5u * dim];
|
||||
let shape_mask = info[i + 6u * dim];
|
||||
|
||||
index_input += id / stride_output % shape_input * stride_input;
|
||||
index_value += id / stride_output % shape_value * stride_value;
|
||||
index_mask += id / stride_output % shape_mask * stride_mask;
|
||||
}
|
||||
|
||||
if mask[index_mask] != 0u {
|
||||
output[id] = value[index_value];
|
||||
} else {
|
||||
output[id] = input[index_input];
|
||||
}
|
||||
}
|
|
@ -1,58 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> input: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> value: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> mask: array<u32>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let dim = info[0];
|
||||
let reverse = info[6u * dim + 1u];
|
||||
|
||||
var index_input = 0u;
|
||||
var index_value = 0u;
|
||||
var index_mask = 0u;
|
||||
|
||||
for (var i = 1u; i <= dim; i++) {
|
||||
let stride_input = info[i];
|
||||
let stride_value = info[i + dim];
|
||||
let stride_mask = info[i + 2u * dim];
|
||||
|
||||
let shape_input = info[i + 3u * dim];
|
||||
let shape_value = info[i + 4u * dim];
|
||||
let shape_mask = info[i + 5u * dim];
|
||||
|
||||
index_input += id / stride_input % shape_input * stride_input;
|
||||
index_value += id / stride_input % shape_value * stride_value;
|
||||
index_mask += id / stride_input % shape_mask * stride_mask;
|
||||
}
|
||||
|
||||
var condition = mask[index_mask] != 0u;
|
||||
|
||||
if reverse == 1u {
|
||||
condition = !condition;
|
||||
}
|
||||
|
||||
if condition {
|
||||
input[index_input] = value[index_value];
|
||||
} else {
|
||||
input[index_input] = input[index_input];
|
||||
}
|
||||
}
|
|
@ -484,6 +484,11 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
|
|||
rhs: self.compile_variable(op.rhs),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::NotEqual(op) => wgsl::Instruction::NotEqual {
|
||||
lhs: self.compile_variable(op.lhs),
|
||||
rhs: self.compile_variable(op.rhs),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Assign(op) => wgsl::Instruction::Assign {
|
||||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
|
|
|
@ -148,6 +148,11 @@ pub enum Instruction {
|
|||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
NotEqual {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Stride {
|
||||
dim: Variable,
|
||||
position: usize,
|
||||
|
@ -289,6 +294,7 @@ impl Display for Instruction {
|
|||
Instruction::Greater { lhs, rhs, out } => comparison(lhs, rhs, out, ">", f),
|
||||
Instruction::LowerEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "<=", f),
|
||||
Instruction::GreaterEqual { lhs, rhs, out } => comparison(lhs, rhs, out, ">=", f),
|
||||
Instruction::NotEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "!=", f),
|
||||
Instruction::Assign { input, out } => match out.item() {
|
||||
Item::Vec4(elem) => {
|
||||
let input0 = input.index(0);
|
||||
|
|
Loading…
Reference in New Issue