mirror of https://github.com/tracel-ai/burn.git
Conv Transpose: migration + decent speedup (#1541)
* convtranspose benchmark * adjust bench * conv transpose works * Conv Transpose: migration + decent speedup * delete template folder * typos * fix
This commit is contained in:
parent
b8fc3f141e
commit
279be0496a
|
@ -65,6 +65,11 @@ name = "max-pool2d"
|
||||||
path = "benches/max_pool2d.rs"
|
path = "benches/max_pool2d.rs"
|
||||||
harness = false
|
harness = false
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "conv-transpose2d"
|
||||||
|
path = "benches/conv_transpose2d.rs"
|
||||||
|
harness = false
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "matmul"
|
name = "matmul"
|
||||||
harness = false
|
harness = false
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
use backend_comparison::persistence::save;
|
||||||
|
use burn::tensor::{
|
||||||
|
backend::Backend, module::conv_transpose2d, ops::ConvTransposeOptions, Distribution, Shape,
|
||||||
|
Tensor,
|
||||||
|
};
|
||||||
|
use burn_common::benchmark::{run_benchmark, Benchmark};
|
||||||
|
|
||||||
|
pub struct ConvTranspose2dBenchmark<B: Backend> {
|
||||||
|
input_shape: Shape<4>,
|
||||||
|
weight_shape: Shape<4>,
|
||||||
|
bias_shape: Shape<1>,
|
||||||
|
options: ConvTransposeOptions<2>,
|
||||||
|
device: B::Device,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Benchmark for ConvTranspose2dBenchmark<B> {
|
||||||
|
type Args = (Tensor<B, 4>, Tensor<B, 4>, Tensor<B, 1>);
|
||||||
|
|
||||||
|
fn name(&self) -> String {
|
||||||
|
"conv_transpose2d".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn shapes(&self) -> Vec<Vec<usize>> {
|
||||||
|
vec![
|
||||||
|
self.input_shape.dims.into(),
|
||||||
|
self.weight_shape.dims.into(),
|
||||||
|
self.bias_shape.dims.into(),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn execute(&self, (x, w, b): Self::Args) {
|
||||||
|
conv_transpose2d(x, w, Some(b), self.options.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prepare(&self) -> Self::Args {
|
||||||
|
(
|
||||||
|
Tensor::random(
|
||||||
|
self.input_shape.clone(),
|
||||||
|
Distribution::Default,
|
||||||
|
&self.device,
|
||||||
|
),
|
||||||
|
Tensor::random(
|
||||||
|
self.weight_shape.clone(),
|
||||||
|
Distribution::Default,
|
||||||
|
&self.device,
|
||||||
|
),
|
||||||
|
Tensor::random(self.bias_shape.clone(), Distribution::Default, &self.device),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sync(&self) {
|
||||||
|
B::sync(&self.device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn bench<B: Backend>(device: &B::Device, url: Option<&str>, token: Option<&str>) {
|
||||||
|
// Shapes
|
||||||
|
let batch_size = 16;
|
||||||
|
let channels_in = 16;
|
||||||
|
let channels_out = 16;
|
||||||
|
let height_in = 64;
|
||||||
|
let width_in = 64;
|
||||||
|
let kernel_size_0 = 8;
|
||||||
|
let kernel_size_1 = 8;
|
||||||
|
|
||||||
|
// Options
|
||||||
|
let strides = [1, 1];
|
||||||
|
let padding = [0, 0];
|
||||||
|
let padding_out = [0, 0];
|
||||||
|
let dilations = [1, 1];
|
||||||
|
let groups = 1;
|
||||||
|
let options = ConvTransposeOptions::new(strides, padding, padding_out, dilations, groups);
|
||||||
|
let benchmark = ConvTranspose2dBenchmark::<B> {
|
||||||
|
input_shape: [batch_size, channels_in, height_in, width_in].into(),
|
||||||
|
weight_shape: [
|
||||||
|
channels_in,
|
||||||
|
channels_out / groups,
|
||||||
|
kernel_size_0,
|
||||||
|
kernel_size_1,
|
||||||
|
]
|
||||||
|
.into(),
|
||||||
|
bias_shape: [channels_out].into(),
|
||||||
|
options,
|
||||||
|
device: device.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
save::<B>(vec![run_benchmark(benchmark)], device, url, token).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
backend_comparison::bench_on_backend!();
|
||||||
|
}
|
|
@ -1,15 +1,348 @@
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
compute::StaticKernel,
|
codegen::{
|
||||||
|
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||||
|
OutputInfo, WorkgroupLaunch,
|
||||||
|
},
|
||||||
element::JitElement,
|
element::JitElement,
|
||||||
kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||||
kernel_wgsl,
|
kernel::{self, DynamicKernelSource, SourceTemplate},
|
||||||
ops::numeric::empty_device,
|
ops::{
|
||||||
|
numeric::{empty_device, zeros_device},
|
||||||
|
reshape,
|
||||||
|
},
|
||||||
tensor::JitTensor,
|
tensor::JitTensor,
|
||||||
Runtime,
|
Compiler, Runtime,
|
||||||
};
|
};
|
||||||
use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape};
|
use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape};
|
||||||
|
|
||||||
kernel_wgsl!(ConvTranspose2d, "../../template/conv/conv_transpose2d.wgsl");
|
#[derive(new)]
|
||||||
|
struct Conv2dTransposeEagerKernel<R, E> {
|
||||||
|
_runtime: PhantomData<R>,
|
||||||
|
_elem: PhantomData<E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(new)]
|
||||||
|
struct Conv2dTransposeComputeShader<E> {
|
||||||
|
input: Variable,
|
||||||
|
weight: Variable,
|
||||||
|
bias: Variable,
|
||||||
|
output: Variable,
|
||||||
|
_elem: PhantomData<E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E: JitElement> Conv2dTransposeComputeShader<E> {
|
||||||
|
fn expand(self, scope: &mut Scope) {
|
||||||
|
let input = self.input;
|
||||||
|
let weight = self.weight;
|
||||||
|
let bias = self.bias;
|
||||||
|
let output = self.output;
|
||||||
|
let id = Variable::Id;
|
||||||
|
|
||||||
|
let input_stride_0 = scope.create_local(Elem::UInt);
|
||||||
|
let input_stride_1 = scope.create_local(Elem::UInt);
|
||||||
|
let input_stride_2 = scope.create_local(Elem::UInt);
|
||||||
|
let input_stride_3 = scope.create_local(Elem::UInt);
|
||||||
|
let input_shape_0 = scope.create_local(Elem::UInt);
|
||||||
|
let input_shape_1 = scope.create_local(Elem::UInt);
|
||||||
|
let input_shape_2 = scope.create_local(Elem::UInt);
|
||||||
|
let input_shape_3 = scope.create_local(Elem::UInt);
|
||||||
|
gpu!(scope, input_stride_0 = stride(input, 0u32));
|
||||||
|
gpu!(scope, input_stride_1 = stride(input, 1u32));
|
||||||
|
gpu!(scope, input_stride_2 = stride(input, 2u32));
|
||||||
|
gpu!(scope, input_stride_3 = stride(input, 3u32));
|
||||||
|
gpu!(scope, input_shape_0 = shape(input, 0u32));
|
||||||
|
gpu!(scope, input_shape_1 = shape(input, 1u32));
|
||||||
|
gpu!(scope, input_shape_2 = shape(input, 2u32));
|
||||||
|
gpu!(scope, input_shape_3 = shape(input, 3u32));
|
||||||
|
|
||||||
|
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||||
|
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||||
|
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||||
|
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||||
|
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||||
|
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||||
|
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||||
|
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||||
|
gpu!(scope, output_stride_0 = stride(output, 0u32));
|
||||||
|
gpu!(scope, output_stride_1 = stride(output, 1u32));
|
||||||
|
gpu!(scope, output_stride_2 = stride(output, 2u32));
|
||||||
|
gpu!(scope, output_stride_3 = stride(output, 3u32));
|
||||||
|
gpu!(scope, output_shape_0 = shape(output, 0u32));
|
||||||
|
gpu!(scope, output_shape_1 = shape(output, 1u32));
|
||||||
|
gpu!(scope, output_shape_2 = shape(output, 2u32));
|
||||||
|
gpu!(scope, output_shape_3 = shape(output, 3u32));
|
||||||
|
|
||||||
|
let weight_stride_0 = scope.create_local(Elem::UInt);
|
||||||
|
let weight_stride_1 = scope.create_local(Elem::UInt);
|
||||||
|
let weight_stride_2 = scope.create_local(Elem::UInt);
|
||||||
|
let weight_stride_3 = scope.create_local(Elem::UInt);
|
||||||
|
let in_channels = scope.create_local(Elem::UInt);
|
||||||
|
let weight_shape_1 = scope.create_local(Elem::UInt);
|
||||||
|
let kernel_size_0 = scope.create_local(Elem::UInt);
|
||||||
|
let kernel_size_1 = scope.create_local(Elem::UInt);
|
||||||
|
gpu!(scope, weight_stride_0 = stride(weight, 0u32));
|
||||||
|
gpu!(scope, weight_stride_1 = stride(weight, 1u32));
|
||||||
|
gpu!(scope, weight_stride_2 = stride(weight, 2u32));
|
||||||
|
gpu!(scope, weight_stride_3 = stride(weight, 3u32));
|
||||||
|
gpu!(scope, in_channels = shape(weight, 0u32));
|
||||||
|
gpu!(scope, weight_shape_1 = shape(weight, 1u32));
|
||||||
|
gpu!(scope, kernel_size_0 = shape(weight, 2u32));
|
||||||
|
gpu!(scope, kernel_size_1 = shape(weight, 3u32));
|
||||||
|
|
||||||
|
let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||||
|
let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||||
|
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
|
||||||
|
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
|
||||||
|
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
||||||
|
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
||||||
|
let groups = Variable::GlobalScalar(6, Elem::UInt);
|
||||||
|
|
||||||
|
let stride_0_i = scope.create_local(Elem::Int);
|
||||||
|
let stride_1_i = scope.create_local(Elem::Int);
|
||||||
|
gpu!(scope, stride_0_i = cast(conv_stride_0));
|
||||||
|
gpu!(scope, stride_1_i = cast(conv_stride_1));
|
||||||
|
|
||||||
|
let oc_out = scope.create_local(Elem::UInt);
|
||||||
|
let oc = scope.create_local(Elem::UInt);
|
||||||
|
|
||||||
|
let b = scope.create_local(Elem::UInt);
|
||||||
|
let oh = scope.create_local(Elem::UInt);
|
||||||
|
let ow = scope.create_local(Elem::UInt);
|
||||||
|
let k = scope.create_local(Elem::UInt);
|
||||||
|
let g = scope.create_local(Elem::UInt);
|
||||||
|
|
||||||
|
let ic_start = scope.create_local(Elem::UInt);
|
||||||
|
let ic_end = scope.create_local(Elem::UInt);
|
||||||
|
let ic_tmp = scope.create_local(Elem::UInt);
|
||||||
|
|
||||||
|
gpu!(scope, b = id / output_stride_0);
|
||||||
|
gpu!(scope, b = b % output_shape_0);
|
||||||
|
|
||||||
|
gpu!(scope, oc_out = id / output_stride_1);
|
||||||
|
gpu!(scope, oc_out = oc_out % output_shape_1);
|
||||||
|
|
||||||
|
gpu!(scope, oh = id / output_stride_2);
|
||||||
|
gpu!(scope, oh = oh % output_shape_2);
|
||||||
|
|
||||||
|
gpu!(scope, ow = id / output_stride_3);
|
||||||
|
gpu!(scope, ow = ow % output_shape_3);
|
||||||
|
|
||||||
|
gpu!(scope, k = oc_out / weight_shape_1);
|
||||||
|
gpu!(scope, g = k % groups);
|
||||||
|
gpu!(scope, oc = weight_shape_1 * g);
|
||||||
|
gpu!(scope, oc = oc_out - oc);
|
||||||
|
|
||||||
|
gpu!(scope, ic_tmp = in_channels / groups);
|
||||||
|
gpu!(scope, ic_start = g * ic_tmp);
|
||||||
|
gpu!(scope, ic_end = ic_start + ic_tmp);
|
||||||
|
|
||||||
|
let tmp_u = scope.create_local(Elem::UInt);
|
||||||
|
let tmp_i = scope.create_local(Elem::Int);
|
||||||
|
let zero_i = scope.zero(Elem::Int);
|
||||||
|
let one_i = scope.create_with_value(1, Elem::Int);
|
||||||
|
|
||||||
|
let kms_u = scope.create_local(Elem::UInt);
|
||||||
|
let kms_0 = scope.create_local(Elem::Int);
|
||||||
|
let kms_1 = scope.create_local(Elem::Int);
|
||||||
|
let ih_start_tmp = scope.create_local(Elem::Int);
|
||||||
|
let iw_start_tmp = scope.create_local(Elem::Int);
|
||||||
|
let ih_start = scope.create_local(Elem::UInt);
|
||||||
|
let iw_start = scope.create_local(Elem::UInt);
|
||||||
|
let ih_end = scope.create_local(Elem::UInt);
|
||||||
|
let iw_end = scope.create_local(Elem::UInt);
|
||||||
|
|
||||||
|
gpu!(scope, kms_u = kernel_size_0 * dilation_0);
|
||||||
|
gpu!(scope, kms_0 = cast(kms_u));
|
||||||
|
gpu!(scope, kms_0 = kms_0 - stride_0_i);
|
||||||
|
gpu!(scope, kms_u = kernel_size_1 * dilation_1);
|
||||||
|
gpu!(scope, kms_1 = cast(kms_u));
|
||||||
|
gpu!(scope, kms_1 = kms_1 - stride_1_i);
|
||||||
|
|
||||||
|
gpu!(scope, tmp_u = oh + padding_0);
|
||||||
|
gpu!(scope, tmp_i = cast(tmp_u));
|
||||||
|
gpu!(scope, ih_start_tmp = tmp_i - kms_0);
|
||||||
|
gpu!(scope, ih_start_tmp = ih_start_tmp / stride_0_i);
|
||||||
|
gpu!(scope, tmp_u = ow + padding_1);
|
||||||
|
gpu!(scope, tmp_i = cast(tmp_u));
|
||||||
|
gpu!(scope, iw_start_tmp = tmp_i - kms_1);
|
||||||
|
gpu!(scope, iw_start_tmp = iw_start_tmp / stride_1_i);
|
||||||
|
|
||||||
|
gpu!(scope, tmp_i = max(ih_start_tmp, zero_i));
|
||||||
|
gpu!(scope, ih_start = cast(tmp_i));
|
||||||
|
gpu!(scope, tmp_i = kms_0 + ih_start_tmp);
|
||||||
|
gpu!(scope, tmp_i += one_i);
|
||||||
|
gpu!(scope, tmp_i = max(tmp_i, zero_i));
|
||||||
|
gpu!(scope, tmp_u = cast(tmp_i));
|
||||||
|
gpu!(scope, ih_end = min(tmp_u, input_shape_2));
|
||||||
|
|
||||||
|
gpu!(scope, tmp_i = max(iw_start_tmp, zero_i));
|
||||||
|
gpu!(scope, iw_start = cast(tmp_i));
|
||||||
|
gpu!(scope, tmp_i = kms_1 + iw_start_tmp);
|
||||||
|
gpu!(scope, tmp_i += one_i);
|
||||||
|
gpu!(scope, tmp_i = max(tmp_i, zero_i));
|
||||||
|
gpu!(scope, tmp_u = cast(tmp_i));
|
||||||
|
gpu!(scope, iw_end = min(tmp_u, input_shape_3));
|
||||||
|
|
||||||
|
let index_input = scope.create_local(Elem::UInt);
|
||||||
|
let index_weight = scope.create_local(Elem::UInt);
|
||||||
|
|
||||||
|
let index_input_b = scope.create_local(Elem::UInt);
|
||||||
|
let index_input_ic = scope.create_local(Elem::UInt);
|
||||||
|
let index_input_ih = scope.create_local(Elem::UInt);
|
||||||
|
let index_input_iw = scope.create_local(Elem::UInt);
|
||||||
|
let index_weight_ic = scope.create_local(Elem::UInt);
|
||||||
|
let index_weight_oc = scope.create_local(Elem::UInt);
|
||||||
|
let index_weight_kh = scope.create_local(Elem::UInt);
|
||||||
|
let index_weight_kw = scope.create_local(Elem::UInt);
|
||||||
|
|
||||||
|
gpu!(scope, index_input_b = b * input_stride_0);
|
||||||
|
gpu!(scope, index_weight_oc = oc * weight_stride_1);
|
||||||
|
|
||||||
|
let prod = scope.create_local(output.item());
|
||||||
|
let prod_tmp = scope.create_local(output.item());
|
||||||
|
let sum = scope.create_local(output.item());
|
||||||
|
gpu!(scope, sum = bias[oc_out]);
|
||||||
|
|
||||||
|
let kh = scope.create_local(Elem::UInt);
|
||||||
|
let kw = scope.create_local(Elem::UInt);
|
||||||
|
let numerator_h_base = scope.create_local(Elem::UInt);
|
||||||
|
let numerator_h = scope.create_local(Elem::UInt);
|
||||||
|
let numerator_w_base = scope.create_local(Elem::UInt);
|
||||||
|
let numerator_w = scope.create_local(Elem::UInt);
|
||||||
|
let numerator_tmp = scope.create_local(Elem::UInt);
|
||||||
|
let numerator_mod = scope.create_local(Elem::UInt);
|
||||||
|
let zero = scope.zero(Elem::UInt);
|
||||||
|
let divisible = scope.create_local(Elem::Bool);
|
||||||
|
let not_neg = scope.create_local(Elem::Bool);
|
||||||
|
let cond = scope.create_local(Elem::Bool);
|
||||||
|
|
||||||
|
gpu!(scope, numerator_h_base = oh + padding_0);
|
||||||
|
gpu!(scope, numerator_w_base = ow + padding_1);
|
||||||
|
|
||||||
|
gpu!(
|
||||||
|
scope,
|
||||||
|
range(ic_start, ic_end).for_each(|ic, scope| {
|
||||||
|
gpu!(scope, index_input_ic = ic * input_stride_1);
|
||||||
|
gpu!(scope, index_weight_ic = ic * weight_stride_0);
|
||||||
|
|
||||||
|
gpu!(
|
||||||
|
scope,
|
||||||
|
range(ih_start, ih_end).for_each(|ih, scope| {
|
||||||
|
gpu!(scope, numerator_tmp = ih * conv_stride_0);
|
||||||
|
gpu!(scope, not_neg = numerator_h_base >= numerator_tmp);
|
||||||
|
gpu!(scope, numerator_h = numerator_h_base - numerator_tmp);
|
||||||
|
|
||||||
|
gpu!(scope, numerator_mod = numerator_h % dilation_0);
|
||||||
|
gpu!(scope, divisible = numerator_mod == zero);
|
||||||
|
gpu!(scope, cond = not_neg && divisible);
|
||||||
|
|
||||||
|
gpu!(scope, if(cond).then(|scope|{
|
||||||
|
gpu!(scope, kh = numerator_h / dilation_0);
|
||||||
|
gpu!(scope, index_input_ih = ih * input_stride_2);
|
||||||
|
gpu!(scope, index_weight_kh = kh * weight_stride_2);
|
||||||
|
|
||||||
|
gpu!(
|
||||||
|
scope,
|
||||||
|
range(iw_start, iw_end).for_each(|iw, scope| {
|
||||||
|
gpu!(scope, numerator_tmp = iw * conv_stride_1);
|
||||||
|
gpu!(scope, not_neg = numerator_w_base >= numerator_tmp);
|
||||||
|
gpu!(scope, numerator_w = numerator_w_base - numerator_tmp);
|
||||||
|
|
||||||
|
gpu!(scope, numerator_mod = numerator_w % dilation_1);
|
||||||
|
gpu!(scope, divisible = numerator_mod == zero);
|
||||||
|
gpu!(scope, cond = not_neg && divisible);
|
||||||
|
|
||||||
|
gpu!(scope, if(cond).then(|scope|{
|
||||||
|
gpu!(scope, kw = numerator_w / dilation_1);
|
||||||
|
gpu!(scope, index_input_iw = iw * input_stride_3);
|
||||||
|
gpu!(scope, index_weight_kw = kw * weight_stride_3);
|
||||||
|
|
||||||
|
gpu!(scope, index_input = index_input_b);
|
||||||
|
gpu!(scope, index_input += index_input_ic);
|
||||||
|
gpu!(scope, index_input += index_input_ih);
|
||||||
|
gpu!(scope, index_input += index_input_iw);
|
||||||
|
|
||||||
|
gpu!(scope, index_weight = index_weight_ic);
|
||||||
|
gpu!(scope, index_weight += index_weight_oc);
|
||||||
|
gpu!(scope, index_weight += index_weight_kh);
|
||||||
|
gpu!(scope, index_weight += index_weight_kw);
|
||||||
|
|
||||||
|
gpu!(scope, prod = input[index_input]);
|
||||||
|
gpu!(scope, prod_tmp = weight[index_weight]);
|
||||||
|
gpu!(scope, prod *= prod_tmp);
|
||||||
|
gpu!(scope, sum += prod);
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
);
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
gpu!(scope, output[id] = sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: Runtime, E: JitElement> DynamicKernelSource for Conv2dTransposeEagerKernel<R, E> {
|
||||||
|
fn source(&self) -> kernel::SourceTemplate {
|
||||||
|
let mut scope = Scope::root();
|
||||||
|
let item = E::gpu_elem().into();
|
||||||
|
|
||||||
|
let input = Variable::GlobalInputArray(0, item);
|
||||||
|
let weight = Variable::GlobalInputArray(1, item);
|
||||||
|
let bias = Variable::GlobalInputArray(2, item);
|
||||||
|
let output = Variable::GlobalOutputArray(0, item);
|
||||||
|
|
||||||
|
scope.write_global_custom(output);
|
||||||
|
|
||||||
|
Conv2dTransposeComputeShader {
|
||||||
|
input,
|
||||||
|
weight,
|
||||||
|
bias,
|
||||||
|
output,
|
||||||
|
_elem: PhantomData::<E>,
|
||||||
|
}
|
||||||
|
.expand(&mut scope);
|
||||||
|
|
||||||
|
let input = InputInfo::Array {
|
||||||
|
item,
|
||||||
|
visibility: Visibility::Read,
|
||||||
|
};
|
||||||
|
let weight = InputInfo::Array {
|
||||||
|
item,
|
||||||
|
visibility: Visibility::Read,
|
||||||
|
};
|
||||||
|
let bias = InputInfo::Array {
|
||||||
|
item,
|
||||||
|
visibility: Visibility::Read,
|
||||||
|
};
|
||||||
|
let scalars = InputInfo::Scalar {
|
||||||
|
elem: Elem::UInt,
|
||||||
|
size: 7,
|
||||||
|
};
|
||||||
|
|
||||||
|
let output = OutputInfo::Array { item };
|
||||||
|
|
||||||
|
let info = CompilationInfo {
|
||||||
|
inputs: vec![input, weight, bias, scalars],
|
||||||
|
outputs: vec![output],
|
||||||
|
scope,
|
||||||
|
};
|
||||||
|
|
||||||
|
let settings = CompilationSettings::default();
|
||||||
|
let shader = Compilation::new(info).compile(settings);
|
||||||
|
let shader = <R::Compiler as Compiler>::compile(shader);
|
||||||
|
SourceTemplate::new(shader.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn id(&self) -> String {
|
||||||
|
format!("{:?}", core::any::TypeId::of::<Self>())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn conv_transpose2d<R: Runtime, E: JitElement + Element>(
|
pub(crate) fn conv_transpose2d<R: Runtime, E: JitElement + Element>(
|
||||||
input: JitTensor<R, E, 4>,
|
input: JitTensor<R, E, 4>,
|
||||||
|
@ -34,42 +367,47 @@ pub(crate) fn conv_transpose2d<R: Runtime, E: JitElement + Element>(
|
||||||
+ 1;
|
+ 1;
|
||||||
|
|
||||||
let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]);
|
let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]);
|
||||||
let num_elems = shape_out.num_elements();
|
|
||||||
|
|
||||||
let output = empty_device(
|
let output = empty_device(
|
||||||
input.client.clone(),
|
input.client.clone(),
|
||||||
input.device.clone(),
|
input.device.clone(),
|
||||||
shape_out.clone(),
|
shape_out.clone(),
|
||||||
);
|
);
|
||||||
let mut info = build_info(&[&input, &output, &weight]);
|
|
||||||
|
|
||||||
info.push(options.stride[0] as u32);
|
let bias = match bias {
|
||||||
info.push(options.stride[1] as u32);
|
Some(bias) => {
|
||||||
info.push(options.padding[0] as u32);
|
let shape = Shape::from([bias.shape.dims[0], 1, 1, 1]);
|
||||||
info.push(options.padding[1] as u32);
|
reshape(bias, shape)
|
||||||
info.push(options.dilation[0] as u32);
|
}
|
||||||
info.push(options.dilation[1] as u32);
|
None => {
|
||||||
info.push(options.groups as u32);
|
let shape = Shape::from([output.shape.dims[0], 1, 1, 1]);
|
||||||
|
zeros_device(input.client.clone(), input.device.clone(), shape)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let bias_handle = bias
|
let kernel = Conv2dTransposeEagerKernel::<R, E>::new();
|
||||||
.map(|bias| bias.handle)
|
|
||||||
.unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()])));
|
|
||||||
|
|
||||||
let info_handle = input.client.create(bytemuck::cast_slice(&info));
|
Execution::start(kernel, input.client.clone())
|
||||||
|
.inputs(&[
|
||||||
let kernel = StaticKernel::<
|
EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
|
||||||
KernelSettings<ConvTranspose2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
EagerHandle::new(&weight.handle, &weight.strides, &weight.shape.dims),
|
||||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
EagerHandle::new(&bias.handle, &bias.strides, &bias.shape.dims),
|
||||||
input.client.execute(
|
])
|
||||||
Box::new(kernel),
|
.outputs(&[EagerHandle::new(
|
||||||
&[
|
|
||||||
&input.handle,
|
|
||||||
&weight.handle,
|
|
||||||
&bias_handle,
|
|
||||||
&output.handle,
|
&output.handle,
|
||||||
&info_handle,
|
&output.strides,
|
||||||
],
|
&output.shape.dims,
|
||||||
);
|
)])
|
||||||
|
.with_scalars(&[
|
||||||
|
(options.stride[0] as u32).elem::<u32>(),
|
||||||
|
(options.stride[1] as u32).elem(),
|
||||||
|
(options.dilation[0] as u32).elem(),
|
||||||
|
(options.dilation[1] as u32).elem(),
|
||||||
|
(options.padding[0] as u32).elem(),
|
||||||
|
(options.padding[1] as u32).elem(),
|
||||||
|
(options.groups as u32).elem(),
|
||||||
|
])
|
||||||
|
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||||
|
|
||||||
output
|
output
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,4 +2,4 @@ mod conv2d;
|
||||||
mod conv_transpose2d;
|
mod conv_transpose2d;
|
||||||
|
|
||||||
pub(crate) use conv2d::*;
|
pub(crate) use conv2d::*;
|
||||||
pub use conv_transpose2d::*;
|
pub(crate) use conv_transpose2d::*;
|
||||||
|
|
|
@ -1,122 +0,0 @@
|
||||||
@group(0)
|
|
||||||
@binding(0)
|
|
||||||
var<storage, read> input: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(1)
|
|
||||||
var<storage, read> weight: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(2)
|
|
||||||
var<storage, read> bias: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(3)
|
|
||||||
var<storage, read_write> output: array<{{ elem }}>;
|
|
||||||
|
|
||||||
@group(0)
|
|
||||||
@binding(4)
|
|
||||||
var<storage, read> info: array<u32, 32>;
|
|
||||||
|
|
||||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
|
||||||
|
|
||||||
@compute
|
|
||||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
|
||||||
fn main(
|
|
||||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
|
||||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
|
||||||
) {
|
|
||||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
|
||||||
|
|
||||||
let input_stride_0 = info[1];
|
|
||||||
let input_stride_1 = info[2];
|
|
||||||
let input_stride_2 = info[3];
|
|
||||||
let input_stride_3 = info[4];
|
|
||||||
let output_stride_0 = info[5];
|
|
||||||
let output_stride_1 = info[6];
|
|
||||||
let output_stride_2 = info[7];
|
|
||||||
let output_stride_3 = info[8];
|
|
||||||
let weight_stride_0 = info[9];
|
|
||||||
let weight_stride_1 = info[10];
|
|
||||||
let weight_stride_2 = info[11];
|
|
||||||
let weight_stride_3 = info[12];
|
|
||||||
|
|
||||||
let input_shape_0 = info[13];
|
|
||||||
let input_shape_1 = info[14];
|
|
||||||
let input_shape_2 = info[15];
|
|
||||||
let input_shape_3 = info[16];
|
|
||||||
let output_shape_0 = info[17];
|
|
||||||
let output_shape_1 = info[18];
|
|
||||||
let output_shape_2 = info[19];
|
|
||||||
let output_shape_3 = info[20];
|
|
||||||
let weight_shape_0 = info[21];
|
|
||||||
let weight_shape_1 = info[22];
|
|
||||||
let weight_shape_2 = info[23];
|
|
||||||
let weight_shape_3 = info[24];
|
|
||||||
|
|
||||||
let stride_0 = info[25];
|
|
||||||
let stride_1 = info[26];
|
|
||||||
let padding_0 = info[27];
|
|
||||||
let padding_1 = info[28];
|
|
||||||
let dilation_0 = info[29];
|
|
||||||
let dilation_1 = info[30];
|
|
||||||
let groups = info[31];
|
|
||||||
|
|
||||||
let in_channels = weight_shape_0;
|
|
||||||
let kernel_size_0 = weight_shape_2;
|
|
||||||
let kernel_size_1 = weight_shape_3;
|
|
||||||
|
|
||||||
let b = id / output_stride_0 % output_shape_0;
|
|
||||||
let oc_out = id / output_stride_1 % output_shape_1;
|
|
||||||
let oh = id / output_stride_2 % output_shape_2;
|
|
||||||
let ow = id / output_stride_3 % output_shape_3;
|
|
||||||
|
|
||||||
let k = oc_out / weight_shape_1;
|
|
||||||
let g = k % groups;
|
|
||||||
let oc = oc_out - (weight_shape_1 * g);
|
|
||||||
|
|
||||||
var sum = bias[oc_out];
|
|
||||||
|
|
||||||
let ic_start = g * (in_channels / groups);
|
|
||||||
let ic_end = ic_start + in_channels / groups;
|
|
||||||
|
|
||||||
// The maximum number of overlapping filters that may content the current index.
|
|
||||||
let kms_0 = i32(kernel_size_0 * dilation_0) - i32(stride_0);
|
|
||||||
let kms_1 = i32(kernel_size_1 * dilation_1) - i32(stride_1);
|
|
||||||
|
|
||||||
let ih_start_tmp = (i32(oh + padding_0) - kms_0) / i32(stride_0);
|
|
||||||
let iw_start_tmp = (i32(ow + padding_1) - kms_1) / i32(stride_1);
|
|
||||||
|
|
||||||
let ih_start = u32(max(ih_start_tmp, 0));
|
|
||||||
let iw_start = u32(max(iw_start_tmp, 0));
|
|
||||||
|
|
||||||
let ih_end = min(u32(max(kms_0 + ih_start_tmp + 1, 0)), input_shape_2);
|
|
||||||
let iw_end = min(u32(max(kms_1 + iw_start_tmp + 1, 0)), input_shape_3);
|
|
||||||
|
|
||||||
for (var ic = ic_start; ic < ic_end; ic++) {
|
|
||||||
for (var ih = ih_start; ih < ih_end; ih++) {
|
|
||||||
for (var iw = iw_start; iw < iw_end; iw++) {
|
|
||||||
for (var kh = 0u; kh < kernel_size_0; kh++) {
|
|
||||||
for (var kw = 0u; kw < kernel_size_1; kw++) {
|
|
||||||
let oh_tmp = ih * stride_0 + kh * dilation_0;
|
|
||||||
let ow_tmp = iw * stride_1 + kw * dilation_1;
|
|
||||||
|
|
||||||
if oh_tmp >= padding_0 && ow_tmp >= padding_1 {
|
|
||||||
let oh_tmp_pad = oh_tmp - padding_0;
|
|
||||||
let ow_tmp_pad = ow_tmp - padding_1;
|
|
||||||
|
|
||||||
if oh_tmp_pad == oh && ow_tmp_pad == ow {
|
|
||||||
let index_input = b * input_stride_0 + ic * input_stride_1 + ih * input_stride_2 + iw * input_stride_3;
|
|
||||||
let index_weight = ic * weight_stride_0 + oc * weight_stride_1 + kh * weight_stride_2 + kw * weight_stride_3;
|
|
||||||
|
|
||||||
sum += input[index_input] * weight[index_weight];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
output[id] = sum;
|
|
||||||
}
|
|
Loading…
Reference in New Issue