mirror of https://github.com/tracel-ai/burn.git
Fix/wgpu/max pool2d backward (#613)
This commit is contained in:
parent
894783f08d
commit
c74e75f748
|
@ -4,9 +4,45 @@ mod tests {
|
|||
use burn_tensor::{module::max_pool2d, Data};
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_simple() {
|
||||
let batch_size = 1;
|
||||
let channels_in = 1;
|
||||
fn test_max_pool2d_simple_1() {
|
||||
let kernel_size_1 = 3;
|
||||
let kernel_size_2 = 3;
|
||||
let padding_1 = 0;
|
||||
let padding_2 = 0;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
|
||||
let x = TestADTensor::from_floats([[[
|
||||
[0.2479, 0.6386, 0.3166, 0.5742],
|
||||
[0.7065, 0.1940, 0.6305, 0.8959],
|
||||
[0.5416, 0.8602, 0.8129, 0.1662],
|
||||
[0.3358, 0.3059, 0.8293, 0.0990],
|
||||
]]])
|
||||
.require_grad();
|
||||
let x_grad_expected = TestADTensor::from_floats([[[
|
||||
[0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 2.0],
|
||||
[0.0, 2.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0],
|
||||
]]]);
|
||||
|
||||
let output = max_pool2d(
|
||||
x.clone(),
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
x_grad_expected
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_simple_2() {
|
||||
let kernel_size_1 = 2;
|
||||
let kernel_size_2 = 2;
|
||||
let padding_1 = 1;
|
||||
|
@ -45,8 +81,6 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_max_pool2d_complex() {
|
||||
let batch_size = 1;
|
||||
let channels_in = 1;
|
||||
let kernel_size_1 = 4;
|
||||
let kernel_size_2 = 2;
|
||||
let padding_1 = 2;
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
use crate::{
|
||||
element::WgpuElement,
|
||||
kernel::{self, elemwise_workgroup, pool::build_output_and_info_pool2d, KernelSettings},
|
||||
kernel::{
|
||||
self, elemwise_workgroup,
|
||||
pool::{build_output_and_info_pool2d, build_pool2d_info},
|
||||
KernelSettings,
|
||||
},
|
||||
kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
|
@ -87,31 +91,7 @@ pub(crate) fn max_pool2d_with_indices_backward<E: WgpuElement, I: WgpuElement>(
|
|||
.create_buffer(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer);
|
||||
|
||||
let mut info: [u32; 18] = [0; 18];
|
||||
info[0] = x.strides[0] as u32;
|
||||
info[1] = x.strides[1] as u32;
|
||||
info[2] = x.strides[2] as u32;
|
||||
info[3] = x.strides[3] as u32;
|
||||
info[4] = x.shape.dims[0] as u32;
|
||||
info[5] = x.shape.dims[1] as u32;
|
||||
info[6] = x.shape.dims[2] as u32;
|
||||
info[7] = x.shape.dims[3] as u32;
|
||||
|
||||
info[8] = grad.strides[0] as u32;
|
||||
info[9] = grad.strides[1] as u32;
|
||||
info[10] = grad.strides[2] as u32;
|
||||
info[11] = grad.strides[3] as u32;
|
||||
|
||||
info[12] = kernel_size[0] as u32;
|
||||
info[13] = kernel_size[1] as u32;
|
||||
info[14] = stride[0] as u32;
|
||||
info[15] = stride[1] as u32;
|
||||
info[16] = padding[0] as u32;
|
||||
info[17] = padding[1] as u32;
|
||||
|
||||
let info_buffer = x
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding);
|
||||
|
||||
let kernel = x.context.compile_static::<KernelSettings<
|
||||
MaxPool2dWithIndicesBackward,
|
||||
|
|
|
@ -13,7 +13,7 @@ var<storage, read_write> output: array<{{ elem }}>;
|
|||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32, 18>;
|
||||
var<storage, read> info: array<u32, 22>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
|
@ -38,13 +38,17 @@ fn main(
|
|||
let grad_stride_1 = info[9];
|
||||
let grad_stride_2 = info[10];
|
||||
let grad_stride_3 = info[11];
|
||||
let grad_shape_0 = info[12];
|
||||
let grad_shape_1 = info[13];
|
||||
let grad_shape_2 = info[14];
|
||||
let grad_shape_3 = info[15];
|
||||
|
||||
let kernel_size_0 = info[12];
|
||||
let kernel_size_1 = info[13];
|
||||
let pool_stride_0 = info[14];
|
||||
let pool_stride_1 = info[15];
|
||||
let padding_0 = info[16];
|
||||
let padding_1 = info[17];
|
||||
let kernel_size_0 = info[16];
|
||||
let kernel_size_1 = info[17];
|
||||
let pool_stride_0 = info[18];
|
||||
let pool_stride_1 = info[19];
|
||||
let padding_0 = info[20];
|
||||
let padding_1 = info[21];
|
||||
|
||||
let b = id / input_stride_0 % input_shape_0;
|
||||
let c = id / input_stride_1 % input_shape_1;
|
||||
|
@ -61,8 +65,8 @@ fn main(
|
|||
let oh_start = u32(max(oh_start_tmp, 0));
|
||||
let ow_start = u32(max(ow_start_tmp, 0));
|
||||
|
||||
let oh_end = u32(max(kms_0, 0)) + oh_start;
|
||||
let ow_end = u32(max(kms_1, 0)) + ow_start;
|
||||
let oh_end = min(u32(max(kms_0, 0)) + oh_start, grad_shape_2 - 1u);
|
||||
let ow_end = min(u32(max(kms_1, 0)) + ow_start, grad_shape_3 - 1u);
|
||||
|
||||
let index_current = ih * input_stride_2 + iw * input_stride_3;
|
||||
var grad_acc = 0.0;
|
||||
|
|
Loading…
Reference in New Issue