Conv Transpose: migration + decent speedup (#1541)

* convtranspose benchmark

* adjust bench

* conv transpose works

* Conv Transpose: migration + decent speedup

* delete template folder

* typos

* fix
Louis Fortier-Dubois 2024-03-28 12:13:06 -04:00, committed by GitHub
parent b8fc3f141e
commit 279be0496a
5 changed files with 469 additions and 155 deletions


@@ -65,6 +65,11 @@ name = "max-pool2d"
 path = "benches/max_pool2d.rs"
 harness = false
 
+[[bench]]
+name = "conv-transpose2d"
+path = "benches/conv_transpose2d.rs"
+harness = false
+
 [[bench]]
 name = "matmul"
 harness = false
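With harness = false, the bench binary supplies its own main via bench_on_backend!. Assuming the crate's existing backend feature flags (for example wgpu), it should run with something like: cargo bench --bench conv-transpose2d --features wgpu, from the backend-comparison directory.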


@@ -0,0 +1,93 @@
use backend_comparison::persistence::save;
use burn::tensor::{
backend::Backend, module::conv_transpose2d, ops::ConvTransposeOptions, Distribution, Shape,
Tensor,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
pub struct ConvTranspose2dBenchmark<B: Backend> {
input_shape: Shape<4>,
weight_shape: Shape<4>,
bias_shape: Shape<1>,
options: ConvTransposeOptions<2>,
device: B::Device,
}
impl<B: Backend> Benchmark for ConvTranspose2dBenchmark<B> {
type Args = (Tensor<B, 4>, Tensor<B, 4>, Tensor<B, 1>);
fn name(&self) -> String {
"conv_transpose2d".into()
}
fn shapes(&self) -> Vec<Vec<usize>> {
vec![
self.input_shape.dims.into(),
self.weight_shape.dims.into(),
self.bias_shape.dims.into(),
]
}
fn execute(&self, (x, w, b): Self::Args) {
conv_transpose2d(x, w, Some(b), self.options.clone());
}
fn prepare(&self) -> Self::Args {
(
Tensor::random(
self.input_shape.clone(),
Distribution::Default,
&self.device,
),
Tensor::random(
self.weight_shape.clone(),
Distribution::Default,
&self.device,
),
Tensor::random(self.bias_shape.clone(), Distribution::Default, &self.device),
)
}
fn sync(&self) {
B::sync(&self.device)
}
}
#[allow(dead_code)]
fn bench<B: Backend>(device: &B::Device, url: Option<&str>, token: Option<&str>) {
// Shapes
let batch_size = 16;
let channels_in = 16;
let channels_out = 16;
let height_in = 64;
let width_in = 64;
let kernel_size_0 = 8;
let kernel_size_1 = 8;
// Options
let strides = [1, 1];
let padding = [0, 0];
let padding_out = [0, 0];
let dilations = [1, 1];
let groups = 1;
let options = ConvTransposeOptions::new(strides, padding, padding_out, dilations, groups);
let benchmark = ConvTranspose2dBenchmark::<B> {
input_shape: [batch_size, channels_in, height_in, width_in].into(),
weight_shape: [
channels_in,
channels_out / groups,
kernel_size_0,
kernel_size_1,
]
.into(),
bias_shape: [channels_out].into(),
options,
device: device.clone(),
};
save::<B>(vec![run_benchmark(benchmark)], device, url, token).unwrap();
}
fn main() {
backend_comparison::bench_on_backend!();
}
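
As a sanity check on these settings, here is a minimal sketch (not part of the PR) of the output size this transposed convolution produces, assuming the standard size formula that the kernel's shape_out computation implements:

// Hypothetical helper:
// out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + out_padding + 1
fn conv_transpose2d_out_size(
    input: usize,
    stride: usize,
    padding: usize,
    out_padding: usize,
    dilation: usize,
    kernel: usize,
) -> usize {
    (input - 1) * stride - 2 * padding + dilation * (kernel - 1) + out_padding + 1
}

fn main() {
    // Benchmark settings above: (64 - 1) * 1 - 0 + 1 * (8 - 1) + 0 + 1 = 71,
    // so the benchmarked output tensor is [16, 16, 71, 71].
    assert_eq!(conv_transpose2d_out_size(64, 1, 0, 0, 1, 8), 71);
}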


@@ -1,15 +1,348 @@
+use std::marker::PhantomData;
+
 use crate::{
-    compute::StaticKernel,
+    codegen::{
+        Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
+        OutputInfo, WorkgroupLaunch,
+    },
     element::JitElement,
-    kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
-    kernel_wgsl,
-    ops::numeric::empty_device,
+    gpu::{gpu, Elem, Scope, Variable, Visibility},
+    kernel::{self, DynamicKernelSource, SourceTemplate},
+    ops::{
+        numeric::{empty_device, zeros_device},
+        reshape,
+    },
     tensor::JitTensor,
-    Runtime,
+    Compiler, Runtime,
 };
 use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape};
 
-kernel_wgsl!(ConvTranspose2d, "../../template/conv/conv_transpose2d.wgsl");

#[derive(new)]
struct Conv2dTransposeEagerKernel<R, E> {
    _runtime: PhantomData<R>,
    _elem: PhantomData<E>,
}
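// Emits the kernel body into burn's gpu IR: `expand` records instructions into
// a Scope that the Compilation pipeline below turns into backend shader source.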
#[derive(new)]
struct Conv2dTransposeComputeShader<E> {
input: Variable,
weight: Variable,
bias: Variable,
output: Variable,
_elem: PhantomData<E>,
}
impl<E: JitElement> Conv2dTransposeComputeShader<E> {
fn expand(self, scope: &mut Scope) {
let input = self.input;
let weight = self.weight;
let bias = self.bias;
let output = self.output;
let id = Variable::Id;
let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt);
let input_stride_2 = scope.create_local(Elem::UInt);
let input_stride_3 = scope.create_local(Elem::UInt);
let input_shape_0 = scope.create_local(Elem::UInt);
let input_shape_1 = scope.create_local(Elem::UInt);
let input_shape_2 = scope.create_local(Elem::UInt);
let input_shape_3 = scope.create_local(Elem::UInt);
gpu!(scope, input_stride_0 = stride(input, 0u32));
gpu!(scope, input_stride_1 = stride(input, 1u32));
gpu!(scope, input_stride_2 = stride(input, 2u32));
gpu!(scope, input_stride_3 = stride(input, 3u32));
gpu!(scope, input_shape_0 = shape(input, 0u32));
gpu!(scope, input_shape_1 = shape(input, 1u32));
gpu!(scope, input_shape_2 = shape(input, 2u32));
gpu!(scope, input_shape_3 = shape(input, 3u32));
let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);
let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);
gpu!(scope, output_stride_0 = stride(output, 0u32));
gpu!(scope, output_stride_1 = stride(output, 1u32));
gpu!(scope, output_stride_2 = stride(output, 2u32));
gpu!(scope, output_stride_3 = stride(output, 3u32));
gpu!(scope, output_shape_0 = shape(output, 0u32));
gpu!(scope, output_shape_1 = shape(output, 1u32));
gpu!(scope, output_shape_2 = shape(output, 2u32));
gpu!(scope, output_shape_3 = shape(output, 3u32));
let weight_stride_0 = scope.create_local(Elem::UInt);
let weight_stride_1 = scope.create_local(Elem::UInt);
let weight_stride_2 = scope.create_local(Elem::UInt);
let weight_stride_3 = scope.create_local(Elem::UInt);
let in_channels = scope.create_local(Elem::UInt);
let weight_shape_1 = scope.create_local(Elem::UInt);
let kernel_size_0 = scope.create_local(Elem::UInt);
let kernel_size_1 = scope.create_local(Elem::UInt);
gpu!(scope, weight_stride_0 = stride(weight, 0u32));
gpu!(scope, weight_stride_1 = stride(weight, 1u32));
gpu!(scope, weight_stride_2 = stride(weight, 2u32));
gpu!(scope, weight_stride_3 = stride(weight, 3u32));
gpu!(scope, in_channels = shape(weight, 0u32));
gpu!(scope, weight_shape_1 = shape(weight, 1u32));
gpu!(scope, kernel_size_0 = shape(weight, 2u32));
gpu!(scope, kernel_size_1 = shape(weight, 3u32));
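// The seven conv options arrive as launch scalars; the indices below must match
// the order passed to with_scalars() at the launch site:
// strides, dilations, paddings, groups.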
let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
let groups = Variable::GlobalScalar(6, Elem::UInt);
let stride_0_i = scope.create_local(Elem::Int);
let stride_1_i = scope.create_local(Elem::Int);
gpu!(scope, stride_0_i = cast(conv_stride_0));
gpu!(scope, stride_1_i = cast(conv_stride_1));
let oc_out = scope.create_local(Elem::UInt);
let oc = scope.create_local(Elem::UInt);
let b = scope.create_local(Elem::UInt);
let oh = scope.create_local(Elem::UInt);
let ow = scope.create_local(Elem::UInt);
let k = scope.create_local(Elem::UInt);
let g = scope.create_local(Elem::UInt);
let ic_start = scope.create_local(Elem::UInt);
let ic_end = scope.create_local(Elem::UInt);
let ic_tmp = scope.create_local(Elem::UInt);
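// Decompose the linear invocation id into (b, oc_out, oh, ow) output
// coordinates, then derive this channel's group and input-channel range.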
gpu!(scope, b = id / output_stride_0);
gpu!(scope, b = b % output_shape_0);
gpu!(scope, oc_out = id / output_stride_1);
gpu!(scope, oc_out = oc_out % output_shape_1);
gpu!(scope, oh = id / output_stride_2);
gpu!(scope, oh = oh % output_shape_2);
gpu!(scope, ow = id / output_stride_3);
gpu!(scope, ow = ow % output_shape_3);
gpu!(scope, k = oc_out / weight_shape_1);
gpu!(scope, g = k % groups);
gpu!(scope, oc = weight_shape_1 * g);
gpu!(scope, oc = oc_out - oc);
gpu!(scope, ic_tmp = in_channels / groups);
gpu!(scope, ic_start = g * ic_tmp);
gpu!(scope, ic_end = ic_start + ic_tmp);
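// Restrict the loops to the input window that can reach (oh, ow): outside
// [ih_start, ih_end) x [iw_start, iw_end) no kernel offset maps an input
// position onto this output position.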
let tmp_u = scope.create_local(Elem::UInt);
let tmp_i = scope.create_local(Elem::Int);
let zero_i = scope.zero(Elem::Int);
let one_i = scope.create_with_value(1, Elem::Int);
let kms_u = scope.create_local(Elem::UInt);
let kms_0 = scope.create_local(Elem::Int);
let kms_1 = scope.create_local(Elem::Int);
let ih_start_tmp = scope.create_local(Elem::Int);
let iw_start_tmp = scope.create_local(Elem::Int);
let ih_start = scope.create_local(Elem::UInt);
let iw_start = scope.create_local(Elem::UInt);
let ih_end = scope.create_local(Elem::UInt);
let iw_end = scope.create_local(Elem::UInt);
gpu!(scope, kms_u = kernel_size_0 * dilation_0);
gpu!(scope, kms_0 = cast(kms_u));
gpu!(scope, kms_0 = kms_0 - stride_0_i);
gpu!(scope, kms_u = kernel_size_1 * dilation_1);
gpu!(scope, kms_1 = cast(kms_u));
gpu!(scope, kms_1 = kms_1 - stride_1_i);
gpu!(scope, tmp_u = oh + padding_0);
gpu!(scope, tmp_i = cast(tmp_u));
gpu!(scope, ih_start_tmp = tmp_i - kms_0);
gpu!(scope, ih_start_tmp = ih_start_tmp / stride_0_i);
gpu!(scope, tmp_u = ow + padding_1);
gpu!(scope, tmp_i = cast(tmp_u));
gpu!(scope, iw_start_tmp = tmp_i - kms_1);
gpu!(scope, iw_start_tmp = iw_start_tmp / stride_1_i);
gpu!(scope, tmp_i = max(ih_start_tmp, zero_i));
gpu!(scope, ih_start = cast(tmp_i));
gpu!(scope, tmp_i = kms_0 + ih_start_tmp);
gpu!(scope, tmp_i += one_i);
gpu!(scope, tmp_i = max(tmp_i, zero_i));
gpu!(scope, tmp_u = cast(tmp_i));
gpu!(scope, ih_end = min(tmp_u, input_shape_2));
gpu!(scope, tmp_i = max(iw_start_tmp, zero_i));
gpu!(scope, iw_start = cast(tmp_i));
gpu!(scope, tmp_i = kms_1 + iw_start_tmp);
gpu!(scope, tmp_i += one_i);
gpu!(scope, tmp_i = max(tmp_i, zero_i));
gpu!(scope, tmp_u = cast(tmp_i));
gpu!(scope, iw_end = min(tmp_u, input_shape_3));
let index_input = scope.create_local(Elem::UInt);
let index_weight = scope.create_local(Elem::UInt);
let index_input_b = scope.create_local(Elem::UInt);
let index_input_ic = scope.create_local(Elem::UInt);
let index_input_ih = scope.create_local(Elem::UInt);
let index_input_iw = scope.create_local(Elem::UInt);
let index_weight_ic = scope.create_local(Elem::UInt);
let index_weight_oc = scope.create_local(Elem::UInt);
let index_weight_kh = scope.create_local(Elem::UInt);
let index_weight_kw = scope.create_local(Elem::UInt);
gpu!(scope, index_input_b = b * input_stride_0);
gpu!(scope, index_weight_oc = oc * weight_stride_1);
let prod = scope.create_local(output.item());
let prod_tmp = scope.create_local(output.item());
let sum = scope.create_local(output.item());
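// Seed the accumulator with this output channel's bias.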
gpu!(scope, sum = bias[oc_out]);
let kh = scope.create_local(Elem::UInt);
let kw = scope.create_local(Elem::UInt);
let numerator_h_base = scope.create_local(Elem::UInt);
let numerator_h = scope.create_local(Elem::UInt);
let numerator_w_base = scope.create_local(Elem::UInt);
let numerator_w = scope.create_local(Elem::UInt);
let numerator_tmp = scope.create_local(Elem::UInt);
let numerator_mod = scope.create_local(Elem::UInt);
let zero = scope.zero(Elem::UInt);
let divisible = scope.create_local(Elem::Bool);
let not_neg = scope.create_local(Elem::Bool);
let cond = scope.create_local(Elem::Bool);
gpu!(scope, numerator_h_base = oh + padding_0);
gpu!(scope, numerator_w_base = ow + padding_1);
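// An input position contributes only if (oh + padding) - ih * stride is
// non-negative and divisible by the dilation; when it is, the quotient is the
// unique kernel offset kh (and likewise kw), so no scan over the kernel
// window is needed.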
gpu!(
scope,
range(ic_start, ic_end).for_each(|ic, scope| {
gpu!(scope, index_input_ic = ic * input_stride_1);
gpu!(scope, index_weight_ic = ic * weight_stride_0);
gpu!(
scope,
range(ih_start, ih_end).for_each(|ih, scope| {
gpu!(scope, numerator_tmp = ih * conv_stride_0);
gpu!(scope, not_neg = numerator_h_base >= numerator_tmp);
gpu!(scope, numerator_h = numerator_h_base - numerator_tmp);
gpu!(scope, numerator_mod = numerator_h % dilation_0);
gpu!(scope, divisible = numerator_mod == zero);
gpu!(scope, cond = not_neg && divisible);
gpu!(scope, if(cond).then(|scope|{
gpu!(scope, kh = numerator_h / dilation_0);
gpu!(scope, index_input_ih = ih * input_stride_2);
gpu!(scope, index_weight_kh = kh * weight_stride_2);
gpu!(
scope,
range(iw_start, iw_end).for_each(|iw, scope| {
gpu!(scope, numerator_tmp = iw * conv_stride_1);
gpu!(scope, not_neg = numerator_w_base >= numerator_tmp);
gpu!(scope, numerator_w = numerator_w_base - numerator_tmp);
gpu!(scope, numerator_mod = numerator_w % dilation_1);
gpu!(scope, divisible = numerator_mod == zero);
gpu!(scope, cond = not_neg && divisible);
gpu!(scope, if(cond).then(|scope|{
gpu!(scope, kw = numerator_w / dilation_1);
gpu!(scope, index_input_iw = iw * input_stride_3);
gpu!(scope, index_weight_kw = kw * weight_stride_3);
gpu!(scope, index_input = index_input_b);
gpu!(scope, index_input += index_input_ic);
gpu!(scope, index_input += index_input_ih);
gpu!(scope, index_input += index_input_iw);
gpu!(scope, index_weight = index_weight_ic);
gpu!(scope, index_weight += index_weight_oc);
gpu!(scope, index_weight += index_weight_kh);
gpu!(scope, index_weight += index_weight_kw);
gpu!(scope, prod = input[index_input]);
gpu!(scope, prod_tmp = weight[index_weight]);
gpu!(scope, prod *= prod_tmp);
gpu!(scope, sum += prod);
}));
})
);
}));
})
);
})
);
gpu!(scope, output[id] = sum);
}
}
impl<R: Runtime, E: JitElement> DynamicKernelSource for Conv2dTransposeEagerKernel<R, E> {
fn source(&self) -> kernel::SourceTemplate {
let mut scope = Scope::root();
let item = E::gpu_elem().into();
let input = Variable::GlobalInputArray(0, item);
let weight = Variable::GlobalInputArray(1, item);
let bias = Variable::GlobalInputArray(2, item);
let output = Variable::GlobalOutputArray(0, item);
scope.write_global_custom(output);
Conv2dTransposeComputeShader {
input,
weight,
bias,
output,
_elem: PhantomData::<E>,
}
.expand(&mut scope);
let input = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let weight = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let bias = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let scalars = InputInfo::Scalar {
elem: Elem::UInt,
size: 7,
};
let output = OutputInfo::Array { item };
let info = CompilationInfo {
inputs: vec![input, weight, bias, scalars],
outputs: vec![output],
scope,
};
let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}
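// The kernel is fully described by its Rust type, so the TypeId can serve as
// a stable identifier for the generated source.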
fn id(&self) -> String {
format!("{:?}", core::any::TypeId::of::<Self>())
}
}
pub(crate) fn conv_transpose2d<R: Runtime, E: JitElement + Element>(
    input: JitTensor<R, E, 4>,
@@ -34,42 +367,47 @@ pub(crate) fn conv_transpose2d<R: Runtime, E: JitElement + Element>(
         + 1;
 
     let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]);
 
-    let num_elems = shape_out.num_elements();
     let output = empty_device(
         input.client.clone(),
         input.device.clone(),
         shape_out.clone(),
     );
 
-    let mut info = build_info(&[&input, &output, &weight]);
-    info.push(options.stride[0] as u32);
-    info.push(options.stride[1] as u32);
-    info.push(options.padding[0] as u32);
-    info.push(options.padding[1] as u32);
-    info.push(options.dilation[0] as u32);
-    info.push(options.dilation[1] as u32);
-    info.push(options.groups as u32);
-
-    let bias_handle = bias
-        .map(|bias| bias.handle)
-        .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()])));
-
-    let info_handle = input.client.create(bytemuck::cast_slice(&info));
-
-    let kernel = StaticKernel::<
-        KernelSettings<ConvTranspose2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
-    >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
-
-    input.client.execute(
-        Box::new(kernel),
-        &[
-            &input.handle,
-            &weight.handle,
-            &bias_handle,
-            &output.handle,
-            &info_handle,
-        ],
-    );
+    let bias = match bias {
+        Some(bias) => {
+            let shape = Shape::from([bias.shape.dims[0], 1, 1, 1]);
+            reshape(bias, shape)
+        }
+        None => {
+            let shape = Shape::from([output.shape.dims[0], 1, 1, 1]);
+            zeros_device(input.client.clone(), input.device.clone(), shape)
+        }
+    };
+
+    let kernel = Conv2dTransposeEagerKernel::<R, E>::new();
+
+    Execution::start(kernel, input.client.clone())
+        .inputs(&[
+            EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
+            EagerHandle::new(&weight.handle, &weight.strides, &weight.shape.dims),
+            EagerHandle::new(&bias.handle, &bias.strides, &bias.shape.dims),
+        ])
+        .outputs(&[EagerHandle::new(
+            &output.handle,
+            &output.strides,
+            &output.shape.dims,
+        )])
+        .with_scalars(&[
+            (options.stride[0] as u32).elem::<u32>(),
+            (options.stride[1] as u32).elem(),
+            (options.dilation[0] as u32).elem(),
+            (options.dilation[1] as u32).elem(),
+            (options.padding[0] as u32).elem(),
+            (options.padding[1] as u32).elem(),
+            (options.groups as u32).elem(),
+        ])
+        .execute(WorkgroupLaunch::Output { pos: 0 });
 
     output
 }


@@ -2,4 +2,4 @@ mod conv2d;
 mod conv_transpose2d;
 
 pub(crate) use conv2d::*;
-pub use conv_transpose2d::*;
+pub(crate) use conv_transpose2d::*;


@ -1,122 +0,0 @@
@group(0)
@binding(0)
var<storage, read> input: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> weight: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read> bias: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(4)
var<storage, read> info: array<u32, 32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let input_stride_0 = info[1];
let input_stride_1 = info[2];
let input_stride_2 = info[3];
let input_stride_3 = info[4];
let output_stride_0 = info[5];
let output_stride_1 = info[6];
let output_stride_2 = info[7];
let output_stride_3 = info[8];
let weight_stride_0 = info[9];
let weight_stride_1 = info[10];
let weight_stride_2 = info[11];
let weight_stride_3 = info[12];
let input_shape_0 = info[13];
let input_shape_1 = info[14];
let input_shape_2 = info[15];
let input_shape_3 = info[16];
let output_shape_0 = info[17];
let output_shape_1 = info[18];
let output_shape_2 = info[19];
let output_shape_3 = info[20];
let weight_shape_0 = info[21];
let weight_shape_1 = info[22];
let weight_shape_2 = info[23];
let weight_shape_3 = info[24];
let stride_0 = info[25];
let stride_1 = info[26];
let padding_0 = info[27];
let padding_1 = info[28];
let dilation_0 = info[29];
let dilation_1 = info[30];
let groups = info[31];
let in_channels = weight_shape_0;
let kernel_size_0 = weight_shape_2;
let kernel_size_1 = weight_shape_3;
let b = id / output_stride_0 % output_shape_0;
let oc_out = id / output_stride_1 % output_shape_1;
let oh = id / output_stride_2 % output_shape_2;
let ow = id / output_stride_3 % output_shape_3;
let k = oc_out / weight_shape_1;
let g = k % groups;
let oc = oc_out - (weight_shape_1 * g);
var sum = bias[oc_out];
let ic_start = g * (in_channels / groups);
let ic_end = ic_start + in_channels / groups;
// The maximum number of overlapping filters that may contain the current index.
let kms_0 = i32(kernel_size_0 * dilation_0) - i32(stride_0);
let kms_1 = i32(kernel_size_1 * dilation_1) - i32(stride_1);
let ih_start_tmp = (i32(oh + padding_0) - kms_0) / i32(stride_0);
let iw_start_tmp = (i32(ow + padding_1) - kms_1) / i32(stride_1);
let ih_start = u32(max(ih_start_tmp, 0));
let iw_start = u32(max(iw_start_tmp, 0));
let ih_end = min(u32(max(kms_0 + ih_start_tmp + 1, 0)), input_shape_2);
let iw_end = min(u32(max(kms_1 + iw_start_tmp + 1, 0)), input_shape_3);
for (var ic = ic_start; ic < ic_end; ic++) {
for (var ih = ih_start; ih < ih_end; ih++) {
for (var iw = iw_start; iw < iw_end; iw++) {
for (var kh = 0u; kh < kernel_size_0; kh++) {
for (var kw = 0u; kw < kernel_size_1; kw++) {
let oh_tmp = ih * stride_0 + kh * dilation_0;
let ow_tmp = iw * stride_1 + kw * dilation_1;
if oh_tmp >= padding_0 && ow_tmp >= padding_1 {
let oh_tmp_pad = oh_tmp - padding_0;
let ow_tmp_pad = ow_tmp - padding_1;
if oh_tmp_pad == oh && ow_tmp_pad == ow {
let index_input = b * input_stride_0 + ic * input_stride_1 + ih * input_stride_2 + iw * input_stride_3;
let index_weight = ic * weight_stride_0 + oc * weight_stride_1 + kh * weight_stride_2 + kw * weight_stride_3;
sum += input[index_input] * weight[index_weight];
}
}
}
}
}
}
}
output[id] = sum;
}
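
For reference, the structural change that likely accounts for most of the speedup: the deleted WGSL kernel above scans every (kh, kw) offset and keeps only the one that lands exactly on (oh, ow), while the migrated kernel solves for that unique offset directly with the not_neg/divisible tests. A minimal one-axis sketch of the two strategies (hypothetical helpers, ignoring padding; not code from this PR):

// Old WGSL strategy: scan the whole kernel window for the matching offset.
fn matching_offset_scan(oh: usize, ih: usize, stride: usize, dilation: usize, kernel: usize) -> Option<usize> {
    (0..kernel).find(|&kh| ih * stride + kh * dilation == oh)
}

// New IR strategy: one subtraction, one modulo, one division.
fn matching_offset_direct(oh: usize, ih: usize, stride: usize, dilation: usize, kernel: usize) -> Option<usize> {
    let numerator = oh.checked_sub(ih * stride)?; // the `not_neg` test
    if numerator % dilation != 0 {
        return None; // the `divisible` test
    }
    let kh = numerator / dilation;
    (kh < kernel).then_some(kh)
}

fn main() {
    // The two strategies agree on a small grid of positions.
    for oh in 0..32 {
        for ih in 0..16 {
            assert_eq!(
                matching_offset_scan(oh, ih, 2, 3, 8),
                matching_offset_direct(oh, ih, 2, 3, 8),
            );
        }
    }
}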