Refactor wgpu max pooling (#1398)

Nathaniel Simard 2024-03-04 13:23:11 -05:00 committed by GitHub
parent 046d975b76
commit efbe818465
18 changed files with 1000 additions and 400 deletions

View File

@@ -53,6 +53,10 @@ harness = false
name = "binary"
harness = false
[[bench]]
name = "max_pool2d"
harness = false
[[bench]]
name = "matmul"
harness = false
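
With this [[bench]] target registered, the new benchmark can typically be launched like the existing ones, e.g. something along the lines of cargo bench --bench max_pool2d with the desired backend feature enabled (the exact invocation depends on the backend_comparison runner).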

View File

@@ -0,0 +1,60 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, module::max_pool2d, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
pub struct MaxPool2dBenchmark<B: Backend> {
shape: Shape<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
device: B::Device,
}
impl<B: Backend> Benchmark for MaxPool2dBenchmark<B> {
type Args = Tensor<B, 4>;
fn name(&self) -> String {
"max_pool2d".into()
}
fn shapes(&self) -> Vec<Vec<usize>> {
vec![self.shape.dims.into()]
}
fn execute(&self, x: Self::Args) {
max_pool2d(
x,
self.kernel_size,
self.stride,
self.padding,
self.dilation,
);
}
fn prepare(&self) -> Self::Args {
Tensor::random(self.shape.clone(), Distribution::Default, &self.device)
}
fn sync(&self) {
B::sync(&self.device)
}
}
#[allow(dead_code)]
fn bench<B: Backend>(device: &B::Device, url: Option<&str>, token: Option<&str>) {
let benchmark = MaxPool2dBenchmark::<B> {
shape: [32, 32, 512, 512].into(),
kernel_size: [5, 5],
stride: [2, 2],
padding: [2, 2],
dilation: [2, 2],
device: device.clone(),
};
save::<B>(vec![run_benchmark(benchmark)], device, url, token).unwrap();
}
fn main() {
backend_comparison::bench_on_backend!();
}
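
As a quick sanity check on the configuration above, the pooled spatial size follows the usual pooling formula, presumably the same one calculate_pool_output_size implements; a minimal sketch (not part of this diff):

#[test]
fn pooled_output_size_sketch() {
    // out = (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1
    let out = (512 + 2 * 2 - 2 * (5 - 1) - 1) / 2 + 1;
    assert_eq!(out, 254); // [32, 32, 512, 512] pools down to roughly [32, 32, 254, 254]
}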

View File

@@ -110,6 +110,8 @@ pub(crate) enum BenchmarkValues {
Matmul,
#[strum(to_string = "unary")]
Unary,
#[strum(to_string = "max_pool2d")]
MaxPool2d,
}
pub fn execute() {

View File

@@ -65,6 +65,36 @@ macro_rules! gpu {
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs && rhs
($scope:expr, $out:ident = $lhs:ident && $rhs:expr) => {
gpu!($scope, $out = and($lhs, $rhs))
};
// out = and(lhs, rhs)
($scope:expr, $out:ident = and($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::And(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs || rhs
($scope:expr, $out:ident = $lhs:ident || $rhs:expr) => {
gpu!($scope, $out = or($lhs, $rhs))
};
// out = or(lhs, rhs)
($scope:expr, $out:ident = or($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Or(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = !input
($scope:expr, $out:ident = !$input:expr) => {
gpu!($scope, $out = not($input))
};
// out = not(input)
($scope:expr, $out:ident = not($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Not(
gpu!(unary $input, $out)
));
};
// out = lhs == rhs
($scope:expr, $out:ident = $lhs:ident == $rhs:expr) => {
gpu!($scope, $out = equal($lhs, $rhs))
@@ -115,6 +145,18 @@ macro_rules! gpu {
gpu!(binary $lhs, $rhs, $out)
));
};
// out = max(lhs, rhs)
($scope:expr, $out:ident = max($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Max(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = min(lhs, rhs)
($scope:expr, $out:ident = min($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Min(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs[rhs]
($scope:expr, $out:ident = $lhs:ident[$rhs:expr]) => {
gpu!($scope, $out = index($lhs, $rhs))
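
For orientation, a minimal sketch (not part of this diff) of how the new boolean and min/max arms might be used when expanding a kernel, assuming the Scope, Elem, and Variable APIs that appear elsewhere in this commit:

// Assumes the crate's gpu dialect items (gpu!, Scope, Elem) are in scope.
fn bool_and_minmax_demo(scope: &mut Scope) {
    let a = scope.create_local(Elem::Bool);
    let b = scope.create_local(Elem::Bool);
    let keep = scope.create_local(Elem::Bool);
    gpu!(scope, keep = a && b);               // registers Operator::And
    gpu!(scope, keep = keep || b);            // registers Operator::Or
    gpu!(scope, keep = !keep);                // registers Operator::Not

    let lhs = scope.create_local(Elem::UInt);
    let rhs = scope.create_local(Elem::UInt);
    let clamped = scope.create_local(Elem::UInt);
    gpu!(scope, clamped = max(lhs, rhs));     // registers Operator::Max
    gpu!(scope, clamped = min(clamped, rhs)); // registers Operator::Min
}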

View File

@@ -48,6 +48,11 @@ pub enum Operator {
Modulo(BinaryOperator),
Index(BinaryOperator),
IndexAssign(BinaryOperator),
And(BinaryOperator),
Or(BinaryOperator),
Not(UnaryOperator),
Max(BinaryOperator),
Min(BinaryOperator),
}
/// All metadata that can be accessed in a shader.

View File

@@ -36,6 +36,8 @@ impl Operation {
impl Operator {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
match self {
Operator::Max(op) => Operator::Max(op.vectorize(vectorization)),
Operator::Min(op) => Operator::Min(op.vectorize(vectorization)),
Operator::Add(op) => Operator::Add(op.vectorize(vectorization)),
Operator::Index(op) => Operator::Index(op.vectorize(vectorization)),
Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)),
@@ -69,6 +71,9 @@ impl Operator {
}
Operator::Modulo(op) => Operator::Modulo(op.vectorize(vectorization)),
Operator::IndexAssign(op) => Operator::IndexAssign(op.vectorize(vectorization)),
Operator::And(op) => Operator::And(op.vectorize(vectorization)),
Operator::Or(op) => Operator::Or(op.vectorize(vectorization)),
Operator::Not(op) => Operator::Not(op.vectorize(vectorization)),
}
}
}

View File

@@ -372,6 +372,16 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
fn compile_instruction(&mut self, value: gpu::Operator) -> wgsl::Instruction {
match value {
gpu::Operator::Max(op) => wgsl::Instruction::Max {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::Min(op) => wgsl::Instruction::Min {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::Add(op) => wgsl::Instruction::Add {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
@@ -487,6 +497,20 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::And(op) => wgsl::Instruction::And {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::Or(op) => wgsl::Instruction::Or {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::Not(op) => wgsl::Instruction::Not {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
}
}

View File

@@ -8,6 +8,16 @@ pub enum Instruction {
DeclareVariable {
var: Variable,
},
Max {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Min {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Add {
lhs: Variable,
rhs: Variable,
@@ -158,6 +168,20 @@ pub enum Instruction {
end: Variable,
instructions: Vec<Instruction>,
},
And {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Or {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Not {
input: Variable,
out: Variable,
},
Loop {
instructions: Vec<Instruction>,
},
@@ -173,6 +197,19 @@ impl Display for Instruction {
Instruction::Add { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} + {rhs};\n"))
}
Instruction::Min { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = min({lhs}, {rhs});\n"))
}
Instruction::Max { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = max({lhs}, {rhs});\n"))
}
Instruction::And { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} && {rhs};\n"))
}
Instruction::Or { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} || {rhs};\n"))
}
Instruction::Not { input, out } => f.write_fmt(format_args!("{out} = !{input};\n")),
Instruction::Index { lhs, rhs, out } => {
let item = out.item();
f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n"))

View File

@@ -186,6 +186,32 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Max(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Min(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::And(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Or(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Not(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Index(op) => mark_binary(
op,
&mut local_tensor_ids_input,

View File

@@ -8,7 +8,7 @@ use crate::{
kernel::{DynamicKernelSource, SourceTemplate},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime,
Runtime, RuntimeInt,
};
use burn_tensor::{ElementConversion, Shape};
use std::{marker::PhantomData, ops::Range};
@@ -120,8 +120,6 @@ pub(crate) fn slice<R: Runtime, E: JitElement, const D1: usize, const D2: usize>
slice_on_output(tensor, output, indices)
}
type IntType<R> = <<R as Runtime>::Compiler as Compiler>::Int;
pub(crate) fn slice_on_output<R: Runtime, E: JitElement, const D1: usize, const D2: usize>(
tensor: JitTensor<R, E, D1>,
output: JitTensor<R, E, D1>,
@@ -136,7 +134,7 @@ pub(crate) fn slice_on_output<R: Runtime, E: JitElement, const D1: usize, const
let kernel = SliceEagerKernel::new(D1);
execute_dynamic::<R, SliceEagerKernel<R, E>, IntType<R>>(
execute_dynamic::<R, SliceEagerKernel<R, E>, RuntimeInt<R>>(
&[EagerHandle::new(
&tensor.handle,
&tensor.strides,

View File

@ -7,7 +7,7 @@ use crate::{
element::JitElement,
kernel::{DynamicKernelSource, SourceTemplate},
tensor::JitTensor,
Runtime,
Runtime, RuntimeInt,
};
use burn_tensor::ElementConversion;
use std::{marker::PhantomData, ops::Range};
@@ -121,8 +121,6 @@ impl<R: Runtime, E: JitElement> DynamicKernelSource for SliceAssignEagerKernel<R
}
}
type IntType<R> = <<R as Runtime>::Compiler as Compiler>::Int;
pub(crate) fn slice_assign<R: Runtime, E: JitElement, const D1: usize, const D2: usize>(
tensor: JitTensor<R, E, D1>,
indices: [Range<usize>; D2],
@@ -141,7 +139,7 @@ pub(crate) fn slice_assign<R: Runtime, E: JitElement, const D1: usize, const D2:
let kernel = SliceAssignEagerKernel::new(D1);
execute_dynamic::<R, SliceAssignEagerKernel<R, E>, IntType<R>>(
execute_dynamic::<R, SliceAssignEagerKernel<R, E>, RuntimeInt<R>>(
&[
EagerHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
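
Both slice kernels now pass the shared RuntimeInt<R> alias (introduced in runtime.rs further down in this commit) as the launch's integer element type, replacing the local type IntType<R> definitions removed above.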

View File

@@ -1,47 +1,296 @@
use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape};
use std::marker::PhantomData;
use crate::{
compute::StaticKernel,
element::JitElement,
kernel::{
self, elemwise_workgroup,
pool::{build_output_and_info_pool2d, build_pool2d_info},
KernelSettings, WORKGROUP_DEFAULT,
codegen::{
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
InputInfo, OutputInfo, WorkgroupLaunch,
},
kernel_wgsl,
element::JitElement,
kernel::{DynamicKernelSource, SourceTemplate},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime,
Runtime, RuntimeInt,
};
kernel_wgsl!(MaxPool2d, "../../template/pool/max_pool2d.wgsl");
kernel_wgsl!(
MaxPool2dWithIndicesBackward,
"../../template/pool/max_pool2d_with_indices_backward.wgsl"
);
kernel_wgsl!(
MaxPool2dWithIndices,
"../../template/pool/max_pool2d_with_indices.wgsl"
);
pub(crate) fn max_pool2d<R: Runtime, E: JitElement>(
x: JitTensor<R, E, 4>,
#[derive(new)]
struct MaxPool2dEagerKernel<R: Runtime, E: JitElement> {
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> JitTensor<R, E, 4> {
let (info_handle, output) =
build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
let kernel = StaticKernel::<
KernelSettings<MaxPool2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
output.shape.num_elements(),
WORKGROUP_DEFAULT,
));
_runtime: PhantomData<R>,
_elem: PhantomData<E>,
}
x.client
.execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]);
#[derive(new)]
struct MaxPool2dWithIndicesEagerKernel<R: Runtime, E: JitElement> {
kernel_size: [usize; 2],
_runtime: PhantomData<R>,
_elem: PhantomData<E>,
}
output
struct MaxPool2dComputeShader {
x: Variable,
output: Variable,
kernel_size: [usize; 2],
indices: Option<Variable>,
}
impl MaxPool2dComputeShader {
fn expand(self, scope: &mut Scope) {
let x = self.x;
let output = self.output;
let id = Variable::Id;
let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt);
let input_stride_2 = scope.create_local(Elem::UInt);
let input_stride_3 = scope.create_local(Elem::UInt);
let input_shape_2 = scope.create_local(Elem::UInt);
let input_shape_3 = scope.create_local(Elem::UInt);
let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);
let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);
gpu!(scope, input_stride_0 = stride(x, 0u32));
gpu!(scope, input_stride_1 = stride(x, 1u32));
gpu!(scope, input_stride_2 = stride(x, 2u32));
gpu!(scope, input_stride_3 = stride(x, 3u32));
gpu!(scope, input_shape_2 = shape(x, 2u32));
gpu!(scope, input_shape_3 = shape(x, 3u32));
gpu!(scope, output_stride_0 = stride(output, 0u32));
gpu!(scope, output_stride_1 = stride(output, 1u32));
gpu!(scope, output_stride_2 = stride(output, 2u32));
gpu!(scope, output_stride_3 = stride(output, 3u32));
gpu!(scope, output_shape_0 = shape(output, 0u32));
gpu!(scope, output_shape_1 = shape(output, 1u32));
gpu!(scope, output_shape_2 = shape(output, 2u32));
gpu!(scope, output_shape_3 = shape(output, 3u32));
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
let [kernel_size_0, kernel_size_1] = self.kernel_size;
let b = scope.create_local(Elem::UInt);
let c = scope.create_local(Elem::UInt);
let oh = scope.create_local(Elem::UInt);
let ow = scope.create_local(Elem::UInt);
gpu!(scope, b = id / output_stride_0);
gpu!(scope, b = b % output_shape_0);
gpu!(scope, c = id / output_stride_1);
gpu!(scope, c = c % output_shape_1);
gpu!(scope, oh = id / output_stride_2);
gpu!(scope, oh = oh % output_shape_2);
gpu!(scope, ow = id / output_stride_3);
gpu!(scope, ow = ow % output_shape_3);
let tmp = scope.create_local(Elem::UInt);
let ih = scope.create_local(Elem::UInt);
let iw = scope.create_local(Elem::UInt);
let ih_pad = scope.create_local(Elem::UInt);
let iw_pad = scope.create_local(Elem::UInt);
let result = scope.create_local(x.item());
let cond = scope.create_local(Elem::Bool);
let cond_tmp = scope.create_local(Elem::Bool);
let index_input = scope.create_local(Elem::UInt);
let index_input_1 = scope.create_local(Elem::UInt);
let index_input_2 = scope.create_local(Elem::UInt);
let index_input_3 = scope.create_local(Elem::UInt);
let index_input_4 = scope.create_local(Elem::UInt);
let is_max = scope.create_local(Elem::Bool);
let max_val = scope.create_local(x.item());
let max_index = self.indices.map(|_| scope.create_local(Elem::UInt));
gpu!(scope, max_val = cast(-32767.0));
(0..kernel_size_0).for_each(|kh| {
gpu!(scope, ih = oh * pool_stride_0);
gpu!(scope, tmp = kh * dilation_0);
gpu!(scope, ih += tmp);
// Up
gpu!(scope, cond = ih < padding_0);
// Down
gpu!(scope, tmp = input_shape_2 + padding_0);
gpu!(scope, cond_tmp = ih >= tmp);
gpu!(scope, cond = cond || cond_tmp);
gpu!(scope, cond = !cond);
gpu!(scope, if (cond).then(|scope| {
(0..kernel_size_1).for_each(|kw| {
gpu!(scope, iw = ow * pool_stride_1);
gpu!(scope, tmp = kw * dilation_1);
gpu!(scope, iw = iw + tmp);
// Left
gpu!(scope, cond = iw < padding_1);
// Right
gpu!(scope, tmp = input_shape_3 + padding_1);
gpu!(scope, cond_tmp = iw >= tmp);
gpu!(scope, cond = cond || cond_tmp);
gpu!(scope, cond = !cond);
gpu!(scope, if (cond).then(|scope| {
gpu!(scope, ih_pad = ih - padding_0);
gpu!(scope, iw_pad = iw - padding_1);
gpu!(scope, index_input_1 = b * input_stride_0);
gpu!(scope, index_input_2 = c * input_stride_1);
gpu!(scope, index_input_3 = ih_pad * input_stride_2);
gpu!(scope, index_input_4 = iw_pad * input_stride_3);
gpu!(scope, index_input = index_input_1);
gpu!(scope, index_input += index_input_2);
gpu!(scope, index_input += index_input_3);
gpu!(scope, index_input += index_input_4);
gpu!(scope, result = x[index_input]);
gpu!(scope, is_max = result > max_val);
gpu!(scope, if(is_max).then(|scope|{
gpu!(scope, max_val = result);
if let Some(max_index) = max_index {
gpu!(scope, max_index = ih_pad * input_shape_2);
gpu!(scope, max_index += iw_pad);
}
}));
}));
});
}));
});
gpu!(scope, output[id] = max_val);
if let Some(indices) = self.indices {
let max_index = max_index.unwrap();
gpu!(scope, indices[id] = max_index);
}
}
}
impl<R: Runtime, E: JitElement> DynamicKernelSource for MaxPool2dEagerKernel<R, E> {
fn source(&self) -> crate::kernel::SourceTemplate {
let mut scope = Scope::root();
let item = E::gpu_elem().into();
let x = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
scope.write_global_custom(output);
MaxPool2dComputeShader {
x,
output,
kernel_size: self.kernel_size,
indices: None,
}
.expand(&mut scope);
let input = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let scalars = InputInfo::Scalar {
elem: Elem::UInt,
size: 6,
};
let output = OutputInfo::Array { item };
let info = CompilationInfo {
inputs: vec![input, scalars],
outputs: vec![output],
scope,
};
let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}
fn id(&self) -> String {
format!(
"{:?}k={:?}",
core::any::TypeId::of::<Self>(),
self.kernel_size,
)
}
}
impl<R: Runtime, E: JitElement> DynamicKernelSource for MaxPool2dWithIndicesEagerKernel<R, E> {
fn source(&self) -> crate::kernel::SourceTemplate {
let mut scope = Scope::root();
let item = E::gpu_elem().into();
let x = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let indices = Variable::GlobalOutputArray(1, Item::Scalar(Elem::Int));
scope.write_global_custom(output);
MaxPool2dComputeShader {
x,
output,
kernel_size: self.kernel_size,
indices: Some(indices),
}
.expand(&mut scope);
let input = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let scalars = InputInfo::Scalar {
elem: Elem::UInt,
size: 6,
};
let output = OutputInfo::Array { item };
let indices = OutputInfo::Array {
item: Item::Scalar(Elem::Int),
};
let info = CompilationInfo {
inputs: vec![input, scalars],
outputs: vec![output, indices],
scope,
};
let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}
fn id(&self) -> String {
format!(
"{:?}k={:?}",
core::any::TypeId::of::<Self>(),
self.kernel_size,
)
}
}
pub(crate) fn max_pool2d_with_indices<R: Runtime, E: JitElement, I: JitElement>(
@@ -51,61 +300,106 @@ pub(crate) fn max_pool2d_with_indices<R: Runtime, E: JitElement, I: JitElement>(
padding: [usize; 2],
dilation: [usize; 2],
) -> (JitTensor<R, E, 4>, JitTensor<R, I, 4>) {
let (info_handle, output) =
build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
let indices = empty_device(x.client.clone(), x.device, output.shape.clone());
let [batch_size, channels, _, _] = x.shape.dims;
let kernel = StaticKernel::<
KernelSettings<MaxPool2dWithIndices, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
output.shape.num_elements(),
WORKGROUP_DEFAULT,
));
let size_0 = calculate_pool_output_size(
kernel_size[0],
stride[0],
padding[0],
dilation[0],
x.shape.dims[2],
);
let size_1 = calculate_pool_output_size(
kernel_size[1],
stride[1],
padding[1],
dilation[1],
x.shape.dims[3],
);
x.client.execute(
Box::new(kernel),
&[&x.handle, &output.handle, &indices.handle, &info_handle],
let shape_out = Shape::new([batch_size, channels, size_0, size_1]);
let output = empty_device(x.client.clone(), x.device.clone(), shape_out.clone());
let indices = empty_device(x.client.clone(), x.device.clone(), shape_out);
let kernel = MaxPool2dWithIndicesEagerKernel::new(kernel_size);
execute_dynamic::<R, MaxPool2dWithIndicesEagerKernel<R, E>, I>(
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
&[
EagerHandle::new(&output.handle, &output.strides, &output.shape.dims),
EagerHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
],
Some(&[
(stride[0] as i32).elem(),
(stride[1] as i32).elem(),
(dilation[0] as i32).elem(),
(dilation[1] as i32).elem(),
(padding[0] as i32).elem(),
(padding[1] as i32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
(output, indices)
}
pub(crate) fn max_pool2d_with_indices_backward<R: Runtime, E: JitElement, I: JitElement>(
pub(crate) fn max_pool2d<R: Runtime, E: JitElement>(
x: JitTensor<R, E, 4>,
grad: JitTensor<R, E, 4>,
indices: JitTensor<R, I, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> JitTensor<R, E, 4> {
let grad = kernel::into_contiguous(grad);
let indices = kernel::into_contiguous(indices);
let [batch_size, channels, _, _] = x.shape.dims;
let num_elems = x.shape.num_elements();
let buffer = x.client.empty(num_elems * core::mem::size_of::<E>());
let output = JitTensor::new(x.client.clone(), x.device.clone(), x.shape.clone(), buffer);
let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation);
let kernel = StaticKernel::<
KernelSettings<MaxPool2dWithIndicesBackward, E, I, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
output.shape.num_elements(),
WORKGROUP_DEFAULT,
));
x.client.execute(
Box::new(kernel),
&[&indices.handle, &grad.handle, &output.handle, &info_handle],
let size_0 = calculate_pool_output_size(
kernel_size[0],
stride[0],
padding[0],
dilation[0],
x.shape.dims[2],
);
let size_1 = calculate_pool_output_size(
kernel_size[1],
stride[1],
padding[1],
dilation[1],
x.shape.dims[3],
);
let shape_out = Shape::new([batch_size, channels, size_0, size_1]);
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
let kernel = MaxPool2dEagerKernel::new(kernel_size);
execute_dynamic::<R, MaxPool2dEagerKernel<R, E>, RuntimeInt<R>>(
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&[
(stride[0] as i32).elem(),
(stride[1] as i32).elem(),
(dilation[0] as i32).elem(),
(dilation[1] as i32).elem(),
(padding[0] as i32).elem(),
(padding[1] as i32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
output
}
#[cfg(test)]
mod tests {
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{module, ops::ModuleOps, Distribution, Tensor};
use burn_tensor::{module, Distribution, Tensor};
#[test]
pub fn max_pool2d_should_work_with_multiple_invocations() {
@@ -153,58 +447,4 @@ mod tests {
.assert_approx_eq(&pooled_ref.into_data(), 3);
assert_eq!(indices.into_data(), indices_ref.into_data().convert());
}
#[test]
pub fn max_pool2d_with_indices_backward_should_work_with_multiple_invocations() {
let test_device = Default::default();
let tensor =
Tensor::<TestBackend, 4>::random([32, 32, 32, 32], Distribution::Default, &test_device);
let grad_output =
Tensor::<TestBackend, 4>::random([32, 32, 16, 16], Distribution::Default, &test_device);
let ref_device = Default::default();
let tensor_ref = Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &ref_device);
let grad_output_ref =
Tensor::<ReferenceBackend, 4>::from_data(grad_output.to_data(), &ref_device);
let kernel_size = [3, 3];
let stride = [2, 2];
let padding = [1, 1];
let dilation = [1, 1];
let (_, indices) =
module::max_pool2d_with_indices(tensor.clone(), kernel_size, stride, padding, dilation);
let (_, indices_ref) = module::max_pool2d_with_indices(
tensor_ref.clone(),
kernel_size,
stride,
padding,
dilation,
);
let grad = TestBackend::max_pool2d_with_indices_backward(
tensor.into_primitive(),
kernel_size,
stride,
padding,
dilation,
grad_output.into_primitive(),
indices.into_primitive(),
)
.x_grad;
let grad_ref = ReferenceBackend::max_pool2d_with_indices_backward(
tensor_ref.into_primitive(),
kernel_size,
stride,
padding,
dilation,
grad_output_ref.into_primitive(),
indices_ref.into_primitive(),
)
.x_grad;
Tensor::<TestBackend, 4>::from_primitive(grad)
.into_data()
.assert_approx_eq(
&Tensor::<ReferenceBackend, 4>::from_primitive(grad_ref).into_data(),
3,
);
}
}
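
The MaxPool2dComputeShader expand above recovers the (b, c, oh, ow) output coordinates from the flat invocation id with the pattern coordinate = (id / stride) % shape; a small CPU-side sketch (not part of this diff) of the same arithmetic for a contiguous NCHW output:

fn decompose_nchw(id: usize, shape: [usize; 4]) -> [usize; 4] {
    // Contiguous strides for [batch, channel, height, width].
    let strides = [
        shape[1] * shape[2] * shape[3],
        shape[2] * shape[3],
        shape[3],
        1,
    ];
    // Same pattern the shader uses: coordinate = (id / stride) % shape.
    [
        (id / strides[0]) % shape[0],
        (id / strides[1]) % shape[1],
        (id / strides[2]) % shape[2],
        (id / strides[3]) % shape[3],
    ]
}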

View File

@@ -0,0 +1,421 @@
use burn_tensor::ElementConversion;
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
kernel::{self, DynamicKernelSource, SourceTemplate},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime,
};
use std::marker::PhantomData;
#[derive(new)]
struct MaxPool2dWithIndicesBackwardEagerKernel<R: Runtime, E: JitElement> {
kernel_size: [usize; 2],
_runtime: PhantomData<R>,
_elem: PhantomData<E>,
}
struct MaxPool2dBackwardComputeShader {
indices: Variable,
grad: Variable,
output: Variable,
kernel_size: [usize; 2],
}
impl MaxPool2dBackwardComputeShader {
fn expand(self, scope: &mut Scope) {
let grad = self.grad;
let output = self.output;
let indices = self.indices;
let id = Variable::Id;
let grad_stride_0 = scope.create_local(Elem::UInt);
let grad_stride_1 = scope.create_local(Elem::UInt);
let grad_stride_2 = scope.create_local(Elem::UInt);
let grad_stride_3 = scope.create_local(Elem::UInt);
let grad_shape_2 = scope.create_local(Elem::UInt);
let grad_shape_3 = scope.create_local(Elem::UInt);
let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);
let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);
gpu!(scope, grad_stride_0 = stride(grad, 0u32));
gpu!(scope, grad_stride_1 = stride(grad, 1u32));
gpu!(scope, grad_stride_2 = stride(grad, 2u32));
gpu!(scope, grad_stride_3 = stride(grad, 3u32));
gpu!(scope, grad_shape_2 = shape(grad, 2u32));
gpu!(scope, grad_shape_3 = shape(grad, 3u32));
gpu!(scope, output_stride_0 = stride(output, 0u32));
gpu!(scope, output_stride_1 = stride(output, 1u32));
gpu!(scope, output_stride_2 = stride(output, 2u32));
gpu!(scope, output_stride_3 = stride(output, 3u32));
gpu!(scope, output_shape_0 = shape(output, 0u32));
gpu!(scope, output_shape_1 = shape(output, 1u32));
gpu!(scope, output_shape_2 = shape(output, 2u32));
gpu!(scope, output_shape_3 = shape(output, 3u32));
let b = scope.create_local(Elem::UInt);
let c = scope.create_local(Elem::UInt);
let ih = scope.create_local(Elem::UInt);
let iw = scope.create_local(Elem::UInt);
gpu!(scope, b = id / output_stride_0);
gpu!(scope, b = b % output_shape_0);
gpu!(scope, c = id / output_stride_1);
gpu!(scope, c = c % output_shape_1);
gpu!(scope, ih = id / output_stride_2);
gpu!(scope, ih = ih % output_shape_2);
gpu!(scope, iw = id / output_stride_3);
gpu!(scope, iw = iw % output_shape_3);
let index_current = scope.create_local(Elem::UInt);
let index_current_tmp = scope.create_local(Elem::UInt);
gpu!(scope, index_current = ih * output_stride_2);
gpu!(scope, index_current_tmp = iw * output_stride_3);
gpu!(scope, index_current += index_current_tmp);
let index_select = scope.create_local(Elem::Int);
let index_max = scope.create_local(Elem::UInt);
let is_max = scope.create_local(Elem::Bool);
let index = scope.create_local(Elem::UInt);
let index_tmp = scope.create_local(Elem::UInt);
let grad_accumulation = scope.zero(grad.item());
let result = scope.create_local(grad.item());
let (oh_start, oh_end, ow_start, ow_end) = self.loop_ranges(
scope,
ih,
iw,
grad_shape_2,
grad_shape_3,
output_stride_2,
output_stride_3,
);
gpu!(
scope,
range(oh_start, oh_end).for_each(|oh, scope| {
gpu!(
scope,
range(ow_start, ow_end).for_each(|ow, scope| {
gpu!(scope, index = b * grad_stride_0);
gpu!(scope, index_tmp = c * grad_stride_1);
gpu!(scope, index += index_tmp);
gpu!(scope, index_tmp = oh * grad_stride_2);
gpu!(scope, index += index_tmp);
gpu!(scope, index_tmp = ow * grad_stride_3);
gpu!(scope, index += index_tmp);
gpu!(scope, index_select = indices[index]);
gpu!(scope, index_max = cast(index_select));
gpu!(scope, is_max = index_max == index_current);
gpu!(scope, if(is_max).then(|scope|{
gpu!(scope, result = grad[index]);
gpu!(scope, grad_accumulation += result);
}));
})
);
})
);
gpu!(scope, output[id] = grad_accumulation);
}
#[allow(clippy::too_many_arguments)]
fn loop_ranges(
self,
scope: &mut Scope,
ih: Variable,
iw: Variable,
grad_shape_2: Variable,
grad_shape_3: Variable,
output_stride_2: Variable,
output_stride_3: Variable,
) -> (Variable, Variable, Variable, Variable) {
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
let [kernel_size_0, kernel_size_1] = self.kernel_size;
let signed_ih = scope.create_local(Elem::Int);
let signed_iw = scope.create_local(Elem::Int);
let signed_pool_stride_0 = scope.create_local(Elem::Int);
let signed_pool_stride_1 = scope.create_local(Elem::Int);
let signed_dilation_0 = scope.create_local(Elem::Int);
let signed_dilation_1 = scope.create_local(Elem::Int);
let signed_padding_0 = scope.create_local(Elem::Int);
let signed_padding_1 = scope.create_local(Elem::Int);
let signed_kernel_size_0 = scope.create_local(Elem::Int);
let signed_kernel_size_1 = scope.create_local(Elem::Int);
gpu!(scope, signed_pool_stride_0 = cast(pool_stride_0));
gpu!(scope, signed_pool_stride_1 = cast(pool_stride_1));
gpu!(scope, signed_dilation_0 = cast(dilation_0));
gpu!(scope, signed_dilation_1 = cast(dilation_1));
gpu!(scope, signed_padding_0 = cast(padding_0));
gpu!(scope, signed_padding_1 = cast(padding_1));
gpu!(scope, signed_kernel_size_0 = cast(kernel_size_0));
gpu!(scope, signed_kernel_size_1 = cast(kernel_size_1));
gpu!(scope, signed_ih = cast(ih));
gpu!(scope, signed_iw = cast(iw));
let kms_0 = scope.create_local(Elem::Int);
let kms_1 = scope.create_local(Elem::Int);
gpu!(scope, kms_0 = signed_dilation_0 * signed_kernel_size_0);
gpu!(scope, kms_0 = kms_0 - signed_pool_stride_0);
gpu!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1);
gpu!(scope, kms_1 = kms_1 - signed_pool_stride_1);
let oh_start_tmp = scope.create_local(Elem::Int);
let ow_start_tmp = scope.create_local(Elem::Int);
gpu!(scope, oh_start_tmp = signed_ih + signed_padding_0);
gpu!(scope, oh_start_tmp = oh_start_tmp - kms_0);
gpu!(scope, oh_start_tmp = oh_start_tmp / signed_pool_stride_0);
gpu!(scope, ow_start_tmp = signed_iw + signed_padding_1);
gpu!(scope, ow_start_tmp = ow_start_tmp - kms_1);
gpu!(scope, ow_start_tmp = ow_start_tmp / signed_pool_stride_1);
gpu!(scope, oh_start_tmp = max(oh_start_tmp, 0i32));
gpu!(scope, ow_start_tmp = max(ow_start_tmp, 0i32));
let oh_start = scope.create_local(Elem::UInt);
let ow_start = scope.create_local(Elem::UInt);
gpu!(scope, oh_start = cast(oh_start_tmp));
gpu!(scope, ow_start = cast(ow_start_tmp));
let oh_end_tmp = scope.create_local(Elem::Int);
let ow_end_tmp = scope.create_local(Elem::Int);
gpu!(scope, oh_end_tmp = max(kms_0, 0i32));
gpu!(scope, ow_end_tmp = max(kms_1, 0i32));
let oh_end = scope.create_local(Elem::UInt);
let ow_end = scope.create_local(Elem::UInt);
let oh_end_limit = scope.create_local(Elem::UInt);
let ow_end_limit = scope.create_local(Elem::UInt);
gpu!(scope, oh_end = cast(oh_end_tmp));
gpu!(scope, ow_end = cast(ow_end_tmp));
gpu!(scope, oh_end = oh_end + oh_start);
gpu!(scope, oh_end_limit = grad_shape_2 - 1u32);
gpu!(scope, ow_end = ow_end + ow_start);
gpu!(scope, ow_end_limit = grad_shape_3 - 1u32);
gpu!(scope, oh_end = min(oh_end, oh_end_limit));
gpu!(scope, ow_end = min(ow_end, ow_end_limit));
let index_current = scope.create_local(Elem::UInt);
let index_current_tmp = scope.create_local(Elem::UInt);
gpu!(scope, index_current = ih * output_stride_2);
gpu!(scope, index_current_tmp = iw * output_stride_3);
gpu!(scope, index_current += index_current_tmp);
gpu!(scope, oh_end = oh_end + 1u32);
gpu!(scope, ow_end = ow_end + 1u32);
(oh_start, oh_end, ow_start, ow_end)
}
}
impl<R: Runtime, E: JitElement> DynamicKernelSource
for MaxPool2dWithIndicesBackwardEagerKernel<R, E>
{
fn source(&self) -> kernel::SourceTemplate {
let mut scope = Scope::root();
let item = E::gpu_elem().into();
let indices = Variable::GlobalInputArray(0, Item::Scalar(Elem::Int));
let grad = Variable::GlobalInputArray(1, item);
let output = Variable::GlobalOutputArray(0, item);
scope.write_global_custom(output);
MaxPool2dBackwardComputeShader {
indices,
grad,
output,
kernel_size: self.kernel_size,
}
.expand(&mut scope);
let indices = InputInfo::Array {
item: Item::Scalar(Elem::Int),
visibility: Visibility::Read,
};
let grad = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let scalars = InputInfo::Scalar {
elem: Elem::UInt,
size: 6,
};
let output = OutputInfo::Array { item };
let info = CompilationInfo {
inputs: vec![indices, grad, scalars],
outputs: vec![output],
scope,
};
let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}
fn id(&self) -> String {
format!(
"{:?}k={:?}",
core::any::TypeId::of::<Self>(),
self.kernel_size,
)
}
}
pub(crate) fn max_pool2d_with_indices_backward<R: Runtime, E: JitElement, I: JitElement>(
x: JitTensor<R, E, 4>,
grad: JitTensor<R, E, 4>,
indices: JitTensor<R, I, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> JitTensor<R, E, 4> {
let grad = kernel::into_contiguous(grad);
let indices = kernel::into_contiguous(indices);
let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
let kernel = MaxPool2dWithIndicesBackwardEagerKernel::new(kernel_size);
execute_dynamic::<R, MaxPool2dWithIndicesBackwardEagerKernel<R, E>, I>(
&[
EagerHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
EagerHandle::new(&grad.handle, &grad.strides, &grad.shape.dims),
],
&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&[
(stride[0] as i32).elem(),
(stride[1] as i32).elem(),
(dilation[0] as i32).elem(),
(dilation[1] as i32).elem(),
(padding[0] as i32).elem(),
(padding[1] as i32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
output
}
#[cfg(test)]
mod tests {
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{module, ops::ModuleOps, Distribution, Tensor};
#[test]
pub fn max_pool2d_with_indices_backward_should_work_with_multiple_invocations() {
let test_device = Default::default();
let tensor =
Tensor::<TestBackend, 4>::random([32, 32, 32, 32], Distribution::Default, &test_device);
let grad_output =
Tensor::<TestBackend, 4>::random([32, 32, 16, 16], Distribution::Default, &test_device);
let ref_device = Default::default();
let tensor_ref = Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &ref_device);
let grad_output_ref =
Tensor::<ReferenceBackend, 4>::from_data(grad_output.to_data(), &ref_device);
let kernel_size = [3, 3];
let stride = [2, 2];
let padding = [1, 1];
let dilation = [1, 1];
let (_, indices) =
module::max_pool2d_with_indices(tensor.clone(), kernel_size, stride, padding, dilation);
let (_, indices_ref) = module::max_pool2d_with_indices(
tensor_ref.clone(),
kernel_size,
stride,
padding,
dilation,
);
let grad = TestBackend::max_pool2d_with_indices_backward(
tensor.into_primitive(),
kernel_size,
stride,
padding,
dilation,
grad_output.into_primitive(),
indices.into_primitive(),
)
.x_grad;
let grad_ref = ReferenceBackend::max_pool2d_with_indices_backward(
tensor_ref.into_primitive(),
kernel_size,
stride,
padding,
dilation,
grad_output_ref.into_primitive(),
indices_ref.into_primitive(),
)
.x_grad;
Tensor::<TestBackend, 4>::from_primitive(grad)
.into_data()
.assert_approx_eq(
&Tensor::<ReferenceBackend, 4>::from_primitive(grad_ref).into_data(),
3,
);
}
}
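
The loop_ranges helper above bounds which output positions can have their pooling window cover a given input position, so the backward pass only scans candidate windows instead of the whole gradient; a CPU-side sketch (not part of this diff) of the one-dimensional bound it computes:

fn oh_bounds(ih: i64, padding: i64, stride: i64, dilation: i64, kernel: i64, grad_h: i64) -> (i64, i64) {
    // Same quantity the shader calls kms_0.
    let kms = dilation * kernel - stride;
    // First candidate output row, clamped at zero.
    let start = ((ih + padding - kms) / stride).max(0);
    // Last candidate output row, clamped to the gradient height.
    let end = (start + kms.max(0)).min(grad_h - 1);
    (start, end) // inclusive; the shader then iterates start..end + 1
}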

View File

@@ -2,8 +2,11 @@ mod adaptive_avg_pool2d;
mod avg_pool2d;
mod base;
mod max_pool2d;
mod max_pool2d_backward;
pub(crate) use adaptive_avg_pool2d::*;
pub use avg_pool2d::*;
pub(super) use base::*;
pub use max_pool2d::*;
pub(crate) use max_pool2d::*;
pub(crate) use max_pool2d_backward::*;

View File

@@ -1,6 +1,9 @@
use crate::{codegen::Compiler, compute::JitAutotuneKey};
use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
/// Type alias to the runtime signed int element type.
pub type RuntimeInt<R> = <<R as Runtime>::Compiler as Compiler>::Int;
/// Runtime for the [just-in-time backend](crate::JitBackend).
pub trait Runtime: Send + Sync + 'static {
/// The compiler used to compile the inner representation into tokens.

View File

@@ -1,84 +0,0 @@
@group(0)
@binding(0)
var<storage, read> x: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read> info: array<u32, 24>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let input_stride_0 = info[0];
let input_stride_1 = info[1];
let input_stride_2 = info[2];
let input_stride_3 = info[3];
let input_shape_0 = info[4];
let input_shape_1 = info[5];
let input_shape_2 = info[6];
let input_shape_3 = info[7];
let output_stride_0 = info[8];
let output_stride_1 = info[9];
let output_stride_2 = info[10];
let output_stride_3 = info[11];
let output_shape_0 = info[12];
let output_shape_1 = info[13];
let output_shape_2 = info[14];
let output_shape_3 = info[15];
let kernel_size_0 = info[16];
let kernel_size_1 = info[17];
let pool_stride_0 = info[18];
let pool_stride_1 = info[19];
let padding_0 = info[20];
let padding_1 = info[21];
let dilation_0 = info[22];
let dilation_1 = info[23];
let b = id / output_stride_0 % output_shape_0;
let c = id / output_stride_1 % output_shape_1;
let oh = id / output_stride_2 % output_shape_2;
let ow = id / output_stride_3 % output_shape_3;
var max_val = -32767.0;
for (var kh = 0u; kh < kernel_size_0; kh++) {
let ih = oh * pool_stride_0 + kh * dilation_0;
// Padding
if ih < padding_0 || ih >= input_shape_2 + padding_0 {
continue;
}
for (var kw = 0u; kw < kernel_size_1; kw++) {
let iw = ow * pool_stride_1 + kw * dilation_1;
// Padding
if iw < padding_1 || iw >= input_shape_3 + padding_1 {
continue;
}
// Correct indexes for padding
let ih_pad = ih - padding_0;
let iw_pad = iw - padding_1;
let index_input = b * input_stride_0 + c * input_stride_1 + ih_pad * input_stride_2 + iw_pad * input_stride_3;
let val = x[index_input];
max_val = max(max_val, val);
}
}
output[id] = max_val;
}

View File

@@ -1,94 +0,0 @@
@group(0)
@binding(0)
var<storage, read> x: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read_write> indices: array<{{ int }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32, 24>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let input_stride_0 = info[0];
let input_stride_1 = info[1];
let input_stride_2 = info[2];
let input_stride_3 = info[3];
let input_shape_0 = info[4];
let input_shape_1 = info[5];
let input_shape_2 = info[6];
let input_shape_3 = info[7];
let output_stride_0 = info[8];
let output_stride_1 = info[9];
let output_stride_2 = info[10];
let output_stride_3 = info[11];
let output_shape_0 = info[12];
let output_shape_1 = info[13];
let output_shape_2 = info[14];
let output_shape_3 = info[15];
let kernel_size_0 = info[16];
let kernel_size_1 = info[17];
let pool_stride_0 = info[18];
let pool_stride_1 = info[19];
let padding_0 = info[20];
let padding_1 = info[21];
let dilation_0 = info[22];
let dilation_1 = info[23];
let b = id / output_stride_0 % output_shape_0;
let c = id / output_stride_1 % output_shape_1;
let oh = id / output_stride_2 % output_shape_2;
let ow = id / output_stride_3 % output_shape_3;
var max_val = -32767.0;
var index = 0u;
for (var kh = 0u; kh < kernel_size_0; kh++) {
let ih = oh * pool_stride_0 + kh * dilation_0;
// Padding
if ih < padding_0 || ih >= input_shape_2 + padding_0 {
continue;
}
for (var kw = 0u; kw < kernel_size_1; kw++) {
let iw = ow * pool_stride_1 + kw * dilation_1;
// Padding
if iw < padding_1 || iw >= input_shape_3 + padding_1 {
continue;
}
// Correct indexes for padding
let ih_pad = ih - padding_0;
let iw_pad = iw - padding_1;
let index_input = b * input_stride_0 + c * input_stride_1 + ih_pad * input_stride_2 + iw_pad * input_stride_3;
let val = x[index_input];
if max_val < val {
max_val = val;
index = ih_pad * input_shape_2 + iw_pad;
}
}
}
output[id] = max_val;
indices[id] = {{ int }}(index);
}

View File

@@ -1,90 +0,0 @@
@group(0)
@binding(0)
var<storage, read> indices: array<{{ int }}>;
@group(0)
@binding(1)
var<storage, read> grad: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32, 24>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let input_stride_0 = info[0];
let input_stride_1 = info[1];
let input_stride_2 = info[2];
let input_stride_3 = info[3];
let input_shape_0 = info[4];
let input_shape_1 = info[5];
let input_shape_2 = info[6];
let input_shape_3 = info[7];
let grad_stride_0 = info[8];
let grad_stride_1 = info[9];
let grad_stride_2 = info[10];
let grad_stride_3 = info[11];
let grad_shape_0 = info[12];
let grad_shape_1 = info[13];
let grad_shape_2 = info[14];
let grad_shape_3 = info[15];
let kernel_size_0 = info[16];
let kernel_size_1 = info[17];
let pool_stride_0 = info[18];
let pool_stride_1 = info[19];
let padding_0 = info[20];
let padding_1 = info[21];
let dilation_0 = info[22];
let dilation_1 = info[23];
let b = id / input_stride_0 % input_shape_0;
let c = id / input_stride_1 % input_shape_1;
let ih = id / input_stride_2 % input_shape_2;
let iw = id / input_stride_3 % input_shape_3;
// The maximum number of overlapping filters that may contain the current index.
let kms_0 = i32(kernel_size_0 * dilation_0) - i32(pool_stride_0);
let kms_1 = i32(kernel_size_1 * dilation_1) - i32(pool_stride_1);
let oh_start_tmp = (i32(ih + padding_0) - kms_0) / i32(pool_stride_0);
let ow_start_tmp = (i32(iw + padding_1) - kms_1) / i32(pool_stride_1);
let oh_start = u32(max(oh_start_tmp, 0));
let ow_start = u32(max(ow_start_tmp, 0));
let oh_end = min(u32(max(kms_0, 0)) + oh_start, grad_shape_2 - 1u);
let ow_end = min(u32(max(kms_1, 0)) + ow_start, grad_shape_3 - 1u);
let index_current = ih * input_stride_2 + iw * input_stride_3;
var grad_acc = 0.0;
// We iterate over each output position whose pooling window may overlap the
// current index and check whether its max index is the current one.
for (var oh = oh_start; oh <= oh_end; oh++) {
for (var ow = ow_start; ow <= ow_end; ow++) {
let index = b * grad_stride_0 + c * grad_stride_1 + oh * grad_stride_2 + ow * grad_stride_3;
let index_max = u32(indices[index]);
if index_max == index_current {
grad_acc += grad[index];
}
}
}
output[id] = grad_acc;
}