From 7ac5deebe2f2c5f0b44a84e8ed000e86f9f13412 Mon Sep 17 00:00:00 2001
From: Asher Jingkong Chen <37398747+AsherJingkongChen@users.noreply.github.com>
Date: Mon, 16 Sep 2024 22:15:27 +0800
Subject: [PATCH] Refactor burn-tensor: Split conv backward ops to allow
 conditional gradient computation (#2278)

---
 crates/burn-autodiff/src/ops/module.rs  | 263 +++++++---
 .../src/tensor/ops/modules/base.rs      | 200 +++++---
 .../src/tensor/ops/modules/conv.rs      | 485 ++++++++++--------
 3 files changed, 576 insertions(+), 372 deletions(-)

diff --git a/crates/burn-autodiff/src/ops/module.rs b/crates/burn-autodiff/src/ops/module.rs
index 3ee28c878..a0a70788b 100644
--- a/crates/burn-autodiff/src/ops/module.rs
+++ b/crates/burn-autodiff/src/ops/module.rs
@@ -78,20 +78,28 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C>
                 let grad = grads.consume::<B, 3>(&ops.node);
 
                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-
-                let backward = B::conv1d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
 
                 if let Some(node) = node_x {
-                    grads.register::<B, 3>(node.id, backward.x_grad)
+                    let grad = B::conv1d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 3>(node.id, backward.weights_grad)
+                    let grad = B::conv1d_weight_backward(x.clone(), weight, grad.clone(), options);
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv1d_bias_backward(x, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -109,16 +117,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C>
                 let grad = grads.consume::<B, 3>(&ops.node);
 
                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-
-                let backward = B::conv1d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
 
                 if let Some(node) = node_x {
-                    grads.register::<B, 3>(node.id, backward.x_grad)
+                    let grad = B::conv1d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 3>(node.id, backward.weights_grad)
+                    let grad = B::conv1d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 3>(node.id, grad)
                 }
             }
         }
@@ -188,20 +202,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C>
                 let grad = grads.consume::<B, 3>(&ops.node);
 
                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-
-                let backward = B::conv_transpose1d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
 
                 if let Some(node) = node_x {
-                    grads.register::<B, 3>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose1d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 3>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose1d_weight_backward(
+                        x.clone(),
+                        weight,
+                        grad.clone(),
+                        options,
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv_transpose1d_bias_backward(x, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -219,16 +245,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C>
                 let grad = grads.consume::<B, 3>(&ops.node);
 
                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-
-                let backward = B::conv_transpose1d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
 
                 if let Some(node) = node_x {
-                    grads.register::<B, 3>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose1d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 3>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose1d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 3>(node.id, grad)
                 }
             }
         }
@@ -307,20 +338,29 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C>
                 let grad = grads.consume::<B, 4>(&ops.node);
 
                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-
-                let backward = B::conv2d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
 
                 if let Some(node) = node_x {
-                    grads.register::<B, 4>(node.id, backward.x_grad)
+                    let grad = B::conv2d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 4>(node.id, backward.weights_grad)
+                    let grad =
+                        B::conv2d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv2d_bias_backward(x, weight, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -338,16 +378,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C>
                 let grad = grads.consume::<B, 4>(&ops.node);
 
                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-
-                let backward = B::conv2d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
 
                 if let Some(node) = node_x {
-                    grads.register::<B, 4>(node.id, backward.x_grad)
+                    let grad = B::conv2d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 4>(node.id, backward.weights_grad)
+                    let grad = B::conv2d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 4>(node.id, grad)
                 }
             }
         }
@@ -419,20 +465,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C>
                 let grad = grads.consume::<B, 4>(&ops.node);
 
                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-
-                let backward = B::conv_transpose2d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
 
                 if let Some(node) = node_x {
-                    grads.register::<B, 4>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose2d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 4>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose2d_weight_backward(
+                        x.clone(),
+                        weight,
+                        grad.clone(),
+                        options,
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv_transpose2d_bias_backward(x, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -450,16 +508,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C>
                 let grad = grads.consume::<B, 4>(&ops.node);
 
                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-
-                let backward = B::conv_transpose2d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
 
                 if let Some(node) = node_x {
-                    grads.register::<B, 4>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose2d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 4>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose2d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 4>(node.id, grad)
                 }
             }
         }
@@ -540,20 +603,29 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C>
                 let grad = grads.consume::<B, 5>(&ops.node);
 
                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-
-                let backward = B::conv3d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
 
                 if let Some(node) = node_x {
-                    grads.register::<B, 5>(node.id, backward.x_grad)
+                    let grad = B::conv3d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 5>(node.id, backward.weights_grad)
+                    let grad =
+                        B::conv3d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv3d_bias_backward(x, weight, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -571,16 +643,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C>
                 let grad = grads.consume::<B, 5>(&ops.node);
 
                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-
-                let backward = B::conv3d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
 
                 if let Some(node) = node_x {
-                    grads.register::<B, 5>(node.id, backward.x_grad)
+                    let grad = B::conv3d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 5>(node.id, backward.weights_grad)
+                    let grad = B::conv3d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 5>(node.id, grad)
                 }
             }
         }
@@ -652,20 +730,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C>
                 let grad = grads.consume::<B, 5>(&ops.node);
 
                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-
-                let backward = B::conv_transpose3d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
 
                 if let Some(node) = node_x {
-                    grads.register::<B, 5>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose3d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 5>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose3d_weight_backward(
+                        x.clone(),
+                        weight,
+                        grad.clone(),
+                        options,
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv_transpose3d_bias_backward(x, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -683,16 +773,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C>
                 let grad = grads.consume::<B, 5>(&ops.node);
 
                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-
-                let backward = B::conv_transpose3d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
 
                 if let Some(node) = node_x {
-                    grads.register::<B, 5>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose3d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 5>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose3d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 5>(node.id, grad)
                 }
             }
         }
diff --git a/crates/burn-tensor/src/tensor/ops/modules/base.rs b/crates/burn-tensor/src/tensor/ops/modules/base.rs
index 1f76b3d84..58ed8682c 100644
--- a/crates/burn-tensor/src/tensor/ops/modules/base.rs
+++ b/crates/burn-tensor/src/tensor/ops/modules/base.rs
@@ -5,32 +5,6 @@ use crate::{
     Shape,
 };
 
-/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
-#[derive(new)]
-pub struct Conv2dBackward<B: Backend> {
-    /// Gradient.
-    pub x_grad: FloatTensor<B, 4>,
-
-    /// Weights gradient.
-    pub weights_grad: FloatTensor<B, 4>,
-
-    /// Bias gradient.
-    pub bias_grad: Option<FloatTensor<B, 1>>,
-}
-
-/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
-#[derive(new)]
-pub struct Conv3dBackward<B: Backend> {
-    /// Gradient.
-    pub x_grad: FloatTensor<B, 5>,
-
-    /// Weights gradient.
-    pub weights_grad: FloatTensor<B, 5>,
-
-    /// Bias gradient.
-    pub bias_grad: Option<FloatTensor<B, 1>>,
-}
-
 /// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
 #[derive(new)]
 pub struct MaxPool1dBackward<B: Backend> {
@@ -65,19 +39,6 @@ pub struct MaxPool2dWithIndices<B: Backend>
     pub indices: IntTensor<B, 4>,
 }
 
-/// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d).
-#[derive(new)]
-pub struct Conv1dBackward<B: Backend> {
-    /// Gradient.
-    pub x_grad: FloatTensor<B, 3>,
-
-    /// Weights gradient.
-    pub weights_grad: FloatTensor<B, 3>,
-
-    /// Bias gradient.
-    pub bias_grad: Option<FloatTensor<B, 1>>,
-}
-
 /// Convolution options.
 #[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
 pub struct ConvOptions<const N: usize> {
@@ -221,15 +182,31 @@ pub trait ModuleOps<B: Backend> {
     ) -> FloatTensor<B, 3> {
         conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
     }
-    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation.
-    fn conv1d_backward(
+    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
+    fn conv1d_x_backward(
         x: FloatTensor<B, 3>,
         weight: FloatTensor<B, 3>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 3>,
         options: ConvOptions<1>,
-    ) -> Conv1dBackward<B> {
-        conv::conv1d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 3> {
+        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
+    fn conv1d_weight_backward(
+        x: FloatTensor<B, 3>,
+        weight: FloatTensor<B, 3>,
+        output_grad: FloatTensor<B, 3>,
+        options: ConvOptions<1>,
+    ) -> FloatTensor<B, 3> {
+        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
+    fn conv1d_bias_backward(
+        x: FloatTensor<B, 3>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 3>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
     }
     /// Two dimensional convolution.
     ///
     fn conv2d(
         bias: Option<FloatTensor<B, 1>>,
         options: ConvOptions<2>,
     ) -> FloatTensor<B, 4>;
-    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation.
-    fn conv2d_backward(
+    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
+    fn conv2d_x_backward(
         x: FloatTensor<B, 4>,
         weight: FloatTensor<B, 4>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 4>,
         options: ConvOptions<2>,
-    ) -> Conv2dBackward<B> {
-        conv::conv2d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 4> {
+        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
+    fn conv2d_weight_backward(
+        x: FloatTensor<B, 4>,
+        weight: FloatTensor<B, 4>,
+        output_grad: FloatTensor<B, 4>,
+        options: ConvOptions<2>,
+    ) -> FloatTensor<B, 4> {
+        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
+    fn conv2d_bias_backward(
+        x: FloatTensor<B, 4>,
+        weight: FloatTensor<B, 4>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 4>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
     }
     /// Three dimensional convolution.
    ///
     fn conv3d(
         bias: Option<FloatTensor<B, 1>>,
         options: ConvOptions<3>,
     ) -> FloatTensor<B, 5>;
-    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation.
-    fn conv3d_backward(
+    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
+    fn conv3d_x_backward(
         x: FloatTensor<B, 5>,
         weight: FloatTensor<B, 5>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 5>,
         options: ConvOptions<3>,
-    ) -> Conv3dBackward<B> {
-        conv::conv3d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 5> {
+        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
+    fn conv3d_weight_backward(
+        x: FloatTensor<B, 5>,
+        weight: FloatTensor<B, 5>,
+        output_grad: FloatTensor<B, 5>,
+        options: ConvOptions<3>,
+    ) -> FloatTensor<B, 5> {
+        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
+    fn conv3d_bias_backward(
+        x: FloatTensor<B, 5>,
+        weight: FloatTensor<B, 5>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 5>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
     }
     /// One dimensional transposed convolution.
     ///
     fn conv_transpose1d(
     ) -> FloatTensor<B, 3> {
         conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
     }
-    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation.
-    fn conv_transpose1d_backward(
-        x: FloatTensor<B, 3>,
+    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
+    fn conv_transpose1d_x_backward(
         weight: FloatTensor<B, 3>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 3>,
         options: ConvTransposeOptions<1>,
-    ) -> Conv1dBackward<B> {
-        conv::conv_transpose1d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 3> {
+        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
+    fn conv_transpose1d_weight_backward(
+        x: FloatTensor<B, 3>,
+        weight: FloatTensor<B, 3>,
+        output_grad: FloatTensor<B, 3>,
+        options: ConvTransposeOptions<1>,
+    ) -> FloatTensor<B, 3> {
+        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
+    fn conv_transpose1d_bias_backward(
+        x: FloatTensor<B, 3>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 3>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
     }
 
     /// Two dimensional transposed convolution.
     fn conv_transpose2d(
         bias: Option<FloatTensor<B, 1>>,
         options: ConvTransposeOptions<2>,
     ) -> FloatTensor<B, 4>;
-    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation.
-    fn conv_transpose2d_backward(
-        x: FloatTensor<B, 4>,
+    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
+    fn conv_transpose2d_x_backward(
         weight: FloatTensor<B, 4>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 4>,
         options: ConvTransposeOptions<2>,
-    ) -> Conv2dBackward<B> {
-        conv::conv_transpose2d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 4> {
+        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
+    fn conv_transpose2d_weight_backward(
+        x: FloatTensor<B, 4>,
+        weight: FloatTensor<B, 4>,
+        output_grad: FloatTensor<B, 4>,
+        options: ConvTransposeOptions<2>,
+    ) -> FloatTensor<B, 4> {
+        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
+    fn conv_transpose2d_bias_backward(
+        x: FloatTensor<B, 4>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 4>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
     }
 
     /// Three dimensional transposed convolution.
     fn conv_transpose3d(
         bias: Option<FloatTensor<B, 1>>,
         options: ConvTransposeOptions<3>,
     ) -> FloatTensor<B, 5>;
-    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation.
-    fn conv_transpose3d_backward(
-        x: FloatTensor<B, 5>,
+    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
+    fn conv_transpose3d_x_backward(
         weight: FloatTensor<B, 5>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 5>,
         options: ConvTransposeOptions<3>,
-    ) -> Conv3dBackward<B> {
-        conv::conv_transpose3d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 5> {
+        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
+    fn conv_transpose3d_weight_backward(
+        x: FloatTensor<B, 5>,
+        weight: FloatTensor<B, 5>,
+        output_grad: FloatTensor<B, 5>,
+        options: ConvTransposeOptions<3>,
+    ) -> FloatTensor<B, 5> {
+        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
+    fn conv_transpose3d_bias_backward(
+        x: FloatTensor<B, 5>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 5>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
     }
 
     /// Four-dimensional unfolding.
diff --git a/crates/burn-tensor/src/tensor/ops/modules/conv.rs b/crates/burn-tensor/src/tensor/ops/modules/conv.rs
index 9d616f3f5..ba801aade 100644
--- a/crates/burn-tensor/src/tensor/ops/modules/conv.rs
+++ b/crates/burn-tensor/src/tensor/ops/modules/conv.rs
@@ -1,5 +1,5 @@
 #![allow(clippy::single_range_in_vec_init)]
-use super::{Conv1dBackward, Conv2dBackward, Conv3dBackward, ConvOptions, ConvTransposeOptions};
+use super::{ConvOptions, ConvTransposeOptions};
 use crate::{backend::Backend, ops::FloatTensor, Shape};
 
 #[cfg(not(feature = "std"))]
@@ -57,19 +57,17 @@ pub fn calculate_pool_output_size(
     ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1
 }
 
-/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass using convolutions.
-pub(crate) fn conv1d_backward<B: Backend>(
+/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv1d_x_backward<B: Backend>(
     x: FloatTensor<B, 3>,
     weight: FloatTensor<B, 3>,
-    bias: Option<FloatTensor<B, 1>>,
     output_grad: FloatTensor<B, 3>,
     options: ConvOptions<1>,
-) -> Conv1dBackward<B> {
+) -> FloatTensor<B, 3> {
     let weight_shape = B::float_shape(&weight);
-    let weight_device = B::float_device(&weight);
 
-    let [batch_size, _, length_in] = B::float_shape(&x).dims;
-    let [_batch_size, channels_out, length_out] = B::float_shape(&output_grad).dims;
+    let [_batch_size, _, length_in] = B::float_shape(&x).dims;
+    let [_batch_size, _channels_out, length_out] = B::float_shape(&output_grad).dims;
     let [_, _, kernel_size] = weight_shape.dims;
 
     let padding_out = calculate_padding_out(
@@ -81,8 +79,8 @@ pub(crate) fn conv1d_backward(
         length_out,
     );
 
-    let x_grad = B::conv_transpose1d(
-        output_grad.clone(),
+    B::conv_transpose1d(
+        output_grad,
         weight,
         None,
         ConvTransposeOptions::new(
@@ -92,45 +90,58 @@ pub(crate) fn conv1d_backward(
             options.dilation,
             options.groups,
         ),
-    );
-
-    let weight_grad = match options.groups == 1 {
-        true => conv1d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
-        false => conv1d_weight_grad_groups::<B>(
-            x,
-            B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
-            options,
-        ),
-    };
-
-    Conv1dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
-            let grad = B::float_sum_dim(grad, 1);
-
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
     )
 }
 
-/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass using convolutions.
-pub(crate) fn conv2d_backward<B: Backend>(
-    x: FloatTensor<B, 4>,
-    weight: FloatTensor<B, 4>,
-    bias: Option<FloatTensor<B, 1>>,
-    output_grad: FloatTensor<B, 4>,
-    options: ConvOptions<2>,
-) -> Conv2dBackward<B> {
+/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv1d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 3>,
+    weight: FloatTensor<B, 3>,
+    output_grad: FloatTensor<B, 3>,
+    options: ConvOptions<1>,
+) -> FloatTensor<B, 3> {
     let weight_shape = B::float_shape(&weight);
     let weight_device = B::float_device(&weight);
 
+    match options.groups == 1 {
+        true => conv1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
+        false => conv1d_weight_grad_groups::<B>(
+            x,
+            B::float_zeros(weight_shape, &weight_device),
+            output_grad,
+            options,
+        ),
+    }
+}
+
+/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv1d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 3>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 3>,
+) -> FloatTensor<B, 1> {
+    let [batch_size, _, _length_in] = B::float_shape(&x).dims;
+    let [_batch_size, channels_out, length_out] = B::float_shape(&output_grad).dims;
+
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
+    let grad = B::float_sum_dim(grad, 1);
+
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+
+/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv2d_x_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    weight: FloatTensor<B, 4>,
+    output_grad: FloatTensor<B, 4>,
+    options: ConvOptions<2>,
+) -> FloatTensor<B, 4> {
+    let weight_shape = B::float_shape(&weight);
+
+    let [_batch_size, _channels_in, height_in, width_in] = B::float_shape(&x).dims;
     let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims;
-    let [channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;
+    let [_channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;
 
     let padding_1_out = calculate_padding_out(
         kernel_size_1,
@@ -149,8 +160,8 @@ pub(crate) fn conv2d_backward(
         width_out,
     );
 
-    let x_grad = B::conv_transpose2d(
-        output_grad.clone(),
+    B::conv_transpose2d(
+        output_grad,
         weight,
         None,
         ConvTransposeOptions::new(
@@ -160,48 +171,65 @@ pub(crate) fn conv2d_backward(
             options.dilation,
             options.groups,
         ),
-    );
-
-    let weight_grad = match options.groups == 1 {
-        true => conv2d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
-        false => conv2d_weight_grad_groups::<B>(
-            x,
-            B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
-            options,
-        ),
-    };
-
-    Conv2dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(
-                grad,
-                Shape::new([channels_out, batch_size * height_out * width_out]),
-            );
-            let grad = B::float_sum_dim(grad, 1);
-
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
     )
 }
 
-/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass using convolutions.
-pub(crate) fn conv3d_backward<B: Backend>(
-    x: FloatTensor<B, 5>,
-    weight: FloatTensor<B, 5>,
-    bias: Option<FloatTensor<B, 1>>,
-    output_grad: FloatTensor<B, 5>,
-    options: ConvOptions<3>,
-) -> Conv3dBackward<B> {
+/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv2d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    weight: FloatTensor<B, 4>,
+    output_grad: FloatTensor<B, 4>,
+    options: ConvOptions<2>,
+) -> FloatTensor<B, 4> {
     let weight_shape = B::float_shape(&weight);
     let weight_device = B::float_device(&weight);
 
+    match options.groups == 1 {
+        true => conv2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
+        false => conv2d_weight_grad_groups::<B>(
+            x,
+            B::float_zeros(weight_shape, &weight_device),
+            output_grad,
+            options,
+        ),
+    }
+}
+
+/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv2d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    weight: FloatTensor<B, 4>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 4>,
+) -> FloatTensor<B, 1> {
+    let weight_shape = B::float_shape(&weight);
+
+    let [batch_size, _channels_in, _height_in, _width_in] = B::float_shape(&x).dims;
+    let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims;
+    let [channels_out, _, _kernel_size_1, _kernel_size_2] = weight_shape.dims;
+
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(
+        grad,
+        Shape::new([channels_out, batch_size * height_out * width_out]),
+    );
+    let grad = B::float_sum_dim(grad, 1);
+
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+
+/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv3d_x_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    weight: FloatTensor<B, 5>,
+    output_grad: FloatTensor<B, 5>,
+    options: ConvOptions<3>,
+) -> FloatTensor<B, 5> {
+    let weight_shape = B::float_shape(&weight);
+
+    let [_batch_size, _channels_in, depth_in, height_in, width_in] = B::float_shape(&x).dims;
     let [_, _, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
-    let [channels_out, _, kernel_size_1, kernel_size_2, kernel_size_3] = weight_shape.dims;
+    let [_channels_out, _, kernel_size_1, kernel_size_2, kernel_size_3] = weight_shape.dims;
 
     let padding_1_out = calculate_padding_out(
         kernel_size_1,
@@ -228,8 +256,8 @@ pub(crate) fn conv3d_backward(
         width_out,
     );
 
-    let x_grad = B::conv_transpose3d(
-        output_grad.clone(),
+    B::conv_transpose3d(
+        output_grad,
         weight,
         None,
         ConvTransposeOptions::new(
@@ -239,53 +267,64 @@ pub(crate) fn conv3d_backward(
             options.dilation,
             options.groups,
         ),
-    );
+    )
+}
 
-    let weight_grad = match options.groups == 1 {
-        true => conv3d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
+/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv3d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    weight: FloatTensor<B, 5>,
+    output_grad: FloatTensor<B, 5>,
+    options: ConvOptions<3>,
+) -> FloatTensor<B, 5> {
+    let weight_shape = B::float_shape(&weight);
+    let weight_device = B::float_device(&weight);
+
+    match options.groups == 1 {
+        true => conv3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
         false => conv3d_weight_grad_groups::<B>(
             x,
             B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
+            output_grad,
             options,
         ),
-    };
-
-    Conv3dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(
-                grad,
-                Shape::new([
-                    channels_out,
-                    batch_size * depth_out * height_out * width_out,
-                ]),
-            );
-            let grad = B::float_sum_dim(grad, 1);
-
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
-    )
+    }
 }
 
-/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass using convolutions.
-pub(crate) fn conv_transpose1d_backward<B: Backend>(
-    x: FloatTensor<B, 3>,
+/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv3d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    weight: FloatTensor<B, 5>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 5>,
+) -> FloatTensor<B, 1> {
+    let weight_shape = B::float_shape(&weight);
+
+    let [batch_size, _channels_in, _depth_in, _height_in, _width_in] = B::float_shape(&x).dims;
+    let [_, _, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
+    let [channels_out, _, _kernel_size_1, _kernel_size_2, _kernel_size_3] = weight_shape.dims;
+
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(
+        grad,
+        Shape::new([
+            channels_out,
+            batch_size * depth_out * height_out * width_out,
+        ]),
+    );
+    let grad = B::float_sum_dim(grad, 1);
+
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+
+/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv_transpose1d_x_backward<B: Backend>(
     weight: FloatTensor<B, 3>,
-    bias: Option<FloatTensor<B, 1>>,
     output_grad: FloatTensor<B, 3>,
     options: ConvTransposeOptions<1>,
-) -> Conv1dBackward<B> {
-    let weight_shape = B::float_shape(&weight);
-    let weight_device = B::float_device(&weight);
-
-    let [batch_size, _channels_in, _] = B::float_shape(&x).dims;
-    let [_, channels_out, length_out] = B::float_shape(&output_grad).dims;
-
-    let x_grad = B::conv1d(
-        output_grad.clone(),
+) -> FloatTensor<B, 3> {
+    B::conv1d(
+        output_grad,
         weight,
         None,
         ConvOptions::new(
             options.stride,
             options.padding,
             options.dilation,
             options.groups,
         ),
-    );
+    )
+}
 
-    let weight_grad = match options.groups == 1 {
-        true => conv_transpose1d_weight_grad_no_groups::<B>(
-            x,
-            output_grad.clone(),
-            weight_shape,
-            options,
-        ),
+/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv_transpose1d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 3>,
+    weight: FloatTensor<B, 3>,
+    output_grad: FloatTensor<B, 3>,
+    options: ConvTransposeOptions<1>,
+) -> FloatTensor<B, 3> {
+    let weight_shape = B::float_shape(&weight);
+    let weight_device = B::float_device(&weight);
+
+    match options.groups == 1 {
+        true => conv_transpose1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
         false => conv_transpose1d_weight_grad_groups::<B>(
             x,
             B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
+            output_grad,
             options,
         ),
-    };
-
-    Conv1dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
-            let grad = B::float_sum_dim(grad, 1);
-
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
-    )
+    }
 }
 
+/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv_transpose1d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 3>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 3>,
+) -> FloatTensor<B, 1> {
+    let [batch_size, _channels_in, _] = B::float_shape(&x).dims;
+    let [_, channels_out, length_out] = B::float_shape(&output_grad).dims;
+
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
+    let grad = B::float_sum_dim(grad, 1);
+
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+
+/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv_transpose2d_x_backward<B: Backend>(
     weight: FloatTensor<B, 4>,
-    bias: Option<FloatTensor<B, 1>>,
     output_grad: FloatTensor<B, 4>,
     options: ConvTransposeOptions<2>,
-) -> Conv2dBackward<B> {
-    let weight_shape = B::float_shape(&weight);
-    let weight_device = B::float_device(&weight);
-
-    let [batch_size, _channels_in, _, _] = B::float_shape(&x).dims;
-    let [_, channels_out, height_out, width_out] = B::float_shape(&output_grad).dims;
-
-    let x_grad = B::conv2d(
-        output_grad.clone(),
+) -> FloatTensor<B, 4> {
+    B::conv2d(
+        output_grad,
         weight,
         None,
         ConvOptions::new(
             options.stride,
             options.padding,
             options.dilation,
             options.groups,
         ),
-    );
+    )
+}
 
-    let weight_grad = match options.groups == 1 {
-        true => conv_transpose2d_weight_grad_no_groups::<B>(
-            x,
-            output_grad.clone(),
-            weight_shape,
-            options,
-        ),
+/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv_transpose2d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    weight: FloatTensor<B, 4>,
+    output_grad: FloatTensor<B, 4>,
+    options: ConvTransposeOptions<2>,
+) -> FloatTensor<B, 4> {
+    let weight_shape = B::float_shape(&weight);
+    let weight_device = B::float_device(&weight);
+
+    match options.groups == 1 {
+        true => conv_transpose2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
         false => conv_transpose2d_weight_grad_groups::<B>(
             x,
             B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
+            output_grad,
             options,
         ),
-    };
-
-    Conv2dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(
-                grad,
-                Shape::new([channels_out, batch_size * height_out * width_out]),
-            );
-            let grad = B::float_sum_dim(grad, 1);
-
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
-    )
+    }
 }
 
+/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv_transpose2d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 4>,
+) -> FloatTensor<B, 1> {
+    let [batch_size, _channels_in, _, _] = B::float_shape(&x).dims;
+    let [_, channels_out, height_out, width_out] = B::float_shape(&output_grad).dims;
+
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(
+        grad,
+        Shape::new([channels_out, batch_size * height_out * width_out]),
+    );
+    let grad = B::float_sum_dim(grad, 1);
+
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+
+/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv_transpose3d_x_backward<B: Backend>(
     weight: FloatTensor<B, 5>,
-    bias: Option<FloatTensor<B, 1>>,
     output_grad: FloatTensor<B, 5>,
     options: ConvTransposeOptions<3>,
-) -> Conv3dBackward<B> {
-    let weight_shape = B::float_shape(&weight);
-    let weight_device = B::float_device(&weight);
-
-    let [batch_size, _channels_in, _, _, _] = B::float_shape(&x).dims;
-    let [_, channels_out, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
-
-    let x_grad = B::conv3d(
-        output_grad.clone(),
+) -> FloatTensor<B, 5> {
+    B::conv3d(
+        output_grad,
         weight,
         None,
         ConvOptions::new(
             options.stride,
             options.padding,
             options.dilation,
             options.groups,
         ),
-    );
+    )
+}
 
-    let weight_grad = match options.groups == 1 {
-        true => conv_transpose3d_weight_grad_no_groups::<B>(
-            x,
-            output_grad.clone(),
-            weight_shape,
-            options,
-        ),
+/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv_transpose3d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    weight: FloatTensor<B, 5>,
+    output_grad: FloatTensor<B, 5>,
+    options: ConvTransposeOptions<3>,
+) -> FloatTensor<B, 5> {
+    let weight_shape = B::float_shape(&weight);
+    let weight_device = B::float_device(&weight);
+
+    match options.groups == 1 {
+        true => conv_transpose3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
         false => conv_transpose3d_weight_grad_groups::<B>(
             x,
             B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
+            output_grad,
             options,
         ),
-    };
+    }
+}
 
-    Conv3dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(
-                grad,
-                Shape::new([
-                    channels_out,
-                    batch_size * depth_out * height_out * width_out,
-                ]),
-            );
-            let grad = B::float_sum_dim(grad, 1);
-
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
-    )
+/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv_transpose3d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 5>,
+) -> FloatTensor<B, 1> {
+    let [batch_size, _channels_in, _, _, _] = B::float_shape(&x).dims;
+    let [_, channels_out, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
+
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(
+        grad,
+        Shape::new([
+            channels_out,
+            batch_size * depth_out * height_out * width_out,
+        ]),
+    );
+    let grad = B::float_sum_dim(grad, 1);
+
+    B::float_reshape(grad, B::float_shape(&bias))
 }
 
 /// Execute a 1D convolution using a 2D convolution.
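---

Illustration (not part of the patch): a minimal sketch of why the split matters, assuming the burn-tensor API shown above (ranked `FloatTensor<B, D>`, `ConvOptions<2>`, and the new `conv2d_x_backward`/`conv2d_weight_backward` methods on `ModuleOps`). The helper name `conv2d_partial_grads` is hypothetical; it mirrors the `if let Some(node)` branches in the autodiff backward, where each gradient is now computed only when its parent is tracked, instead of the old `B::conv2d_backward` building a full `Conv2dBackward` struct unconditionally.

    use burn_tensor::backend::Backend;
    use burn_tensor::ops::{ConvOptions, FloatTensor, ModuleOps};

    /// Hypothetical helper: compute only the conv2d gradients that are
    /// actually requested. Skipping a flag skips the corresponding
    /// convolution entirely, which is the point of the refactor.
    fn conv2d_partial_grads<B: Backend>(
        x: FloatTensor<B, 4>,
        weight: FloatTensor<B, 4>,
        output_grad: FloatTensor<B, 4>,
        options: ConvOptions<2>,
        need_x: bool,
        need_weight: bool,
    ) -> (Option<FloatTensor<B, 4>>, Option<FloatTensor<B, 4>>) {
        // Gradient w.r.t. the input: one transposed convolution, run only on demand.
        let x_grad = need_x.then(|| {
            B::conv2d_x_backward(x.clone(), weight.clone(), output_grad.clone(), options.clone())
        });
        // Gradient w.r.t. the weights: run only on demand; tensors are moved here.
        let weight_grad =
            need_weight.then(|| B::conv2d_weight_backward(x, weight, output_grad, options));
        (x_grad, weight_grad)
    }

For a frozen-backbone fine-tune (`need_weight == false`), the weight-gradient convolution never runs; before this patch both `x_grad` and `weights_grad` were always computed and the unused one discarded.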