diff --git a/burn-autodiff/src/tests/maxpool2d.rs b/burn-autodiff/src/tests/maxpool2d.rs index 25a7c332c..ae7a43b7b 100644 --- a/burn-autodiff/src/tests/maxpool2d.rs +++ b/burn-autodiff/src/tests/maxpool2d.rs @@ -4,9 +4,45 @@ mod tests { use burn_tensor::{module::max_pool2d, Data}; #[test] - fn test_max_pool2d_simple() { - let batch_size = 1; - let channels_in = 1; + fn test_max_pool2d_simple_1() { + let kernel_size_1 = 3; + let kernel_size_2 = 3; + let padding_1 = 0; + let padding_2 = 0; + let stride_1 = 1; + let stride_2 = 1; + + let x = TestADTensor::from_floats([[[ + [0.2479, 0.6386, 0.3166, 0.5742], + [0.7065, 0.1940, 0.6305, 0.8959], + [0.5416, 0.8602, 0.8129, 0.1662], + [0.3358, 0.3059, 0.8293, 0.0990], + ]]]) + .require_grad(); + let x_grad_expected = TestADTensor::from_floats([[[ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 2.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ]]]); + + let output = max_pool2d( + x.clone(), + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + ); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool2d_simple_2() { let kernel_size_1 = 2; let kernel_size_2 = 2; let padding_1 = 1; @@ -45,8 +81,6 @@ mod tests { #[test] fn test_max_pool2d_complex() { - let batch_size = 1; - let channels_in = 1; let kernel_size_1 = 4; let kernel_size_2 = 2; let padding_1 = 2; diff --git a/burn-wgpu/src/kernel/pool/max_pool2d.rs b/burn-wgpu/src/kernel/pool/max_pool2d.rs index 7ac1beca7..7bec89639 100644 --- a/burn-wgpu/src/kernel/pool/max_pool2d.rs +++ b/burn-wgpu/src/kernel/pool/max_pool2d.rs @@ -1,6 +1,10 @@ use crate::{ element::WgpuElement, - kernel::{self, elemwise_workgroup, pool::build_output_and_info_pool2d, KernelSettings}, + kernel::{ + self, elemwise_workgroup, + pool::{build_output_and_info_pool2d, build_pool2d_info}, + KernelSettings, + }, kernel_wgsl, tensor::WgpuTensor, }; @@ -87,31 +91,7 @@ pub(crate) fn max_pool2d_with_indices_backward( .create_buffer(num_elems * core::mem::size_of::()); let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer); - let mut info: [u32; 18] = [0; 18]; - info[0] = x.strides[0] as u32; - info[1] = x.strides[1] as u32; - info[2] = x.strides[2] as u32; - info[3] = x.strides[3] as u32; - info[4] = x.shape.dims[0] as u32; - info[5] = x.shape.dims[1] as u32; - info[6] = x.shape.dims[2] as u32; - info[7] = x.shape.dims[3] as u32; - - info[8] = grad.strides[0] as u32; - info[9] = grad.strides[1] as u32; - info[10] = grad.strides[2] as u32; - info[11] = grad.strides[3] as u32; - - info[12] = kernel_size[0] as u32; - info[13] = kernel_size[1] as u32; - info[14] = stride[0] as u32; - info[15] = stride[1] as u32; - info[16] = padding[0] as u32; - info[17] = padding[1] as u32; - - let info_buffer = x - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding); let kernel = x.context.compile_static:: output: array<{{ elem }}>; @group(0) @binding(3) -var info: array; +var info: array; const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u; @@ -38,13 +38,17 @@ fn main( let grad_stride_1 = info[9]; let grad_stride_2 = info[10]; let grad_stride_3 = info[11]; + let grad_shape_0 = info[12]; + let grad_shape_1 = info[13]; + let grad_shape_2 = info[14]; + let grad_shape_3 = info[15]; - let kernel_size_0 = info[12]; - let kernel_size_1 = info[13]; - let pool_stride_0 = info[14]; - let pool_stride_1 = info[15]; - let padding_0 = info[16]; - let padding_1 = info[17]; + let kernel_size_0 = info[16]; + let kernel_size_1 = info[17]; + let pool_stride_0 = info[18]; + let pool_stride_1 = info[19]; + let padding_0 = info[20]; + let padding_1 = info[21]; let b = id / input_stride_0 % input_shape_0; let c = id / input_stride_1 % input_shape_1; @@ -61,8 +65,8 @@ fn main( let oh_start = u32(max(oh_start_tmp, 0)); let ow_start = u32(max(ow_start_tmp, 0)); - let oh_end = u32(max(kms_0, 0)) + oh_start; - let ow_end = u32(max(kms_1, 0)) + ow_start; + let oh_end = min(u32(max(kms_0, 0)) + oh_start, grad_shape_2 - 1u); + let ow_end = min(u32(max(kms_1, 0)) + ow_start, grad_shape_3 - 1u); let index_current = ih * input_stride_2 + iw * input_stride_3; var grad_acc = 0.0;