mirror of https://github.com/tracel-ai/burn.git
Conv Transpose: migration + decent speedup (#1541)

* convtranspose benchmark
* adjust bench
* conv transpose works
* Conv Transpose: migration + decent speedup
* delete template folder
* typos
* fix
This commit is contained in:
parent b8fc3f141e
commit 279be0496a
backend-comparison/Cargo.toml
@@ -65,6 +65,11 @@ name = "max-pool2d"
 path = "benches/max_pool2d.rs"
 harness = false
 
+[[bench]]
+name = "conv-transpose2d"
+path = "benches/conv_transpose2d.rs"
+harness = false
+
 [[bench]]
 name = "matmul"
 harness = false
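
Since the new target sets harness = false, Cargo runs the file's own main (which expands backend_comparison::bench_on_backend!) instead of the libtest harness; assuming the usual Cargo target naming, something like `cargo bench --bench conv-transpose2d` should select it, with the backend chosen through the crate's feature flags.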

backend-comparison/benches/conv_transpose2d.rs (new file)
@@ -0,0 +1,93 @@
+use backend_comparison::persistence::save;
+use burn::tensor::{
+    backend::Backend, module::conv_transpose2d, ops::ConvTransposeOptions, Distribution, Shape,
+    Tensor,
+};
+use burn_common::benchmark::{run_benchmark, Benchmark};
+
+pub struct ConvTranspose2dBenchmark<B: Backend> {
+    input_shape: Shape<4>,
+    weight_shape: Shape<4>,
+    bias_shape: Shape<1>,
+    options: ConvTransposeOptions<2>,
+    device: B::Device,
+}
+
+impl<B: Backend> Benchmark for ConvTranspose2dBenchmark<B> {
+    type Args = (Tensor<B, 4>, Tensor<B, 4>, Tensor<B, 1>);
+
+    fn name(&self) -> String {
+        "conv_transpose2d".into()
+    }
+
+    fn shapes(&self) -> Vec<Vec<usize>> {
+        vec![
+            self.input_shape.dims.into(),
+            self.weight_shape.dims.into(),
+            self.bias_shape.dims.into(),
+        ]
+    }
+
+    fn execute(&self, (x, w, b): Self::Args) {
+        conv_transpose2d(x, w, Some(b), self.options.clone());
+    }
+
+    fn prepare(&self) -> Self::Args {
+        (
+            Tensor::random(
+                self.input_shape.clone(),
+                Distribution::Default,
+                &self.device,
+            ),
+            Tensor::random(
+                self.weight_shape.clone(),
+                Distribution::Default,
+                &self.device,
+            ),
+            Tensor::random(self.bias_shape.clone(), Distribution::Default, &self.device),
+        )
+    }
+
+    fn sync(&self) {
+        B::sync(&self.device)
+    }
+}
+
+#[allow(dead_code)]
+fn bench<B: Backend>(device: &B::Device, url: Option<&str>, token: Option<&str>) {
+    // Shapes
+    let batch_size = 16;
+    let channels_in = 16;
+    let channels_out = 16;
+    let height_in = 64;
+    let width_in = 64;
+    let kernel_size_0 = 8;
+    let kernel_size_1 = 8;
+
+    // Options
+    let strides = [1, 1];
+    let padding = [0, 0];
+    let padding_out = [0, 0];
+    let dilations = [1, 1];
+    let groups = 1;
+    let options = ConvTransposeOptions::new(strides, padding, padding_out, dilations, groups);
+    let benchmark = ConvTranspose2dBenchmark::<B> {
+        input_shape: [batch_size, channels_in, height_in, width_in].into(),
+        weight_shape: [
+            channels_in,
+            channels_out / groups,
+            kernel_size_0,
+            kernel_size_1,
+        ]
+        .into(),
+        bias_shape: [channels_out].into(),
+        options,
+        device: device.clone(),
+    };
+
+    save::<B>(vec![run_benchmark(benchmark)], device, url, token).unwrap();
+}
+
+fn main() {
+    backend_comparison::bench_on_backend!();
+}
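
For context, the call being measured is the public `conv_transpose2d` module function. A minimal standalone sketch of the same call outside the benchmark harness (the NdArray backend and the `main` scaffolding are assumptions for illustration, not part of this commit):

use burn::backend::NdArray; // any Backend implementation works; NdArray assumed here
use burn::tensor::{module::conv_transpose2d, ops::ConvTransposeOptions, Distribution, Tensor};

fn main() {
    type B = NdArray<f32>;
    let device = Default::default();

    // Same shapes as the benchmark: NCHW input, [in, out / groups, kh, kw] weight.
    let x = Tensor::<B, 4>::random([16, 16, 64, 64], Distribution::Default, &device);
    let w = Tensor::<B, 4>::random([16, 16, 8, 8], Distribution::Default, &device);
    let b = Tensor::<B, 1>::random([16], Distribution::Default, &device);

    let options = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 1);
    let y = conv_transpose2d(x, w, Some(b), options);

    // Transposed conv grows the spatial dims: (64 - 1) * 1 - 2 * 0 + 1 * (8 - 1) + 0 + 1 = 71.
    assert_eq!(y.dims(), [16, 16, 71, 71]);
}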

src/kernel/conv/conv_transpose2d.rs
@@ -1,15 +1,348 @@
 use std::marker::PhantomData;
 
 use crate::{
-    compute::StaticKernel,
+    codegen::{
+        Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
+        OutputInfo, WorkgroupLaunch,
+    },
     element::JitElement,
-    kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
-    kernel_wgsl,
-    ops::numeric::empty_device,
+    gpu::{gpu, Elem, Scope, Variable, Visibility},
+    kernel::{self, DynamicKernelSource, SourceTemplate},
+    ops::{
+        numeric::{empty_device, zeros_device},
+        reshape,
+    },
     tensor::JitTensor,
-    Runtime,
+    Compiler, Runtime,
 };
 use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape};
 
-kernel_wgsl!(ConvTranspose2d, "../../template/conv/conv_transpose2d.wgsl");
+#[derive(new)]
+struct Conv2dTransposeEagerKernel<R, E> {
+    _runtime: PhantomData<R>,
+    _elem: PhantomData<E>,
+}
+
+#[derive(new)]
+struct Conv2dTransposeComputeShader<E> {
+    input: Variable,
+    weight: Variable,
+    bias: Variable,
+    output: Variable,
+    _elem: PhantomData<E>,
+}
+
+impl<E: JitElement> Conv2dTransposeComputeShader<E> {
+    fn expand(self, scope: &mut Scope) {
+        let input = self.input;
+        let weight = self.weight;
+        let bias = self.bias;
+        let output = self.output;
+        let id = Variable::Id;
+
+        let input_stride_0 = scope.create_local(Elem::UInt);
+        let input_stride_1 = scope.create_local(Elem::UInt);
+        let input_stride_2 = scope.create_local(Elem::UInt);
+        let input_stride_3 = scope.create_local(Elem::UInt);
+        let input_shape_0 = scope.create_local(Elem::UInt);
+        let input_shape_1 = scope.create_local(Elem::UInt);
+        let input_shape_2 = scope.create_local(Elem::UInt);
+        let input_shape_3 = scope.create_local(Elem::UInt);
+        gpu!(scope, input_stride_0 = stride(input, 0u32));
+        gpu!(scope, input_stride_1 = stride(input, 1u32));
+        gpu!(scope, input_stride_2 = stride(input, 2u32));
+        gpu!(scope, input_stride_3 = stride(input, 3u32));
+        gpu!(scope, input_shape_0 = shape(input, 0u32));
+        gpu!(scope, input_shape_1 = shape(input, 1u32));
+        gpu!(scope, input_shape_2 = shape(input, 2u32));
+        gpu!(scope, input_shape_3 = shape(input, 3u32));
+
+        let output_stride_0 = scope.create_local(Elem::UInt);
+        let output_stride_1 = scope.create_local(Elem::UInt);
+        let output_stride_2 = scope.create_local(Elem::UInt);
+        let output_stride_3 = scope.create_local(Elem::UInt);
+        let output_shape_0 = scope.create_local(Elem::UInt);
+        let output_shape_1 = scope.create_local(Elem::UInt);
+        let output_shape_2 = scope.create_local(Elem::UInt);
+        let output_shape_3 = scope.create_local(Elem::UInt);
+        gpu!(scope, output_stride_0 = stride(output, 0u32));
+        gpu!(scope, output_stride_1 = stride(output, 1u32));
+        gpu!(scope, output_stride_2 = stride(output, 2u32));
+        gpu!(scope, output_stride_3 = stride(output, 3u32));
+        gpu!(scope, output_shape_0 = shape(output, 0u32));
+        gpu!(scope, output_shape_1 = shape(output, 1u32));
+        gpu!(scope, output_shape_2 = shape(output, 2u32));
+        gpu!(scope, output_shape_3 = shape(output, 3u32));
+
+        let weight_stride_0 = scope.create_local(Elem::UInt);
+        let weight_stride_1 = scope.create_local(Elem::UInt);
+        let weight_stride_2 = scope.create_local(Elem::UInt);
+        let weight_stride_3 = scope.create_local(Elem::UInt);
+        let in_channels = scope.create_local(Elem::UInt);
+        let weight_shape_1 = scope.create_local(Elem::UInt);
+        let kernel_size_0 = scope.create_local(Elem::UInt);
+        let kernel_size_1 = scope.create_local(Elem::UInt);
+        gpu!(scope, weight_stride_0 = stride(weight, 0u32));
+        gpu!(scope, weight_stride_1 = stride(weight, 1u32));
+        gpu!(scope, weight_stride_2 = stride(weight, 2u32));
+        gpu!(scope, weight_stride_3 = stride(weight, 3u32));
+        gpu!(scope, in_channels = shape(weight, 0u32));
+        gpu!(scope, weight_shape_1 = shape(weight, 1u32));
+        gpu!(scope, kernel_size_0 = shape(weight, 2u32));
+        gpu!(scope, kernel_size_1 = shape(weight, 3u32));
+
+        let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
+        let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
+        let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
+        let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
+        let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
+        let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
+        let groups = Variable::GlobalScalar(6, Elem::UInt);
+
+        let stride_0_i = scope.create_local(Elem::Int);
+        let stride_1_i = scope.create_local(Elem::Int);
+        gpu!(scope, stride_0_i = cast(conv_stride_0));
+        gpu!(scope, stride_1_i = cast(conv_stride_1));
+
+        let oc_out = scope.create_local(Elem::UInt);
+        let oc = scope.create_local(Elem::UInt);
+
+        let b = scope.create_local(Elem::UInt);
+        let oh = scope.create_local(Elem::UInt);
+        let ow = scope.create_local(Elem::UInt);
+        let k = scope.create_local(Elem::UInt);
+        let g = scope.create_local(Elem::UInt);
+
+        let ic_start = scope.create_local(Elem::UInt);
+        let ic_end = scope.create_local(Elem::UInt);
+        let ic_tmp = scope.create_local(Elem::UInt);
+
+        gpu!(scope, b = id / output_stride_0);
+        gpu!(scope, b = b % output_shape_0);
+
+        gpu!(scope, oc_out = id / output_stride_1);
+        gpu!(scope, oc_out = oc_out % output_shape_1);
+
+        gpu!(scope, oh = id / output_stride_2);
+        gpu!(scope, oh = oh % output_shape_2);
+
+        gpu!(scope, ow = id / output_stride_3);
+        gpu!(scope, ow = ow % output_shape_3);
+
+        gpu!(scope, k = oc_out / weight_shape_1);
+        gpu!(scope, g = k % groups);
+        gpu!(scope, oc = weight_shape_1 * g);
+        gpu!(scope, oc = oc_out - oc);
+
+        gpu!(scope, ic_tmp = in_channels / groups);
+        gpu!(scope, ic_start = g * ic_tmp);
+        gpu!(scope, ic_end = ic_start + ic_tmp);
+
+        let tmp_u = scope.create_local(Elem::UInt);
+        let tmp_i = scope.create_local(Elem::Int);
+        let zero_i = scope.zero(Elem::Int);
+        let one_i = scope.create_with_value(1, Elem::Int);
+
+        let kms_u = scope.create_local(Elem::UInt);
+        let kms_0 = scope.create_local(Elem::Int);
+        let kms_1 = scope.create_local(Elem::Int);
+        let ih_start_tmp = scope.create_local(Elem::Int);
+        let iw_start_tmp = scope.create_local(Elem::Int);
+        let ih_start = scope.create_local(Elem::UInt);
+        let iw_start = scope.create_local(Elem::UInt);
+        let ih_end = scope.create_local(Elem::UInt);
+        let iw_end = scope.create_local(Elem::UInt);
+
+        gpu!(scope, kms_u = kernel_size_0 * dilation_0);
+        gpu!(scope, kms_0 = cast(kms_u));
+        gpu!(scope, kms_0 = kms_0 - stride_0_i);
+        gpu!(scope, kms_u = kernel_size_1 * dilation_1);
+        gpu!(scope, kms_1 = cast(kms_u));
+        gpu!(scope, kms_1 = kms_1 - stride_1_i);
+
+        gpu!(scope, tmp_u = oh + padding_0);
+        gpu!(scope, tmp_i = cast(tmp_u));
+        gpu!(scope, ih_start_tmp = tmp_i - kms_0);
+        gpu!(scope, ih_start_tmp = ih_start_tmp / stride_0_i);
+        gpu!(scope, tmp_u = ow + padding_1);
+        gpu!(scope, tmp_i = cast(tmp_u));
+        gpu!(scope, iw_start_tmp = tmp_i - kms_1);
+        gpu!(scope, iw_start_tmp = iw_start_tmp / stride_1_i);
+
+        gpu!(scope, tmp_i = max(ih_start_tmp, zero_i));
+        gpu!(scope, ih_start = cast(tmp_i));
+        gpu!(scope, tmp_i = kms_0 + ih_start_tmp);
+        gpu!(scope, tmp_i += one_i);
+        gpu!(scope, tmp_i = max(tmp_i, zero_i));
+        gpu!(scope, tmp_u = cast(tmp_i));
+        gpu!(scope, ih_end = min(tmp_u, input_shape_2));
+
+        gpu!(scope, tmp_i = max(iw_start_tmp, zero_i));
+        gpu!(scope, iw_start = cast(tmp_i));
+        gpu!(scope, tmp_i = kms_1 + iw_start_tmp);
+        gpu!(scope, tmp_i += one_i);
+        gpu!(scope, tmp_i = max(tmp_i, zero_i));
+        gpu!(scope, tmp_u = cast(tmp_i));
+        gpu!(scope, iw_end = min(tmp_u, input_shape_3));
+
+        let index_input = scope.create_local(Elem::UInt);
+        let index_weight = scope.create_local(Elem::UInt);
+
+        let index_input_b = scope.create_local(Elem::UInt);
+        let index_input_ic = scope.create_local(Elem::UInt);
+        let index_input_ih = scope.create_local(Elem::UInt);
+        let index_input_iw = scope.create_local(Elem::UInt);
+        let index_weight_ic = scope.create_local(Elem::UInt);
+        let index_weight_oc = scope.create_local(Elem::UInt);
+        let index_weight_kh = scope.create_local(Elem::UInt);
+        let index_weight_kw = scope.create_local(Elem::UInt);
+
+        gpu!(scope, index_input_b = b * input_stride_0);
+        gpu!(scope, index_weight_oc = oc * weight_stride_1);
+
+        let prod = scope.create_local(output.item());
+        let prod_tmp = scope.create_local(output.item());
+        let sum = scope.create_local(output.item());
+        gpu!(scope, sum = bias[oc_out]);
+
+        let kh = scope.create_local(Elem::UInt);
+        let kw = scope.create_local(Elem::UInt);
+        let numerator_h_base = scope.create_local(Elem::UInt);
+        let numerator_h = scope.create_local(Elem::UInt);
+        let numerator_w_base = scope.create_local(Elem::UInt);
+        let numerator_w = scope.create_local(Elem::UInt);
+        let numerator_tmp = scope.create_local(Elem::UInt);
+        let numerator_mod = scope.create_local(Elem::UInt);
+        let zero = scope.zero(Elem::UInt);
+        let divisible = scope.create_local(Elem::Bool);
+        let not_neg = scope.create_local(Elem::Bool);
+        let cond = scope.create_local(Elem::Bool);
+
+        gpu!(scope, numerator_h_base = oh + padding_0);
+        gpu!(scope, numerator_w_base = ow + padding_1);
+
+        gpu!(
+            scope,
+            range(ic_start, ic_end).for_each(|ic, scope| {
+                gpu!(scope, index_input_ic = ic * input_stride_1);
+                gpu!(scope, index_weight_ic = ic * weight_stride_0);
+
+                gpu!(
+                    scope,
+                    range(ih_start, ih_end).for_each(|ih, scope| {
+                        gpu!(scope, numerator_tmp = ih * conv_stride_0);
+                        gpu!(scope, not_neg = numerator_h_base >= numerator_tmp);
+                        gpu!(scope, numerator_h = numerator_h_base - numerator_tmp);
+
+                        gpu!(scope, numerator_mod = numerator_h % dilation_0);
+                        gpu!(scope, divisible = numerator_mod == zero);
+                        gpu!(scope, cond = not_neg && divisible);
+
+                        gpu!(scope, if(cond).then(|scope|{
+                            gpu!(scope, kh = numerator_h / dilation_0);
+                            gpu!(scope, index_input_ih = ih * input_stride_2);
+                            gpu!(scope, index_weight_kh = kh * weight_stride_2);
+
+                            gpu!(
+                                scope,
+                                range(iw_start, iw_end).for_each(|iw, scope| {
+                                    gpu!(scope, numerator_tmp = iw * conv_stride_1);
+                                    gpu!(scope, not_neg = numerator_w_base >= numerator_tmp);
+                                    gpu!(scope, numerator_w = numerator_w_base - numerator_tmp);
+
+                                    gpu!(scope, numerator_mod = numerator_w % dilation_1);
+                                    gpu!(scope, divisible = numerator_mod == zero);
+                                    gpu!(scope, cond = not_neg && divisible);
+
+                                    gpu!(scope, if(cond).then(|scope|{
+                                        gpu!(scope, kw = numerator_w / dilation_1);
+                                        gpu!(scope, index_input_iw = iw * input_stride_3);
+                                        gpu!(scope, index_weight_kw = kw * weight_stride_3);
+
+                                        gpu!(scope, index_input = index_input_b);
+                                        gpu!(scope, index_input += index_input_ic);
+                                        gpu!(scope, index_input += index_input_ih);
+                                        gpu!(scope, index_input += index_input_iw);
+
+                                        gpu!(scope, index_weight = index_weight_ic);
+                                        gpu!(scope, index_weight += index_weight_oc);
+                                        gpu!(scope, index_weight += index_weight_kh);
+                                        gpu!(scope, index_weight += index_weight_kw);
+
+                                        gpu!(scope, prod = input[index_input]);
+                                        gpu!(scope, prod_tmp = weight[index_weight]);
+                                        gpu!(scope, prod *= prod_tmp);
+                                        gpu!(scope, sum += prod);
+                                    }));
+                                })
+                            );
+
+                        }));
+                    })
+                );
+            })
+        );
+
+        gpu!(scope, output[id] = sum);
+    }
+}
+
+impl<R: Runtime, E: JitElement> DynamicKernelSource for Conv2dTransposeEagerKernel<R, E> {
+    fn source(&self) -> kernel::SourceTemplate {
+        let mut scope = Scope::root();
+        let item = E::gpu_elem().into();
+
+        let input = Variable::GlobalInputArray(0, item);
+        let weight = Variable::GlobalInputArray(1, item);
+        let bias = Variable::GlobalInputArray(2, item);
+        let output = Variable::GlobalOutputArray(0, item);
+
+        scope.write_global_custom(output);
+
+        Conv2dTransposeComputeShader {
+            input,
+            weight,
+            bias,
+            output,
+            _elem: PhantomData::<E>,
+        }
+        .expand(&mut scope);
+
+        let input = InputInfo::Array {
+            item,
+            visibility: Visibility::Read,
+        };
+        let weight = InputInfo::Array {
+            item,
+            visibility: Visibility::Read,
+        };
+        let bias = InputInfo::Array {
+            item,
+            visibility: Visibility::Read,
+        };
+        let scalars = InputInfo::Scalar {
+            elem: Elem::UInt,
+            size: 7,
+        };
+
+        let output = OutputInfo::Array { item };
+
+        let info = CompilationInfo {
+            inputs: vec![input, weight, bias, scalars],
+            outputs: vec![output],
+            scope,
+        };
+
+        let settings = CompilationSettings::default();
+        let shader = Compilation::new(info).compile(settings);
+        let shader = <R::Compiler as Compiler>::compile(shader);
+        SourceTemplate::new(shader.to_string())
+    }
+
+    fn id(&self) -> String {
+        format!("{:?}", core::any::TypeId::of::<Self>())
+    }
+}
 
 pub(crate) fn conv_transpose2d<R: Runtime, E: JitElement + Element>(
     input: JitTensor<R, E, 4>,
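
The first block of `expand` above recovers the output coordinates from the flat invocation id using the usual stride/shape trick. The same decomposition in plain Rust (an illustrative helper, not part of the commit):

// Recover (b, oc_out, oh, ow) from a flat element id, given the output
// tensor's strides and shape; mirrors `b = id / output_stride_0 % output_shape_0`.
fn decompose(id: usize, strides: [usize; 4], shape: [usize; 4]) -> [usize; 4] {
    let mut coords = [0usize; 4];
    for i in 0..4 {
        coords[i] = id / strides[i] % shape[i];
    }
    coords
}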

@@ -34,42 +367,47 @@ pub(crate) fn conv_transpose2d<R: Runtime, E: JitElement + Element>(
         + 1;
 
     let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]);
-    let num_elems = shape_out.num_elements();
 
     let output = empty_device(
         input.client.clone(),
        input.device.clone(),
         shape_out.clone(),
     );
-    let mut info = build_info(&[&input, &output, &weight]);
 
-    info.push(options.stride[0] as u32);
-    info.push(options.stride[1] as u32);
-    info.push(options.padding[0] as u32);
-    info.push(options.padding[1] as u32);
-    info.push(options.dilation[0] as u32);
-    info.push(options.dilation[1] as u32);
-    info.push(options.groups as u32);
+    let bias = match bias {
+        Some(bias) => {
+            let shape = Shape::from([bias.shape.dims[0], 1, 1, 1]);
+            reshape(bias, shape)
+        }
+        None => {
+            let shape = Shape::from([output.shape.dims[0], 1, 1, 1]);
+            zeros_device(input.client.clone(), input.device.clone(), shape)
+        }
+    };
 
-    let bias_handle = bias
-        .map(|bias| bias.handle)
-        .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()])));
+    let kernel = Conv2dTransposeEagerKernel::<R, E>::new();
 
-    let info_handle = input.client.create(bytemuck::cast_slice(&info));
-
-    let kernel = StaticKernel::<
-        KernelSettings<ConvTranspose2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
-    >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
-    input.client.execute(
-        Box::new(kernel),
-        &[
-            &input.handle,
-            &weight.handle,
-            &bias_handle,
-            &info_handle,
-        ],
-    );
+    Execution::start(kernel, input.client.clone())
+        .inputs(&[
+            EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
+            EagerHandle::new(&weight.handle, &weight.strides, &weight.shape.dims),
+            EagerHandle::new(&bias.handle, &bias.strides, &bias.shape.dims),
+        ])
+        .outputs(&[EagerHandle::new(
+            &output.handle,
+            &output.strides,
+            &output.shape.dims,
+        )])
+        .with_scalars(&[
+            (options.stride[0] as u32).elem::<u32>(),
+            (options.stride[1] as u32).elem(),
+            (options.dilation[0] as u32).elem(),
+            (options.dilation[1] as u32).elem(),
+            (options.padding[0] as u32).elem(),
+            (options.padding[1] as u32).elem(),
+            (options.groups as u32).elem(),
+        ])
+        .execute(WorkgroupLaunch::Output { pos: 0 });
 
     output
 }
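
The `+ 1;` context line above is the tail of the standard transposed-convolution output-size computation that `shape_out` is built from. A worked check against the benchmark configuration (the helper name is illustrative, not from the diff):

// out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + padding_out + 1
fn conv_transpose_out_size(
    in_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    kernel: usize,
    padding_out: usize,
) -> usize {
    (in_size - 1) * stride + dilation * (kernel - 1) + padding_out + 1 - 2 * padding
}

// Benchmark shapes: 64x64 input, 8x8 kernel, stride 1, no padding, dilation 1:
// (64 - 1) * 1 + 1 * (8 - 1) + 0 + 1 - 0 = 71.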

src/kernel/conv/mod.rs
@@ -2,4 +2,4 @@ mod conv2d;
 mod conv_transpose2d;
 
 pub(crate) use conv2d::*;
-pub use conv_transpose2d::*;
+pub(crate) use conv_transpose2d::*;

src/template/conv/conv_transpose2d.wgsl (deleted)
@@ -1,122 +0,0 @@
-@group(0)
-@binding(0)
-var<storage, read> input: array<{{ elem }}>;
-
-@group(0)
-@binding(1)
-var<storage, read> weight: array<{{ elem }}>;
-
-@group(0)
-@binding(2)
-var<storage, read> bias: array<{{ elem }}>;
-
-@group(0)
-@binding(3)
-var<storage, read_write> output: array<{{ elem }}>;
-
-@group(0)
-@binding(4)
-var<storage, read> info: array<u32, 32>;
-
-const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
-
-@compute
-@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
-fn main(
-    @builtin(global_invocation_id) global_id: vec3<u32>,
-    @builtin(num_workgroups) num_workgroups: vec3<u32>,
-) {
-    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
-
-    let input_stride_0 = info[1];
-    let input_stride_1 = info[2];
-    let input_stride_2 = info[3];
-    let input_stride_3 = info[4];
-    let output_stride_0 = info[5];
-    let output_stride_1 = info[6];
-    let output_stride_2 = info[7];
-    let output_stride_3 = info[8];
-    let weight_stride_0 = info[9];
-    let weight_stride_1 = info[10];
-    let weight_stride_2 = info[11];
-    let weight_stride_3 = info[12];
-
-    let input_shape_0 = info[13];
-    let input_shape_1 = info[14];
-    let input_shape_2 = info[15];
-    let input_shape_3 = info[16];
-    let output_shape_0 = info[17];
-    let output_shape_1 = info[18];
-    let output_shape_2 = info[19];
-    let output_shape_3 = info[20];
-    let weight_shape_0 = info[21];
-    let weight_shape_1 = info[22];
-    let weight_shape_2 = info[23];
-    let weight_shape_3 = info[24];
-
-    let stride_0 = info[25];
-    let stride_1 = info[26];
-    let padding_0 = info[27];
-    let padding_1 = info[28];
-    let dilation_0 = info[29];
-    let dilation_1 = info[30];
-    let groups = info[31];
-
-    let in_channels = weight_shape_0;
-    let kernel_size_0 = weight_shape_2;
-    let kernel_size_1 = weight_shape_3;
-
-    let b = id / output_stride_0 % output_shape_0;
-    let oc_out = id / output_stride_1 % output_shape_1;
-    let oh = id / output_stride_2 % output_shape_2;
-    let ow = id / output_stride_3 % output_shape_3;
-
-    let k = oc_out / weight_shape_1;
-    let g = k % groups;
-    let oc = oc_out - (weight_shape_1 * g);
-
-    var sum = bias[oc_out];
-
-    let ic_start = g * (in_channels / groups);
-    let ic_end = ic_start + in_channels / groups;
-
-    // The maximum number of overlapping filters that may content the current index.
-    let kms_0 = i32(kernel_size_0 * dilation_0) - i32(stride_0);
-    let kms_1 = i32(kernel_size_1 * dilation_1) - i32(stride_1);
-
-    let ih_start_tmp = (i32(oh + padding_0) - kms_0) / i32(stride_0);
-    let iw_start_tmp = (i32(ow + padding_1) - kms_1) / i32(stride_1);
-
-    let ih_start = u32(max(ih_start_tmp, 0));
-    let iw_start = u32(max(iw_start_tmp, 0));
-
-    let ih_end = min(u32(max(kms_0 + ih_start_tmp + 1, 0)), input_shape_2);
-    let iw_end = min(u32(max(kms_1 + iw_start_tmp + 1, 0)), input_shape_3);
-
-    for (var ic = ic_start; ic < ic_end; ic++) {
-        for (var ih = ih_start; ih < ih_end; ih++) {
-            for (var iw = iw_start; iw < iw_end; iw++) {
-                for (var kh = 0u; kh < kernel_size_0; kh++) {
-                    for (var kw = 0u; kw < kernel_size_1; kw++) {
-                        let oh_tmp = ih * stride_0 + kh * dilation_0;
-                        let ow_tmp = iw * stride_1 + kw * dilation_1;
-
-                        if oh_tmp >= padding_0 && ow_tmp >= padding_1 {
-                            let oh_tmp_pad = oh_tmp - padding_0;
-                            let ow_tmp_pad = ow_tmp - padding_1;
-
-                            if oh_tmp_pad == oh && ow_tmp_pad == ow {
-                                let index_input = b * input_stride_0 + ic * input_stride_1 + ih * input_stride_2 + iw * input_stride_3;
-                                let index_weight = ic * weight_stride_0 + oc * weight_stride_1 + kh * weight_stride_2 + kw * weight_stride_3;
-
-                                sum += input[index_input] * weight[index_weight];
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    output[id] = sum;
-}
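
The speedup comes from inverting the loop structure. The deleted template scans every (ih, iw, kh, kw) combination in the input window and keeps only iterations where `ih * stride + kh * dilation` happens to land on the current output pixel, so most inner iterations are wasted. The new kernel iterates over the input window only and solves for the unique kernel tap that can contribute, at the cost of one divisibility test. A scalar sketch of that inversion (function name and signature are illustrative):

// For output row `oh`, does input row `ih` contribute, and through which kernel
// row? There is at most one kh with ih * stride + kh * dilation == oh + padding.
fn kernel_tap(oh: usize, padding: usize, ih: usize, stride: usize, dilation: usize) -> Option<usize> {
    let target = oh + padding;
    let start = ih * stride;
    if target < start {
        return None; // this input row starts past the output position
    }
    let numerator = target - start;
    if numerator % dilation != 0 {
        return None; // no kernel tap lands exactly on this output position
    }
    // The precomputed ih_start..ih_end bounds already keep kh inside the kernel.
    Some(numerator / dilation)
}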