mirror of https://github.com/tracel-ai/burn.git
Migrate/jit/pooling (#1509)
* separate forward backward * refactor with pool strategy * refactor further * pooling refactored * refactoring for adaptive wip * wip adaptive * adaptive * delete some wgsl * avg pool backward * clippy * minor refactor
This commit is contained in:
parent
613e698007
commit
da5b0438ec
|
@ -275,6 +275,12 @@ macro_rules! gpu {
|
|||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = floor(input)
|
||||
($scope:expr, $out:ident = floor($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Floor(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = ceil(input)
|
||||
($scope:expr, $out:ident = ceil($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Ceil(
|
||||
|
|
|
@ -36,6 +36,7 @@ pub enum Operator {
|
|||
Tanh(UnaryOperator),
|
||||
Powf(BinaryOperator),
|
||||
Sqrt(UnaryOperator),
|
||||
Floor(UnaryOperator),
|
||||
Ceil(UnaryOperator),
|
||||
Erf(UnaryOperator),
|
||||
Recip(UnaryOperator),
|
||||
|
|
|
@ -43,6 +43,8 @@ impl Operator {
|
|||
Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)),
|
||||
Operator::Mul(op) => Operator::Mul(op.vectorize(vectorization)),
|
||||
Operator::Div(op) => Operator::Div(op.vectorize(vectorization)),
|
||||
Operator::Floor(op) => Operator::Floor(op.vectorize(vectorization)),
|
||||
Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)),
|
||||
Operator::Abs(op) => Operator::Abs(op.vectorize(vectorization)),
|
||||
Operator::Exp(op) => Operator::Exp(op.vectorize(vectorization)),
|
||||
Operator::Log(op) => Operator::Log(op.vectorize(vectorization)),
|
||||
|
@ -52,7 +54,6 @@ impl Operator {
|
|||
Operator::Tanh(op) => Operator::Tanh(op.vectorize(vectorization)),
|
||||
Operator::Powf(op) => Operator::Powf(op.vectorize(vectorization)),
|
||||
Operator::Sqrt(op) => Operator::Sqrt(op.vectorize(vectorization)),
|
||||
Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)),
|
||||
Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)),
|
||||
Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)),
|
||||
Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)),
|
||||
|
|
|
@ -247,11 +247,6 @@ impl TraceBuilder {
|
|||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Ceil(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Log(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
|
@ -326,6 +321,16 @@ impl TraceBuilder {
|
|||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Floor(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Ceil(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Modulo(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
|
|
|
@ -1,100 +1,41 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
codegen::{execute_dynamic, EagerHandle, WorkgroupLaunch},
|
||||
element::JitElement,
|
||||
kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
use burn_compute::server::Handle;
|
||||
use burn_tensor::Shape;
|
||||
|
||||
kernel_wgsl!(
|
||||
AdaptiveAvgPool2d,
|
||||
"../../template/pool/adaptive_avg_pool2d.wgsl"
|
||||
);
|
||||
kernel_wgsl!(
|
||||
AdaptiveAvgPool2dBackward,
|
||||
"../../template/pool/adaptive_avg_pool2d_backward.wgsl"
|
||||
);
|
||||
use super::AdaptivePool2dEagerKernel;
|
||||
|
||||
pub(crate) fn adaptive_avg_pool2d<R: Runtime, E: JitElement>(
|
||||
x: JitTensor<R, E, 4>,
|
||||
input: JitTensor<R, E, 4>,
|
||||
output_size: [usize; 2],
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let [batch_size, channels, _, _] = x.shape.dims;
|
||||
let [batch_size, channels, _, _] = input.shape.dims;
|
||||
|
||||
let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]);
|
||||
let output = empty_device(x.client.clone(), x.device.clone(), output_shape);
|
||||
let output = empty_device(input.client.clone(), input.device.clone(), output_shape);
|
||||
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
output.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
let kernel = AdaptivePool2dEagerKernel::new();
|
||||
|
||||
let info_handle = build_info(&x, &output);
|
||||
x.client
|
||||
.execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
|
||||
x: JitTensor<R, E, 4>,
|
||||
out_grad: JitTensor<R, E, 4>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let output_shape = x.shape.clone();
|
||||
let num_elems = output_shape.num_elements();
|
||||
let output_buffer = x.client.empty(num_elems * core::mem::size_of::<E>());
|
||||
let output = JitTensor::new(
|
||||
x.client.clone(),
|
||||
x.device.clone(),
|
||||
output_shape,
|
||||
output_buffer,
|
||||
);
|
||||
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
output.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
|
||||
let info_handle = build_info(&x, &out_grad);
|
||||
|
||||
x.client.execute(
|
||||
Box::new(kernel),
|
||||
&[&out_grad.handle, &output.handle, &info_handle],
|
||||
execute_dynamic::<R, AdaptivePool2dEagerKernel<R, E>, E>(
|
||||
&[EagerHandle::new(
|
||||
&input.handle,
|
||||
&input.strides,
|
||||
&input.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
input.client,
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn build_info<R: Runtime, E: JitElement>(
|
||||
x: &JitTensor<R, E, 4>,
|
||||
output: &JitTensor<R, E, 4>,
|
||||
) -> Handle<R::Server> {
|
||||
let mut info: [u32; 16] = [0; 16];
|
||||
info[0] = x.strides[0] as u32;
|
||||
info[1] = x.strides[1] as u32;
|
||||
info[2] = x.strides[2] as u32;
|
||||
info[3] = x.strides[3] as u32;
|
||||
info[4] = x.shape.dims[0] as u32;
|
||||
info[5] = x.shape.dims[1] as u32;
|
||||
info[6] = x.shape.dims[2] as u32;
|
||||
info[7] = x.shape.dims[3] as u32;
|
||||
|
||||
info[8] = output.strides[0] as u32;
|
||||
info[9] = output.strides[1] as u32;
|
||||
info[10] = output.strides[2] as u32;
|
||||
info[11] = output.strides[3] as u32;
|
||||
info[12] = output.shape.dims[0] as u32;
|
||||
info[13] = output.shape.dims[1] as u32;
|
||||
info[14] = output.shape.dims[2] as u32;
|
||||
info[15] = output.shape.dims[3] as u32;
|
||||
|
||||
output.client.create(bytemuck::cast_slice(&info))
|
||||
}
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::JitElement,
|
||||
kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
use burn_compute::server::Handle;
|
||||
|
||||
kernel_wgsl!(
|
||||
AdaptiveAvgPool2dBackward,
|
||||
"../../template/pool/adaptive_avg_pool2d_backward.wgsl"
|
||||
);
|
||||
|
||||
pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
|
||||
x: JitTensor<R, E, 4>,
|
||||
out_grad: JitTensor<R, E, 4>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let output_shape = x.shape.clone();
|
||||
let num_elems = output_shape.num_elements();
|
||||
let output_buffer = x.client.empty(num_elems * core::mem::size_of::<E>());
|
||||
let output = JitTensor::new(
|
||||
x.client.clone(),
|
||||
x.device.clone(),
|
||||
output_shape,
|
||||
output_buffer,
|
||||
);
|
||||
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
output.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
|
||||
let info_handle = build_info(&x, &out_grad);
|
||||
|
||||
x.client.execute(
|
||||
Box::new(kernel),
|
||||
&[&out_grad.handle, &output.handle, &info_handle],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn build_info<R: Runtime, E: JitElement>(
|
||||
x: &JitTensor<R, E, 4>,
|
||||
output: &JitTensor<R, E, 4>,
|
||||
) -> Handle<R::Server> {
|
||||
let mut info: [u32; 16] = [0; 16];
|
||||
info[0] = x.strides[0] as u32;
|
||||
info[1] = x.strides[1] as u32;
|
||||
info[2] = x.strides[2] as u32;
|
||||
info[3] = x.strides[3] as u32;
|
||||
info[4] = x.shape.dims[0] as u32;
|
||||
info[5] = x.shape.dims[1] as u32;
|
||||
info[6] = x.shape.dims[2] as u32;
|
||||
info[7] = x.shape.dims[3] as u32;
|
||||
|
||||
info[8] = output.strides[0] as u32;
|
||||
info[9] = output.strides[1] as u32;
|
||||
info[10] = output.strides[2] as u32;
|
||||
info[11] = output.strides[3] as u32;
|
||||
info[12] = output.shape.dims[0] as u32;
|
||||
info[13] = output.shape.dims[1] as u32;
|
||||
info[14] = output.shape.dims[2] as u32;
|
||||
info[15] = output.shape.dims[3] as u32;
|
||||
|
||||
output.client.create(bytemuck::cast_slice(&info))
|
||||
}
|
|
@ -0,0 +1,229 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo},
|
||||
gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
kernel::{DynamicKernelSource, SourceTemplate},
|
||||
Compiler, JitElement, Runtime,
|
||||
};
|
||||
|
||||
pub(crate) struct AdaptivePool2dComputeShader<R: Runtime, E: JitElement> {
|
||||
input: Variable,
|
||||
output: Variable,
|
||||
_elem: PhantomData<E>,
|
||||
_runtime: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<R: Runtime, E: JitElement> AdaptivePool2dComputeShader<R, E> {
|
||||
fn expand(self, scope: &mut Scope) {
|
||||
let input = self.input;
|
||||
let output = self.output;
|
||||
let id = Variable::Id;
|
||||
|
||||
let input_stride_0 = scope.create_local(Elem::UInt);
|
||||
let input_stride_1 = scope.create_local(Elem::UInt);
|
||||
let input_stride_2 = scope.create_local(Elem::UInt);
|
||||
let input_stride_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let input_shape_0 = scope.create_local(Elem::UInt);
|
||||
let input_shape_1 = scope.create_local(Elem::UInt);
|
||||
let input_shape_2 = scope.create_local(Elem::UInt);
|
||||
let input_shape_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, input_stride_0 = stride(input, 0u32));
|
||||
gpu!(scope, input_stride_1 = stride(input, 1u32));
|
||||
gpu!(scope, input_stride_2 = stride(input, 2u32));
|
||||
gpu!(scope, input_stride_3 = stride(input, 3u32));
|
||||
|
||||
gpu!(scope, input_shape_0 = shape(input, 2u32));
|
||||
gpu!(scope, input_shape_1 = shape(input, 3u32));
|
||||
gpu!(scope, input_shape_2 = shape(input, 2u32));
|
||||
gpu!(scope, input_shape_3 = shape(input, 3u32));
|
||||
|
||||
gpu!(scope, output_stride_0 = stride(output, 0u32));
|
||||
gpu!(scope, output_stride_1 = stride(output, 1u32));
|
||||
gpu!(scope, output_stride_2 = stride(output, 2u32));
|
||||
gpu!(scope, output_stride_3 = stride(output, 3u32));
|
||||
|
||||
gpu!(scope, output_shape_0 = shape(output, 0u32));
|
||||
gpu!(scope, output_shape_1 = shape(output, 1u32));
|
||||
gpu!(scope, output_shape_2 = shape(output, 2u32));
|
||||
gpu!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let c = scope.create_local(Elem::UInt);
|
||||
let oh = scope.create_local(Elem::UInt);
|
||||
let ow = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, b = id / output_stride_0);
|
||||
gpu!(scope, b = b % output_shape_0);
|
||||
|
||||
gpu!(scope, c = id / output_stride_1);
|
||||
gpu!(scope, c = c % output_shape_1);
|
||||
|
||||
gpu!(scope, oh = id / output_stride_2);
|
||||
gpu!(scope, oh = oh % output_shape_2);
|
||||
|
||||
gpu!(scope, ow = id / output_stride_3);
|
||||
gpu!(scope, ow = ow % output_shape_3);
|
||||
|
||||
let ih_start = Self::start_index(scope, oh, output_shape_2, input_shape_2);
|
||||
let ih_end = Self::end_index(scope, oh, output_shape_2, input_shape_2);
|
||||
let iw_start = Self::start_index(scope, ow, output_shape_3, input_shape_3);
|
||||
let iw_end = Self::end_index(scope, ow, output_shape_3, input_shape_3);
|
||||
|
||||
let result = scope.create_local(input.item());
|
||||
|
||||
let index_input = scope.create_local(Elem::UInt);
|
||||
let index_input_0 = scope.create_local(Elem::UInt);
|
||||
let index_input_1 = scope.create_local(Elem::UInt);
|
||||
let index_input_2 = scope.create_local(Elem::UInt);
|
||||
let index_input_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, index_input_0 = b * input_stride_0);
|
||||
gpu!(scope, index_input_1 = c * input_stride_1);
|
||||
|
||||
let sum = scope.zero(output.item());
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
range(ih_start, ih_end).for_each(|ih, scope| {
|
||||
gpu!(
|
||||
scope,
|
||||
range(iw_start, iw_end).for_each(|iw, scope| {
|
||||
gpu!(scope, index_input_2 = ih * input_stride_2);
|
||||
gpu!(scope, index_input_3 = iw * input_stride_3);
|
||||
|
||||
gpu!(scope, index_input = index_input_0);
|
||||
gpu!(scope, index_input += index_input_1);
|
||||
gpu!(scope, index_input += index_input_2);
|
||||
gpu!(scope, index_input += index_input_3);
|
||||
|
||||
gpu!(scope, result = input[index_input]);
|
||||
|
||||
gpu!(scope, sum += result);
|
||||
})
|
||||
);
|
||||
})
|
||||
);
|
||||
|
||||
let count = scope.create_local(Elem::UInt);
|
||||
let count_tmp = scope.create_local(Elem::UInt);
|
||||
let count_float = scope.create_local(output.item());
|
||||
let avg = scope.create_local(output.item());
|
||||
|
||||
gpu!(scope, count = ih_end - ih_start);
|
||||
gpu!(scope, count_tmp = iw_end - iw_start);
|
||||
gpu!(scope, count *= count_tmp);
|
||||
|
||||
gpu!(scope, count_float = cast(count));
|
||||
gpu!(scope, avg = sum / count_float);
|
||||
gpu!(scope, output[id] = avg);
|
||||
}
|
||||
|
||||
fn start_index(
|
||||
scope: &mut Scope,
|
||||
output_size_index: Variable,
|
||||
output_size: Variable,
|
||||
input_size: Variable,
|
||||
) -> Variable {
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let div = scope.create_local(Elem::Float);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, index = output_size_index * input_size);
|
||||
gpu!(scope, numerator_float = cast(index));
|
||||
gpu!(scope, div = cast(output_size));
|
||||
gpu!(scope, div = numerator_float / div);
|
||||
gpu!(scope, div = floor(div));
|
||||
gpu!(scope, index = cast(div));
|
||||
index
|
||||
}
|
||||
|
||||
fn end_index(
|
||||
scope: &mut Scope,
|
||||
output_size_index: Variable,
|
||||
output_size: Variable,
|
||||
input_size: Variable,
|
||||
) -> Variable {
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let div = scope.create_local(Elem::Float);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let min = scope.create_local(Elem::Bool);
|
||||
let end_index = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, index = output_size_index + 1u32);
|
||||
gpu!(scope, index *= input_size);
|
||||
gpu!(scope, numerator_float = cast(index));
|
||||
gpu!(scope, div = cast(output_size));
|
||||
gpu!(scope, div = numerator_float / div);
|
||||
gpu!(scope, div = ceil(div));
|
||||
gpu!(scope, index = cast(div));
|
||||
|
||||
gpu!(scope, min = input_size < index);
|
||||
gpu!(scope, if(min).then(|scope|{
|
||||
gpu!(scope, end_index = input_size);
|
||||
}).else(|scope|{
|
||||
gpu!(scope, end_index = index);
|
||||
}));
|
||||
end_index
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct AdaptivePool2dEagerKernel<R: Runtime, E: JitElement> {
|
||||
_runtime: PhantomData<R>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<R: Runtime, E: JitElement> DynamicKernelSource for AdaptivePool2dEagerKernel<R, E> {
|
||||
fn source(&self) -> crate::kernel::SourceTemplate {
|
||||
let mut scope = Scope::root();
|
||||
let item = E::gpu_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
AdaptivePool2dComputeShader {
|
||||
input,
|
||||
output,
|
||||
_elem: PhantomData::<E>,
|
||||
_runtime: PhantomData::<R>,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
let input = InputInfo::Array {
|
||||
item,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
|
||||
let output = OutputInfo::Array { item };
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![input],
|
||||
outputs: vec![output],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = CompilationSettings::default();
|
||||
let shader = Compilation::new(info).compile(settings);
|
||||
let shader = <R::Compiler as Compiler>::compile(shader);
|
||||
SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}", core::any::TypeId::of::<Self>(),)
|
||||
}
|
||||
}
|
|
@ -1,35 +1,72 @@
|
|||
use crate::{
|
||||
compute::{Kernel, StaticKernel},
|
||||
codegen::{dialect::gpu::Variable, execute_dynamic, EagerHandle, WorkgroupLaunch},
|
||||
element::JitElement,
|
||||
kernel::{
|
||||
self, elemwise_workgroup,
|
||||
pool::{build_output_and_info_pool2d, build_pool2d_info},
|
||||
KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
},
|
||||
kernel_wgsl,
|
||||
gpu::{gpu, Elem, Item, Scope},
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
Runtime, RuntimeInt,
|
||||
};
|
||||
use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape};
|
||||
use std::fmt::Debug;
|
||||
|
||||
kernel_wgsl!(AvgPool2dRaw, "../../template/pool/avg_pool2d.wgsl");
|
||||
kernel_wgsl!(
|
||||
AvgPool2dBackwardRaw,
|
||||
"../../template/pool/avg_pool2d_backward.wgsl"
|
||||
);
|
||||
use super::{Pool2dEagerKernel, PoolStrategy};
|
||||
|
||||
struct AvgPool2dBackward<const COUNT_INCLUDE_PAD: bool>;
|
||||
struct AvgPool2d<const COUNT_INCLUDE_PAD: bool>;
|
||||
|
||||
impl<const COUNT_INCLUDE_PAD: bool> StaticKernelSource for AvgPool2dBackward<COUNT_INCLUDE_PAD> {
|
||||
fn source() -> kernel::SourceTemplate {
|
||||
AvgPool2dBackwardRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}"))
|
||||
}
|
||||
#[derive(new, Debug, Clone)]
|
||||
struct AvgPool {
|
||||
kernel_size: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
}
|
||||
|
||||
impl<const COUNT_INCLUDE_PAD: bool> StaticKernelSource for AvgPool2d<COUNT_INCLUDE_PAD> {
|
||||
fn source() -> kernel::SourceTemplate {
|
||||
AvgPool2dRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}"))
|
||||
impl PoolStrategy for AvgPool {
|
||||
type Accumulator = (Variable, Variable);
|
||||
|
||||
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
|
||||
let sum = scope.create_local(item);
|
||||
let count = scope.create_local(Elem::UInt);
|
||||
if self.count_include_pad {
|
||||
let kernel_size: Variable = (self.kernel_size[0] * self.kernel_size[1]).into();
|
||||
gpu!(scope, count = kernel_size);
|
||||
} else {
|
||||
let zero: Variable = 0u32.into();
|
||||
gpu!(scope, count = zero);
|
||||
}
|
||||
(sum, count)
|
||||
}
|
||||
|
||||
fn process_result(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
accumulator: Self::Accumulator,
|
||||
result: Variable,
|
||||
_idx: Variable,
|
||||
) -> Self::Accumulator {
|
||||
let (sum, count) = accumulator;
|
||||
if !self.count_include_pad {
|
||||
let one: Variable = 1u32.into();
|
||||
gpu!(scope, count += one);
|
||||
}
|
||||
gpu!(scope, sum += result);
|
||||
(sum, count)
|
||||
}
|
||||
|
||||
fn assign(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
id: Variable,
|
||||
output: Variable,
|
||||
_indices: Option<Variable>,
|
||||
accumulator: Self::Accumulator,
|
||||
) {
|
||||
let (sum, count) = accumulator;
|
||||
let avg = scope.create_local(output.item());
|
||||
let count_float = scope.create_local(output.item());
|
||||
gpu!(scope, count_float = cast(count));
|
||||
gpu!(scope, avg = sum / count_float);
|
||||
gpu!(scope, output[id] = avg);
|
||||
}
|
||||
|
||||
fn with_indices() -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -40,63 +77,49 @@ pub(crate) fn avg_pool2d<R: Runtime, E: JitElement>(
|
|||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let (info_handle, output) =
|
||||
build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]);
|
||||
let [batch_size, channels, _, _] = x.shape.dims;
|
||||
let dilation = 1;
|
||||
|
||||
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT);
|
||||
let kernel: Box<dyn Kernel> = match count_include_pad {
|
||||
true => Box::new(StaticKernel::<
|
||||
KernelSettings<AvgPool2d<true>, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(workgroup)),
|
||||
false => Box::new(StaticKernel::<
|
||||
KernelSettings<AvgPool2d<false>, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(workgroup)),
|
||||
};
|
||||
let size_0 = calculate_pool_output_size(
|
||||
kernel_size[0],
|
||||
stride[0],
|
||||
padding[0],
|
||||
dilation,
|
||||
x.shape.dims[2],
|
||||
);
|
||||
let size_1 = calculate_pool_output_size(
|
||||
kernel_size[1],
|
||||
stride[1],
|
||||
padding[1],
|
||||
dilation,
|
||||
x.shape.dims[3],
|
||||
);
|
||||
|
||||
x.client
|
||||
.execute(kernel, &[&x.handle, &output.handle, &info_handle]);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) fn avg_pool2d_backward<R: Runtime, E: JitElement>(
|
||||
x: JitTensor<R, E, 4>,
|
||||
grad: JitTensor<R, E, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let grad = kernel::into_contiguous(grad);
|
||||
let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
|
||||
let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]);
|
||||
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT);
|
||||
|
||||
let kernel: Box<dyn Kernel> = match count_include_pad {
|
||||
true => Box::new(StaticKernel::<
|
||||
KernelSettings<
|
||||
AvgPool2dBackward<true>,
|
||||
E,
|
||||
i32,
|
||||
WORKGROUP_DEFAULT,
|
||||
WORKGROUP_DEFAULT,
|
||||
1,
|
||||
>,
|
||||
>::new(workgroup)),
|
||||
false => Box::new(StaticKernel::<
|
||||
KernelSettings<
|
||||
AvgPool2dBackward<false>,
|
||||
E,
|
||||
i32,
|
||||
WORKGROUP_DEFAULT,
|
||||
WORKGROUP_DEFAULT,
|
||||
1,
|
||||
>,
|
||||
>::new(workgroup)),
|
||||
};
|
||||
|
||||
x.client
|
||||
.execute(kernel, &[&grad.handle, &output.handle, &info_handle]);
|
||||
let shape_out = Shape::new([batch_size, channels, size_0, size_1]);
|
||||
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
|
||||
|
||||
let pool_strategy = AvgPool::new(kernel_size, count_include_pad);
|
||||
let kernel = Pool2dEagerKernel::new(kernel_size, pool_strategy);
|
||||
|
||||
execute_dynamic::<R, Pool2dEagerKernel<AvgPool, R, E>, RuntimeInt<R>>(
|
||||
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
|
||||
&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
Some(&[
|
||||
(stride[0] as u32).elem(),
|
||||
(stride[1] as u32).elem(),
|
||||
(dilation as u32).elem(),
|
||||
(dilation as u32).elem(),
|
||||
(padding[0] as u32).elem(),
|
||||
(padding[1] as u32).elem(),
|
||||
]),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
x.client,
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -0,0 +1,408 @@
|
|||
use burn_tensor::ElementConversion;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
kernel::{self, DynamicKernelSource, SourceTemplate},
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime, RuntimeInt,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[derive(new)]
|
||||
struct AvgPool2dBackwardEagerKernel<R: Runtime, E: JitElement> {
|
||||
kernel_size: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
_runtime: PhantomData<R>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
struct AvgPool2dBackwardComputeShader {
|
||||
grad: Variable,
|
||||
output: Variable,
|
||||
kernel_size: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
}
|
||||
|
||||
impl AvgPool2dBackwardComputeShader {
|
||||
fn expand(self, scope: &mut Scope) {
|
||||
let grad = self.grad;
|
||||
let output = self.output;
|
||||
let id = Variable::Id;
|
||||
|
||||
let grad_stride_0 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_1 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_2 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let grad_shape_2 = scope.create_local(Elem::UInt);
|
||||
let grad_shape_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, grad_stride_0 = stride(grad, 0u32));
|
||||
gpu!(scope, grad_stride_1 = stride(grad, 1u32));
|
||||
gpu!(scope, grad_stride_2 = stride(grad, 2u32));
|
||||
gpu!(scope, grad_stride_3 = stride(grad, 3u32));
|
||||
|
||||
gpu!(scope, grad_shape_2 = shape(grad, 2u32));
|
||||
gpu!(scope, grad_shape_3 = shape(grad, 3u32));
|
||||
|
||||
gpu!(scope, output_stride_0 = stride(output, 0u32));
|
||||
gpu!(scope, output_stride_1 = stride(output, 1u32));
|
||||
gpu!(scope, output_stride_2 = stride(output, 2u32));
|
||||
gpu!(scope, output_stride_3 = stride(output, 3u32));
|
||||
|
||||
gpu!(scope, output_shape_0 = shape(output, 0u32));
|
||||
gpu!(scope, output_shape_1 = shape(output, 1u32));
|
||||
gpu!(scope, output_shape_2 = shape(output, 2u32));
|
||||
gpu!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let c = scope.create_local(Elem::UInt);
|
||||
let ih = scope.create_local(Elem::UInt);
|
||||
let iw = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, b = id / output_stride_0);
|
||||
gpu!(scope, b = b % output_shape_0);
|
||||
|
||||
gpu!(scope, c = id / output_stride_1);
|
||||
gpu!(scope, c = c % output_shape_1);
|
||||
|
||||
gpu!(scope, ih = id / output_stride_2);
|
||||
gpu!(scope, ih = ih % output_shape_2);
|
||||
|
||||
gpu!(scope, iw = id / output_stride_3);
|
||||
gpu!(scope, iw = iw % output_shape_3);
|
||||
|
||||
let index_current = scope.create_local(Elem::UInt);
|
||||
let index_current_tmp = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, index_current = ih * output_stride_2);
|
||||
gpu!(scope, index_current_tmp = iw * output_stride_3);
|
||||
gpu!(scope, index_current += index_current_tmp);
|
||||
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let index_tmp = scope.create_local(Elem::UInt);
|
||||
let index_base = scope.create_local(Elem::UInt);
|
||||
|
||||
let grad_accumulation = scope.zero(grad.item());
|
||||
let result = scope.create_local(grad.item());
|
||||
let count = scope.create_local(grad.item());
|
||||
|
||||
let count_include_pad = self.count_include_pad;
|
||||
if count_include_pad {
|
||||
let kernel_size: Variable = (self.kernel_size[0] * self.kernel_size[1]).into();
|
||||
gpu!(scope, count = kernel_size);
|
||||
}
|
||||
|
||||
let (oh_start, oh_end, ow_start, ow_end) = self.loop_ranges(
|
||||
scope,
|
||||
ih,
|
||||
iw,
|
||||
grad_shape_2,
|
||||
grad_shape_3,
|
||||
output_stride_2,
|
||||
output_stride_3,
|
||||
);
|
||||
|
||||
gpu!(scope, index_base = b * grad_stride_0);
|
||||
gpu!(scope, index_tmp = c * grad_stride_1);
|
||||
gpu!(scope, index_base += index_tmp);
|
||||
|
||||
let border_bottom = scope.create_local(Elem::UInt);
|
||||
let border_right = scope.create_local(Elem::UInt);
|
||||
let begin_h = scope.create_local(Elem::UInt);
|
||||
let begin_w = scope.create_local(Elem::UInt);
|
||||
let iw_start = scope.create_local(Elem::UInt);
|
||||
let iw_end = scope.create_local(Elem::UInt);
|
||||
let ih_start = scope.create_local(Elem::UInt);
|
||||
let ih_end = scope.create_local(Elem::UInt);
|
||||
let after_start = scope.create_local(Elem::Bool);
|
||||
let before_end = scope.create_local(Elem::Bool);
|
||||
let contributed_h = scope.create_local(Elem::Bool);
|
||||
let contributed_w = scope.create_local(Elem::Bool);
|
||||
gpu!(scope, border_bottom = output_shape_2 + padding_0);
|
||||
gpu!(scope, border_right = output_shape_3 + padding_1);
|
||||
gpu!(scope, begin_h = ih + padding_0);
|
||||
gpu!(scope, begin_w = iw + padding_1);
|
||||
|
||||
let ih_diff = scope.create_local(Elem::UInt);
|
||||
let iw_diff = scope.create_local(Elem::UInt);
|
||||
let count_int = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
range(oh_start, oh_end).for_each(|oh, scope| {
|
||||
// Contributed h
|
||||
gpu!(scope, ih_start = oh * pool_stride_0);
|
||||
gpu!(scope, ih_end = ih_start + kernel_size_0);
|
||||
gpu!(scope, ih_start = max(ih_start, padding_0));
|
||||
gpu!(scope, ih_end = min(ih_end, border_bottom));
|
||||
gpu!(scope, after_start = begin_h >= ih_start);
|
||||
gpu!(scope, before_end = ih < ih_end);
|
||||
gpu!(scope, contributed_h = after_start && before_end);
|
||||
|
||||
if !count_include_pad {
|
||||
gpu!(scope, ih_diff = ih_end - ih_start);
|
||||
}
|
||||
|
||||
gpu!(scope, if(contributed_h).then(|scope|{
|
||||
gpu!(
|
||||
scope,
|
||||
range(ow_start, ow_end).for_each(|ow, scope| {
|
||||
gpu!(scope, index = index_base);
|
||||
gpu!(scope, index_tmp = oh * grad_stride_2);
|
||||
gpu!(scope, index += index_tmp);
|
||||
gpu!(scope, index_tmp = ow * grad_stride_3);
|
||||
gpu!(scope, index += index_tmp);
|
||||
|
||||
// Contributed w
|
||||
gpu!(scope, iw_start = ow * pool_stride_1);
|
||||
gpu!(scope, iw_end = iw_start + kernel_size_1);
|
||||
gpu!(scope, iw_start = max(iw_start, padding_1));
|
||||
gpu!(scope, iw_end = min(iw_end, border_right));
|
||||
gpu!(scope, after_start = begin_w >= iw_start);
|
||||
gpu!(scope, before_end = iw < iw_end);
|
||||
gpu!(scope, contributed_w = after_start && before_end);
|
||||
|
||||
gpu!(scope, if(contributed_w).then(|scope|{
|
||||
if !count_include_pad {
|
||||
gpu!(scope, iw_diff = iw_end - iw_start);
|
||||
gpu!(scope, count_int = ih_diff * iw_diff);
|
||||
gpu!(scope, count = cast(count_int));
|
||||
}
|
||||
|
||||
gpu!(scope, result = grad[index]);
|
||||
gpu!(scope, result = result / count);
|
||||
gpu!(scope, grad_accumulation += result);
|
||||
}));
|
||||
}));
|
||||
}));
|
||||
})
|
||||
);
|
||||
|
||||
gpu!(scope, output[id] = grad_accumulation);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn loop_ranges(
|
||||
self,
|
||||
scope: &mut Scope,
|
||||
ih: Variable,
|
||||
iw: Variable,
|
||||
grad_shape_2: Variable,
|
||||
grad_shape_3: Variable,
|
||||
output_stride_2: Variable,
|
||||
output_stride_3: Variable,
|
||||
) -> (Variable, Variable, Variable, Variable) {
|
||||
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
|
||||
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
|
||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
||||
|
||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||
|
||||
let signed_ih = scope.create_local(Elem::Int);
|
||||
let signed_iw = scope.create_local(Elem::Int);
|
||||
|
||||
let signed_pool_stride_0 = scope.create_local(Elem::Int);
|
||||
let signed_pool_stride_1 = scope.create_local(Elem::Int);
|
||||
let signed_dilation_0 = scope.create_local(Elem::Int);
|
||||
let signed_dilation_1 = scope.create_local(Elem::Int);
|
||||
let signed_padding_0 = scope.create_local(Elem::Int);
|
||||
let signed_padding_1 = scope.create_local(Elem::Int);
|
||||
let signed_kernel_size_0 = scope.create_local(Elem::Int);
|
||||
let signed_kernel_size_1 = scope.create_local(Elem::Int);
|
||||
|
||||
gpu!(scope, signed_pool_stride_0 = cast(pool_stride_0));
|
||||
gpu!(scope, signed_pool_stride_1 = cast(pool_stride_1));
|
||||
gpu!(scope, signed_dilation_0 = cast(dilation_0));
|
||||
gpu!(scope, signed_dilation_1 = cast(dilation_1));
|
||||
gpu!(scope, signed_padding_0 = cast(padding_0));
|
||||
gpu!(scope, signed_padding_1 = cast(padding_1));
|
||||
|
||||
gpu!(scope, signed_kernel_size_0 = cast(kernel_size_0));
|
||||
gpu!(scope, signed_kernel_size_1 = cast(kernel_size_1));
|
||||
|
||||
gpu!(scope, signed_ih = cast(ih));
|
||||
gpu!(scope, signed_iw = cast(iw));
|
||||
|
||||
let kms_0 = scope.create_local(Elem::Int);
|
||||
let kms_1 = scope.create_local(Elem::Int);
|
||||
|
||||
gpu!(scope, kms_0 = signed_dilation_0 * signed_kernel_size_0);
|
||||
gpu!(scope, kms_0 = kms_0 - signed_pool_stride_0);
|
||||
|
||||
gpu!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1);
|
||||
gpu!(scope, kms_1 = kms_1 - signed_pool_stride_1);
|
||||
|
||||
let oh_start_tmp = scope.create_local(Elem::Int);
|
||||
let ow_start_tmp = scope.create_local(Elem::Int);
|
||||
|
||||
gpu!(scope, oh_start_tmp = signed_ih + signed_padding_0);
|
||||
gpu!(scope, oh_start_tmp = oh_start_tmp - kms_0);
|
||||
gpu!(scope, oh_start_tmp = oh_start_tmp / signed_pool_stride_0);
|
||||
|
||||
gpu!(scope, ow_start_tmp = signed_iw + signed_padding_1);
|
||||
gpu!(scope, ow_start_tmp = ow_start_tmp - kms_1);
|
||||
gpu!(scope, ow_start_tmp = ow_start_tmp / signed_pool_stride_1);
|
||||
|
||||
gpu!(scope, oh_start_tmp = max(oh_start_tmp, 0i32));
|
||||
gpu!(scope, ow_start_tmp = max(ow_start_tmp, 0i32));
|
||||
|
||||
let oh_start = scope.create_local(Elem::UInt);
|
||||
let ow_start = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, oh_start = cast(oh_start_tmp));
|
||||
gpu!(scope, ow_start = cast(ow_start_tmp));
|
||||
|
||||
let oh_end_tmp = scope.create_local(Elem::Int);
|
||||
let ow_end_tmp = scope.create_local(Elem::Int);
|
||||
|
||||
gpu!(scope, oh_end_tmp = max(kms_0, 0i32));
|
||||
gpu!(scope, ow_end_tmp = max(kms_1, 0i32));
|
||||
|
||||
let oh_end = scope.create_local(Elem::UInt);
|
||||
let ow_end = scope.create_local(Elem::UInt);
|
||||
|
||||
let oh_end_limit = scope.create_local(Elem::UInt);
|
||||
let ow_end_limit = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, oh_end = cast(oh_end_tmp));
|
||||
gpu!(scope, ow_end = cast(ow_end_tmp));
|
||||
|
||||
gpu!(scope, oh_end = oh_end + oh_start);
|
||||
gpu!(scope, oh_end_limit = grad_shape_2 - 1u32);
|
||||
|
||||
gpu!(scope, ow_end = ow_end + ow_start);
|
||||
gpu!(scope, ow_end_limit = grad_shape_3 - 1u32);
|
||||
|
||||
gpu!(scope, oh_end = min(oh_end, oh_end_limit));
|
||||
gpu!(scope, ow_end = min(ow_end, ow_end_limit));
|
||||
|
||||
let index_current = scope.create_local(Elem::UInt);
|
||||
let index_current_tmp = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, index_current = ih * output_stride_2);
|
||||
gpu!(scope, index_current_tmp = iw * output_stride_3);
|
||||
gpu!(scope, index_current += index_current_tmp);
|
||||
|
||||
gpu!(scope, oh_end = oh_end + 1u32);
|
||||
gpu!(scope, ow_end = ow_end + 1u32);
|
||||
|
||||
(oh_start, oh_end, ow_start, ow_end)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime, E: JitElement> DynamicKernelSource for AvgPool2dBackwardEagerKernel<R, E> {
|
||||
fn source(&self) -> kernel::SourceTemplate {
|
||||
let mut scope = Scope::root();
|
||||
let item = E::gpu_elem().into();
|
||||
|
||||
let grad = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
AvgPool2dBackwardComputeShader {
|
||||
grad,
|
||||
output,
|
||||
kernel_size: self.kernel_size,
|
||||
count_include_pad: self.count_include_pad,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
let grad = InputInfo::Array {
|
||||
item,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
let scalars = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
size: 6,
|
||||
};
|
||||
let output = OutputInfo::Array { item };
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![grad, scalars],
|
||||
outputs: vec![output],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = CompilationSettings::default();
|
||||
let shader = Compilation::new(info).compile(settings);
|
||||
let shader = <R::Compiler as Compiler>::compile(shader);
|
||||
SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!(
|
||||
"{:?}k={:?}count_include_pad={:?}",
|
||||
core::any::TypeId::of::<Self>(),
|
||||
self.kernel_size,
|
||||
self.count_include_pad
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn avg_pool2d_backward<R: Runtime, E: JitElement>(
|
||||
x: JitTensor<R, E, 4>,
|
||||
grad: JitTensor<R, E, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let grad = kernel::into_contiguous(grad);
|
||||
let dilation = 1;
|
||||
|
||||
let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
|
||||
let kernel = AvgPool2dBackwardEagerKernel::new(kernel_size, count_include_pad);
|
||||
|
||||
execute_dynamic::<R, AvgPool2dBackwardEagerKernel<R, E>, RuntimeInt<R>>(
|
||||
&[EagerHandle::new(
|
||||
&grad.handle,
|
||||
&grad.strides,
|
||||
&grad.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
Some(&[
|
||||
(stride[0] as i32).elem(),
|
||||
(stride[1] as i32).elem(),
|
||||
dilation.elem(),
|
||||
dilation.elem(),
|
||||
(padding[0] as i32).elem(),
|
||||
(padding[1] as i32).elem(),
|
||||
]),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
x.client,
|
||||
);
|
||||
|
||||
output
|
||||
}
|
|
@ -1,72 +1,27 @@
|
|||
use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, Runtime};
|
||||
use burn_compute::server::Handle;
|
||||
use burn_tensor::Shape;
|
||||
use crate::gpu::{Item, Scope, Variable};
|
||||
use std::fmt::Debug;
|
||||
|
||||
/// Build basic info to launch pool 2d kernels.
|
||||
pub fn build_output_and_info_pool2d<R: Runtime, E: JitElement>(
|
||||
x: &JitTensor<R, E, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> (Handle<R::Server>, JitTensor<R, E, 4>) {
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [padding_height, padding_width] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [dilation_height, dilation_width] = dilation;
|
||||
let [batch_size, channels, x_height, x_width] = x.shape.dims;
|
||||
pub(crate) trait PoolStrategy: Send + Sync + 'static + Clone + Debug {
|
||||
type Accumulator: Copy;
|
||||
|
||||
let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1)
|
||||
/ stride_height)
|
||||
+ 1;
|
||||
let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1)
|
||||
/ stride_width)
|
||||
+ 1;
|
||||
let shape_out = Shape::new([batch_size, channels, out_height, out_width]);
|
||||
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
|
||||
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator;
|
||||
|
||||
let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding, dilation);
|
||||
fn process_result(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
accumulator: Self::Accumulator,
|
||||
result: Variable,
|
||||
idx: Variable,
|
||||
) -> Self::Accumulator;
|
||||
|
||||
(info_buffer, output)
|
||||
}
|
||||
|
||||
pub fn build_pool2d_info<R: Runtime, E: JitElement>(
|
||||
input: &JitTensor<R, E, 4>,
|
||||
output: &JitTensor<R, E, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> Handle<R::Server> {
|
||||
let mut info: [u32; 24] = [0; 24];
|
||||
info[0] = input.strides[0] as u32;
|
||||
info[1] = input.strides[1] as u32;
|
||||
info[2] = input.strides[2] as u32;
|
||||
info[3] = input.strides[3] as u32;
|
||||
info[4] = input.shape.dims[0] as u32;
|
||||
info[5] = input.shape.dims[1] as u32;
|
||||
info[6] = input.shape.dims[2] as u32;
|
||||
info[7] = input.shape.dims[3] as u32;
|
||||
|
||||
info[8] = output.strides[0] as u32;
|
||||
info[9] = output.strides[1] as u32;
|
||||
info[10] = output.strides[2] as u32;
|
||||
info[11] = output.strides[3] as u32;
|
||||
info[12] = output.shape.dims[0] as u32;
|
||||
info[13] = output.shape.dims[1] as u32;
|
||||
info[14] = output.shape.dims[2] as u32;
|
||||
info[15] = output.shape.dims[3] as u32;
|
||||
|
||||
info[16] = kernel_size[0] as u32;
|
||||
info[17] = kernel_size[1] as u32;
|
||||
info[18] = stride[0] as u32;
|
||||
info[19] = stride[1] as u32;
|
||||
info[20] = padding[0] as u32;
|
||||
info[21] = padding[1] as u32;
|
||||
info[22] = dilation[0] as u32;
|
||||
info[23] = dilation[1] as u32;
|
||||
|
||||
let info_buffer = input.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
info_buffer
|
||||
fn assign(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
id: Variable,
|
||||
output: Variable,
|
||||
indices: Option<Variable>,
|
||||
accumulator: Self::Accumulator,
|
||||
);
|
||||
|
||||
fn with_indices() -> bool;
|
||||
}
|
||||
|
|
|
@ -1,302 +1,165 @@
|
|||
use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape};
|
||||
use std::marker::PhantomData;
|
||||
use std::{fmt::Debug, marker::PhantomData};
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
codegen::{dialect::gpu::Variable, execute_dynamic, EagerHandle, WorkgroupLaunch},
|
||||
element::JitElement,
|
||||
kernel::{DynamicKernelSource, SourceTemplate},
|
||||
gpu::{gpu, Elem, Item, Scope},
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime, RuntimeInt,
|
||||
};
|
||||
use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape};
|
||||
|
||||
#[derive(new)]
|
||||
struct MaxPool2dEagerKernel<R: Runtime, E: JitElement> {
|
||||
kernel_size: [usize; 2],
|
||||
_runtime: PhantomData<R>,
|
||||
use super::{Pool2dEagerKernel, PoolStrategy};
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
struct MaxPool<E: JitElement> {
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct MaxPool2dWithIndicesEagerKernel<R: Runtime, E: JitElement> {
|
||||
kernel_size: [usize; 2],
|
||||
_runtime: PhantomData<R>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
impl<E: JitElement> PoolStrategy for MaxPool<E> {
|
||||
type Accumulator = Variable;
|
||||
|
||||
struct MaxPool2dComputeShader<E: JitElement> {
|
||||
x: Variable,
|
||||
output: Variable,
|
||||
kernel_size: [usize; 2],
|
||||
indices: Option<Variable>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<E: JitElement> MaxPool2dComputeShader<E> {
|
||||
fn expand(self, scope: &mut Scope) {
|
||||
let x = self.x;
|
||||
let output = self.output;
|
||||
let id = Variable::Id;
|
||||
|
||||
let input_stride_0 = scope.create_local(Elem::UInt);
|
||||
let input_stride_1 = scope.create_local(Elem::UInt);
|
||||
let input_stride_2 = scope.create_local(Elem::UInt);
|
||||
let input_stride_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let input_shape_2 = scope.create_local(Elem::UInt);
|
||||
let input_shape_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, input_stride_0 = stride(x, 0u32));
|
||||
gpu!(scope, input_stride_1 = stride(x, 1u32));
|
||||
gpu!(scope, input_stride_2 = stride(x, 2u32));
|
||||
gpu!(scope, input_stride_3 = stride(x, 3u32));
|
||||
|
||||
gpu!(scope, input_shape_2 = shape(x, 2u32));
|
||||
gpu!(scope, input_shape_3 = shape(x, 3u32));
|
||||
|
||||
gpu!(scope, output_stride_0 = stride(output, 0u32));
|
||||
gpu!(scope, output_stride_1 = stride(output, 1u32));
|
||||
gpu!(scope, output_stride_2 = stride(output, 2u32));
|
||||
gpu!(scope, output_stride_3 = stride(output, 3u32));
|
||||
|
||||
gpu!(scope, output_shape_0 = shape(output, 0u32));
|
||||
gpu!(scope, output_shape_1 = shape(output, 1u32));
|
||||
gpu!(scope, output_shape_2 = shape(output, 2u32));
|
||||
gpu!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
|
||||
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
|
||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
||||
|
||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let c = scope.create_local(Elem::UInt);
|
||||
let oh = scope.create_local(Elem::UInt);
|
||||
let ow = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, b = id / output_stride_0);
|
||||
gpu!(scope, b = b % output_shape_0);
|
||||
|
||||
gpu!(scope, c = id / output_stride_1);
|
||||
gpu!(scope, c = c % output_shape_1);
|
||||
|
||||
gpu!(scope, oh = id / output_stride_2);
|
||||
gpu!(scope, oh = oh % output_shape_2);
|
||||
|
||||
gpu!(scope, ow = id / output_stride_3);
|
||||
gpu!(scope, ow = ow % output_shape_3);
|
||||
|
||||
let tmp = scope.create_local(Elem::UInt);
|
||||
let ih = scope.create_local(Elem::UInt);
|
||||
let iw = scope.create_local(Elem::UInt);
|
||||
|
||||
let ih_pad = scope.create_local(Elem::UInt);
|
||||
let iw_pad = scope.create_local(Elem::UInt);
|
||||
let result = scope.create_local(x.item());
|
||||
|
||||
let cond = scope.create_local(Elem::Bool);
|
||||
let cond_tmp = scope.create_local(Elem::Bool);
|
||||
|
||||
let index_input = scope.create_local(Elem::UInt);
|
||||
let index_input_1 = scope.create_local(Elem::UInt);
|
||||
let index_input_2 = scope.create_local(Elem::UInt);
|
||||
let index_input_3 = scope.create_local(Elem::UInt);
|
||||
let index_input_4 = scope.create_local(Elem::UInt);
|
||||
|
||||
let is_max = scope.create_local(Elem::Bool);
|
||||
let max_index = self.indices.map(|_| scope.create_local(Elem::UInt));
|
||||
|
||||
let max_val = scope.create_local(x.item());
|
||||
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
|
||||
let max_val = scope.create_local(item);
|
||||
let max_initial =
|
||||
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), x.item().elem());
|
||||
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem());
|
||||
gpu!(scope, max_val = max_initial);
|
||||
max_val
|
||||
}
|
||||
|
||||
(0..kernel_size_0).for_each(|kh| {
|
||||
gpu!(scope, ih = oh * pool_stride_0);
|
||||
gpu!(scope, tmp = kh * dilation_0);
|
||||
gpu!(scope, ih += tmp);
|
||||
fn process_result(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
accumulator: Self::Accumulator,
|
||||
result: Variable,
|
||||
_idx: Variable,
|
||||
) -> Self::Accumulator {
|
||||
let is_max = scope.create_local(Elem::Bool);
|
||||
gpu!(scope, is_max = result > accumulator);
|
||||
gpu!(scope, if(is_max).then(|scope|{
|
||||
gpu!(scope, accumulator = result);
|
||||
}));
|
||||
accumulator
|
||||
}
|
||||
|
||||
// Up
|
||||
gpu!(scope, cond = ih < padding_0);
|
||||
// Down
|
||||
gpu!(scope, tmp = input_shape_2 + padding_0);
|
||||
gpu!(scope, cond_tmp = ih >= tmp);
|
||||
gpu!(scope, cond = cond || cond_tmp);
|
||||
gpu!(scope, cond = !cond);
|
||||
fn assign(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
id: Variable,
|
||||
output: Variable,
|
||||
_indices: Option<Variable>,
|
||||
accumulator: Self::Accumulator,
|
||||
) {
|
||||
gpu!(scope, output[id] = accumulator);
|
||||
}
|
||||
|
||||
gpu!(scope, if (cond).then(|scope| {
|
||||
(0..kernel_size_1).for_each(|kw| {
|
||||
gpu!(scope, iw = ow * pool_stride_1);
|
||||
gpu!(scope, tmp = kw * dilation_1);
|
||||
gpu!(scope, iw = iw + tmp);
|
||||
fn with_indices() -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
// Left
|
||||
gpu!(scope, cond = iw < padding_1);
|
||||
// Right
|
||||
gpu!(scope, tmp = input_shape_3 + padding_1);
|
||||
gpu!(scope, cond_tmp = iw >= tmp);
|
||||
gpu!(scope, cond = cond || cond_tmp);
|
||||
gpu!(scope, cond = !cond);
|
||||
#[derive(Default, Debug, Clone)]
|
||||
struct MaxPoolWithIndices<E: JitElement> {
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
gpu!(scope, if (cond).then(|scope| {
|
||||
gpu!(scope, ih_pad = ih - padding_0);
|
||||
gpu!(scope, iw_pad = iw - padding_1);
|
||||
impl<E: JitElement> PoolStrategy for MaxPoolWithIndices<E> {
|
||||
type Accumulator = (Variable, Variable);
|
||||
|
||||
gpu!(scope, index_input_1 = b * input_stride_0);
|
||||
gpu!(scope, index_input_2 = c * input_stride_1);
|
||||
gpu!(scope, index_input_3 = ih_pad * input_stride_2);
|
||||
gpu!(scope, index_input_4 = iw_pad * input_stride_3);
|
||||
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
|
||||
let max_val = scope.create_local(item);
|
||||
let max_initial =
|
||||
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem());
|
||||
gpu!(scope, max_val = max_initial);
|
||||
let max_index = scope.create_local(Elem::UInt);
|
||||
(max_val, max_index)
|
||||
}
|
||||
|
||||
gpu!(scope, index_input = index_input_1);
|
||||
gpu!(scope, index_input += index_input_2);
|
||||
gpu!(scope, index_input += index_input_3);
|
||||
gpu!(scope, index_input += index_input_4);
|
||||
|
||||
gpu!(scope, result = x[index_input]);
|
||||
|
||||
gpu!(scope, is_max = result > max_val);
|
||||
|
||||
gpu!(scope, if(is_max).then(|scope|{
|
||||
gpu!(scope, max_val = result);
|
||||
if let Some(max_index) = max_index {
|
||||
gpu!(scope, max_index = ih_pad * input_shape_2);
|
||||
gpu!(scope, max_index += iw_pad);
|
||||
}
|
||||
}));
|
||||
}));
|
||||
});
|
||||
}));
|
||||
});
|
||||
fn process_result(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
(max_val, max_index): Self::Accumulator,
|
||||
result: Variable,
|
||||
idx: Variable,
|
||||
) -> Self::Accumulator {
|
||||
let is_max = scope.create_local(Elem::Bool);
|
||||
gpu!(scope, is_max = result > max_val);
|
||||
gpu!(scope, if(is_max).then(|scope|{
|
||||
gpu!(scope, max_val = result);
|
||||
gpu!(scope, max_index = idx);
|
||||
}));
|
||||
(max_val, max_index)
|
||||
}
|
||||
|
||||
fn assign(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
id: Variable,
|
||||
output: Variable,
|
||||
indices: Option<Variable>,
|
||||
(max_val, max_index): Self::Accumulator,
|
||||
) {
|
||||
let indices = indices.unwrap();
|
||||
gpu!(scope, output[id] = max_val);
|
||||
gpu!(scope, indices[id] = max_index);
|
||||
}
|
||||
|
||||
if let Some(indices) = self.indices {
|
||||
let max_index = max_index.unwrap();
|
||||
gpu!(scope, indices[id] = max_index);
|
||||
}
|
||||
fn with_indices() -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime, E: JitElement> DynamicKernelSource for MaxPool2dEagerKernel<R, E> {
|
||||
fn source(&self) -> crate::kernel::SourceTemplate {
|
||||
let mut scope = Scope::root();
|
||||
let item = E::gpu_elem().into();
|
||||
pub(crate) fn max_pool2d<R: Runtime, E: JitElement>(
|
||||
x: JitTensor<R, E, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let [batch_size, channels, _, _] = x.shape.dims;
|
||||
|
||||
let x = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let size_0 = calculate_pool_output_size(
|
||||
kernel_size[0],
|
||||
stride[0],
|
||||
padding[0],
|
||||
dilation[0],
|
||||
x.shape.dims[2],
|
||||
);
|
||||
let size_1 = calculate_pool_output_size(
|
||||
kernel_size[1],
|
||||
stride[1],
|
||||
padding[1],
|
||||
dilation[1],
|
||||
x.shape.dims[3],
|
||||
);
|
||||
|
||||
scope.write_global_custom(output);
|
||||
let shape_out = Shape::new([batch_size, channels, size_0, size_1]);
|
||||
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
|
||||
|
||||
MaxPool2dComputeShader {
|
||||
x,
|
||||
output,
|
||||
kernel_size: self.kernel_size,
|
||||
indices: None,
|
||||
_elem: PhantomData::<E>,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
let kernel = Pool2dEagerKernel::new(kernel_size, MaxPool::default());
|
||||
|
||||
let input = InputInfo::Array {
|
||||
item,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
let scalars = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
size: 6,
|
||||
};
|
||||
let output = OutputInfo::Array { item };
|
||||
execute_dynamic::<R, Pool2dEagerKernel<MaxPool<E>, R, E>, RuntimeInt<R>>(
|
||||
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
|
||||
&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
Some(&[
|
||||
(stride[0] as u32).elem(),
|
||||
(stride[1] as u32).elem(),
|
||||
(dilation[0] as u32).elem(),
|
||||
(dilation[1] as u32).elem(),
|
||||
(padding[0] as u32).elem(),
|
||||
(padding[1] as u32).elem(),
|
||||
]),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
x.client,
|
||||
);
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![input, scalars],
|
||||
outputs: vec![output],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = CompilationSettings::default();
|
||||
let shader = Compilation::new(info).compile(settings);
|
||||
let shader = <R::Compiler as Compiler>::compile(shader);
|
||||
SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!(
|
||||
"{:?}k={:?}",
|
||||
core::any::TypeId::of::<Self>(),
|
||||
self.kernel_size,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime, E: JitElement> DynamicKernelSource for MaxPool2dWithIndicesEagerKernel<R, E> {
|
||||
fn source(&self) -> crate::kernel::SourceTemplate {
|
||||
let mut scope = Scope::root();
|
||||
let item = E::gpu_elem().into();
|
||||
|
||||
let x = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let indices = Variable::GlobalOutputArray(1, Item::Scalar(Elem::Int));
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
MaxPool2dComputeShader {
|
||||
x,
|
||||
output,
|
||||
kernel_size: self.kernel_size,
|
||||
indices: Some(indices),
|
||||
_elem: PhantomData::<E>,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
let input = InputInfo::Array {
|
||||
item,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
let scalars = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
size: 6,
|
||||
};
|
||||
let output = OutputInfo::Array { item };
|
||||
let indices = OutputInfo::Array {
|
||||
item: Item::Scalar(Elem::Int),
|
||||
};
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![input, scalars],
|
||||
outputs: vec![output, indices],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = CompilationSettings::default();
|
||||
let shader = Compilation::new(info).compile(settings);
|
||||
let shader = <R::Compiler as Compiler>::compile(shader);
|
||||
SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!(
|
||||
"{:?}k={:?}",
|
||||
core::any::TypeId::of::<Self>(),
|
||||
self.kernel_size,
|
||||
)
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) fn max_pool2d_with_indices<R: Runtime, E: JitElement, I: JitElement>(
|
||||
|
@ -327,8 +190,9 @@ pub(crate) fn max_pool2d_with_indices<R: Runtime, E: JitElement, I: JitElement>(
|
|||
let output = empty_device(x.client.clone(), x.device.clone(), shape_out.clone());
|
||||
let indices = empty_device(x.client.clone(), x.device.clone(), shape_out);
|
||||
|
||||
let kernel = MaxPool2dWithIndicesEagerKernel::new(kernel_size);
|
||||
execute_dynamic::<R, MaxPool2dWithIndicesEagerKernel<R, E>, I>(
|
||||
let kernel = Pool2dEagerKernel::new(kernel_size, MaxPoolWithIndices::default());
|
||||
|
||||
execute_dynamic::<R, Pool2dEagerKernel<MaxPoolWithIndices<E>, R, E>, I>(
|
||||
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
|
||||
&[
|
||||
EagerHandle::new(&output.handle, &output.strides, &output.shape.dims),
|
||||
|
@ -349,55 +213,3 @@ pub(crate) fn max_pool2d_with_indices<R: Runtime, E: JitElement, I: JitElement>(
|
|||
|
||||
(output, indices)
|
||||
}
|
||||
|
||||
pub(crate) fn max_pool2d<R: Runtime, E: JitElement>(
|
||||
x: JitTensor<R, E, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let [batch_size, channels, _, _] = x.shape.dims;
|
||||
|
||||
let size_0 = calculate_pool_output_size(
|
||||
kernel_size[0],
|
||||
stride[0],
|
||||
padding[0],
|
||||
dilation[0],
|
||||
x.shape.dims[2],
|
||||
);
|
||||
let size_1 = calculate_pool_output_size(
|
||||
kernel_size[1],
|
||||
stride[1],
|
||||
padding[1],
|
||||
dilation[1],
|
||||
x.shape.dims[3],
|
||||
);
|
||||
|
||||
let shape_out = Shape::new([batch_size, channels, size_0, size_1]);
|
||||
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
|
||||
|
||||
let kernel = MaxPool2dEagerKernel::new(kernel_size);
|
||||
|
||||
execute_dynamic::<R, MaxPool2dEagerKernel<R, E>, RuntimeInt<R>>(
|
||||
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
|
||||
&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
Some(&[
|
||||
(stride[0] as i32).elem(),
|
||||
(stride[1] as i32).elem(),
|
||||
(dilation[0] as i32).elem(),
|
||||
(dilation[1] as i32).elem(),
|
||||
(padding[0] as i32).elem(),
|
||||
(padding[1] as i32).elem(),
|
||||
]),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
x.client,
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -101,6 +101,7 @@ impl MaxPool2dBackwardComputeShader {
|
|||
let is_max = scope.create_local(Elem::Bool);
|
||||
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let index_base = scope.create_local(Elem::UInt);
|
||||
let index_tmp = scope.create_local(Elem::UInt);
|
||||
|
||||
let grad_accumulation = scope.zero(grad.item());
|
||||
|
@ -116,20 +117,19 @@ impl MaxPool2dBackwardComputeShader {
|
|||
output_stride_3,
|
||||
);
|
||||
|
||||
gpu!(scope, index_base = b * grad_stride_0);
|
||||
gpu!(scope, index_tmp = c * grad_stride_1);
|
||||
gpu!(scope, index_base += index_tmp);
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
range(oh_start, oh_end).for_each(|oh, scope| {
|
||||
gpu!(
|
||||
scope,
|
||||
range(ow_start, ow_end).for_each(|ow, scope| {
|
||||
gpu!(scope, index = b * grad_stride_0);
|
||||
|
||||
gpu!(scope, index_tmp = c * grad_stride_1);
|
||||
gpu!(scope, index += index_tmp);
|
||||
|
||||
gpu!(scope, index = index_base);
|
||||
gpu!(scope, index_tmp = oh * grad_stride_2);
|
||||
gpu!(scope, index += index_tmp);
|
||||
|
||||
gpu!(scope, index_tmp = ow * grad_stride_3);
|
||||
gpu!(scope, index += index_tmp);
|
||||
|
||||
|
|
|
@ -1,12 +1,19 @@
|
|||
mod adaptive_avg_pool2d;
|
||||
mod adaptive_avg_pool2d_backward;
|
||||
mod adaptive_pool2d_shader;
|
||||
mod avg_pool2d;
|
||||
mod avg_pool2d_backward;
|
||||
mod base;
|
||||
mod max_pool2d;
|
||||
mod max_pool2d_backward;
|
||||
mod pool2d_shader;
|
||||
pub(crate) use adaptive_pool2d_shader::*;
|
||||
pub(crate) use pool2d_shader::*;
|
||||
|
||||
pub(crate) use adaptive_avg_pool2d::*;
|
||||
pub use avg_pool2d::*;
|
||||
pub use adaptive_avg_pool2d_backward::*;
|
||||
pub(crate) use avg_pool2d::*;
|
||||
pub(crate) use avg_pool2d_backward::*;
|
||||
pub(super) use base::*;
|
||||
|
||||
pub(crate) use max_pool2d::*;
|
||||
pub(crate) use max_pool2d_backward::*;
|
||||
|
|
|
@ -0,0 +1,243 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo},
|
||||
gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
|
||||
kernel::{DynamicKernelSource, SourceTemplate},
|
||||
Compiler, JitElement, Runtime,
|
||||
};
|
||||
|
||||
use super::PoolStrategy;
|
||||
|
||||
pub(crate) struct Pool2dComputeShader<P: PoolStrategy, R: Runtime, E: JitElement> {
|
||||
input: Variable,
|
||||
output: Variable,
|
||||
indices: Option<Variable>,
|
||||
kernel_size: [usize; 2],
|
||||
pool_strategy: P,
|
||||
_elem: PhantomData<E>,
|
||||
_runtime: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<P: PoolStrategy, R: Runtime, E: JitElement> Pool2dComputeShader<P, R, E> {
|
||||
fn expand(self, scope: &mut Scope) {
|
||||
let input = self.input;
|
||||
let output = self.output;
|
||||
let id = Variable::Id;
|
||||
|
||||
let input_stride_0 = scope.create_local(Elem::UInt);
|
||||
let input_stride_1 = scope.create_local(Elem::UInt);
|
||||
let input_stride_2 = scope.create_local(Elem::UInt);
|
||||
let input_stride_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let input_shape_0 = scope.create_local(Elem::UInt);
|
||||
let input_shape_1 = scope.create_local(Elem::UInt);
|
||||
let input_shape_2 = scope.create_local(Elem::UInt);
|
||||
let input_shape_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, input_stride_0 = stride(input, 0u32));
|
||||
gpu!(scope, input_stride_1 = stride(input, 1u32));
|
||||
gpu!(scope, input_stride_2 = stride(input, 2u32));
|
||||
gpu!(scope, input_stride_3 = stride(input, 3u32));
|
||||
|
||||
gpu!(scope, input_shape_0 = shape(input, 2u32));
|
||||
gpu!(scope, input_shape_1 = shape(input, 3u32));
|
||||
gpu!(scope, input_shape_2 = shape(input, 2u32));
|
||||
gpu!(scope, input_shape_3 = shape(input, 3u32));
|
||||
|
||||
gpu!(scope, output_stride_0 = stride(output, 0u32));
|
||||
gpu!(scope, output_stride_1 = stride(output, 1u32));
|
||||
gpu!(scope, output_stride_2 = stride(output, 2u32));
|
||||
gpu!(scope, output_stride_3 = stride(output, 3u32));
|
||||
|
||||
gpu!(scope, output_shape_0 = shape(output, 0u32));
|
||||
gpu!(scope, output_shape_1 = shape(output, 1u32));
|
||||
gpu!(scope, output_shape_2 = shape(output, 2u32));
|
||||
gpu!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
|
||||
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
|
||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let c = scope.create_local(Elem::UInt);
|
||||
let oh = scope.create_local(Elem::UInt);
|
||||
let ow = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, b = id / output_stride_0);
|
||||
gpu!(scope, b = b % output_shape_0);
|
||||
|
||||
gpu!(scope, c = id / output_stride_1);
|
||||
gpu!(scope, c = c % output_shape_1);
|
||||
|
||||
gpu!(scope, oh = id / output_stride_2);
|
||||
gpu!(scope, oh = oh % output_shape_2);
|
||||
|
||||
gpu!(scope, ow = id / output_stride_3);
|
||||
gpu!(scope, ow = ow % output_shape_3);
|
||||
|
||||
let ih = scope.create_local(Elem::UInt);
|
||||
let iw = scope.create_local(Elem::UInt);
|
||||
let dilated = scope.create_local(Elem::UInt);
|
||||
|
||||
let ih_pad = scope.create_local(Elem::UInt);
|
||||
let iw_pad = scope.create_local(Elem::UInt);
|
||||
let result = scope.create_local(input.item());
|
||||
|
||||
let index_input = scope.create_local(Elem::UInt);
|
||||
let index_input_0 = scope.create_local(Elem::UInt);
|
||||
let index_input_1 = scope.create_local(Elem::UInt);
|
||||
let index_input_2 = scope.create_local(Elem::UInt);
|
||||
let index_input_3 = scope.create_local(Elem::UInt);
|
||||
let idx = scope.create_local(Elem::UInt);
|
||||
|
||||
let within_padding_h = scope.create_local(Elem::Bool);
|
||||
let within_padding_w = scope.create_local(Elem::Bool);
|
||||
let tmp_padding = scope.create_local(Elem::Bool);
|
||||
let border_bottom = scope.create_local(Elem::UInt);
|
||||
let border_right = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, border_bottom = input_shape_2 + padding_0);
|
||||
gpu!(scope, border_right = input_shape_3 + padding_1);
|
||||
|
||||
gpu!(scope, index_input_0 = b * input_stride_0);
|
||||
gpu!(scope, index_input_1 = c * input_stride_1);
|
||||
|
||||
let accumulator = self.pool_strategy.initialize(scope, input.item());
|
||||
|
||||
(0..self.kernel_size[0]).for_each(|kh| {
|
||||
gpu!(scope, ih = oh * pool_stride_0);
|
||||
gpu!(scope, dilated = kh * dilation_0);
|
||||
gpu!(scope, ih += dilated);
|
||||
|
||||
gpu!(scope, within_padding_h = ih >= padding_0);
|
||||
gpu!(scope, tmp_padding = ih < border_bottom);
|
||||
gpu!(scope, within_padding_h = within_padding_h && tmp_padding);
|
||||
|
||||
gpu!(scope, if (within_padding_h).then(|scope| {
|
||||
(0..self.kernel_size[1]).for_each(|kw| {
|
||||
gpu!(scope, iw = ow * pool_stride_1);
|
||||
gpu!(scope, dilated = kw * dilation_1);
|
||||
gpu!(scope, iw += dilated);
|
||||
|
||||
gpu!(scope, within_padding_w = iw >= padding_1);
|
||||
gpu!(scope, tmp_padding = iw < border_right);
|
||||
gpu!(scope, within_padding_w = within_padding_w && tmp_padding);
|
||||
|
||||
gpu!(scope, if (within_padding_w).then(|scope| {
|
||||
gpu!(scope, ih_pad = ih - padding_0);
|
||||
gpu!(scope, iw_pad = iw - padding_1);
|
||||
|
||||
gpu!(scope, index_input_2 = ih_pad * input_stride_2);
|
||||
gpu!(scope, idx = index_input_2);
|
||||
gpu!(scope, idx += iw_pad);
|
||||
gpu!(scope, index_input_3 = iw_pad * input_stride_3);
|
||||
|
||||
gpu!(scope, index_input = index_input_0);
|
||||
gpu!(scope, index_input += index_input_1);
|
||||
gpu!(scope, index_input += index_input_2);
|
||||
gpu!(scope, index_input += index_input_3);
|
||||
|
||||
gpu!(scope, result = input[index_input]);
|
||||
|
||||
self.pool_strategy.process_result(scope, accumulator, result, idx);
|
||||
}));
|
||||
});
|
||||
}));
|
||||
});
|
||||
|
||||
self.pool_strategy
|
||||
.assign(scope, id, output, self.indices, accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct Pool2dEagerKernel<P: PoolStrategy, R: Runtime, E: JitElement> {
|
||||
kernel_size: [usize; 2],
|
||||
pool_strategy: P,
|
||||
_runtime: PhantomData<R>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<P: PoolStrategy, R: Runtime, E: JitElement> DynamicKernelSource
|
||||
for Pool2dEagerKernel<P, R, E>
|
||||
{
|
||||
fn source(&self) -> crate::kernel::SourceTemplate {
|
||||
let mut scope = Scope::root();
|
||||
let item = E::gpu_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let indices = if P::with_indices() {
|
||||
Some(Variable::GlobalOutputArray(1, Item::Scalar(Elem::Int)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
Pool2dComputeShader {
|
||||
input,
|
||||
output,
|
||||
indices,
|
||||
kernel_size: self.kernel_size,
|
||||
pool_strategy: self.pool_strategy.clone(),
|
||||
_elem: PhantomData::<E>,
|
||||
_runtime: PhantomData::<R>,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
let input = InputInfo::Array {
|
||||
item,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
let scalars = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
size: 6,
|
||||
};
|
||||
let output = OutputInfo::Array { item };
|
||||
let outputs = if P::with_indices() {
|
||||
vec![
|
||||
output,
|
||||
OutputInfo::Array {
|
||||
item: Item::Scalar(Elem::Int),
|
||||
},
|
||||
]
|
||||
} else {
|
||||
vec![output]
|
||||
};
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![input, scalars],
|
||||
outputs,
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = CompilationSettings::default();
|
||||
let shader = Compilation::new(info).compile(settings);
|
||||
let shader = <R::Compiler as Compiler>::compile(shader);
|
||||
SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!(
|
||||
"{:?}k={:?}pl={:?}",
|
||||
core::any::TypeId::of::<Self>(),
|
||||
self.kernel_size,
|
||||
self.pool_strategy
|
||||
)
|
||||
}
|
||||
}
|
|
@ -1,243 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const B_M = {{b_m}}u;
|
||||
const B_N = {{b_n}}u;
|
||||
const B_K = {{b_k}}u;
|
||||
const B_M_X_B_K_4 = {{bm_x_bk_4}}u;
|
||||
const B_K_X_B_N_4 = {{bk_x_bn_4}}u;
|
||||
|
||||
const T_M = 4u;
|
||||
const T_N = 4u;
|
||||
const T_M_X_T_N = 16u;
|
||||
|
||||
var<workgroup> shared_lhs: array<vec4<{{ elem }}>, B_M_X_B_K_4>;
|
||||
var<workgroup> shared_rhs: array<vec4<{{ elem }}>, B_K_X_B_N_4>;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(local_invocation_index) local_idx: u32,
|
||||
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
||||
) {
|
||||
let skip_row = workgroup_id.x * B_M;
|
||||
let skip_col = workgroup_id.y * B_N;
|
||||
|
||||
let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
|
||||
|
||||
// Position of the first element of the thread, relative to the block
|
||||
let thread_row = (local_idx / n_thread_per_row) * T_M;
|
||||
let thread_col = (local_idx % n_thread_per_row) * T_N;
|
||||
|
||||
// Position of the first element of the thread, in absolute (in one batch)
|
||||
let row = skip_row + thread_row;
|
||||
let col = skip_col + thread_col;
|
||||
|
||||
let batch = global_id.z;
|
||||
|
||||
// Basic information
|
||||
let dim = info[0];
|
||||
let n_rows = info[6u * dim - 1u];
|
||||
let n_cols = info[6u * dim];
|
||||
let K = info[5u * dim - 1u];
|
||||
|
||||
// Row / col strides
|
||||
let lhs_stride_row = info[dim - 1u];
|
||||
let lhs_stride_col = info[dim];
|
||||
let rhs_stride_row = info[2u * dim - 1u];
|
||||
let rhs_stride_col = info[2u * dim];
|
||||
let out_stride_row = info [3u * dim - 1u];
|
||||
let out_stride_col = info [3u * dim];
|
||||
|
||||
// Calculate the corresponding offsets with support for broadcasting.
|
||||
let offset_output = batch * n_rows * n_cols;
|
||||
var offset_lhs: u32 = skip_row * lhs_stride_row;
|
||||
var offset_rhs: u32 = skip_col * rhs_stride_col;
|
||||
|
||||
let batch_dims = dim - 2u;
|
||||
for (var b: u32 = 1u; b <= batch_dims; b++) {
|
||||
let stride_lhs = info[b];
|
||||
let stride_rhs = info[b + dim];
|
||||
let stride_output = info[b + 2u * dim];
|
||||
let shape_lhs = info[b + 3u * dim];
|
||||
let shape_rhs = info[b + 4u * dim];
|
||||
|
||||
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
|
||||
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
|
||||
}
|
||||
|
||||
// Registers used in the compute pass
|
||||
var results: array<{{ elem }}, T_M_X_T_N>;
|
||||
var register_M: vec4<{{ elem }}>;
|
||||
var register_N: vec4<{{ elem }}>;
|
||||
|
||||
// How close is the thread to the end of the matrix.
|
||||
// If < 4 then it is an edge case
|
||||
let remain_row_lhs = n_rows - row;
|
||||
let remain_col_rhs = n_cols - col;
|
||||
|
||||
for (var k = 0u; k < K; k += B_K) {
|
||||
|
||||
// LHS LOAD PASS
|
||||
|
||||
// For the 4 vec4 columns of this thread
|
||||
for (var j = 0u; j < 4u; j++) {
|
||||
|
||||
// The precise
|
||||
let current_col = thread_col + j;
|
||||
|
||||
// Position of the column vec4 in shared memory
|
||||
let lhs_sm_position = (thread_row/4u) * B_K + current_col;
|
||||
|
||||
// To avoid overwriting following row in share memory
|
||||
if current_col < B_K {
|
||||
// To pad with zeros if outside lhs
|
||||
if current_col + k < K && remain_row_lhs >= 1u {
|
||||
let lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row;
|
||||
let lhs_position1 = lhs_position0 + lhs_stride_row;
|
||||
let lhs_position2 = lhs_position1 + lhs_stride_row;
|
||||
let lhs_position3 = lhs_position2 + lhs_stride_row;
|
||||
|
||||
if remain_row_lhs >= 4u {
|
||||
shared_lhs[lhs_sm_position] = vec4(
|
||||
lhs[lhs_position0],
|
||||
lhs[lhs_position1],
|
||||
lhs[lhs_position2],
|
||||
lhs[lhs_position3],
|
||||
);
|
||||
} else if remain_row_lhs == 3u {
|
||||
shared_lhs[lhs_sm_position] = vec4(
|
||||
lhs[lhs_position0],
|
||||
lhs[lhs_position1],
|
||||
lhs[lhs_position2],
|
||||
0.
|
||||
);
|
||||
} else if remain_row_lhs == 2u {
|
||||
shared_lhs[lhs_sm_position] = vec4(
|
||||
lhs[lhs_position0],
|
||||
lhs[lhs_position1],
|
||||
0.,
|
||||
0.
|
||||
);
|
||||
} else if remain_row_lhs == 1u {
|
||||
shared_lhs[lhs_sm_position] = vec4(
|
||||
lhs[lhs_position0],
|
||||
0.,
|
||||
0.,
|
||||
0.
|
||||
);
|
||||
}
|
||||
} else {
|
||||
shared_lhs[lhs_sm_position] = vec4(0.,0.,0.,0.);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RHS LOAD PASS
|
||||
|
||||
for (var i = 0u; i < 4u; i++) {
|
||||
let current_row = thread_row + i;
|
||||
|
||||
let rhs_sm_position = (current_row * B_N + thread_col) / 4u;
|
||||
|
||||
if current_row < B_K {
|
||||
if current_row + k < K && remain_col_rhs >= 1u {
|
||||
|
||||
let rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col;
|
||||
let rhs_position1 = rhs_position0 + rhs_stride_col;
|
||||
let rhs_position2 = rhs_position1 + rhs_stride_col;
|
||||
let rhs_position3 = rhs_position2 + rhs_stride_col;
|
||||
|
||||
if remain_col_rhs >= 4u {
|
||||
shared_rhs[rhs_sm_position] = vec4(
|
||||
rhs[rhs_position0],
|
||||
rhs[rhs_position1],
|
||||
rhs[rhs_position2],
|
||||
rhs[rhs_position3],
|
||||
);
|
||||
} else if remain_col_rhs == 3u {
|
||||
shared_rhs[rhs_sm_position] = vec4(
|
||||
rhs[rhs_position0],
|
||||
rhs[rhs_position1],
|
||||
rhs[rhs_position2],
|
||||
0.
|
||||
);
|
||||
} else if remain_col_rhs == 2u {
|
||||
shared_rhs[rhs_sm_position] = vec4(
|
||||
rhs[rhs_position0],
|
||||
rhs[rhs_position1],
|
||||
0.,
|
||||
0.
|
||||
);
|
||||
} else if remain_col_rhs == 1u {
|
||||
shared_rhs[rhs_sm_position] = vec4(
|
||||
rhs[rhs_position0],
|
||||
0.,
|
||||
0.,
|
||||
0.
|
||||
);
|
||||
}
|
||||
} else {
|
||||
shared_rhs[rhs_sm_position] = vec4(0.,0.,0.,0.);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// COMPUTE PASS
|
||||
|
||||
// Compute intermediate results
|
||||
// Results are cumulated in results array and updated at each block
|
||||
// Outer loop indicates which subcolumns/subrows to read from shared memories
|
||||
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
|
||||
|
||||
// Load a subcolumn of values from lhs
|
||||
let lhs_sm_position = (thread_row/4u) * B_K + dot_index;
|
||||
register_M = shared_lhs[lhs_sm_position];
|
||||
|
||||
// Load a subrow of values from rhs
|
||||
let rhs_sm_position = (dot_index * B_N + thread_col) / 4u;
|
||||
register_N = shared_rhs[rhs_sm_position];
|
||||
|
||||
// Multiply subcolumn and subrow and store results
|
||||
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
|
||||
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
|
||||
results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
// OUTPUT PASS
|
||||
|
||||
// Write output matrix
|
||||
// Each thread is responsible of writing T_M x T_N results
|
||||
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
|
||||
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
|
||||
let row_index = row + res_idx_M;
|
||||
let col_index = col + res_idx_N;
|
||||
if row_index < n_rows && col_index < n_cols {
|
||||
let result_position = res_idx_M * T_N + res_idx_N;
|
||||
let output_position = offset_output + row_index * out_stride_row + col_index * out_stride_col;
|
||||
output[output_position] = results[result_position];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,165 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const B_M = {{b_m}}u;
|
||||
const B_N = {{b_n}}u;
|
||||
const B_K = {{b_k}}u;
|
||||
const B_M_X_B_K_4 = {{bm_x_bk_4}}u;
|
||||
const B_K_X_B_N_4 = {{bk_x_bn_4}}u;
|
||||
|
||||
const T_M = 4u;
|
||||
const T_N = 4u;
|
||||
const T_M_X_T_N = 16u;
|
||||
|
||||
var<workgroup> shared_lhs: array<vec4<{{ elem }}>, B_M_X_B_K_4>;
|
||||
var<workgroup> shared_rhs: array<vec4<{{ elem }}>, B_K_X_B_N_4>;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(local_invocation_index) local_idx: u32,
|
||||
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
||||
) {
|
||||
let skip_row = workgroup_id.x * B_M;
|
||||
let skip_col = workgroup_id.y * B_N;
|
||||
|
||||
let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
|
||||
let thread_row = (local_idx / n_thread_per_row) * T_M;
|
||||
let thread_col = (local_idx % n_thread_per_row) * T_N;
|
||||
|
||||
let row = skip_row + thread_row;
|
||||
let col = skip_col + thread_col;
|
||||
|
||||
let batch = global_id.z;
|
||||
|
||||
// Basic information
|
||||
let dim = info[0];
|
||||
let n_rows = info[6u * dim - 1u];
|
||||
let n_cols = info[6u * dim];
|
||||
let K = info[5u * dim - 1u];
|
||||
|
||||
// Row / col strides
|
||||
let lhs_stride_row = info[dim - 1u];
|
||||
let lhs_stride_col = info[dim];
|
||||
let rhs_stride_row = info[2u * dim - 1u];
|
||||
let rhs_stride_col = info[2u * dim];
|
||||
let out_stride_row = info [3u * dim - 1u];
|
||||
let out_stride_col = info [3u * dim];
|
||||
|
||||
// Calculate the corresponding offsets with support for broadcasting.
|
||||
let offset_output = batch * n_rows * n_cols;
|
||||
var offset_lhs: u32 = skip_row * lhs_stride_row;
|
||||
var offset_rhs: u32 = skip_col * rhs_stride_col;
|
||||
|
||||
let batch_dims = dim - 2u;
|
||||
for (var b: u32 = 1u; b <= batch_dims; b++) {
|
||||
let stride_lhs = info[b];
|
||||
let stride_rhs = info[b + dim];
|
||||
let stride_output = info[b + 2u * dim];
|
||||
let shape_lhs = info[b + 3u * dim];
|
||||
let shape_rhs = info[b + 4u * dim];
|
||||
|
||||
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
|
||||
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
|
||||
}
|
||||
|
||||
var results: array<{{ elem }}, T_M_X_T_N>;
|
||||
var register_M: vec4<{{ elem }}>;
|
||||
var register_N: vec4<{{ elem }}>;
|
||||
|
||||
for (var k = 0u; k < K; k += B_K) {
|
||||
// Load data into shared memories
|
||||
// Each thread is responsible of loading T_M x T_N values from both lhs and rhs
|
||||
|
||||
for (var j = 0u; j < 4u; j++) {
|
||||
let current_col = thread_col + j;
|
||||
|
||||
if current_col < B_K { // so that threads who work on between B_K and B_N store nothing
|
||||
|
||||
let lhs_sm_position = (thread_row/4u) * B_K + current_col;
|
||||
|
||||
let lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row;
|
||||
let lhs_position1 = lhs_position0 + lhs_stride_row;
|
||||
let lhs_position2 = lhs_position1 + lhs_stride_row;
|
||||
let lhs_position3 = lhs_position2 + lhs_stride_row;
|
||||
|
||||
shared_lhs[lhs_sm_position] = vec4(
|
||||
lhs[lhs_position0],
|
||||
lhs[lhs_position1],
|
||||
lhs[lhs_position2],
|
||||
lhs[lhs_position3],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for (var i = 0u; i < 4u; i++) {
|
||||
let current_row = thread_row + i;
|
||||
|
||||
if current_row < B_K { // so that threads who work on between B_K and B_N store nothing
|
||||
|
||||
let rhs_sm_position = (current_row * B_N + thread_col) / 4u;
|
||||
|
||||
let rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col;
|
||||
let rhs_position1 = rhs_position0 + rhs_stride_col;
|
||||
let rhs_position2 = rhs_position1 + rhs_stride_col;
|
||||
let rhs_position3 = rhs_position2 + rhs_stride_col;
|
||||
|
||||
shared_rhs[rhs_sm_position] = vec4(
|
||||
rhs[rhs_position0],
|
||||
rhs[rhs_position1],
|
||||
rhs[rhs_position2],
|
||||
rhs[rhs_position3],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// Compute intermediate results
|
||||
// Results are cumulated in results array and updated at each block
|
||||
// Outer loop indicates which subcolumns/subrows to read from shared memories
|
||||
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
|
||||
|
||||
// Load a subcolumn of values from lhs
|
||||
let lhs_sm_position = (thread_row/4u) * B_K + dot_index;
|
||||
register_M = shared_lhs[lhs_sm_position];
|
||||
|
||||
// Load a subrow of values from rhs
|
||||
let rhs_sm_position = (dot_index * B_N + thread_col) / 4u;
|
||||
register_N = shared_rhs[rhs_sm_position];
|
||||
|
||||
// Multiply subcolumn and subrow and store results
|
||||
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
|
||||
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
|
||||
results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
// Write output matrix
|
||||
// Each thread is responsible of writing T_M x T_N results
|
||||
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
|
||||
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
|
||||
let result_position = res_idx_M * T_N + res_idx_N;
|
||||
let output_position = offset_output + (row + res_idx_M) * out_stride_row + (col + res_idx_N) * out_stride_col;
|
||||
output[output_position] = results[result_position];
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,70 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const BLOCK_SIZE = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(local_invocation_index) local_idx: u32,
|
||||
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
||||
) {
|
||||
// Indices
|
||||
let row = workgroup_id.x * BLOCK_SIZE + (local_idx / BLOCK_SIZE);
|
||||
let col = workgroup_id.y * BLOCK_SIZE + (local_idx % BLOCK_SIZE);
|
||||
let batch = global_id.z;
|
||||
|
||||
// Basic information
|
||||
let dim = info[0];
|
||||
let n_rows = info[6u * dim - 1u];
|
||||
let n_cols = info[6u * dim];
|
||||
let K = info[5u * dim - 1u];
|
||||
|
||||
// Returns if outside the output dimension
|
||||
if row >= n_rows || col >= n_cols {
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate the corresponding offsets with support for broadcasting.
|
||||
let offset_output = batch * n_rows * n_cols;
|
||||
var offset_lhs: u32 = 0u;
|
||||
var offset_rhs: u32 = 0u;
|
||||
|
||||
let batch_dims = dim - 2u;
|
||||
for (var b: u32 = 1u; b <= batch_dims; b++) {
|
||||
let stride_lhs = info[b];
|
||||
let stride_rhs = info[b + dim];
|
||||
let stride_output = info[b + 2u * dim];
|
||||
let shape_lhs = info[b + 3u * dim];
|
||||
let shape_rhs = info[b + 4u * dim];
|
||||
|
||||
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
|
||||
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
|
||||
}
|
||||
|
||||
// Basic matmul implementation
|
||||
var sum = 0.0;
|
||||
for (var k: u32 = 0u; k < K; k++) {
|
||||
let lhs_index = row * K + k;
|
||||
let rhs_index = k * n_cols + col;
|
||||
|
||||
sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];
|
||||
}
|
||||
|
||||
let output_index = row * n_cols + col;
|
||||
output[offset_output + output_index] = sum;
|
||||
}
|
|
@ -1,74 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> x: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> info: array<u32, 16>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
|
||||
let input_stride_0 = info[0];
|
||||
let input_stride_1 = info[1];
|
||||
let input_stride_2 = info[2];
|
||||
let input_stride_3 = info[3];
|
||||
let input_shape_0 = info[4];
|
||||
let input_shape_1 = info[5];
|
||||
let input_shape_2 = info[6];
|
||||
let input_shape_3 = info[7];
|
||||
|
||||
let output_stride_0 = info[8];
|
||||
let output_stride_1 = info[9];
|
||||
let output_stride_2 = info[10];
|
||||
let output_stride_3 = info[11];
|
||||
let output_shape_0 = info[12];
|
||||
let output_shape_1 = info[13];
|
||||
let output_shape_2 = info[14];
|
||||
let output_shape_3 = info[15];
|
||||
|
||||
let b = id / output_stride_0 % output_shape_0;
|
||||
let c = id / output_stride_1 % output_shape_1;
|
||||
let oh = id / output_stride_2 % output_shape_2;
|
||||
let ow = id / output_stride_3 % output_shape_3;
|
||||
|
||||
let ih_start = start_index(oh, output_shape_2, input_shape_2);
|
||||
let ih_end = end_index(oh, output_shape_2, input_shape_2);
|
||||
|
||||
let iw_start = start_index(ow, output_shape_3, input_shape_3);
|
||||
let iw_end = end_index(ow, output_shape_3, input_shape_3);
|
||||
|
||||
var sum = 0.0;
|
||||
|
||||
for (var ih = ih_start; ih < ih_end; ih++) {
|
||||
for (var iw = iw_start; iw < iw_end; iw++) {
|
||||
let index_input = b * input_stride_0 + c * input_stride_1 + ih * input_stride_2 + iw * input_stride_3;
|
||||
sum += x[index_input];
|
||||
}
|
||||
}
|
||||
|
||||
let count = {{ elem }}((ih_end - ih_start) * (iw_end - iw_start));
|
||||
output[id] = sum / count;
|
||||
}
|
||||
|
||||
fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
|
||||
return u32(floor((f32(output_size_index) * f32(input_size)) / f32(output_size)));
|
||||
}
|
||||
|
||||
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
|
||||
let index = u32(ceil((f32(output_size_index + 1u) * f32(input_size)) / f32(output_size)));
|
||||
|
||||
return min(index, input_size);
|
||||
}
|
||||
|
|
@ -1,87 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> x: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> info: array<u32, 22>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
const COUNT_INCLUDE_PAD = {{ count_include_pad }};
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
|
||||
let input_stride_0 = info[0];
|
||||
let input_stride_1 = info[1];
|
||||
let input_stride_2 = info[2];
|
||||
let input_stride_3 = info[3];
|
||||
let input_shape_0 = info[4];
|
||||
let input_shape_1 = info[5];
|
||||
let input_shape_2 = info[6];
|
||||
let input_shape_3 = info[7];
|
||||
|
||||
let output_stride_0 = info[8];
|
||||
let output_stride_1 = info[9];
|
||||
let output_stride_2 = info[10];
|
||||
let output_stride_3 = info[11];
|
||||
let output_shape_0 = info[12];
|
||||
let output_shape_1 = info[13];
|
||||
let output_shape_2 = info[14];
|
||||
let output_shape_3 = info[15];
|
||||
|
||||
let kernel_size_0 = info[16];
|
||||
let kernel_size_1 = info[17];
|
||||
let pool_stride_0 = info[18];
|
||||
let pool_stride_1 = info[19];
|
||||
let padding_0 = info[20];
|
||||
let padding_1 = info[21];
|
||||
|
||||
let b = id / output_stride_0 % output_shape_0;
|
||||
let c = id / output_stride_1 % output_shape_1;
|
||||
let oh = id / output_stride_2 % output_shape_2;
|
||||
let ow = id / output_stride_3 % output_shape_3;
|
||||
|
||||
var sum = 0.0;
|
||||
var count = 0.0;
|
||||
|
||||
for (var kh = 0u; kh < kernel_size_0; kh++) {
|
||||
let ih = oh * pool_stride_0 + kh;
|
||||
|
||||
// Padding
|
||||
if ih < padding_0 || ih >= input_shape_2 + padding_0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (var kw = 0u; kw < kernel_size_1; kw++) {
|
||||
let iw = ow * pool_stride_1 + kw;
|
||||
|
||||
// Padding
|
||||
if iw < padding_1 || iw >= input_shape_3 + padding_1 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Correct indexes for padding
|
||||
let ih_pad = ih - padding_0;
|
||||
let iw_pad = iw - padding_1;
|
||||
|
||||
let index_input = b * input_stride_0 + c * input_stride_1 + ih_pad * input_stride_2 + iw_pad * input_stride_3;
|
||||
count += 1.0;
|
||||
sum += x[index_input];
|
||||
}
|
||||
}
|
||||
|
||||
if COUNT_INCLUDE_PAD {
|
||||
count = {{ elem }}(kernel_size_1 * kernel_size_0);
|
||||
}
|
||||
output[id] = sum / count;
|
||||
}
|
|
@ -1,110 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> grad: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> info: array<u32, 22>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
const COUNT_INCLUDE_PAD = {{ count_include_pad }};
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
|
||||
let input_stride_0 = info[0];
|
||||
let input_stride_1 = info[1];
|
||||
let input_stride_2 = info[2];
|
||||
let input_stride_3 = info[3];
|
||||
let input_shape_0 = info[4];
|
||||
let input_shape_1 = info[5];
|
||||
let input_shape_2 = info[6];
|
||||
let input_shape_3 = info[7];
|
||||
|
||||
let grad_stride_0 = info[8];
|
||||
let grad_stride_1 = info[9];
|
||||
let grad_stride_2 = info[10];
|
||||
let grad_stride_3 = info[11];
|
||||
let grad_shape_0 = info[12];
|
||||
let grad_shape_1 = info[13];
|
||||
let grad_shape_2 = info[14];
|
||||
let grad_shape_3 = info[15];
|
||||
|
||||
let kernel_size_0 = info[16];
|
||||
let kernel_size_1 = info[17];
|
||||
let pool_stride_0 = info[18];
|
||||
let pool_stride_1 = info[19];
|
||||
let padding_0 = info[20];
|
||||
let padding_1 = info[21];
|
||||
|
||||
let b = id / input_stride_0 % input_shape_0;
|
||||
let c = id / input_stride_1 % input_shape_1;
|
||||
let ih = id / input_stride_2 % input_shape_2;
|
||||
let iw = id / input_stride_3 % input_shape_3;
|
||||
|
||||
// The maximum number of overlapping filters that may content the current index.
|
||||
let kms_0 = i32(kernel_size_0) - i32(pool_stride_0);
|
||||
let kms_1 = i32(kernel_size_1) - i32(pool_stride_1);
|
||||
|
||||
let oh_start_tmp = (i32(ih + padding_0) - kms_0) / i32(pool_stride_0);
|
||||
let ow_start_tmp = (i32(iw + padding_1) - kms_1) / i32(pool_stride_1);
|
||||
|
||||
let oh_start = u32(max(oh_start_tmp, 0));
|
||||
let ow_start = u32(max(ow_start_tmp, 0));
|
||||
|
||||
let oh_end = u32(max(kms_0, 0)) + oh_start;
|
||||
let ow_end = u32(max(kms_1, 0)) + ow_start;
|
||||
|
||||
var grad_acc = 0.0;
|
||||
// We iterate over each potentially resulting overlapping filters and check
|
||||
// if their max index is the current one.
|
||||
for (var oh = oh_start; oh <= oh_end; oh++) {
|
||||
for (var ow = ow_start; ow <= ow_end; ow++) {
|
||||
if oh >= grad_shape_2 || ow >= grad_shape_3 {
|
||||
continue;
|
||||
}
|
||||
|
||||
var ih_start = oh * pool_stride_0;
|
||||
var iw_start = ow * pool_stride_1;
|
||||
|
||||
var ih_end = ih_start + kernel_size_0;
|
||||
var iw_end = iw_start + kernel_size_1;
|
||||
|
||||
ih_start = max(ih_start, padding_0);
|
||||
iw_start = max(iw_start, padding_1);
|
||||
|
||||
ih_end = min(ih_end, input_shape_2 + padding_0);
|
||||
iw_end = min(iw_end, input_shape_3 + padding_1);
|
||||
|
||||
let contributed_h = ih + padding_0 >= ih_start && ih < ih_end;
|
||||
let contributed_w = iw + padding_1 >= iw_start && iw < iw_end;
|
||||
|
||||
// If no contribution or outside of output shape, skip.
|
||||
if !contributed_h || !contributed_w {
|
||||
continue;
|
||||
}
|
||||
|
||||
var count = 0.0;
|
||||
|
||||
if COUNT_INCLUDE_PAD {
|
||||
count = {{ elem }}(kernel_size_0 * kernel_size_1);
|
||||
} else {
|
||||
count = {{ elem }}((ih_end - ih_start) * (iw_end - iw_start));
|
||||
}
|
||||
|
||||
let index = b * grad_stride_0 + c * grad_stride_1 + oh * grad_stride_2 + ow * grad_stride_3;
|
||||
grad_acc += grad[index] / count;
|
||||
}
|
||||
}
|
||||
|
||||
output[id] = grad_acc;
|
||||
}
|
|
@ -432,10 +432,6 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
|
|||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Ceil(op) => wgsl::Instruction::Ceil {
|
||||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Log(op) => wgsl::Instruction::Log {
|
||||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
|
@ -465,6 +461,14 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
|
|||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Floor(op) => wgsl::Instruction::Floor {
|
||||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Ceil(op) => wgsl::Instruction::Ceil {
|
||||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Erf(op) => wgsl::Instruction::Erf {
|
||||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
|
|
|
@ -109,10 +109,6 @@ pub enum Instruction {
|
|||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Ceil {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Erf {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
|
@ -214,6 +210,14 @@ pub enum Instruction {
|
|||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Floor {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Ceil {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
}
|
||||
|
||||
impl Display for Instruction {
|
||||
|
@ -276,9 +280,6 @@ impl Display for Instruction {
|
|||
Instruction::Sqrt { input, out } => {
|
||||
f.write_fmt(format_args!("{out} = sqrt({input});\n"))
|
||||
}
|
||||
Instruction::Ceil { input, out } => {
|
||||
f.write_fmt(format_args!("{out} = ceil({input});\n"))
|
||||
}
|
||||
Instruction::Log1p { input, out } => {
|
||||
f.write_fmt(format_args!("{out} = log({input} + 1.0);\n"))
|
||||
}
|
||||
|
@ -468,6 +469,12 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
|
|||
Instruction::ShiftRight { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = {lhs} >> {rhs};\n"))
|
||||
}
|
||||
Instruction::Floor { input, out } => {
|
||||
f.write_fmt(format_args!("{out} = floor({input});\n"))
|
||||
}
|
||||
Instruction::Ceil { input, out } => {
|
||||
f.write_fmt(format_args!("{out} = ceil({input});\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue