Migrate/jit/mask (#1456)

This commit is contained in:
Louis Fortier-Dubois 2024-03-12 12:43:05 -04:00 committed by GitHub
parent fa0dfec3c7
commit 278fcb3dad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 373 additions and 303 deletions

View File

@ -145,6 +145,16 @@ macro_rules! gpu {
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs != rhs
($scope:expr, $out:ident = $lhs:ident != $rhs:expr) => {
gpu!($scope, $out = not_equal($lhs, $rhs))
};
// out = not_equal(lhs, rhs)
($scope:expr, $out:ident = not_equal($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::NotEqual(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs > rhs
($scope:expr, $out:ident = $lhs:ident > $rhs:expr) => {
gpu!($scope, $out = greater($lhs, $rhs))

View File

@ -39,6 +39,7 @@ pub enum Operator {
Erf(UnaryOperator),
Recip(UnaryOperator),
Equal(BinaryOperator),
NotEqual(BinaryOperator),
Lower(BinaryOperator),
Clamp(ClampOperator),
Greater(BinaryOperator),

View File

@ -55,6 +55,7 @@ impl Operator {
Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)),
Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)),
Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)),
Operator::NotEqual(op) => Operator::NotEqual(op.vectorize(vectorization)),
Operator::Lower(op) => Operator::Lower(op.vectorize(vectorization)),
Operator::Clamp(op) => Operator::Clamp(op.vectorize(vectorization)),
Operator::Greater(op) => Operator::Greater(op.vectorize(vectorization)),

View File

@ -311,6 +311,11 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::NotEqual(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Sqrt(op) => mark_unary(
op,
&mut local_tensor_ids_input,

View File

@ -105,7 +105,6 @@ impl IntoContiguousShader {
let offset_input = scope.zero(Elem::UInt);
// Batch offset for the lhs & rhs matrices.
IndexOffsetGlobalWithLayout {
tensors: vec![tensor],
indexes: vec![offset_input],

View File

@ -1,15 +1,12 @@
use crate::{
compute::StaticKernel,
codegen::{EagerHandle, Execution, WorkgroupLaunch},
element::JitElement,
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::JitTensor,
Runtime,
};
kernel_wgsl!(MaskFill, "../../template/mask/fill.wgsl");
kernel_wgsl!(MaskFillInplace, "../../template/mask/fill_inplace.wgsl");
use super::{MaskFill, MaskInplaceEagerKernel, MaskReadOnlyEagerKernel};
#[derive(Clone, Copy, Debug)]
/// Define how to run the mask fill kernel.
@ -37,58 +34,52 @@ pub fn mask_fill<R: Runtime, E: JitElement, const D: usize>(
}
}
fn mask_fill_readonly<R: Runtime, E: JitElement, const D: usize>(
input: JitTensor<R, E, D>,
mask: JitTensor<R, u32, D>,
value: E,
) -> JitTensor<R, E, D> {
let num_elems = input.shape.num_elements();
fn mask_fill_readonly<R: Runtime, EI: JitElement, EM: JitElement, const D: usize>(
input: JitTensor<R, EI, D>,
mask: JitTensor<R, EM, D>,
value: EI,
) -> JitTensor<R, EI, D> {
let client = input.client.clone();
let kernel = MaskReadOnlyEagerKernel::<MaskFill, R, EI, EM>::new(false);
let output = empty_device(
input.client.clone(),
input.device.clone(),
input.shape.clone(),
);
let value_handle = output.client.create(E::as_bytes(&[value]));
let kernel = StaticKernel::<
KernelSettings<MaskFill, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle);
let info = build_info(&[&input, &mask, &output]);
let info_handle = input.client.create(bytemuck::cast_slice(&info));
input.client.execute(
Box::new(kernel),
&[
&input.handle,
&value_handle,
&mask.handle,
Execution::start(kernel, client)
.inputs(&[
EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
])
.outputs(&[EagerHandle::new(
&output.handle,
&info_handle,
],
);
&output.strides,
&output.shape.dims,
)])
.with_scalars(&[value])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}
fn mask_fill_inplace<R: Runtime, E: JitElement, const D: usize>(
input: JitTensor<R, E, D>,
mask: JitTensor<R, u32, D>,
value: E,
) -> JitTensor<R, E, D> {
let num_elems = input.shape.num_elements();
let value_handle = input.client.create(E::as_bytes(&[value]));
let kernel = StaticKernel::<
KernelSettings<MaskFillInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle);
let info = build_info(&[&input, &mask]);
let info_handle = input.client.create(bytemuck::cast_slice(&info));
fn mask_fill_inplace<R: Runtime, EI: JitElement, EM: JitElement, const D: usize>(
input: JitTensor<R, EI, D>,
mask: JitTensor<R, EM, D>,
value: EI,
) -> JitTensor<R, EI, D> {
let kernel = MaskInplaceEagerKernel::<MaskFill, R, EI, EM>::new(false);
input.client.execute(
Box::new(kernel),
&[&input.handle, &value_handle, &mask.handle, &info_handle],
);
let client = input.client.clone();
Execution::start(kernel, client)
.inputs(&[
EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
])
.with_scalars(&[value])
.execute(WorkgroupLaunch::Input { pos: 0 });
input
}

View File

@ -1,15 +1,12 @@
use crate::{
compute::StaticKernel,
codegen::{EagerHandle, Execution, WorkgroupLaunch},
element::JitElement,
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::JitTensor,
Runtime,
};
kernel_wgsl!(MaskWhere, "../../template/mask/where.wgsl");
kernel_wgsl!(MaskWhereInplace, "../../template/mask/where_inplace.wgsl");
use super::{MaskInplaceEagerKernel, MaskReadOnlyEagerKernel, MaskWhere};
#[derive(Clone, Copy, Debug)]
/// Define how to run the mask where kernel.
@ -40,63 +37,53 @@ pub fn mask_where<R: Runtime, E: JitElement, const D: usize>(
}
}
fn mask_where_readonly<R: Runtime, E: JitElement, const D: usize>(
input: JitTensor<R, E, D>,
mask: JitTensor<R, u32, D>,
value: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
let num_elems = input.shape.num_elements();
fn mask_where_readonly<R: Runtime, EI: JitElement, EM: JitElement, const D: usize>(
input: JitTensor<R, EI, D>,
mask: JitTensor<R, EM, D>,
value: JitTensor<R, EI, D>,
) -> JitTensor<R, EI, D> {
let client = input.client.clone();
let kernel = MaskReadOnlyEagerKernel::<MaskWhere, R, EI, EM>::new(false);
let output = empty_device(
input.client.clone(),
input.device.clone(),
input.shape.clone(),
);
let kernel = StaticKernel::<
KernelSettings<MaskWhere, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle);
let info = build_info(&[&input, &value, &mask, &output]);
let info_handle = input.client.create(bytemuck::cast_slice(&info));
input.client.execute(
Box::new(kernel),
&[
&input.handle,
&value.handle,
&mask.handle,
Execution::start(kernel, client)
.inputs(&[
EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
])
.outputs(&[EagerHandle::new(
&output.handle,
&info_handle,
],
);
&output.strides,
&output.shape.dims,
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}
fn mask_where_inplace<R: Runtime, E: JitElement, const D: usize>(
input: JitTensor<R, E, D>,
mask: JitTensor<R, u32, D>,
value: JitTensor<R, E, D>,
fn mask_where_inplace<R: Runtime, EI: JitElement, EM: JitElement, const D: usize>(
input: JitTensor<R, EI, D>,
mask: JitTensor<R, EM, D>,
value: JitTensor<R, EI, D>,
reverse: bool,
) -> JitTensor<R, E, D> {
let kernel = StaticKernel::<
KernelSettings<MaskWhereInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
input.shape.num_elements(),
WORKGROUP_DEFAULT,
));
let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle);
let mut info = build_info(&[&input, &value, &mask]);
info.push(match reverse {
true => 1,
false => 0,
});
let info_handle = input.client.create(bytemuck::cast_slice(&info));
) -> JitTensor<R, EI, D> {
let kernel = MaskInplaceEagerKernel::<MaskWhere, R, EI, EM>::new(reverse);
input.client.execute(
Box::new(kernel),
&[&input.handle, &value.handle, &mask.handle, &info_handle],
);
let client = input.client.clone();
Execution::start(kernel, client)
.inputs(&[
EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
])
.execute(WorkgroupLaunch::Input { pos: 0 });
input
}

View File

@ -1,8 +1,10 @@
mod base;
mod mask_fill;
mod mask_where;
mod shader;
pub(crate) use base::*;
pub(crate) use shader::*;
pub use mask_fill::*;
pub use mask_where::*;

View File

@ -0,0 +1,271 @@
use std::marker::PhantomData;
use crate::{
codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo},
gpu::{gpu, Elem, IndexOffsetGlobalWithLayout, Item, Scope, Variable, Visibility},
kernel::{DynamicKernelSource, SourceTemplate},
Compiler, JitElement, Runtime,
};
/// Strategy describing where a masked element's replacement value comes from:
/// a single scalar for mask-fill, or a second tensor for mask-where.
pub(crate) trait MaskStrategy: Send + Sync + 'static {
    /// Emit the GPU instructions that compute the replacement for a masked
    /// position, returning the variable that holds it.
    fn mask(
        scope: &mut Scope,
        masked_value: Variable,
        value: Variable,
        index: Variable,
    ) -> Variable;

    /// Describe how the replacement value is bound as a kernel input.
    fn value_info(value_item: Item) -> InputInfo;

    /// The variable through which the shader reads the replacement value.
    fn value_variable(value_item: Item) -> Variable;
}
/// Mask-fill strategy: every masked element is replaced by one scalar value.
pub(crate) struct MaskFill;

impl MaskStrategy for MaskFill {
    fn mask(
        scope: &mut Scope,
        masked_value: Variable,
        value: Variable,
        _index: Variable,
    ) -> Variable {
        // The scalar fill value is copied as-is; the element index is unused.
        gpu!(scope, masked_value = value);
        masked_value
    }

    fn value_info(value_item: Item) -> InputInfo {
        // Bound as a single global scalar of the input's element type.
        InputInfo::Scalar {
            elem: value_item.elem(),
            size: 1,
        }
    }

    fn value_variable(value_item: Item) -> Variable {
        // Scalar binding 0 carries the fill value.
        Variable::GlobalScalar(0, value_item.elem())
    }
}
/// Mask-where strategy: every masked element is replaced by the element taken
/// from a second `value` tensor.
pub(crate) struct MaskWhere;

impl MaskStrategy for MaskWhere {
    fn mask(
        scope: &mut Scope,
        masked_value: Variable,
        value: Variable,
        index: Variable,
    ) -> Variable {
        // Reads value at the caller-supplied index (the input tensor's strided
        // offset). NOTE(review): this assumes `value` shares the input's
        // layout — confirm for non-contiguous value tensors.
        gpu!(scope, masked_value = value[index]);
        masked_value
    }

    fn value_info(value_item: Item) -> InputInfo {
        // Bound as a read-only global array with the input's item type.
        InputInfo::Array {
            item: value_item,
            visibility: Visibility::Read,
        }
    }

    fn value_variable(value_item: Item) -> Variable {
        // The value tensor occupies input binding 2 (after input and mask).
        Variable::GlobalInputArray(2, value_item)
    }
}
/// Scope-level description of a mask kernel: copies `input` to `output`,
/// substituting elements selected by `mask` with values supplied by the
/// strategy `M`.
pub(crate) struct MaskShader<EI: JitElement, EM: JitElement, M: MaskStrategy> {
    input: Variable,
    mask: Variable,
    // Scalar (mask-fill) or tensor (mask-where) replacement source.
    value: Variable,
    // For the in-place variant this aliases `input`.
    output: Variable,
    // When true, the mask predicate is inverted (masked where mask == 0).
    reversed: bool,
    _mask_strategy: PhantomData<M>,
    _input_elem: PhantomData<EI>,
    _mask_elem: PhantomData<EM>,
}
/// Eager mask kernel that reads `input` and writes a separate output tensor.
///
/// `#[derive(new)]` generates `new(reversed)`; the phantom fields pin the
/// strategy, runtime, and element types used when generating the source.
#[derive(new)]
pub(crate) struct MaskReadOnlyEagerKernel<
    M: MaskStrategy,
    R: Runtime,
    EI: JitElement,
    EM: JitElement,
> {
    // Inverts the mask predicate when true.
    reversed: bool,
    _mask_strategy: PhantomData<M>,
    _runtime: PhantomData<R>,
    _input_elem: PhantomData<EI>,
    _mask_elem: PhantomData<EM>,
}
impl<M: MaskStrategy, R: Runtime, EI: JitElement, EM: JitElement> DynamicKernelSource
    for MaskReadOnlyEagerKernel<M, R, EI, EM>
{
    /// Build the shader source: input at array binding 0, mask at binding 1,
    /// the strategy-defined value binding, and a distinct output array.
    fn source(&self) -> crate::kernel::SourceTemplate {
        let mut scope = Scope::root();
        let tensor_item = EI::gpu_elem().into();
        let mask_item = EM::gpu_elem().into();

        let input = Variable::GlobalInputArray(0, tensor_item);
        let mask = Variable::GlobalInputArray(1, mask_item);
        // The strategy decides whether this is a scalar or input array 2.
        let value = M::value_variable(tensor_item);
        let output = Variable::GlobalOutputArray(0, tensor_item);

        MaskShader::<EI, EM, M> {
            input,
            mask,
            value,
            output,
            reversed: self.reversed,
            _mask_strategy: PhantomData::<M>,
            _input_elem: PhantomData::<EI>,
            _mask_elem: PhantomData::<EM>,
        }
        .expand(&mut scope);

        scope.write_global_custom(output);

        let input = InputInfo::Array {
            item: tensor_item,
            visibility: Visibility::Read,
        };
        let mask = InputInfo::Array {
            item: mask_item,
            visibility: Visibility::Read,
        };
        let value = M::value_info(tensor_item);
        let out = OutputInfo::Array { item: tensor_item };

        let info = CompilationInfo {
            inputs: vec![input, mask, value],
            outputs: vec![out],
            scope,
        };

        let settings = CompilationSettings::default();
        let shader = Compilation::new(info).compile(settings);
        let shader = <R::Compiler as Compiler>::compile(shader);
        SourceTemplate::new(shader.to_string())
    }

    /// Unique cache id: the concrete type plus the `reversed` flag, since
    /// both change the generated source.
    fn id(&self) -> String {
        format!(
            "{:?} rev={}",
            core::any::TypeId::of::<Self>(),
            self.reversed
        )
    }
}
/// Eager mask kernel that mutates `input` in place (no separate output).
///
/// `#[derive(new)]` generates `new(reversed)`; the phantom fields pin the
/// strategy, runtime, and element types used when generating the source.
#[derive(new)]
pub(crate) struct MaskInplaceEagerKernel<
    M: MaskStrategy,
    R: Runtime,
    EI: JitElement,
    EM: JitElement,
> {
    // Inverts the mask predicate when true.
    reversed: bool,
    _mask_strategy: PhantomData<M>,
    _runtime: PhantomData<R>,
    _input_elem: PhantomData<EI>,
    _mask_elem: PhantomData<EM>,
}
impl<M: MaskStrategy, R: Runtime, EI: JitElement, EM: JitElement> DynamicKernelSource
    for MaskInplaceEagerKernel<M, R, EI, EM>
{
    /// Build the shader source for the in-place variant: the input array is
    /// bound read-write and doubles as the output (`output: input` below).
    fn source(&self) -> crate::kernel::SourceTemplate {
        let mut scope = Scope::root();
        let tensor_item = EI::gpu_elem().into();
        let mask_item = EM::gpu_elem().into();

        let input = Variable::GlobalInputArray(0, tensor_item);
        let mask = Variable::GlobalInputArray(1, mask_item);
        let value = M::value_variable(tensor_item);

        MaskShader::<EI, EM, M> {
            input,
            mask,
            value,
            // Results are written back into the input array.
            output: input,
            reversed: self.reversed,
            _mask_strategy: PhantomData::<M>,
            _input_elem: PhantomData::<EI>,
            _mask_elem: PhantomData::<EM>,
        }
        .expand(&mut scope);

        let input = InputInfo::Array {
            item: tensor_item,
            visibility: Visibility::ReadWrite,
        };
        let mask = InputInfo::Array {
            item: mask_item,
            visibility: Visibility::Read,
        };
        let value = M::value_info(tensor_item);

        // No output arrays: the kernel mutates input binding 0 in place.
        let info = CompilationInfo {
            inputs: vec![input, mask, value],
            outputs: vec![],
            scope,
        };

        let settings = CompilationSettings::default();
        let shader = Compilation::new(info).compile(settings);
        let shader = <R::Compiler as Compiler>::compile(shader);
        SourceTemplate::new(shader.to_string())
    }

    /// Unique cache id: the concrete type plus the `reversed` flag.
    fn id(&self) -> String {
        format!(
            "{:?} rev={}",
            core::any::TypeId::of::<Self>(),
            self.reversed
        )
    }
}
impl<EI: JitElement, EM: JitElement, M: MaskStrategy> MaskShader<EI, EM, M> {
    /// Emit the mask kernel body into `scope`: for each output element, pick
    /// either the strategy's replacement value (when masked) or the input.
    pub(crate) fn expand(self, scope: &mut Scope) {
        let id = Variable::Id;
        let input = self.input;
        let mask = self.mask;
        let value = self.value;
        let output = self.output;

        // Map the linear thread id onto strided offsets for input and mask,
        // using the output tensor as the reference layout.
        let index_input = scope.zero(Elem::UInt);
        let index_mask = scope.zero(Elem::UInt);
        IndexOffsetGlobalWithLayout {
            tensors: vec![input, mask],
            indexes: vec![index_input, index_mask],
            layout: output,
            index_ref: id,
            dim_start: 0u32.into(),
            dim_end: Variable::Rank,
        }
        .expand(scope);

        // Determine if index should be masked
        let value_in_mask = scope.create_local(mask.item());
        gpu!(scope, value_in_mask = mask[index_mask]);
        let masked = scope.create_local(Elem::Bool);
        let zero = scope.zero(value_in_mask.item());
        // `reversed` flips the predicate so masked means mask == 0.
        // NOTE(review): only the in-place mask-where path passes a caller
        // flag here — confirm intended semantics against the callers.
        if self.reversed {
            gpu!(scope, masked = value_in_mask == zero);
        } else {
            gpu!(scope, masked = value_in_mask != zero);
        }

        // Assign a value at the index
        let used_value = scope.create_local(output.item());
        gpu!(scope, if(masked).then(|scope| {
            M::mask(scope, used_value, value, index_input);
        }).else(|scope| {
            gpu!(scope, used_value = input[index_input]);
        }));
        gpu!(scope, output[id] = used_value);
    }
}

View File

@ -1,51 +0,0 @@
// Mask-fill (out-of-place): output[i] = value where the mask is non-zero,
// otherwise output[i] = input[i]. `{{ elem }}` placeholders are substituted
// with the concrete element type when the kernel source is rendered.

@group(0)
@binding(0)
var<storage, read> input: array<{{ elem }}>;

@group(0)
@binding(1)
var<storage, read> value: {{ elem }};

@group(0)
@binding(2)
var<storage, read> mask: array<u32>;

@group(0)
@binding(3)
var<storage, read_write> output: array<{{ elem }}>;

// info layout (as consumed by the loop below): [rank, input strides,
// mask strides, output strides, input shape, mask shape], rank entries each.
@group(0)
@binding(4)
var<storage, read> info: array<u32>;

const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;

@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    // Flatten the 2D dispatch into a linear index over the output elements.
    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
    let dim = info[0];
    var index_input = 0u;
    var index_mask = 0u;

    // Convert the contiguous output index into strided offsets for the
    // (possibly broadcast) input and mask tensors.
    for (var i = 1u; i <= dim; i++) {
        let stride_input = info[i];
        let stride_mask = info[i + dim];
        let stride_output = info[i + 2u * dim];
        let shape_input = info[i + 3u * dim];
        let shape_mask = info[i + 4u * dim];

        index_input += id / stride_output % shape_input * stride_input;
        index_mask += id / stride_output % shape_mask * stride_mask;
    }

    // Non-zero mask selects the fill value; zero keeps the input element.
    if mask[index_mask] != 0u {
        output[id] = value;
    } else {
        output[id] = input[index_input];
    }
}

View File

@ -1,44 +0,0 @@
// Mask-fill (in-place): input[i] = value where the mask is non-zero;
// unmasked elements are left untouched. `{{ elem }}` placeholders are
// substituted with the concrete element type at render time.

@group(0)
@binding(0)
var<storage, read_write> input: array<{{ elem }}>;

@group(0)
@binding(1)
var<storage, read> value: {{ elem }};

@group(0)
@binding(2)
var<storage, read> mask: array<u32>;

// info layout (as consumed by the loop below): [rank, input strides,
// mask strides, input shape, mask shape], rank entries each.
@group(0)
@binding(3)
var<storage, read> info: array<u32>;

const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;

@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    // Linear element index; the input is also the write target.
    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
    let dim = info[0];
    var index_input = 0u;
    var index_mask = 0u;

    // The input tensor serves as the reference layout here since the write
    // happens in place (note the divisor is stride_input, not stride_output).
    for (var i = 1u; i <= dim; i++) {
        let stride_input = info[i];
        let stride_mask = info[i + dim];
        let shape_input = info[i + 2u * dim];
        let shape_mask = info[i + 3u * dim];

        index_input += id / stride_input % shape_input * stride_input;
        index_mask += id / stride_input % shape_mask * stride_mask;
    }

    // Only masked positions are written.
    if mask[index_mask] != 0u {
        input[index_input] = value;
    }
}

View File

@ -1,55 +0,0 @@
// Mask-where (out-of-place): output[i] = value[i] where the mask is non-zero,
// otherwise output[i] = input[i]. `{{ elem }}` placeholders are substituted
// with the concrete element type at render time.

@group(0)
@binding(0)
var<storage, read> input: array<{{ elem }}>;

@group(0)
@binding(1)
var<storage, read> value: array<{{ elem }}>;

@group(0)
@binding(2)
var<storage, read> mask: array<u32>;

@group(0)
@binding(3)
var<storage, read_write> output: array<{{ elem }}>;

// info layout (as consumed by the loop below): [rank, input strides,
// value strides, mask strides, output strides, input shape, value shape,
// mask shape], rank entries each.
@group(0)
@binding(4)
var<storage, read> info: array<u32>;

const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;

@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    // Flatten the 2D dispatch into a linear index over the output elements.
    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
    let dim = info[0];
    var index_input = 0u;
    var index_value = 0u;
    var index_mask = 0u;

    // Convert the contiguous output index into strided offsets for the
    // (possibly broadcast) input, value, and mask tensors.
    for (var i = 1u; i <= dim; i++) {
        let stride_input = info[i];
        let stride_value = info[i + dim];
        let stride_mask = info[i + 2u * dim];
        let stride_output = info[i + 3u * dim];
        let shape_input = info[i + 4u * dim];
        let shape_value = info[i + 5u * dim];
        let shape_mask = info[i + 6u * dim];

        index_input += id / stride_output % shape_input * stride_input;
        index_value += id / stride_output % shape_value * stride_value;
        index_mask += id / stride_output % shape_mask * stride_mask;
    }

    // Non-zero mask selects the value-tensor element; zero keeps the input.
    if mask[index_mask] != 0u {
        output[id] = value[index_value];
    } else {
        output[id] = input[index_input];
    }
}

View File

@ -1,58 +0,0 @@
// Mask-where (in-place): input[i] = value[i] where the mask condition holds.
// When the `reverse` flag is set, the condition is inverted so the same kernel
// can write the complementary selection. `{{ elem }}` placeholders are
// substituted with the concrete element type at render time.

@group(0)
@binding(0)
var<storage, read_write> input: array<{{ elem }}>;

@group(0)
@binding(1)
var<storage, read> value: array<{{ elem }}>;

@group(0)
@binding(2)
var<storage, read> mask: array<u32>;

// info layout (as consumed below): [rank, input strides, value strides,
// mask strides, input shape, value shape, mask shape] (rank entries each),
// followed by the reverse flag at info[6 * rank + 1].
@group(0)
@binding(3)
var<storage, read> info: array<u32>;

const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;

@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    // Linear element index; the input is also the write target.
    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
    let dim = info[0];
    // The reverse flag lives just past the 6 * dim stride/shape entries.
    let reverse = info[6u * dim + 1u];

    var index_input = 0u;
    var index_value = 0u;
    var index_mask = 0u;

    // The input tensor serves as the reference layout since the write
    // happens in place (divisor is stride_input, not an output stride).
    for (var i = 1u; i <= dim; i++) {
        let stride_input = info[i];
        let stride_value = info[i + dim];
        let stride_mask = info[i + 2u * dim];
        let shape_input = info[i + 3u * dim];
        let shape_value = info[i + 4u * dim];
        let shape_mask = info[i + 5u * dim];

        index_input += id / stride_input % shape_input * stride_input;
        index_value += id / stride_input % shape_value * stride_value;
        index_mask += id / stride_input % shape_mask * stride_mask;
    }

    var condition = mask[index_mask] != 0u;
    if reverse == 1u {
        condition = !condition;
    }

    // Fix: the original `else` branch performed a redundant self-assignment
    // (`input[index_input] = input[index_input]`) — a dead store. Unmasked
    // elements are simply left untouched.
    if condition {
        input[index_input] = value[index_value];
    }
}

View File

@ -484,6 +484,11 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::NotEqual(op) => wgsl::Instruction::NotEqual {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::Assign(op) => wgsl::Instruction::Assign {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),

View File

@ -148,6 +148,11 @@ pub enum Instruction {
rhs: Variable,
out: Variable,
},
NotEqual {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Stride {
dim: Variable,
position: usize,
@ -289,6 +294,7 @@ impl Display for Instruction {
Instruction::Greater { lhs, rhs, out } => comparison(lhs, rhs, out, ">", f),
Instruction::LowerEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "<=", f),
Instruction::GreaterEqual { lhs, rhs, out } => comparison(lhs, rhs, out, ">=", f),
Instruction::NotEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "!=", f),
Instruction::Assign { input, out } => match out.item() {
Item::Vec4(elem) => {
let input0 = input.index(0);