Migrate/jit/pooling (#1509)

* separate forward backward

* refactor with pool strategy

* refactor further

* pooling refactored

* refactoring for adaptive wip

* wip adaptive

* adaptive

* delete some wgsl

* avg pool backward

* clippy

* minor refactor
This commit is contained in:
Louis Fortier-Dubois 2024-03-25 16:04:58 -04:00 committed by GitHub
parent 613e698007
commit da5b0438ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1284 additions and 1320 deletions

View File

@ -275,6 +275,12 @@ macro_rules! gpu {
gpu!(unary $input, $out)
));
};
// out = floor(input)
($scope:expr, $out:ident = floor($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Floor(
gpu!(unary $input, $out)
));
};
// out = ceil(input)
($scope:expr, $out:ident = ceil($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Ceil(

View File

@ -36,6 +36,7 @@ pub enum Operator {
Tanh(UnaryOperator),
Powf(BinaryOperator),
Sqrt(UnaryOperator),
Floor(UnaryOperator),
Ceil(UnaryOperator),
Erf(UnaryOperator),
Recip(UnaryOperator),

View File

@ -43,6 +43,8 @@ impl Operator {
Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)),
Operator::Mul(op) => Operator::Mul(op.vectorize(vectorization)),
Operator::Div(op) => Operator::Div(op.vectorize(vectorization)),
Operator::Floor(op) => Operator::Floor(op.vectorize(vectorization)),
Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)),
Operator::Abs(op) => Operator::Abs(op.vectorize(vectorization)),
Operator::Exp(op) => Operator::Exp(op.vectorize(vectorization)),
Operator::Log(op) => Operator::Log(op.vectorize(vectorization)),
@ -52,7 +54,6 @@ impl Operator {
Operator::Tanh(op) => Operator::Tanh(op.vectorize(vectorization)),
Operator::Powf(op) => Operator::Powf(op.vectorize(vectorization)),
Operator::Sqrt(op) => Operator::Sqrt(op.vectorize(vectorization)),
Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)),
Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)),
Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)),
Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)),

View File

@ -247,11 +247,6 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Ceil(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Log(op) => mark_unary(
op,
&mut local_tensor_ids_input,
@ -326,6 +321,16 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Floor(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Ceil(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Modulo(op) => mark_binary(
op,
&mut local_tensor_ids_input,

View File

@ -1,100 +1,41 @@
use crate::{
compute::StaticKernel,
codegen::{execute_dynamic, EagerHandle, WorkgroupLaunch},
element::JitElement,
kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::JitTensor,
Runtime,
};
use burn_compute::server::Handle;
use burn_tensor::Shape;
kernel_wgsl!(
AdaptiveAvgPool2d,
"../../template/pool/adaptive_avg_pool2d.wgsl"
);
kernel_wgsl!(
AdaptiveAvgPool2dBackward,
"../../template/pool/adaptive_avg_pool2d_backward.wgsl"
);
use super::AdaptivePool2dEagerKernel;
pub(crate) fn adaptive_avg_pool2d<R: Runtime, E: JitElement>(
x: JitTensor<R, E, 4>,
input: JitTensor<R, E, 4>,
output_size: [usize; 2],
) -> JitTensor<R, E, 4> {
let [batch_size, channels, _, _] = x.shape.dims;
let [batch_size, channels, _, _] = input.shape.dims;
let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]);
let output = empty_device(x.client.clone(), x.device.clone(), output_shape);
let output = empty_device(input.client.clone(), input.device.clone(), output_shape);
let kernel = StaticKernel::<
KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
output.shape.num_elements(),
WORKGROUP_DEFAULT,
));
let kernel = AdaptivePool2dEagerKernel::new();
let info_handle = build_info(&x, &output);
x.client
.execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]);
output
}
pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
x: JitTensor<R, E, 4>,
out_grad: JitTensor<R, E, 4>,
) -> JitTensor<R, E, 4> {
let output_shape = x.shape.clone();
let num_elems = output_shape.num_elements();
let output_buffer = x.client.empty(num_elems * core::mem::size_of::<E>());
let output = JitTensor::new(
x.client.clone(),
x.device.clone(),
output_shape,
output_buffer,
);
let kernel = StaticKernel::<
KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
output.shape.num_elements(),
WORKGROUP_DEFAULT,
));
let info_handle = build_info(&x, &out_grad);
x.client.execute(
Box::new(kernel),
&[&out_grad.handle, &output.handle, &info_handle],
execute_dynamic::<R, AdaptivePool2dEagerKernel<R, E>, E>(
&[EagerHandle::new(
&input.handle,
&input.strides,
&input.shape.dims,
)],
&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
input.client,
);
output
}
fn build_info<R: Runtime, E: JitElement>(
x: &JitTensor<R, E, 4>,
output: &JitTensor<R, E, 4>,
) -> Handle<R::Server> {
let mut info: [u32; 16] = [0; 16];
info[0] = x.strides[0] as u32;
info[1] = x.strides[1] as u32;
info[2] = x.strides[2] as u32;
info[3] = x.strides[3] as u32;
info[4] = x.shape.dims[0] as u32;
info[5] = x.shape.dims[1] as u32;
info[6] = x.shape.dims[2] as u32;
info[7] = x.shape.dims[3] as u32;
info[8] = output.strides[0] as u32;
info[9] = output.strides[1] as u32;
info[10] = output.strides[2] as u32;
info[11] = output.strides[3] as u32;
info[12] = output.shape.dims[0] as u32;
info[13] = output.shape.dims[1] as u32;
info[14] = output.shape.dims[2] as u32;
info[15] = output.shape.dims[3] as u32;
output.client.create(bytemuck::cast_slice(&info))
}

View File

@ -0,0 +1,71 @@
use crate::{
compute::StaticKernel,
element::JitElement,
kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
tensor::JitTensor,
Runtime,
};
use burn_compute::server::Handle;
// Declares `AdaptiveAvgPool2dBackward`, which is used below as the kernel source
// type parameter of `KernelSettings`.
// NOTE(review): `kernel_wgsl!` is defined elsewhere; presumably it binds the type
// to the WGSL template at this relative path — confirm against the macro.
kernel_wgsl!(
AdaptiveAvgPool2dBackward,
"../../template/pool/adaptive_avg_pool2d_backward.wgsl"
);
/// Launches the static `AdaptiveAvgPool2dBackward` kernel.
///
/// A new tensor with the shape, device and client of `x` is allocated and
/// returned. The kernel is bound to three buffers, in this order: the
/// `out_grad` data, the freshly allocated result, and the info buffer built by
/// [`build_info`] from `x` and `out_grad`.
///
/// One unit of work is scheduled per element of the returned tensor.
pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
    x: JitTensor<R, E, 4>,
    out_grad: JitTensor<R, E, 4>,
) -> JitTensor<R, E, 4> {
    // Allocate an uninitialised buffer large enough for a tensor shaped like `x`.
    let result_shape = x.shape.clone();
    let byte_len = result_shape.num_elements() * core::mem::size_of::<E>();
    let result_buffer = x.client.empty(byte_len);
    let result = JitTensor::new(
        x.client.clone(),
        x.device.clone(),
        result_shape,
        result_buffer,
    );

    let workgroup = elemwise_workgroup(result.shape.num_elements(), WORKGROUP_DEFAULT);
    let kernel = StaticKernel::<
        KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
    >::new(workgroup);

    let info = build_info(&x, &out_grad);
    let bindings = [&out_grad.handle, &result.handle, &info];
    x.client.execute(Box::new(kernel), &bindings);

    result
}
/// Packs the strides and shapes of two rank-4 tensors into a 16-word `u32`
/// buffer and uploads it through `output.client`.
///
/// Layout of the buffer (four words per group, one per axis):
///
/// | words   | content          |
/// |---------|------------------|
/// | 0..4    | `x.strides`      |
/// | 4..8    | `x.shape.dims`   |
/// | 8..12   | `output.strides` |
/// | 12..16  | `output.shape.dims` |
///
/// Values are narrowed with `as u32`, exactly as before.
fn build_info<R: Runtime, E: JitElement>(
    x: &JitTensor<R, E, 4>,
    output: &JitTensor<R, E, 4>,
) -> Handle<R::Server> {
    let mut info = [0u32; 16];

    for axis in 0..4 {
        info[axis] = x.strides[axis] as u32;
        info[4 + axis] = x.shape.dims[axis] as u32;
        info[8 + axis] = output.strides[axis] as u32;
        info[12 + axis] = output.shape.dims[axis] as u32;
    }

    output.client.create(bytemuck::cast_slice(&info))
}

View File

@ -0,0 +1,229 @@
use std::marker::PhantomData;
use crate::{
codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo},
gpu::{gpu, Elem, Scope, Variable, Visibility},
kernel::{DynamicKernelSource, SourceTemplate},
Compiler, JitElement, Runtime,
};
/// Code generator for the adaptive 2D pooling kernel body.
///
/// `expand` consumes the value and registers the kernel operations on a `Scope`.
/// `R` and `E` are carried only as type markers.
pub(crate) struct AdaptivePool2dComputeShader<R: Runtime, E: JitElement> {
/// Array variable the generated kernel reads from.
input: Variable,
/// Array variable the generated kernel writes to.
output: Variable,
/// Marker for the element type; holds no data.
_elem: PhantomData<E>,
/// Marker for the runtime; holds no data.
_runtime: PhantomData<R>,
}
impl<R: Runtime, E: JitElement> AdaptivePool2dComputeShader<R, E> {
/// Registers the adaptive average pooling kernel body on `scope`.
///
/// For the invocation index `id` the generated code:
/// 1. decodes `id` into `(b, c, oh, ow)` using the output strides and shapes;
/// 2. derives the input window `[ih_start, ih_end) x [iw_start, iw_end)` with
///    [`Self::start_index`] and [`Self::end_index`];
/// 3. sums every input element of that window for batch `b` and channel `c`;
/// 4. divides the sum by the number of window elements and stores it in
///    `output[id]`.
fn expand(self, scope: &mut Scope) {
    let input = self.input;
    let output = self.output;
    let id = Variable::Id;

    let input_stride_0 = scope.create_local(Elem::UInt);
    let input_stride_1 = scope.create_local(Elem::UInt);
    let input_stride_2 = scope.create_local(Elem::UInt);
    let input_stride_3 = scope.create_local(Elem::UInt);

    let input_shape_0 = scope.create_local(Elem::UInt);
    let input_shape_1 = scope.create_local(Elem::UInt);
    let input_shape_2 = scope.create_local(Elem::UInt);
    let input_shape_3 = scope.create_local(Elem::UInt);

    let output_stride_0 = scope.create_local(Elem::UInt);
    let output_stride_1 = scope.create_local(Elem::UInt);
    let output_stride_2 = scope.create_local(Elem::UInt);
    let output_stride_3 = scope.create_local(Elem::UInt);

    let output_shape_0 = scope.create_local(Elem::UInt);
    let output_shape_1 = scope.create_local(Elem::UInt);
    let output_shape_2 = scope.create_local(Elem::UInt);
    let output_shape_3 = scope.create_local(Elem::UInt);

    gpu!(scope, input_stride_0 = stride(input, 0u32));
    gpu!(scope, input_stride_1 = stride(input, 1u32));
    gpu!(scope, input_stride_2 = stride(input, 2u32));
    gpu!(scope, input_stride_3 = stride(input, 3u32));

    // Each shape variable reads the axis its name refers to. Axes 0 and 1 were
    // previously loaded from axes 2 and 3; they are not consumed below, so the
    // computed result is unchanged, but the names are now truthful.
    gpu!(scope, input_shape_0 = shape(input, 0u32));
    gpu!(scope, input_shape_1 = shape(input, 1u32));
    gpu!(scope, input_shape_2 = shape(input, 2u32));
    gpu!(scope, input_shape_3 = shape(input, 3u32));

    gpu!(scope, output_stride_0 = stride(output, 0u32));
    gpu!(scope, output_stride_1 = stride(output, 1u32));
    gpu!(scope, output_stride_2 = stride(output, 2u32));
    gpu!(scope, output_stride_3 = stride(output, 3u32));

    gpu!(scope, output_shape_0 = shape(output, 0u32));
    gpu!(scope, output_shape_1 = shape(output, 1u32));
    gpu!(scope, output_shape_2 = shape(output, 2u32));
    gpu!(scope, output_shape_3 = shape(output, 3u32));

    // Decode the flat invocation index into per-axis output coordinates.
    let b = scope.create_local(Elem::UInt);
    let c = scope.create_local(Elem::UInt);
    let oh = scope.create_local(Elem::UInt);
    let ow = scope.create_local(Elem::UInt);

    gpu!(scope, b = id / output_stride_0);
    gpu!(scope, b = b % output_shape_0);

    gpu!(scope, c = id / output_stride_1);
    gpu!(scope, c = c % output_shape_1);

    gpu!(scope, oh = id / output_stride_2);
    gpu!(scope, oh = oh % output_shape_2);

    gpu!(scope, ow = id / output_stride_3);
    gpu!(scope, ow = ow % output_shape_3);

    // Input window covered by this output position, per spatial axis.
    let ih_start = Self::start_index(scope, oh, output_shape_2, input_shape_2);
    let ih_end = Self::end_index(scope, oh, output_shape_2, input_shape_2);
    let iw_start = Self::start_index(scope, ow, output_shape_3, input_shape_3);
    let iw_end = Self::end_index(scope, ow, output_shape_3, input_shape_3);

    let result = scope.create_local(input.item());

    let index_input = scope.create_local(Elem::UInt);
    let index_input_0 = scope.create_local(Elem::UInt);
    let index_input_1 = scope.create_local(Elem::UInt);
    let index_input_2 = scope.create_local(Elem::UInt);
    let index_input_3 = scope.create_local(Elem::UInt);

    // The batch and channel offsets do not change inside the window loops.
    gpu!(scope, index_input_0 = b * input_stride_0);
    gpu!(scope, index_input_1 = c * input_stride_1);

    let sum = scope.zero(output.item());

    gpu!(
        scope,
        range(ih_start, ih_end).for_each(|ih, scope| {
            gpu!(
                scope,
                range(iw_start, iw_end).for_each(|iw, scope| {
                    gpu!(scope, index_input_2 = ih * input_stride_2);
                    gpu!(scope, index_input_3 = iw * input_stride_3);

                    gpu!(scope, index_input = index_input_0);
                    gpu!(scope, index_input += index_input_1);
                    gpu!(scope, index_input += index_input_2);
                    gpu!(scope, index_input += index_input_3);

                    gpu!(scope, result = input[index_input]);

                    gpu!(scope, sum += result);
                })
            );
        })
    );

    // Average = sum / (window height * window width).
    let count = scope.create_local(Elem::UInt);
    let count_tmp = scope.create_local(Elem::UInt);
    let count_float = scope.create_local(output.item());
    let avg = scope.create_local(output.item());

    gpu!(scope, count = ih_end - ih_start);
    gpu!(scope, count_tmp = iw_end - iw_start);
    gpu!(scope, count *= count_tmp);

    gpu!(scope, count_float = cast(count));
    gpu!(scope, avg = sum / count_float);
    gpu!(scope, output[id] = avg);
}
/// Registers the computation of the first input index of an adaptive window
/// and returns the variable holding it.
///
/// The generated code evaluates
/// `floor(float(output_size_index * input_size) / float(output_size))`
/// and casts the result back to an unsigned integer.
fn start_index(
    scope: &mut Scope,
    output_size_index: Variable,
    output_size: Variable,
    input_size: Variable,
) -> Variable {
    let scaled_float = scope.create_local(Elem::Float);
    let quotient = scope.create_local(Elem::Float);
    let start = scope.create_local(Elem::UInt);

    // The product is formed in unsigned integers and only then converted.
    gpu!(scope, start = output_size_index * input_size);
    gpu!(scope, scaled_float = cast(start));

    gpu!(scope, quotient = cast(output_size));
    gpu!(scope, quotient = scaled_float / quotient);
    gpu!(scope, quotient = floor(quotient));

    gpu!(scope, start = cast(quotient));
    start
}
/// Registers the computation of the exclusive end input index of an adaptive
/// window and returns the variable holding it.
///
/// The generated code evaluates
/// `ceil(float((output_size_index + 1) * input_size) / float(output_size))`,
/// casts it to an unsigned integer, and clamps it so it never exceeds
/// `input_size`.
fn end_index(
    scope: &mut Scope,
    output_size_index: Variable,
    output_size: Variable,
    input_size: Variable,
) -> Variable {
    let scaled_float = scope.create_local(Elem::Float);
    let quotient = scope.create_local(Elem::Float);
    let candidate = scope.create_local(Elem::UInt);
    let past_input = scope.create_local(Elem::Bool);
    let clamped_end = scope.create_local(Elem::UInt);

    gpu!(scope, candidate = output_size_index + 1u32);
    gpu!(scope, candidate *= input_size);
    gpu!(scope, scaled_float = cast(candidate));

    gpu!(scope, quotient = cast(output_size));
    gpu!(scope, quotient = scaled_float / quotient);
    gpu!(scope, quotient = ceil(quotient));

    gpu!(scope, candidate = cast(quotient));

    // Pick the smaller of `candidate` and `input_size`.
    gpu!(scope, past_input = input_size < candidate);
    gpu!(scope, if(past_input).then(|scope|{
        gpu!(scope, clamped_end = input_size);
    }).else(|scope|{
        gpu!(scope, clamped_end = candidate);
    }));

    clamped_end
}
}
/// Dynamic kernel source for adaptive 2D pooling.
///
/// The struct carries no runtime data; `R` and `E` only select the compiler
/// and element type used when the source is generated.
#[derive(new)]
pub(crate) struct AdaptivePool2dEagerKernel<R: Runtime, E: JitElement> {
/// Marker for the runtime; holds no data.
_runtime: PhantomData<R>,
/// Marker for the element type; holds no data.
_elem: PhantomData<E>,
}
impl<R: Runtime, E: JitElement> DynamicKernelSource for AdaptivePool2dEagerKernel<R, E> {
fn source(&self) -> crate::kernel::SourceTemplate {
let mut scope = Scope::root();
let item = E::gpu_elem().into();
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
scope.write_global_custom(output);
AdaptivePool2dComputeShader {
input,
output,
_elem: PhantomData::<E>,
_runtime: PhantomData::<R>,
}
.expand(&mut scope);
let input = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let output = OutputInfo::Array { item };
let info = CompilationInfo {
inputs: vec![input],
outputs: vec![output],
scope,
};
let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}
fn id(&self) -> String {
format!("{:?}", core::any::TypeId::of::<Self>(),)
}
}

View File

@ -1,35 +1,72 @@
use crate::{
compute::{Kernel, StaticKernel},
codegen::{dialect::gpu::Variable, execute_dynamic, EagerHandle, WorkgroupLaunch},
element::JitElement,
kernel::{
self, elemwise_workgroup,
pool::{build_output_and_info_pool2d, build_pool2d_info},
KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT,
},
kernel_wgsl,
gpu::{gpu, Elem, Item, Scope},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime,
Runtime, RuntimeInt,
};
use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape};
use std::fmt::Debug;
kernel_wgsl!(AvgPool2dRaw, "../../template/pool/avg_pool2d.wgsl");
kernel_wgsl!(
AvgPool2dBackwardRaw,
"../../template/pool/avg_pool2d_backward.wgsl"
);
use super::{Pool2dEagerKernel, PoolStrategy};
struct AvgPool2dBackward<const COUNT_INCLUDE_PAD: bool>;
struct AvgPool2d<const COUNT_INCLUDE_PAD: bool>;
impl<const COUNT_INCLUDE_PAD: bool> StaticKernelSource for AvgPool2dBackward<COUNT_INCLUDE_PAD> {
fn source() -> kernel::SourceTemplate {
AvgPool2dBackwardRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}"))
}
#[derive(new, Debug, Clone)]
struct AvgPool {
kernel_size: [usize; 2],
count_include_pad: bool,
}
impl<const COUNT_INCLUDE_PAD: bool> StaticKernelSource for AvgPool2d<COUNT_INCLUDE_PAD> {
fn source() -> kernel::SourceTemplate {
AvgPool2dRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}"))
impl PoolStrategy for AvgPool {
type Accumulator = (Variable, Variable);
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
let sum = scope.create_local(item);
let count = scope.create_local(Elem::UInt);
if self.count_include_pad {
let kernel_size: Variable = (self.kernel_size[0] * self.kernel_size[1]).into();
gpu!(scope, count = kernel_size);
} else {
let zero: Variable = 0u32.into();
gpu!(scope, count = zero);
}
(sum, count)
}
fn process_result(
&self,
scope: &mut Scope,
accumulator: Self::Accumulator,
result: Variable,
_idx: Variable,
) -> Self::Accumulator {
let (sum, count) = accumulator;
if !self.count_include_pad {
let one: Variable = 1u32.into();
gpu!(scope, count += one);
}
gpu!(scope, sum += result);
(sum, count)
}
fn assign(
&self,
scope: &mut Scope,
id: Variable,
output: Variable,
_indices: Option<Variable>,
accumulator: Self::Accumulator,
) {
let (sum, count) = accumulator;
let avg = scope.create_local(output.item());
let count_float = scope.create_local(output.item());
gpu!(scope, count_float = cast(count));
gpu!(scope, avg = sum / count_float);
gpu!(scope, output[id] = avg);
}
fn with_indices() -> bool {
false
}
}
@ -40,63 +77,49 @@ pub(crate) fn avg_pool2d<R: Runtime, E: JitElement>(
padding: [usize; 2],
count_include_pad: bool,
) -> JitTensor<R, E, 4> {
let (info_handle, output) =
build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]);
let [batch_size, channels, _, _] = x.shape.dims;
let dilation = 1;
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT);
let kernel: Box<dyn Kernel> = match count_include_pad {
true => Box::new(StaticKernel::<
KernelSettings<AvgPool2d<true>, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(workgroup)),
false => Box::new(StaticKernel::<
KernelSettings<AvgPool2d<false>, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(workgroup)),
};
let size_0 = calculate_pool_output_size(
kernel_size[0],
stride[0],
padding[0],
dilation,
x.shape.dims[2],
);
let size_1 = calculate_pool_output_size(
kernel_size[1],
stride[1],
padding[1],
dilation,
x.shape.dims[3],
);
x.client
.execute(kernel, &[&x.handle, &output.handle, &info_handle]);
output
}
pub(crate) fn avg_pool2d_backward<R: Runtime, E: JitElement>(
x: JitTensor<R, E, 4>,
grad: JitTensor<R, E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
) -> JitTensor<R, E, 4> {
let grad = kernel::into_contiguous(grad);
let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]);
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT);
let kernel: Box<dyn Kernel> = match count_include_pad {
true => Box::new(StaticKernel::<
KernelSettings<
AvgPool2dBackward<true>,
E,
i32,
WORKGROUP_DEFAULT,
WORKGROUP_DEFAULT,
1,
>,
>::new(workgroup)),
false => Box::new(StaticKernel::<
KernelSettings<
AvgPool2dBackward<false>,
E,
i32,
WORKGROUP_DEFAULT,
WORKGROUP_DEFAULT,
1,
>,
>::new(workgroup)),
};
x.client
.execute(kernel, &[&grad.handle, &output.handle, &info_handle]);
let shape_out = Shape::new([batch_size, channels, size_0, size_1]);
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
let pool_strategy = AvgPool::new(kernel_size, count_include_pad);
let kernel = Pool2dEagerKernel::new(kernel_size, pool_strategy);
execute_dynamic::<R, Pool2dEagerKernel<AvgPool, R, E>, RuntimeInt<R>>(
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&[
(stride[0] as u32).elem(),
(stride[1] as u32).elem(),
(dilation as u32).elem(),
(dilation as u32).elem(),
(padding[0] as u32).elem(),
(padding[1] as u32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
output
}

View File

@ -0,0 +1,408 @@
use burn_tensor::ElementConversion;
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
kernel::{self, DynamicKernelSource, SourceTemplate},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime, RuntimeInt,
};
use std::marker::PhantomData;
/// Dynamic kernel source for the backward pass of 2D average pooling.
///
/// `kernel_size` and `count_include_pad` are baked into the generated source
/// and are part of the kernel `id`.
#[derive(new)]
struct AvgPool2dBackwardEagerKernel<R: Runtime, E: JitElement> {
/// Pooling window size for the two spatial axes.
kernel_size: [usize; 2],
/// When true, the divisor is the full window size; otherwise it is computed
/// per window from its clamped extent.
count_include_pad: bool,
/// Marker for the runtime; holds no data.
_runtime: PhantomData<R>,
/// Marker for the element type; holds no data.
_elem: PhantomData<E>,
}
/// Code generator for the average pooling backward kernel body.
///
/// `expand` consumes the value and registers the kernel operations on a `Scope`.
struct AvgPool2dBackwardComputeShader {
/// Array variable holding the incoming gradient that the kernel reads.
grad: Variable,
/// Array variable the kernel writes the accumulated gradient to.
output: Variable,
/// Pooling window size for the two spatial axes, known at generation time.
kernel_size: [usize; 2],
/// Selects the divisor: full window size when true, clamped window extent
/// otherwise.
count_include_pad: bool,
}
impl AvgPool2dBackwardComputeShader {
/// Registers the average pooling backward kernel body on `scope`.
///
/// For the invocation index `id` the generated code:
/// 1. decodes `id` into `(b, c, ih, iw)` using the output strides and shapes;
/// 2. obtains candidate ranges of `grad` positions from [`Self::loop_ranges`];
/// 3. for each candidate `(oh, ow)` rebuilds the pooling window, clamps it with
///    the padding and the padded border, and tests whether `(ih, iw)` lies in it;
/// 4. for contributing windows adds `grad[index] / count` to an accumulator,
///    where `count` is `kernel_size_0 * kernel_size_1` when `count_include_pad`
///    is set and the clamped window area otherwise;
/// 5. stores the accumulator in `output[id]`.
///
/// Scalars are read from global scalar slots 0/1 (pool strides) and 4/5
/// (padding).
fn expand(self, scope: &mut Scope) {
    let grad = self.grad;
    let output = self.output;
    let id = Variable::Id;

    let grad_stride_0 = scope.create_local(Elem::UInt);
    let grad_stride_1 = scope.create_local(Elem::UInt);
    let grad_stride_2 = scope.create_local(Elem::UInt);
    let grad_stride_3 = scope.create_local(Elem::UInt);

    let grad_shape_2 = scope.create_local(Elem::UInt);
    let grad_shape_3 = scope.create_local(Elem::UInt);

    let output_stride_0 = scope.create_local(Elem::UInt);
    let output_stride_1 = scope.create_local(Elem::UInt);
    let output_stride_2 = scope.create_local(Elem::UInt);
    let output_stride_3 = scope.create_local(Elem::UInt);

    let output_shape_0 = scope.create_local(Elem::UInt);
    let output_shape_1 = scope.create_local(Elem::UInt);
    let output_shape_2 = scope.create_local(Elem::UInt);
    let output_shape_3 = scope.create_local(Elem::UInt);

    gpu!(scope, grad_stride_0 = stride(grad, 0u32));
    gpu!(scope, grad_stride_1 = stride(grad, 1u32));
    gpu!(scope, grad_stride_2 = stride(grad, 2u32));
    gpu!(scope, grad_stride_3 = stride(grad, 3u32));

    gpu!(scope, grad_shape_2 = shape(grad, 2u32));
    gpu!(scope, grad_shape_3 = shape(grad, 3u32));

    gpu!(scope, output_stride_0 = stride(output, 0u32));
    gpu!(scope, output_stride_1 = stride(output, 1u32));
    gpu!(scope, output_stride_2 = stride(output, 2u32));
    gpu!(scope, output_stride_3 = stride(output, 3u32));

    gpu!(scope, output_shape_0 = shape(output, 0u32));
    gpu!(scope, output_shape_1 = shape(output, 1u32));
    gpu!(scope, output_shape_2 = shape(output, 2u32));
    gpu!(scope, output_shape_3 = shape(output, 3u32));

    let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
    let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
    let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
    let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
    let [kernel_size_0, kernel_size_1] = self.kernel_size;

    // Decode the flat invocation index into per-axis output coordinates.
    let b = scope.create_local(Elem::UInt);
    let c = scope.create_local(Elem::UInt);
    let ih = scope.create_local(Elem::UInt);
    let iw = scope.create_local(Elem::UInt);

    gpu!(scope, b = id / output_stride_0);
    gpu!(scope, b = b % output_shape_0);

    gpu!(scope, c = id / output_stride_1);
    gpu!(scope, c = c % output_shape_1);

    gpu!(scope, ih = id / output_stride_2);
    gpu!(scope, ih = ih % output_shape_2);

    gpu!(scope, iw = id / output_stride_3);
    gpu!(scope, iw = iw % output_shape_3);

    // The former `index_current` computation was never read; it has been
    // dropped so the generated kernel no longer carries dead operations.

    let index = scope.create_local(Elem::UInt);
    let index_tmp = scope.create_local(Elem::UInt);
    let index_base = scope.create_local(Elem::UInt);

    let grad_accumulation = scope.zero(grad.item());
    let result = scope.create_local(grad.item());
    let count = scope.create_local(grad.item());

    let count_include_pad = self.count_include_pad;
    if count_include_pad {
        // Constant divisor: the full window area, fixed at generation time.
        let kernel_size: Variable = (self.kernel_size[0] * self.kernel_size[1]).into();
        gpu!(scope, count = kernel_size);
    }

    let (oh_start, oh_end, ow_start, ow_end) = self.loop_ranges(
        scope,
        ih,
        iw,
        grad_shape_2,
        grad_shape_3,
        output_stride_2,
        output_stride_3,
    );

    // Batch and channel offset into `grad`; constant across the loops below.
    gpu!(scope, index_base = b * grad_stride_0);
    gpu!(scope, index_tmp = c * grad_stride_1);
    gpu!(scope, index_base += index_tmp);

    let border_bottom = scope.create_local(Elem::UInt);
    let border_right = scope.create_local(Elem::UInt);
    let begin_h = scope.create_local(Elem::UInt);
    let begin_w = scope.create_local(Elem::UInt);
    let iw_start = scope.create_local(Elem::UInt);
    let iw_end = scope.create_local(Elem::UInt);
    let ih_start = scope.create_local(Elem::UInt);
    let ih_end = scope.create_local(Elem::UInt);
    let after_start = scope.create_local(Elem::Bool);
    let before_end = scope.create_local(Elem::Bool);
    let contributed_h = scope.create_local(Elem::Bool);
    let contributed_w = scope.create_local(Elem::Bool);
    gpu!(scope, border_bottom = output_shape_2 + padding_0);
    gpu!(scope, border_right = output_shape_3 + padding_1);
    gpu!(scope, begin_h = ih + padding_0);
    gpu!(scope, begin_w = iw + padding_1);

    let ih_diff = scope.create_local(Elem::UInt);
    let iw_diff = scope.create_local(Elem::UInt);
    let count_int = scope.create_local(Elem::UInt);

    gpu!(
        scope,
        range(oh_start, oh_end).for_each(|oh, scope| {
            // Contributed h
            gpu!(scope, ih_start = oh * pool_stride_0);
            gpu!(scope, ih_end = ih_start + kernel_size_0);
            gpu!(scope, ih_start = max(ih_start, padding_0));
            gpu!(scope, ih_end = min(ih_end, border_bottom));
            gpu!(scope, after_start = begin_h >= ih_start);
            gpu!(scope, before_end = ih < ih_end);
            gpu!(scope, contributed_h = after_start && before_end);

            if !count_include_pad {
                // Clamped window height, used for the per-window divisor.
                gpu!(scope, ih_diff = ih_end - ih_start);
            }

            gpu!(scope, if(contributed_h).then(|scope|{
                gpu!(
                    scope,
                    range(ow_start, ow_end).for_each(|ow, scope| {
                        gpu!(scope, index = index_base);
                        gpu!(scope, index_tmp = oh * grad_stride_2);
                        gpu!(scope, index += index_tmp);
                        gpu!(scope, index_tmp = ow * grad_stride_3);
                        gpu!(scope, index += index_tmp);

                        // Contributed w
                        gpu!(scope, iw_start = ow * pool_stride_1);
                        gpu!(scope, iw_end = iw_start + kernel_size_1);
                        gpu!(scope, iw_start = max(iw_start, padding_1));
                        gpu!(scope, iw_end = min(iw_end, border_right));
                        gpu!(scope, after_start = begin_w >= iw_start);
                        gpu!(scope, before_end = iw < iw_end);
                        gpu!(scope, contributed_w = after_start && before_end);

                        gpu!(scope, if(contributed_w).then(|scope|{
                            if !count_include_pad {
                                // Divisor = clamped window area.
                                gpu!(scope, iw_diff = iw_end - iw_start);
                                gpu!(scope, count_int = ih_diff * iw_diff);
                                gpu!(scope, count = cast(count_int));
                            }

                            gpu!(scope, result = grad[index]);
                            gpu!(scope, result = result / count);
                            gpu!(scope, grad_accumulation += result);
                        }));
                    }));
            }));
        })
    );

    gpu!(scope, output[id] = grad_accumulation);
}
#[allow(clippy::too_many_arguments)]
fn loop_ranges(
self,
scope: &mut Scope,
ih: Variable,
iw: Variable,
grad_shape_2: Variable,
grad_shape_3: Variable,
output_stride_2: Variable,
output_stride_3: Variable,
) -> (Variable, Variable, Variable, Variable) {
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
let [kernel_size_0, kernel_size_1] = self.kernel_size;
let signed_ih = scope.create_local(Elem::Int);
let signed_iw = scope.create_local(Elem::Int);
let signed_pool_stride_0 = scope.create_local(Elem::Int);
let signed_pool_stride_1 = scope.create_local(Elem::Int);
let signed_dilation_0 = scope.create_local(Elem::Int);
let signed_dilation_1 = scope.create_local(Elem::Int);
let signed_padding_0 = scope.create_local(Elem::Int);
let signed_padding_1 = scope.create_local(Elem::Int);
let signed_kernel_size_0 = scope.create_local(Elem::Int);
let signed_kernel_size_1 = scope.create_local(Elem::Int);
gpu!(scope, signed_pool_stride_0 = cast(pool_stride_0));
gpu!(scope, signed_pool_stride_1 = cast(pool_stride_1));
gpu!(scope, signed_dilation_0 = cast(dilation_0));
gpu!(scope, signed_dilation_1 = cast(dilation_1));
gpu!(scope, signed_padding_0 = cast(padding_0));
gpu!(scope, signed_padding_1 = cast(padding_1));
gpu!(scope, signed_kernel_size_0 = cast(kernel_size_0));
gpu!(scope, signed_kernel_size_1 = cast(kernel_size_1));
gpu!(scope, signed_ih = cast(ih));
gpu!(scope, signed_iw = cast(iw));
let kms_0 = scope.create_local(Elem::Int);
let kms_1 = scope.create_local(Elem::Int);
gpu!(scope, kms_0 = signed_dilation_0 * signed_kernel_size_0);
gpu!(scope, kms_0 = kms_0 - signed_pool_stride_0);
gpu!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1);
gpu!(scope, kms_1 = kms_1 - signed_pool_stride_1);
let oh_start_tmp = scope.create_local(Elem::Int);
let ow_start_tmp = scope.create_local(Elem::Int);
gpu!(scope, oh_start_tmp = signed_ih + signed_padding_0);
gpu!(scope, oh_start_tmp = oh_start_tmp - kms_0);
gpu!(scope, oh_start_tmp = oh_start_tmp / signed_pool_stride_0);
gpu!(scope, ow_start_tmp = signed_iw + signed_padding_1);
gpu!(scope, ow_start_tmp = ow_start_tmp - kms_1);
gpu!(scope, ow_start_tmp = ow_start_tmp / signed_pool_stride_1);
gpu!(scope, oh_start_tmp = max(oh_start_tmp, 0i32));
gpu!(scope, ow_start_tmp = max(ow_start_tmp, 0i32));
let oh_start = scope.create_local(Elem::UInt);
let ow_start = scope.create_local(Elem::UInt);
gpu!(scope, oh_start = cast(oh_start_tmp));
gpu!(scope, ow_start = cast(ow_start_tmp));
let oh_end_tmp = scope.create_local(Elem::Int);
let ow_end_tmp = scope.create_local(Elem::Int);
gpu!(scope, oh_end_tmp = max(kms_0, 0i32));
gpu!(scope, ow_end_tmp = max(kms_1, 0i32));
let oh_end = scope.create_local(Elem::UInt);
let ow_end = scope.create_local(Elem::UInt);
let oh_end_limit = scope.create_local(Elem::UInt);
let ow_end_limit = scope.create_local(Elem::UInt);
gpu!(scope, oh_end = cast(oh_end_tmp));
gpu!(scope, ow_end = cast(ow_end_tmp));
gpu!(scope, oh_end = oh_end + oh_start);
gpu!(scope, oh_end_limit = grad_shape_2 - 1u32);
gpu!(scope, ow_end = ow_end + ow_start);
gpu!(scope, ow_end_limit = grad_shape_3 - 1u32);
gpu!(scope, oh_end = min(oh_end, oh_end_limit));
gpu!(scope, ow_end = min(ow_end, ow_end_limit));
let index_current = scope.create_local(Elem::UInt);
let index_current_tmp = scope.create_local(Elem::UInt);
gpu!(scope, index_current = ih * output_stride_2);
gpu!(scope, index_current_tmp = iw * output_stride_3);
gpu!(scope, index_current += index_current_tmp);
gpu!(scope, oh_end = oh_end + 1u32);
gpu!(scope, ow_end = ow_end + 1u32);
(oh_start, oh_end, ow_start, ow_end)
}
}
impl<R: Runtime, E: JitElement> DynamicKernelSource for AvgPool2dBackwardEagerKernel<R, E> {
    /// Builds the kernel source: one read-only gradient array, six unsigned
    /// scalars and one output array, with the body produced by
    /// `AvgPool2dBackwardComputeShader::expand`, compiled with the default
    /// settings and then by the runtime's compiler.
    fn source(&self) -> kernel::SourceTemplate {
        let item = E::gpu_elem().into();
        let grad_array = Variable::GlobalInputArray(0, item);
        let output_array = Variable::GlobalOutputArray(0, item);

        let mut scope = Scope::root();
        // The output is registered on the scope before the body is expanded.
        scope.write_global_custom(output_array);

        let body = AvgPool2dBackwardComputeShader {
            grad: grad_array,
            output: output_array,
            kernel_size: self.kernel_size,
            count_include_pad: self.count_include_pad,
        };
        body.expand(&mut scope);

        let inputs = vec![
            InputInfo::Array {
                item,
                visibility: Visibility::Read,
            },
            InputInfo::Scalar {
                elem: Elem::UInt,
                size: 6,
            },
        ];
        let outputs = vec![OutputInfo::Array { item }];
        let info = CompilationInfo {
            inputs,
            outputs,
            scope,
        };

        let dialect_shader = Compilation::new(info).compile(CompilationSettings::default());
        let compiled = <R::Compiler as Compiler>::compile(dialect_shader);

        SourceTemplate::new(compiled.to_string())
    }

    /// Identifier of this kernel. It combines the concrete type's `TypeId`
    /// with the two values baked into the source, so kernels generated with a
    /// different `kernel_size` or `count_include_pad` get distinct ids.
    fn id(&self) -> String {
        let type_id = core::any::TypeId::of::<Self>();
        format!(
            "{:?}k={:?}count_include_pad={:?}",
            type_id, self.kernel_size, self.count_include_pad
        )
    }
}
/// Launches the average pooling backward kernel and returns its result.
///
/// `grad` is first made contiguous. The returned tensor is allocated with the
/// shape, device and client of `x`. The kernel receives six scalars in this
/// order: stride (2), dilation (2, always 1 here), padding (2); `kernel_size`
/// and `count_include_pad` are baked into the kernel itself.
///
/// The launch is sized from the output at position 0.
pub(crate) fn avg_pool2d_backward<R: Runtime, E: JitElement>(
    x: JitTensor<R, E, 4>,
    grad: JitTensor<R, E, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    count_include_pad: bool,
) -> JitTensor<R, E, 4> {
    let grad = kernel::into_contiguous(grad);
    let dilation = 1;

    let input_grad = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
    let kernel = AvgPool2dBackwardEagerKernel::new(kernel_size, count_include_pad);

    let [stride_h, stride_w] = stride;
    let [padding_h, padding_w] = padding;

    // Order must match the scalar slots read by the generated kernel.
    let scalars: [RuntimeInt<R>; 6] = [
        (stride_h as i32).elem(),
        (stride_w as i32).elem(),
        dilation.elem(),
        dilation.elem(),
        (padding_h as i32).elem(),
        (padding_w as i32).elem(),
    ];

    let grad_handle = EagerHandle::new(&grad.handle, &grad.strides, &grad.shape.dims);
    let input_grad_handle = EagerHandle::new(
        &input_grad.handle,
        &input_grad.strides,
        &input_grad.shape.dims,
    );

    execute_dynamic::<R, AvgPool2dBackwardEagerKernel<R, E>, RuntimeInt<R>>(
        &[grad_handle],
        &[input_grad_handle],
        Some(&scalars),
        kernel,
        WorkgroupLaunch::Output { pos: 0 },
        x.client,
    );

    input_grad
}

View File

@ -1,72 +1,27 @@
use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, Runtime};
use burn_compute::server::Handle;
use burn_tensor::Shape;
use crate::gpu::{Item, Scope, Variable};
use std::fmt::Debug;
/// Build basic info to launch pool 2d kernels.
pub fn build_output_and_info_pool2d<R: Runtime, E: JitElement>(
x: &JitTensor<R, E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> (Handle<R::Server>, JitTensor<R, E, 4>) {
let [kernel_height, kernel_width] = kernel_size;
let [padding_height, padding_width] = padding;
let [stride_height, stride_width] = stride;
let [dilation_height, dilation_width] = dilation;
let [batch_size, channels, x_height, x_width] = x.shape.dims;
pub(crate) trait PoolStrategy: Send + Sync + 'static + Clone + Debug {
type Accumulator: Copy;
let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1)
/ stride_height)
+ 1;
let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1)
/ stride_width)
+ 1;
let shape_out = Shape::new([batch_size, channels, out_height, out_width]);
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator;
let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding, dilation);
fn process_result(
&self,
scope: &mut Scope,
accumulator: Self::Accumulator,
result: Variable,
idx: Variable,
) -> Self::Accumulator;
(info_buffer, output)
}
pub fn build_pool2d_info<R: Runtime, E: JitElement>(
input: &JitTensor<R, E, 4>,
output: &JitTensor<R, E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> Handle<R::Server> {
let mut info: [u32; 24] = [0; 24];
info[0] = input.strides[0] as u32;
info[1] = input.strides[1] as u32;
info[2] = input.strides[2] as u32;
info[3] = input.strides[3] as u32;
info[4] = input.shape.dims[0] as u32;
info[5] = input.shape.dims[1] as u32;
info[6] = input.shape.dims[2] as u32;
info[7] = input.shape.dims[3] as u32;
info[8] = output.strides[0] as u32;
info[9] = output.strides[1] as u32;
info[10] = output.strides[2] as u32;
info[11] = output.strides[3] as u32;
info[12] = output.shape.dims[0] as u32;
info[13] = output.shape.dims[1] as u32;
info[14] = output.shape.dims[2] as u32;
info[15] = output.shape.dims[3] as u32;
info[16] = kernel_size[0] as u32;
info[17] = kernel_size[1] as u32;
info[18] = stride[0] as u32;
info[19] = stride[1] as u32;
info[20] = padding[0] as u32;
info[21] = padding[1] as u32;
info[22] = dilation[0] as u32;
info[23] = dilation[1] as u32;
let info_buffer = input.client.create(bytemuck::cast_slice(&info));
info_buffer
fn assign(
&self,
scope: &mut Scope,
id: Variable,
output: Variable,
indices: Option<Variable>,
accumulator: Self::Accumulator,
);
fn with_indices() -> bool;
}

View File

@ -1,302 +1,165 @@
use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape};
use std::marker::PhantomData;
use std::{fmt::Debug, marker::PhantomData};
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
InputInfo, OutputInfo, WorkgroupLaunch,
},
codegen::{dialect::gpu::Variable, execute_dynamic, EagerHandle, WorkgroupLaunch},
element::JitElement,
kernel::{DynamicKernelSource, SourceTemplate},
gpu::{gpu, Elem, Item, Scope},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime, RuntimeInt,
};
use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape};
#[derive(new)]
struct MaxPool2dEagerKernel<R: Runtime, E: JitElement> {
kernel_size: [usize; 2],
_runtime: PhantomData<R>,
use super::{Pool2dEagerKernel, PoolStrategy};
#[derive(Default, Debug, Clone)]
struct MaxPool<E: JitElement> {
_elem: PhantomData<E>,
}
#[derive(new)]
struct MaxPool2dWithIndicesEagerKernel<R: Runtime, E: JitElement> {
kernel_size: [usize; 2],
_runtime: PhantomData<R>,
_elem: PhantomData<E>,
}
impl<E: JitElement> PoolStrategy for MaxPool<E> {
type Accumulator = Variable;
struct MaxPool2dComputeShader<E: JitElement> {
x: Variable,
output: Variable,
kernel_size: [usize; 2],
indices: Option<Variable>,
_elem: PhantomData<E>,
}
impl<E: JitElement> MaxPool2dComputeShader<E> {
fn expand(self, scope: &mut Scope) {
let x = self.x;
let output = self.output;
let id = Variable::Id;
let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt);
let input_stride_2 = scope.create_local(Elem::UInt);
let input_stride_3 = scope.create_local(Elem::UInt);
let input_shape_2 = scope.create_local(Elem::UInt);
let input_shape_3 = scope.create_local(Elem::UInt);
let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);
let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);
gpu!(scope, input_stride_0 = stride(x, 0u32));
gpu!(scope, input_stride_1 = stride(x, 1u32));
gpu!(scope, input_stride_2 = stride(x, 2u32));
gpu!(scope, input_stride_3 = stride(x, 3u32));
gpu!(scope, input_shape_2 = shape(x, 2u32));
gpu!(scope, input_shape_3 = shape(x, 3u32));
gpu!(scope, output_stride_0 = stride(output, 0u32));
gpu!(scope, output_stride_1 = stride(output, 1u32));
gpu!(scope, output_stride_2 = stride(output, 2u32));
gpu!(scope, output_stride_3 = stride(output, 3u32));
gpu!(scope, output_shape_0 = shape(output, 0u32));
gpu!(scope, output_shape_1 = shape(output, 1u32));
gpu!(scope, output_shape_2 = shape(output, 2u32));
gpu!(scope, output_shape_3 = shape(output, 3u32));
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
let [kernel_size_0, kernel_size_1] = self.kernel_size;
let b = scope.create_local(Elem::UInt);
let c = scope.create_local(Elem::UInt);
let oh = scope.create_local(Elem::UInt);
let ow = scope.create_local(Elem::UInt);
gpu!(scope, b = id / output_stride_0);
gpu!(scope, b = b % output_shape_0);
gpu!(scope, c = id / output_stride_1);
gpu!(scope, c = c % output_shape_1);
gpu!(scope, oh = id / output_stride_2);
gpu!(scope, oh = oh % output_shape_2);
gpu!(scope, ow = id / output_stride_3);
gpu!(scope, ow = ow % output_shape_3);
let tmp = scope.create_local(Elem::UInt);
let ih = scope.create_local(Elem::UInt);
let iw = scope.create_local(Elem::UInt);
let ih_pad = scope.create_local(Elem::UInt);
let iw_pad = scope.create_local(Elem::UInt);
let result = scope.create_local(x.item());
let cond = scope.create_local(Elem::Bool);
let cond_tmp = scope.create_local(Elem::Bool);
let index_input = scope.create_local(Elem::UInt);
let index_input_1 = scope.create_local(Elem::UInt);
let index_input_2 = scope.create_local(Elem::UInt);
let index_input_3 = scope.create_local(Elem::UInt);
let index_input_4 = scope.create_local(Elem::UInt);
let is_max = scope.create_local(Elem::Bool);
let max_index = self.indices.map(|_| scope.create_local(Elem::UInt));
let max_val = scope.create_local(x.item());
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
let max_val = scope.create_local(item);
let max_initial =
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), x.item().elem());
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem());
gpu!(scope, max_val = max_initial);
max_val
}
(0..kernel_size_0).for_each(|kh| {
gpu!(scope, ih = oh * pool_stride_0);
gpu!(scope, tmp = kh * dilation_0);
gpu!(scope, ih += tmp);
fn process_result(
&self,
scope: &mut Scope,
accumulator: Self::Accumulator,
result: Variable,
_idx: Variable,
) -> Self::Accumulator {
let is_max = scope.create_local(Elem::Bool);
gpu!(scope, is_max = result > accumulator);
gpu!(scope, if(is_max).then(|scope|{
gpu!(scope, accumulator = result);
}));
accumulator
}
// Up
gpu!(scope, cond = ih < padding_0);
// Down
gpu!(scope, tmp = input_shape_2 + padding_0);
gpu!(scope, cond_tmp = ih >= tmp);
gpu!(scope, cond = cond || cond_tmp);
gpu!(scope, cond = !cond);
fn assign(
&self,
scope: &mut Scope,
id: Variable,
output: Variable,
_indices: Option<Variable>,
accumulator: Self::Accumulator,
) {
gpu!(scope, output[id] = accumulator);
}
gpu!(scope, if (cond).then(|scope| {
(0..kernel_size_1).for_each(|kw| {
gpu!(scope, iw = ow * pool_stride_1);
gpu!(scope, tmp = kw * dilation_1);
gpu!(scope, iw = iw + tmp);
fn with_indices() -> bool {
false
}
}
// Left
gpu!(scope, cond = iw < padding_1);
// Right
gpu!(scope, tmp = input_shape_3 + padding_1);
gpu!(scope, cond_tmp = iw >= tmp);
gpu!(scope, cond = cond || cond_tmp);
gpu!(scope, cond = !cond);
#[derive(Default, Debug, Clone)]
struct MaxPoolWithIndices<E: JitElement> {
_elem: PhantomData<E>,
}
gpu!(scope, if (cond).then(|scope| {
gpu!(scope, ih_pad = ih - padding_0);
gpu!(scope, iw_pad = iw - padding_1);
impl<E: JitElement> PoolStrategy for MaxPoolWithIndices<E> {
type Accumulator = (Variable, Variable);
gpu!(scope, index_input_1 = b * input_stride_0);
gpu!(scope, index_input_2 = c * input_stride_1);
gpu!(scope, index_input_3 = ih_pad * input_stride_2);
gpu!(scope, index_input_4 = iw_pad * input_stride_3);
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
let max_val = scope.create_local(item);
let max_initial =
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem());
gpu!(scope, max_val = max_initial);
let max_index = scope.create_local(Elem::UInt);
(max_val, max_index)
}
gpu!(scope, index_input = index_input_1);
gpu!(scope, index_input += index_input_2);
gpu!(scope, index_input += index_input_3);
gpu!(scope, index_input += index_input_4);
gpu!(scope, result = x[index_input]);
gpu!(scope, is_max = result > max_val);
gpu!(scope, if(is_max).then(|scope|{
gpu!(scope, max_val = result);
if let Some(max_index) = max_index {
gpu!(scope, max_index = ih_pad * input_shape_2);
gpu!(scope, max_index += iw_pad);
}
}));
}));
});
}));
});
fn process_result(
&self,
scope: &mut Scope,
(max_val, max_index): Self::Accumulator,
result: Variable,
idx: Variable,
) -> Self::Accumulator {
let is_max = scope.create_local(Elem::Bool);
gpu!(scope, is_max = result > max_val);
gpu!(scope, if(is_max).then(|scope|{
gpu!(scope, max_val = result);
gpu!(scope, max_index = idx);
}));
(max_val, max_index)
}
fn assign(
&self,
scope: &mut Scope,
id: Variable,
output: Variable,
indices: Option<Variable>,
(max_val, max_index): Self::Accumulator,
) {
let indices = indices.unwrap();
gpu!(scope, output[id] = max_val);
gpu!(scope, indices[id] = max_index);
}
if let Some(indices) = self.indices {
let max_index = max_index.unwrap();
gpu!(scope, indices[id] = max_index);
}
fn with_indices() -> bool {
true
}
}
impl<R: Runtime, E: JitElement> DynamicKernelSource for MaxPool2dEagerKernel<R, E> {
fn source(&self) -> crate::kernel::SourceTemplate {
let mut scope = Scope::root();
let item = E::gpu_elem().into();
pub(crate) fn max_pool2d<R: Runtime, E: JitElement>(
x: JitTensor<R, E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> JitTensor<R, E, 4> {
let [batch_size, channels, _, _] = x.shape.dims;
let x = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let size_0 = calculate_pool_output_size(
kernel_size[0],
stride[0],
padding[0],
dilation[0],
x.shape.dims[2],
);
let size_1 = calculate_pool_output_size(
kernel_size[1],
stride[1],
padding[1],
dilation[1],
x.shape.dims[3],
);
scope.write_global_custom(output);
let shape_out = Shape::new([batch_size, channels, size_0, size_1]);
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
MaxPool2dComputeShader {
x,
output,
kernel_size: self.kernel_size,
indices: None,
_elem: PhantomData::<E>,
}
.expand(&mut scope);
let kernel = Pool2dEagerKernel::new(kernel_size, MaxPool::default());
let input = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let scalars = InputInfo::Scalar {
elem: Elem::UInt,
size: 6,
};
let output = OutputInfo::Array { item };
execute_dynamic::<R, Pool2dEagerKernel<MaxPool<E>, R, E>, RuntimeInt<R>>(
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&[
(stride[0] as u32).elem(),
(stride[1] as u32).elem(),
(dilation[0] as u32).elem(),
(dilation[1] as u32).elem(),
(padding[0] as u32).elem(),
(padding[1] as u32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
let info = CompilationInfo {
inputs: vec![input, scalars],
outputs: vec![output],
scope,
};
let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}
fn id(&self) -> String {
format!(
"{:?}k={:?}",
core::any::TypeId::of::<Self>(),
self.kernel_size,
)
}
}
impl<R: Runtime, E: JitElement> DynamicKernelSource for MaxPool2dWithIndicesEagerKernel<R, E> {
fn source(&self) -> crate::kernel::SourceTemplate {
let mut scope = Scope::root();
let item = E::gpu_elem().into();
let x = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let indices = Variable::GlobalOutputArray(1, Item::Scalar(Elem::Int));
scope.write_global_custom(output);
MaxPool2dComputeShader {
x,
output,
kernel_size: self.kernel_size,
indices: Some(indices),
_elem: PhantomData::<E>,
}
.expand(&mut scope);
let input = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let scalars = InputInfo::Scalar {
elem: Elem::UInt,
size: 6,
};
let output = OutputInfo::Array { item };
let indices = OutputInfo::Array {
item: Item::Scalar(Elem::Int),
};
let info = CompilationInfo {
inputs: vec![input, scalars],
outputs: vec![output, indices],
scope,
};
let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}
fn id(&self) -> String {
format!(
"{:?}k={:?}",
core::any::TypeId::of::<Self>(),
self.kernel_size,
)
}
output
}
pub(crate) fn max_pool2d_with_indices<R: Runtime, E: JitElement, I: JitElement>(
@ -327,8 +190,9 @@ pub(crate) fn max_pool2d_with_indices<R: Runtime, E: JitElement, I: JitElement>(
let output = empty_device(x.client.clone(), x.device.clone(), shape_out.clone());
let indices = empty_device(x.client.clone(), x.device.clone(), shape_out);
let kernel = MaxPool2dWithIndicesEagerKernel::new(kernel_size);
execute_dynamic::<R, MaxPool2dWithIndicesEagerKernel<R, E>, I>(
let kernel = Pool2dEagerKernel::new(kernel_size, MaxPoolWithIndices::default());
execute_dynamic::<R, Pool2dEagerKernel<MaxPoolWithIndices<E>, R, E>, I>(
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
&[
EagerHandle::new(&output.handle, &output.strides, &output.shape.dims),
@ -349,55 +213,3 @@ pub(crate) fn max_pool2d_with_indices<R: Runtime, E: JitElement, I: JitElement>(
(output, indices)
}
/// Launches the max-pool 2D kernel and returns the pooled tensor.
///
/// The output keeps the batch and channel dimensions of `x`; the two trailing
/// dimensions come from `calculate_pool_output_size`. The kernel receives six
/// scalars in this order: both strides, both dilations, then both paddings.
pub(crate) fn max_pool2d<R: Runtime, E: JitElement>(
    x: JitTensor<R, E, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
) -> JitTensor<R, E, 4> {
    let [batch_size, channels, in_dim_2, in_dim_3] = x.shape.dims;

    let out_dim_2 =
        calculate_pool_output_size(kernel_size[0], stride[0], padding[0], dilation[0], in_dim_2);
    let out_dim_3 =
        calculate_pool_output_size(kernel_size[1], stride[1], padding[1], dilation[1], in_dim_3);

    let output = empty_device(
        x.client.clone(),
        x.device.clone(),
        Shape::new([batch_size, channels, out_dim_2, out_dim_3]),
    );
    let kernel = MaxPool2dEagerKernel::new(kernel_size);

    let [stride_0, stride_1] = stride;
    let [dilation_0, dilation_1] = dilation;
    let [padding_0, padding_1] = padding;

    execute_dynamic::<R, MaxPool2dEagerKernel<R, E>, RuntimeInt<R>>(
        &[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
        &[EagerHandle::new(
            &output.handle,
            &output.strides,
            &output.shape.dims,
        )],
        Some(&[
            (stride_0 as i32).elem(),
            (stride_1 as i32).elem(),
            (dilation_0 as i32).elem(),
            (dilation_1 as i32).elem(),
            (padding_0 as i32).elem(),
            (padding_1 as i32).elem(),
        ]),
        kernel,
        WorkgroupLaunch::Output { pos: 0 },
        x.client,
    );

    output
}

View File

@ -101,6 +101,7 @@ impl MaxPool2dBackwardComputeShader {
let is_max = scope.create_local(Elem::Bool);
let index = scope.create_local(Elem::UInt);
let index_base = scope.create_local(Elem::UInt);
let index_tmp = scope.create_local(Elem::UInt);
let grad_accumulation = scope.zero(grad.item());
@ -116,20 +117,19 @@ impl MaxPool2dBackwardComputeShader {
output_stride_3,
);
gpu!(scope, index_base = b * grad_stride_0);
gpu!(scope, index_tmp = c * grad_stride_1);
gpu!(scope, index_base += index_tmp);
gpu!(
scope,
range(oh_start, oh_end).for_each(|oh, scope| {
gpu!(
scope,
range(ow_start, ow_end).for_each(|ow, scope| {
gpu!(scope, index = b * grad_stride_0);
gpu!(scope, index_tmp = c * grad_stride_1);
gpu!(scope, index += index_tmp);
gpu!(scope, index = index_base);
gpu!(scope, index_tmp = oh * grad_stride_2);
gpu!(scope, index += index_tmp);
gpu!(scope, index_tmp = ow * grad_stride_3);
gpu!(scope, index += index_tmp);

View File

@ -1,12 +1,19 @@
mod adaptive_avg_pool2d;
mod adaptive_avg_pool2d_backward;
mod adaptive_pool2d_shader;
mod avg_pool2d;
mod avg_pool2d_backward;
mod base;
mod max_pool2d;
mod max_pool2d_backward;
mod pool2d_shader;
pub(crate) use adaptive_pool2d_shader::*;
pub(crate) use pool2d_shader::*;
pub(crate) use adaptive_avg_pool2d::*;
pub use avg_pool2d::*;
pub use adaptive_avg_pool2d_backward::*;
pub(crate) use avg_pool2d::*;
pub(crate) use avg_pool2d_backward::*;
pub(super) use base::*;
pub(crate) use max_pool2d::*;
pub(crate) use max_pool2d_backward::*;

View File

@ -0,0 +1,243 @@
use std::marker::PhantomData;
use crate::{
codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo},
gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
kernel::{DynamicKernelSource, SourceTemplate},
Compiler, JitElement, Runtime,
};
use super::PoolStrategy;
/// Builder for the body of a 2D pooling shader, generic over the pooling
/// strategy `P` that decides how window elements are accumulated and written.
pub(crate) struct Pool2dComputeShader<P: PoolStrategy, R: Runtime, E: JitElement> {
/// Array the window elements are read from.
input: Variable,
/// Array handed to the strategy's `assign` to receive the pooled value.
output: Variable,
/// Optional second output array, forwarded as-is to the strategy's `assign`.
indices: Option<Variable>,
/// Window extent along the two pooled dimensions; `expand` emits one set of
/// operations per kernel position.
kernel_size: [usize; 2],
/// Strategy that initializes, updates and finally stores the accumulator.
pool_strategy: P,
/// Marker for the element type `E`; holds no data.
_elem: PhantomData<E>,
/// Marker for the runtime `R`; holds no data.
_runtime: PhantomData<R>,
}
impl<P: PoolStrategy, R: Runtime, E: JitElement> Pool2dComputeShader<P, R, E> {
    /// Registers in `scope` the operations that compute one pooled output element.
    ///
    /// The invocation id is decomposed into `(b, c, oh, ow)` using the output
    /// strides and shape. Every kernel position is then visited (the loops over
    /// `kernel_size` run at shader-build time, so they are unrolled in the
    /// generated code); positions that fall inside the padding are skipped, the
    /// remaining ones are read from `input` and passed to the pool strategy.
    /// Finally the strategy writes its accumulator to `output` (and `indices`).
    ///
    /// Scalars are expected in this order: pool stride (0, 1), dilation (0, 1),
    /// padding (0, 1).
    fn expand(self, scope: &mut Scope) {
        let input = self.input;
        let output = self.output;
        let id = Variable::Id;

        let input_stride_0 = scope.create_local(Elem::UInt);
        let input_stride_1 = scope.create_local(Elem::UInt);
        let input_stride_2 = scope.create_local(Elem::UInt);
        let input_stride_3 = scope.create_local(Elem::UInt);

        // Only the extents of the two pooled dimensions are needed (for the
        // padding borders); the extents of dimensions 0 and 1 are never read.
        let input_shape_2 = scope.create_local(Elem::UInt);
        let input_shape_3 = scope.create_local(Elem::UInt);

        let output_stride_0 = scope.create_local(Elem::UInt);
        let output_stride_1 = scope.create_local(Elem::UInt);
        let output_stride_2 = scope.create_local(Elem::UInt);
        let output_stride_3 = scope.create_local(Elem::UInt);

        let output_shape_0 = scope.create_local(Elem::UInt);
        let output_shape_1 = scope.create_local(Elem::UInt);
        let output_shape_2 = scope.create_local(Elem::UInt);
        let output_shape_3 = scope.create_local(Elem::UInt);

        gpu!(scope, input_stride_0 = stride(input, 0u32));
        gpu!(scope, input_stride_1 = stride(input, 1u32));
        gpu!(scope, input_stride_2 = stride(input, 2u32));
        gpu!(scope, input_stride_3 = stride(input, 3u32));

        gpu!(scope, input_shape_2 = shape(input, 2u32));
        gpu!(scope, input_shape_3 = shape(input, 3u32));

        gpu!(scope, output_stride_0 = stride(output, 0u32));
        gpu!(scope, output_stride_1 = stride(output, 1u32));
        gpu!(scope, output_stride_2 = stride(output, 2u32));
        gpu!(scope, output_stride_3 = stride(output, 3u32));

        gpu!(scope, output_shape_0 = shape(output, 0u32));
        gpu!(scope, output_shape_1 = shape(output, 1u32));
        gpu!(scope, output_shape_2 = shape(output, 2u32));
        gpu!(scope, output_shape_3 = shape(output, 3u32));

        let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
        let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
        let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
        let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
        let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
        let padding_1 = Variable::GlobalScalar(5, Elem::UInt);

        // Coordinates of the output element handled by this invocation.
        let b = scope.create_local(Elem::UInt);
        let c = scope.create_local(Elem::UInt);
        let oh = scope.create_local(Elem::UInt);
        let ow = scope.create_local(Elem::UInt);

        gpu!(scope, b = id / output_stride_0);
        gpu!(scope, b = b % output_shape_0);

        gpu!(scope, c = id / output_stride_1);
        gpu!(scope, c = c % output_shape_1);

        gpu!(scope, oh = id / output_stride_2);
        gpu!(scope, oh = oh % output_shape_2);

        gpu!(scope, ow = id / output_stride_3);
        gpu!(scope, ow = ow % output_shape_3);

        let ih = scope.create_local(Elem::UInt);
        let iw = scope.create_local(Elem::UInt);
        let dilated = scope.create_local(Elem::UInt);

        let ih_pad = scope.create_local(Elem::UInt);
        let iw_pad = scope.create_local(Elem::UInt);
        let result = scope.create_local(input.item());

        let index_input = scope.create_local(Elem::UInt);
        let index_input_0 = scope.create_local(Elem::UInt);
        let index_input_1 = scope.create_local(Elem::UInt);
        let index_input_2 = scope.create_local(Elem::UInt);
        let index_input_3 = scope.create_local(Elem::UInt);
        let idx = scope.create_local(Elem::UInt);

        let within_padding_h = scope.create_local(Elem::Bool);
        let within_padding_w = scope.create_local(Elem::Bool);
        let tmp_padding = scope.create_local(Elem::Bool);
        let border_bottom = scope.create_local(Elem::UInt);
        let border_right = scope.create_local(Elem::UInt);

        // `ih` / `iw` below are expressed in padded coordinates, so the valid
        // range along each axis is [padding, input_shape + padding).
        gpu!(scope, border_bottom = input_shape_2 + padding_0);
        gpu!(scope, border_right = input_shape_3 + padding_1);

        // Batch and channel offsets do not depend on the kernel position.
        gpu!(scope, index_input_0 = b * input_stride_0);
        gpu!(scope, index_input_1 = c * input_stride_1);

        let accumulator = self.pool_strategy.initialize(scope, input.item());

        (0..self.kernel_size[0]).for_each(|kh| {
            gpu!(scope, ih = oh * pool_stride_0);
            gpu!(scope, dilated = kh * dilation_0);
            gpu!(scope, ih += dilated);

            gpu!(scope, within_padding_h = ih >= padding_0);
            gpu!(scope, tmp_padding = ih < border_bottom);
            gpu!(scope, within_padding_h = within_padding_h && tmp_padding);

            gpu!(scope, if (within_padding_h).then(|scope| {
                (0..self.kernel_size[1]).for_each(|kw| {
                    gpu!(scope, iw = ow * pool_stride_1);
                    gpu!(scope, dilated = kw * dilation_1);
                    gpu!(scope, iw += dilated);

                    gpu!(scope, within_padding_w = iw >= padding_1);
                    gpu!(scope, tmp_padding = iw < border_right);
                    gpu!(scope, within_padding_w = within_padding_w && tmp_padding);

                    gpu!(scope, if (within_padding_w).then(|scope| {
                        // Back to unpadded input coordinates.
                        gpu!(scope, ih_pad = ih - padding_0);
                        gpu!(scope, iw_pad = iw - padding_1);

                        // Position handed to the strategy as the element index:
                        // row offset (via the dim-2 stride) plus the column.
                        // NOTE(review): this is a flat (h, w) index only if the
                        // dim-2 stride equals the input width — confirm that
                        // callers always pass a contiguous input.
                        gpu!(scope, index_input_2 = ih_pad * input_stride_2);
                        gpu!(scope, idx = index_input_2);
                        gpu!(scope, idx += iw_pad);
                        gpu!(scope, index_input_3 = iw_pad * input_stride_3);

                        gpu!(scope, index_input = index_input_0);
                        gpu!(scope, index_input += index_input_1);
                        gpu!(scope, index_input += index_input_2);
                        gpu!(scope, index_input += index_input_3);

                        gpu!(scope, result = input[index_input]);

                        // The returned accumulator is deliberately not rebound:
                        // `Accumulator` is `Copy` and the value created by
                        // `initialize` is the one later passed to `assign`.
                        self.pool_strategy.process_result(scope, accumulator, result, idx);
                    }));
                });
            }));
        });

        self.pool_strategy
            .assign(scope, id, output, self.indices, accumulator);
    }
}
/// Eagerly launched 2D pooling kernel, parameterized by the pooling strategy
/// `P`, the runtime `R` and the element type `E`.
#[derive(new)]
pub(crate) struct Pool2dEagerKernel<P: PoolStrategy, R: Runtime, E: JitElement> {
/// Window extent along the two pooled dimensions; part of the kernel id.
kernel_size: [usize; 2],
/// Strategy cloned into the compute shader; its `Debug` output is part of
/// the kernel id.
pool_strategy: P,
/// Marker for the runtime `R`; holds no data.
_runtime: PhantomData<R>,
/// Marker for the element type `E`; holds no data.
_elem: PhantomData<E>,
}
impl<P: PoolStrategy, R: Runtime, E: JitElement> DynamicKernelSource
for Pool2dEagerKernel<P, R, E>
{
fn source(&self) -> crate::kernel::SourceTemplate {
let mut scope = Scope::root();
let item = E::gpu_elem().into();
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let indices = if P::with_indices() {
Some(Variable::GlobalOutputArray(1, Item::Scalar(Elem::Int)))
} else {
None
};
scope.write_global_custom(output);
Pool2dComputeShader {
input,
output,
indices,
kernel_size: self.kernel_size,
pool_strategy: self.pool_strategy.clone(),
_elem: PhantomData::<E>,
_runtime: PhantomData::<R>,
}
.expand(&mut scope);
let input = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let scalars = InputInfo::Scalar {
elem: Elem::UInt,
size: 6,
};
let output = OutputInfo::Array { item };
let outputs = if P::with_indices() {
vec![
output,
OutputInfo::Array {
item: Item::Scalar(Elem::Int),
},
]
} else {
vec![output]
};
let info = CompilationInfo {
inputs: vec![input, scalars],
outputs,
scope,
};
let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}
fn id(&self) -> String {
format!(
"{:?}k={:?}pl={:?}",
core::any::TypeId::of::<Self>(),
self.kernel_size,
self.pool_strategy
)
}
}

View File

@ -1,243 +0,0 @@
// Left-hand operand, read-only. `{{ elem }}` is a template placeholder for the element type.
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
// Right-hand operand, read-only.
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
// Result matrix, written by `main`.
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
// Metadata buffer; `main` reads the number of dimensions from info[0] and
// strides/shapes at offsets derived from it.
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
// Block sizes, filled in by the template: a workgroup covers B_M rows and
// B_N columns of the output and advances B_K along the shared dimension per step.
const B_M = {{b_m}}u;
const B_N = {{b_n}}u;
const B_K = {{b_k}}u;
// Lengths (in vec4 elements) of the two shared-memory arrays below.
const B_M_X_B_K_4 = {{bm_x_bk_4}}u;
const B_K_X_B_N_4 = {{bk_x_bn_4}}u;
// Each thread produces a T_M x T_N tile of results.
const T_M = 4u;
const T_N = 4u;
const T_M_X_T_N = 16u;
// Workgroup-shared staging buffers for the current lhs / rhs blocks.
var<workgroup> shared_lhs: array<vec4<{{ elem }}>, B_M_X_B_K_4>;
var<workgroup> shared_rhs: array<vec4<{{ elem }}>, B_K_X_B_N_4>;
// Tiled matrix multiplication entry point.
//
// A workgroup computes a B_M x B_N block of the output and each thread a
// T_M x T_N tile of it. The shared dimension K is processed B_K at a time:
// every step loads a block of lhs and rhs into workgroup memory (zero-padded
// at the matrix edges), synchronizes, accumulates partial products into
// `results`, and synchronizes again before the next load. The batch index
// comes from global_id.z.
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
// First row / column of the output block owned by this workgroup
let skip_row = workgroup_id.x * B_M;
let skip_col = workgroup_id.y * B_N;
// Number of threads needed to cover B_N columns, T_N columns each (ceil division)
let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
// Position of the first element of the thread, relative to the block
let thread_row = (local_idx / n_thread_per_row) * T_M;
let thread_col = (local_idx % n_thread_per_row) * T_N;
// Position of the first element of the thread, in absolute (in one batch)
let row = skip_row + thread_row;
let col = skip_col + thread_col;
let batch = global_id.z;
// Basic information
let dim = info[0];
let n_rows = info[6u * dim - 1u];
let n_cols = info[6u * dim];
let K = info[5u * dim - 1u];
// Row / col strides
let lhs_stride_row = info[dim - 1u];
let lhs_stride_col = info[dim];
let rhs_stride_row = info[2u * dim - 1u];
let rhs_stride_col = info[2u * dim];
let out_stride_row = info [3u * dim - 1u];
let out_stride_col = info [3u * dim];
// Calculate the corresponding offsets with support for broadcasting.
let offset_output = batch * n_rows * n_cols;
var offset_lhs: u32 = skip_row * lhs_stride_row;
var offset_rhs: u32 = skip_col * rhs_stride_col;
let batch_dims = dim - 2u;
for (var b: u32 = 1u; b <= batch_dims; b++) {
let stride_lhs = info[b];
let stride_rhs = info[b + dim];
let stride_output = info[b + 2u * dim];
let shape_lhs = info[b + 3u * dim];
let shape_rhs = info[b + 4u * dim];
// The modulo by the operand's own shape maps the output batch coordinate
// back into the operand's range along this dimension.
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
}
// Registers used in the compute pass
var results: array<{{ elem }}, T_M_X_T_N>;
var register_M: vec4<{{ elem }}>;
var register_N: vec4<{{ elem }}>;
// How close is the thread to the end of the matrix.
// If < 4 then it is an edge case
let remain_row_lhs = n_rows - row;
let remain_col_rhs = n_cols - col;
for (var k = 0u; k < K; k += B_K) {
// LHS LOAD PASS
// For the 4 vec4 columns of this thread
for (var j = 0u; j < 4u; j++) {
// The precise column handled in this iteration, relative to the block
let current_col = thread_col + j;
// Position of the column vec4 in shared memory
let lhs_sm_position = (thread_row/4u) * B_K + current_col;
// To avoid overwriting following row in shared memory
if current_col < B_K {
// To pad with zeros if outside lhs
if current_col + k < K && remain_row_lhs >= 1u {
let lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row;
let lhs_position1 = lhs_position0 + lhs_stride_row;
let lhs_position2 = lhs_position1 + lhs_stride_row;
let lhs_position3 = lhs_position2 + lhs_stride_row;
// Load as many of the 4 rows as still lie inside the matrix; the rest are zero.
if remain_row_lhs >= 4u {
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
lhs[lhs_position1],
lhs[lhs_position2],
lhs[lhs_position3],
);
} else if remain_row_lhs == 3u {
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
lhs[lhs_position1],
lhs[lhs_position2],
0.
);
} else if remain_row_lhs == 2u {
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
lhs[lhs_position1],
0.,
0.
);
} else if remain_row_lhs == 1u {
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
0.,
0.,
0.
);
}
} else {
shared_lhs[lhs_sm_position] = vec4(0.,0.,0.,0.);
}
}
}
// RHS LOAD PASS
// Mirror of the lhs pass: for the 4 rows of this thread, one vec4 spanning 4 columns each
for (var i = 0u; i < 4u; i++) {
let current_row = thread_row + i;
let rhs_sm_position = (current_row * B_N + thread_col) / 4u;
if current_row < B_K {
if current_row + k < K && remain_col_rhs >= 1u {
let rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col;
let rhs_position1 = rhs_position0 + rhs_stride_col;
let rhs_position2 = rhs_position1 + rhs_stride_col;
let rhs_position3 = rhs_position2 + rhs_stride_col;
// Load as many of the 4 columns as still lie inside the matrix; the rest are zero.
if remain_col_rhs >= 4u {
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
rhs[rhs_position1],
rhs[rhs_position2],
rhs[rhs_position3],
);
} else if remain_col_rhs == 3u {
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
rhs[rhs_position1],
rhs[rhs_position2],
0.
);
} else if remain_col_rhs == 2u {
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
rhs[rhs_position1],
0.,
0.
);
} else if remain_col_rhs == 1u {
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
0.,
0.,
0.
);
}
} else {
shared_rhs[rhs_sm_position] = vec4(0.,0.,0.,0.);
}
}
}
// All loads into shared memory must be finished before any thread reads them
workgroupBarrier();
// COMPUTE PASS
// Compute intermediate results
// Results are accumulated in the results array and updated at each block
// Outer loop indicates which subcolumns/subrows to read from shared memories
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
// Load a subcolumn of values from lhs
let lhs_sm_position = (thread_row/4u) * B_K + dot_index;
register_M = shared_lhs[lhs_sm_position];
// Load a subrow of values from rhs
let rhs_sm_position = (dot_index * B_N + thread_col) / 4u;
register_N = shared_rhs[rhs_sm_position];
// Multiply subcolumn and subrow and store results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
}
}
}
// Shared memory is overwritten by the next iteration's load passes; wait until every thread is done reading
workgroupBarrier();
}
// OUTPUT PASS
// Write output matrix
// Each thread is responsible for writing T_M x T_N results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
let row_index = row + res_idx_M;
let col_index = col + res_idx_N;
// Skip tile elements that fall outside the output matrix
if row_index < n_rows && col_index < n_cols {
let result_position = res_idx_M * T_N + res_idx_N;
let output_position = offset_output + row_index * out_stride_row + col_index * out_stride_col;
output[output_position] = results[result_position];
}
}
}
}

View File

@ -1,165 +0,0 @@
// Bindings and compile-time constants for the tiled matrix-multiplication kernel below.
// `{{ ... }}` placeholders are template parameters substituted before the shader is compiled.

// Left-hand-side operand (read-only).
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
// Right-hand-side operand (read-only).
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
// Result buffer written by the kernel.
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
// Shape/stride metadata. As indexed by `main`: info[0] is the rank `dim`, followed by blocks of
// `dim` entries read as lhs strides, rhs strides, output strides, lhs shape and rhs shape;
// n_rows / n_cols are read from info[6 * dim - 1] and info[6 * dim].
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
// Tile sizes: `main` assigns each workgroup a B_M x B_N tile of the output
// (skip_row = workgroup_id.x * B_M, skip_col = workgroup_id.y * B_N) and walks the
// inner dimension in steps of B_K.
const B_M = {{b_m}}u;
const B_N = {{b_n}}u;
const B_K = {{b_k}}u;
// Lengths (in vec4 elements) of the two shared-memory tiles declared below.
// Presumably B_M * B_K / 4 and B_K * B_N / 4 — confirm against the template caller.
const B_M_X_B_K_4 = {{bm_x_bk_4}}u;
const B_K_X_B_N_4 = {{bk_x_bn_4}}u;
// Each thread accumulates a T_M x T_N block of results (T_M_X_T_N values in total).
const T_M = 4u;
const T_N = 4u;
const T_M_X_T_N = 16u;
// Workgroup-shared staging tiles for lhs and rhs, stored as vec4s.
var<workgroup> shared_lhs: array<vec4<{{ elem }}>, B_M_X_B_K_4>;
var<workgroup> shared_rhs: array<vec4<{{ elem }}>, B_K_X_B_N_4>;
// Tiled matrix multiplication entry point.
//
// Each workgroup computes one B_M x B_N tile of the output for one batch (global_id.z);
// each thread accumulates a T_M x T_N block of that tile. The inner dimension K is consumed
// in chunks of B_K: every chunk is first staged into `shared_lhs` / `shared_rhs`, then
// multiplied out of shared memory into the per-thread `results` array.
//
// NOTE(review): none of the loads or stores below is bounds-checked against n_rows, n_cols
// or K; presumably the caller guarantees shapes that are multiples of the tile sizes — confirm.
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
// Top-left corner of this workgroup's output tile.
let skip_row = workgroup_id.x * B_M;
let skip_col = workgroup_id.y * B_N;
// Number of threads needed to cover one tile row: ceil(B_N / T_N).
let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
// Top-left corner of this thread's T_M x T_N block, relative to the workgroup tile.
let thread_row = (local_idx / n_thread_per_row) * T_M;
let thread_col = (local_idx % n_thread_per_row) * T_N;
// Same corner in absolute output coordinates.
let row = skip_row + thread_row;
let col = skip_col + thread_col;
let batch = global_id.z;
// Basic information: rank, output rows/cols and the shared inner dimension K.
let dim = info[0];
let n_rows = info[6u * dim - 1u];
let n_cols = info[6u * dim];
let K = info[5u * dim - 1u];
// Row / col strides of the two innermost dimensions of each tensor.
let lhs_stride_row = info[dim - 1u];
let lhs_stride_col = info[dim];
let rhs_stride_row = info[2u * dim - 1u];
let rhs_stride_col = info[2u * dim];
let out_stride_row = info [3u * dim - 1u];
let out_stride_col = info [3u * dim];
// Calculate the corresponding offsets with support for broadcasting.
// The lhs/rhs offsets start at this workgroup's tile (rows of lhs, columns of rhs).
let offset_output = batch * n_rows * n_cols;
var offset_lhs: u32 = skip_row * lhs_stride_row;
var offset_rhs: u32 = skip_col * rhs_stride_col;
let batch_dims = dim - 2u;
// For every batch dimension, recover its index from the output offset; the `% shape`
// maps a dimension of size 1 to index 0, which is what implements broadcasting.
for (var b: u32 = 1u; b <= batch_dims; b++) {
let stride_lhs = info[b];
let stride_rhs = info[b + dim];
let stride_output = info[b + 2u * dim];
let shape_lhs = info[b + 3u * dim];
let shape_rhs = info[b + 4u * dim];
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
}
// Per-thread accumulators (T_M x T_N, row-major) and the two vec4 operands of the inner product.
var results: array<{{ elem }}, T_M_X_T_N>;
var register_M: vec4<{{ elem }}>;
var register_N: vec4<{{ elem }}>;
// Walk the inner dimension one B_K-wide chunk at a time.
for (var k = 0u; k < K; k += B_K) {
// Load data into shared memories
// Each thread is responsible for loading up to T_M x T_N values from both lhs and rhs
// lhs: for each of 4 columns, pack 4 consecutive rows (thread_row .. thread_row + 3) into one vec4.
for (var j = 0u; j < 4u; j++) {
let current_col = thread_col + j;
if current_col < B_K { // threads whose column falls outside the B_K-wide lhs tile store nothing
let lhs_sm_position = (thread_row/4u) * B_K + current_col;
let lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row;
let lhs_position1 = lhs_position0 + lhs_stride_row;
let lhs_position2 = lhs_position1 + lhs_stride_row;
let lhs_position3 = lhs_position2 + lhs_stride_row;
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
lhs[lhs_position1],
lhs[lhs_position2],
lhs[lhs_position3],
);
}
}
// rhs: for each of 4 rows, pack 4 consecutive columns (thread_col .. thread_col + 3) into one vec4.
for (var i = 0u; i < 4u; i++) {
let current_row = thread_row + i;
if current_row < B_K { // threads whose row falls outside the B_K-tall rhs tile store nothing
let rhs_sm_position = (current_row * B_N + thread_col) / 4u;
let rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col;
let rhs_position1 = rhs_position0 + rhs_stride_col;
let rhs_position2 = rhs_position1 + rhs_stride_col;
let rhs_position3 = rhs_position2 + rhs_stride_col;
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
rhs[rhs_position1],
rhs[rhs_position2],
rhs[rhs_position3],
);
}
}
// Wait until every thread of the workgroup has finished writing the shared tiles.
workgroupBarrier();
// Compute intermediate results
// Results are accumulated in the results array and updated at each block
// Outer loop indicates which subcolumns/subrows to read from shared memories
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
// Load a subcolumn of values from lhs
let lhs_sm_position = (thread_row/4u) * B_K + dot_index;
register_M = shared_lhs[lhs_sm_position];
// Load a subrow of values from rhs
let rhs_sm_position = (dot_index * B_N + thread_col) / 4u;
register_N = shared_rhs[rhs_sm_position];
// Multiply subcolumn and subrow and store results (outer product of the two vec4s)
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
}
}
}
// Wait until every thread is done reading before the next chunk overwrites the shared tiles.
workgroupBarrier();
}
// Write output matrix
// Each thread is responsible for writing T_M x T_N results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
let result_position = res_idx_M * T_N + res_idx_N;
let output_position = offset_output + (row + res_idx_M) * out_stride_row + (col + res_idx_N) * out_stride_col;
output[output_position] = results[result_position];
}
}
}

View File

@ -1,70 +0,0 @@
// Bindings and constants for the naive (one thread per output element) matmul kernel below.
// `{{ ... }}` placeholders are template parameters substituted before compilation.

// Left-hand-side operand (read-only).
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
// Right-hand-side operand (read-only).
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
// Result buffer written by the kernel.
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
// Shape/stride metadata. As indexed by `main`: info[0] is the rank `dim`, followed by blocks of
// `dim` entries read as lhs strides, rhs strides, output strides, lhs shape and rhs shape;
// n_rows / n_cols are read from info[6 * dim - 1] and info[6 * dim].
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
// Side length used by `main` to turn the flat local invocation index into a (row, col) pair.
const BLOCK_SIZE = {{ workgroup_size_x }}u;
// Naive matrix multiplication: every invocation computes exactly one output element as the
// dot product of one lhs row with one rhs column. Batch dimensions may be broadcast.
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_index) local_idx: u32,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    // Problem sizes read from the metadata buffer.
    let rank = info[0];
    let n_rows = info[6u * rank - 1u];
    let n_cols = info[6u * rank];
    let inner_dim = info[5u * rank - 1u];

    // Output coordinates handled by this invocation.
    let row = workgroup_id.x * BLOCK_SIZE + (local_idx / BLOCK_SIZE);
    let col = workgroup_id.y * BLOCK_SIZE + (local_idx % BLOCK_SIZE);
    let batch = global_id.z;

    // Invocations that fall outside the output matrix have nothing to do.
    if !(row < n_rows && col < n_cols) {
        return;
    }

    // Translate the output batch offset into lhs/rhs offsets. Taking the batch index modulo
    // the operand's shape maps size-1 dimensions to index 0 (broadcasting).
    let offset_output = batch * n_rows * n_cols;
    var offset_lhs: u32 = 0u;
    var offset_rhs: u32 = 0u;
    let last_batch_dim = rank - 2u;
    for (var d: u32 = 1u; d <= last_batch_dim; d++) {
        let batch_index = offset_output / info[d + 2u * rank];
        offset_lhs += batch_index % info[d + 3u * rank] * info[d];
        offset_rhs += batch_index % info[d + 4u * rank] * info[d + rank];
    }

    // Walk the lhs row element by element and the rhs column row by row.
    var lhs_cursor = offset_lhs + row * inner_dim;
    var rhs_cursor = offset_rhs + col;
    var acc = 0.0;
    for (var k: u32 = 0u; k < inner_dim; k++) {
        acc += lhs[lhs_cursor] * rhs[rhs_cursor];
        lhs_cursor += 1u;
        rhs_cursor += n_cols;
    }

    output[offset_output + row * n_cols + col] = acc;
}

View File

@ -1,74 +0,0 @@
// Bindings and constants for the adaptive average-pooling kernel below.
// `{{ ... }}` placeholders are template parameters substituted before compilation.

// Input tensor (read-only).
@group(0)
@binding(0)
var<storage, read> x: array<{{ elem }}>;
// Pooled output, one element written per invocation.
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
// Metadata as read by `main`: [0..3] input strides, [4..7] input shape,
// [8..11] output strides, [12..15] output shape.
@group(0)
@binding(2)
var<storage, read> info: array<u32, 16>;
// Workgroup width, used by `main` to flatten the 2D global invocation id.
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
// Adaptive 2D average pooling: each invocation writes one output element, the mean of the
// input window selected by `start_index` / `end_index` along height and width.
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    // Flat index of the output element handled by this invocation.
    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;

    // Spatial extents of input and output.
    let in_height = info[6];
    let in_width = info[7];
    let out_height = info[14];
    let out_width = info[15];

    // Decompose the flat index into (batch, channel, out row, out col) using output strides/shape.
    let b = id / info[8] % info[12];
    let c = id / info[9] % info[13];
    let oh = id / info[10] % out_height;
    let ow = id / info[11] % out_width;

    // Half-open input window [row_begin, row_end) x [col_begin, col_end).
    let row_begin = start_index(oh, out_height, in_height);
    let row_end = end_index(oh, out_height, in_height);
    let col_begin = start_index(ow, out_width, in_width);
    let col_end = end_index(ow, out_width, in_width);

    // Offset of the (batch, channel) plane, then per-row offsets inside it.
    let plane_offset = b * info[0] + c * info[1];
    let row_stride = info[2];
    let col_stride = info[3];

    var sum = 0.0;
    for (var ih = row_begin; ih < row_end; ih++) {
        let row_offset = plane_offset + ih * row_stride;
        for (var iw = col_begin; iw < col_end; iw++) {
            sum += x[row_offset + iw * col_stride];
        }
    }

    // Divide by the number of elements in the window.
    output[id] = sum / {{ elem }}((row_end - row_begin) * (col_end - col_begin));
}
// First input index (inclusive) of the adaptive-pooling window for output position
// `output_size_index`: floor(output_size_index * input_size / output_size).
//
// Computed with u32 arithmetic so the result is exact; the f32 formulation is only exact
// while the product fits in 24 bits and relies on float division rounding.
fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
    // Unsigned integer division truncates, which is exactly floor() for non-negative values.
    return (output_size_index * input_size) / output_size;
}
// End input index (exclusive) of the adaptive-pooling window for output position
// `output_size_index`: ceil((output_size_index + 1) * input_size / output_size),
// clamped to `input_size`.
//
// Computed with u32 arithmetic so the result is exact; the f32 formulation is only exact
// while the product fits in 24 bits and relies on float division rounding.
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
    let numerator = (output_size_index + 1u) * input_size;
    // Ceil division: truncate, then add one when there is a remainder.
    let round_up = select(0u, 1u, numerator % output_size != 0u);
    return min(numerator / output_size + round_up, input_size);
}

View File

@ -1,87 +0,0 @@
// Bindings and constants for the 2D average-pooling (forward) kernel below.
// `{{ ... }}` placeholders are template parameters substituted before compilation.

// Input tensor (read-only).
@group(0)
@binding(0)
var<storage, read> x: array<{{ elem }}>;
// Pooled output, one element written per invocation.
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
// Metadata as read by `main`: [0..3] input strides, [4..7] input shape, [8..11] output strides,
// [12..15] output shape, [16..17] kernel size, [18..19] pooling stride, [20..21] padding.
@group(0)
@binding(2)
var<storage, read> info: array<u32, 22>;
// Workgroup width, used by `main` to flatten the 2D global invocation id.
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
// When true, `main` divides by the full kernel area instead of the number of
// non-padding elements actually summed.
const COUNT_INCLUDE_PAD = {{ count_include_pad }};
// 2D average pooling (forward): each invocation writes one output element, the average of
// the kernel window anchored at (oh * stride, ow * stride) in padded input coordinates.
// Positions that fall in the padding contribute nothing to the sum.
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    // Flat index of the output element handled by this invocation.
    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;

    // Pooling geometry.
    let in_height = info[6];
    let in_width = info[7];
    let kernel_h = info[16];
    let kernel_w = info[17];
    let pad_h = info[20];
    let pad_w = info[21];

    // Decompose the flat index into (batch, channel, out row, out col) using output strides/shape.
    let b = id / info[8] % info[12];
    let c = id / info[9] % info[13];
    let oh = id / info[10] % info[14];
    let ow = id / info[11] % info[15];

    // Offset of the (batch, channel) plane and the window origin in padded coordinates.
    let plane_offset = b * info[0] + c * info[1];
    let row_stride = info[2];
    let col_stride = info[3];
    let first_ih = oh * info[18];
    let first_iw = ow * info[19];

    var sum = 0.0;
    var count = 0.0;
    for (var kh = 0u; kh < kernel_h; kh++) {
        let ih = first_ih + kh;
        // Only rows that land on real input (not padding) are visited.
        if ih >= pad_h && ih < in_height + pad_h {
            // Subtracting the padding converts back to unpadded input coordinates.
            let row_offset = plane_offset + (ih - pad_h) * row_stride;
            for (var kw = 0u; kw < kernel_w; kw++) {
                let iw = first_iw + kw;
                // Same padding test along the width.
                if iw >= pad_w && iw < in_width + pad_w {
                    count += 1.0;
                    sum += x[row_offset + (iw - pad_w) * col_stride];
                }
            }
        }
    }

    // Optionally divide by the full kernel area rather than by the elements actually summed.
    if COUNT_INCLUDE_PAD {
        count = {{ elem }}(kernel_w * kernel_h);
    }
    output[id] = sum / count;
}

View File

@ -1,110 +0,0 @@
// Bindings and constants for the 2D average-pooling backward kernel below.
// `{{ ... }}` placeholders are template parameters substituted before compilation.

// Gradient with respect to the pooling output (read-only).
@group(0)
@binding(0)
var<storage, read> grad: array<{{ elem }}>;
// Gradient with respect to the pooling input, one element written per invocation.
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
// Metadata as read by `main`: [0..3] input strides, [4..7] input shape, [8..11] grad strides,
// [12..15] grad shape, [16..17] kernel size, [18..19] pooling stride, [20..21] padding.
@group(0)
@binding(2)
var<storage, read> info: array<u32, 22>;
// Workgroup width, used by `main` to flatten the 2D global invocation id.
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
// When true, `main` divides each gradient by the full kernel area instead of the
// window area clipped to the unpadded input.
const COUNT_INCLUDE_PAD = {{ count_include_pad }};
// 2D average pooling (backward): each invocation computes the gradient of one input element
// by summing `grad / count` over every pooling window that contains that element, where
// `count` is the divisor the forward pass used for that window.
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    // Flat index of the input-gradient element handled by this invocation.
    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;

    let input_stride_0 = info[0];
    let input_stride_1 = info[1];
    let input_stride_2 = info[2];
    let input_stride_3 = info[3];
    let input_shape_0 = info[4];
    let input_shape_1 = info[5];
    let input_shape_2 = info[6];
    let input_shape_3 = info[7];

    let grad_stride_0 = info[8];
    let grad_stride_1 = info[9];
    let grad_stride_2 = info[10];
    let grad_stride_3 = info[11];
    let grad_shape_2 = info[14];
    let grad_shape_3 = info[15];

    let kernel_size_0 = info[16];
    let kernel_size_1 = info[17];
    let pool_stride_0 = info[18];
    let pool_stride_1 = info[19];
    let padding_0 = info[20];
    let padding_1 = info[21];

    // Decompose the flat index into (batch, channel, input row, input col).
    let b = id / input_stride_0 % input_shape_0;
    let c = id / input_stride_1 % input_shape_1;
    let ih = id / input_stride_2 % input_shape_2;
    let iw = id / input_stride_3 % input_shape_3;

    // Position of this element in padded coordinates, the frame in which windows are laid out.
    let ih_padded = ih + padding_0;
    let iw_padded = iw + padding_1;

    // kernel - stride bounds how many extra windows may overlap (contain) the current index.
    let kms_0 = i32(kernel_size_0) - i32(pool_stride_0);
    let kms_1 = i32(kernel_size_1) - i32(pool_stride_1);

    // Range of candidate output positions whose window may cover the current element.
    let oh_start_tmp = (i32(ih_padded) - kms_0) / i32(pool_stride_0);
    let ow_start_tmp = (i32(iw_padded) - kms_1) / i32(pool_stride_1);
    let oh_start = u32(max(oh_start_tmp, 0));
    let ow_start = u32(max(ow_start_tmp, 0));
    let oh_end = u32(max(kms_0, 0)) + oh_start;
    let ow_end = u32(max(kms_1, 0)) + ow_start;

    var grad_acc = 0.0;

    // Iterate over each candidate window and accumulate its averaged gradient
    // if the window actually contains the current element.
    for (var oh = oh_start; oh <= oh_end; oh++) {
        for (var ow = ow_start; ow <= ow_end; ow++) {
            // Candidates past the end of the gradient tensor do not exist.
            if oh >= grad_shape_2 || ow >= grad_shape_3 {
                continue;
            }

            // Window extent in padded coordinates, clipped to the real (unpadded) input area.
            var ih_start = oh * pool_stride_0;
            var iw_start = ow * pool_stride_1;
            var ih_end = ih_start + kernel_size_0;
            var iw_end = iw_start + kernel_size_1;
            ih_start = max(ih_start, padding_0);
            iw_start = max(iw_start, padding_1);
            ih_end = min(ih_end, input_shape_2 + padding_0);
            iw_end = min(iw_end, input_shape_3 + padding_1);

            // Both bounds are compared in padded coordinates. The original compared the
            // unpadded index against the padded upper bound.
            let contributed_h = ih_padded >= ih_start && ih_padded < ih_end;
            let contributed_w = iw_padded >= iw_start && iw_padded < iw_end;

            // Skip windows that do not contain the current element.
            if !contributed_h || !contributed_w {
                continue;
            }

            // Divisor used for this window: full kernel area, or only the clipped area.
            var count = 0.0;
            if COUNT_INCLUDE_PAD {
                count = {{ elem }}(kernel_size_0 * kernel_size_1);
            } else {
                count = {{ elem }}((ih_end - ih_start) * (iw_end - iw_start));
            }

            let index = b * grad_stride_0 + c * grad_stride_1 + oh * grad_stride_2 + ow * grad_stride_3;
            grad_acc += grad[index] / count;
        }
    }

    output[id] = grad_acc;
}

View File

@ -432,10 +432,6 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
gpu::Operator::Ceil(op) => wgsl::Instruction::Ceil {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
gpu::Operator::Log(op) => wgsl::Instruction::Log {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
@ -465,6 +461,14 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
gpu::Operator::Floor(op) => wgsl::Instruction::Floor {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
gpu::Operator::Ceil(op) => wgsl::Instruction::Ceil {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
gpu::Operator::Erf(op) => wgsl::Instruction::Erf {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),

View File

@ -109,10 +109,6 @@ pub enum Instruction {
input: Variable,
out: Variable,
},
Ceil {
input: Variable,
out: Variable,
},
Erf {
input: Variable,
out: Variable,
@ -214,6 +210,14 @@ pub enum Instruction {
rhs: Variable,
out: Variable,
},
Floor {
input: Variable,
out: Variable,
},
Ceil {
input: Variable,
out: Variable,
},
}
impl Display for Instruction {
@ -276,9 +280,6 @@ impl Display for Instruction {
Instruction::Sqrt { input, out } => {
f.write_fmt(format_args!("{out} = sqrt({input});\n"))
}
Instruction::Ceil { input, out } => {
f.write_fmt(format_args!("{out} = ceil({input});\n"))
}
Instruction::Log1p { input, out } => {
f.write_fmt(format_args!("{out} = log({input} + 1.0);\n"))
}
@ -468,6 +469,12 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
Instruction::ShiftRight { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} >> {rhs};\n"))
}
Instruction::Floor { input, out } => {
f.write_fmt(format_args!("{out} = floor({input});\n"))
}
Instruction::Ceil { input, out } => {
f.write_fmt(format_args!("{out} = ceil({input});\n"))
}
}
}
}