Fix/wgpu/max pool2d backward (#613)

This commit is contained in:
Nathaniel Simard 2023-08-09 16:45:49 -04:00 committed by GitHub
parent 894783f08d
commit c74e75f748
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 40 deletions

View File

@ -4,9 +4,45 @@ mod tests {
use burn_tensor::{module::max_pool2d, Data};
#[test]
fn test_max_pool2d_simple() {
let batch_size = 1;
let channels_in = 1;
fn test_max_pool2d_simple_1() {
let kernel_size_1 = 3;
let kernel_size_2 = 3;
let padding_1 = 0;
let padding_2 = 0;
let stride_1 = 1;
let stride_2 = 1;
let x = TestADTensor::from_floats([[[
[0.2479, 0.6386, 0.3166, 0.5742],
[0.7065, 0.1940, 0.6305, 0.8959],
[0.5416, 0.8602, 0.8129, 0.1662],
[0.3358, 0.3059, 0.8293, 0.0990],
]]])
.require_grad();
let x_grad_expected = TestADTensor::from_floats([[[
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 2.0],
[0.0, 2.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
]]]);
let output = max_pool2d(
x.clone(),
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
}
#[test]
fn test_max_pool2d_simple_2() {
let kernel_size_1 = 2;
let kernel_size_2 = 2;
let padding_1 = 1;
@ -45,8 +81,6 @@ mod tests {
#[test]
fn test_max_pool2d_complex() {
let batch_size = 1;
let channels_in = 1;
let kernel_size_1 = 4;
let kernel_size_2 = 2;
let padding_1 = 2;

View File

@ -1,6 +1,10 @@
use crate::{
element::WgpuElement,
kernel::{self, elemwise_workgroup, pool::build_output_and_info_pool2d, KernelSettings},
kernel::{
self, elemwise_workgroup,
pool::{build_output_and_info_pool2d, build_pool2d_info},
KernelSettings,
},
kernel_wgsl,
tensor::WgpuTensor,
};
@ -87,31 +91,7 @@ pub(crate) fn max_pool2d_with_indices_backward<E: WgpuElement, I: WgpuElement>(
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer);
let mut info: [u32; 18] = [0; 18];
info[0] = x.strides[0] as u32;
info[1] = x.strides[1] as u32;
info[2] = x.strides[2] as u32;
info[3] = x.strides[3] as u32;
info[4] = x.shape.dims[0] as u32;
info[5] = x.shape.dims[1] as u32;
info[6] = x.shape.dims[2] as u32;
info[7] = x.shape.dims[3] as u32;
info[8] = grad.strides[0] as u32;
info[9] = grad.strides[1] as u32;
info[10] = grad.strides[2] as u32;
info[11] = grad.strides[3] as u32;
info[12] = kernel_size[0] as u32;
info[13] = kernel_size[1] as u32;
info[14] = stride[0] as u32;
info[15] = stride[1] as u32;
info[16] = padding[0] as u32;
info[17] = padding[1] as u32;
let info_buffer = x
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding);
let kernel = x.context.compile_static::<KernelSettings<
MaxPool2dWithIndicesBackward,

View File

@ -13,7 +13,7 @@ var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32, 18>;
var<storage, read> info: array<u32, 22>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@ -38,13 +38,17 @@ fn main(
let grad_stride_1 = info[9];
let grad_stride_2 = info[10];
let grad_stride_3 = info[11];
let grad_shape_0 = info[12];
let grad_shape_1 = info[13];
let grad_shape_2 = info[14];
let grad_shape_3 = info[15];
let kernel_size_0 = info[12];
let kernel_size_1 = info[13];
let pool_stride_0 = info[14];
let pool_stride_1 = info[15];
let padding_0 = info[16];
let padding_1 = info[17];
let kernel_size_0 = info[16];
let kernel_size_1 = info[17];
let pool_stride_0 = info[18];
let pool_stride_1 = info[19];
let padding_0 = info[20];
let padding_1 = info[21];
let b = id / input_stride_0 % input_shape_0;
let c = id / input_stride_1 % input_shape_1;
@ -61,8 +65,8 @@ fn main(
let oh_start = u32(max(oh_start_tmp, 0));
let ow_start = u32(max(ow_start_tmp, 0));
let oh_end = u32(max(kms_0, 0)) + oh_start;
let ow_end = u32(max(kms_1, 0)) + ow_start;
let oh_end = min(u32(max(kms_0, 0)) + oh_start, grad_shape_2 - 1u);
let ow_end = min(u32(max(kms_1, 0)) + ow_start, grad_shape_3 - 1u);
let index_current = ih * input_stride_2 + iw * input_stride_3;
var grad_acc = 0.0;