Refactor burn-tensor: Split conv backward ops to allow conditional gradient computation (#2278)

This commit is contained in:
Asher Jingkong Chen 2024-09-16 22:15:27 +08:00 committed by GitHub
parent 81ec64a929
commit 7ac5deebe2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 576 additions and 372 deletions

View File

@ -78,20 +78,28 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let grad = grads.consume::<B, 3>(&ops.node);
let (x_state, weight_state, bias_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let bias = Some(checkpointer.retrieve_node_output(bias_state));
let backward = B::conv1d_backward(x, weight, bias, grad, options);
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
let weight =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
let bias =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
if let Some(node) = node_x {
grads.register::<B, 3>(node.id, backward.x_grad)
let grad = B::conv1d_x_backward(
x.clone(),
weight.clone(),
grad.clone(),
options.clone(),
);
grads.register::<B, 3>(node.id, grad)
}
if let Some(node) = node_weight {
grads.register::<B, 3>(node.id, backward.weights_grad)
let grad = B::conv1d_weight_backward(x.clone(), weight, grad.clone(), options);
grads.register::<B, 3>(node.id, grad)
}
if let Some(node) = node_bias {
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
let grad = B::conv1d_bias_backward(x, bias, grad);
grads.register::<B, 1>(node.id, grad)
}
}
}
@ -109,16 +117,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let grad = grads.consume::<B, 3>(&ops.node);
let (x_state, weight_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let backward = B::conv1d_backward(x, weight, None, grad, options);
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
let weight =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
if let Some(node) = node_x {
grads.register::<B, 3>(node.id, backward.x_grad)
let grad = B::conv1d_x_backward(
x.clone(),
weight.clone(),
grad.clone(),
options.clone(),
);
grads.register::<B, 3>(node.id, grad)
}
if let Some(node) = node_weight {
grads.register::<B, 3>(node.id, backward.weights_grad)
let grad = B::conv1d_weight_backward(x, weight, grad, options);
grads.register::<B, 3>(node.id, grad)
}
}
}
@ -188,20 +202,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let grad = grads.consume::<B, 3>(&ops.node);
let (x_state, weight_state, bias_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let bias = Some(checkpointer.retrieve_node_output(bias_state));
let backward = B::conv_transpose1d_backward(x, weight, bias, grad, options);
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
let weight =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
let bias =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
if let Some(node) = node_x {
grads.register::<B, 3>(node.id, backward.x_grad)
let grad = B::conv_transpose1d_x_backward(
weight.clone(),
grad.clone(),
options.clone(),
);
grads.register::<B, 3>(node.id, grad)
}
if let Some(node) = node_weight {
grads.register::<B, 3>(node.id, backward.weights_grad)
let grad = B::conv_transpose1d_weight_backward(
x.clone(),
weight,
grad.clone(),
options,
);
grads.register::<B, 3>(node.id, grad)
}
if let Some(node) = node_bias {
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
let grad = B::conv_transpose1d_bias_backward(x, bias, grad);
grads.register::<B, 1>(node.id, grad)
}
}
}
@ -219,16 +245,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let grad = grads.consume::<B, 3>(&ops.node);
let (x_state, weight_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let backward = B::conv_transpose1d_backward(x, weight, None, grad, options);
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
let weight =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
if let Some(node) = node_x {
grads.register::<B, 3>(node.id, backward.x_grad)
let grad = B::conv_transpose1d_x_backward(
weight.clone(),
grad.clone(),
options.clone(),
);
grads.register::<B, 3>(node.id, grad)
}
if let Some(node) = node_weight {
grads.register::<B, 3>(node.id, backward.weights_grad)
let grad = B::conv_transpose1d_weight_backward(x, weight, grad, options);
grads.register::<B, 3>(node.id, grad)
}
}
}
@ -307,20 +338,29 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let grad = grads.consume::<B, 4>(&ops.node);
let (x_state, weight_state, bias_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let bias = Some(checkpointer.retrieve_node_output(bias_state));
let backward = B::conv2d_backward(x, weight, bias, grad, options);
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
let weight =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
let bias =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
if let Some(node) = node_x {
grads.register::<B, 4>(node.id, backward.x_grad)
let grad = B::conv2d_x_backward(
x.clone(),
weight.clone(),
grad.clone(),
options.clone(),
);
grads.register::<B, 4>(node.id, grad)
}
if let Some(node) = node_weight {
grads.register::<B, 4>(node.id, backward.weights_grad)
let grad =
B::conv2d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
grads.register::<B, 4>(node.id, grad)
}
if let Some(node) = node_bias {
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
let grad = B::conv2d_bias_backward(x, weight, bias, grad);
grads.register::<B, 1>(node.id, grad)
}
}
}
@ -338,16 +378,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let grad = grads.consume::<B, 4>(&ops.node);
let (x_state, weight_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let backward = B::conv2d_backward(x, weight, None, grad, options);
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
let weight =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
if let Some(node) = node_x {
grads.register::<B, 4>(node.id, backward.x_grad)
let grad = B::conv2d_x_backward(
x.clone(),
weight.clone(),
grad.clone(),
options.clone(),
);
grads.register::<B, 4>(node.id, grad)
}
if let Some(node) = node_weight {
grads.register::<B, 4>(node.id, backward.weights_grad)
let grad = B::conv2d_weight_backward(x, weight, grad, options);
grads.register::<B, 4>(node.id, grad)
}
}
}
@ -419,20 +465,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let grad = grads.consume::<B, 4>(&ops.node);
let (x_state, weight_state, bias_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let bias = Some(checkpointer.retrieve_node_output(bias_state));
let backward = B::conv_transpose2d_backward(x, weight, bias, grad, options);
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
let weight =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
let bias =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
if let Some(node) = node_x {
grads.register::<B, 4>(node.id, backward.x_grad)
let grad = B::conv_transpose2d_x_backward(
weight.clone(),
grad.clone(),
options.clone(),
);
grads.register::<B, 4>(node.id, grad)
}
if let Some(node) = node_weight {
grads.register::<B, 4>(node.id, backward.weights_grad)
let grad = B::conv_transpose2d_weight_backward(
x.clone(),
weight,
grad.clone(),
options,
);
grads.register::<B, 4>(node.id, grad)
}
if let Some(node) = node_bias {
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
let grad = B::conv_transpose2d_bias_backward(x, bias, grad);
grads.register::<B, 1>(node.id, grad)
}
}
}
@ -450,16 +508,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let grad = grads.consume::<B, 4>(&ops.node);
let (x_state, weight_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let backward = B::conv_transpose2d_backward(x, weight, None, grad, options);
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
let weight =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
if let Some(node) = node_x {
grads.register::<B, 4>(node.id, backward.x_grad)
let grad = B::conv_transpose2d_x_backward(
weight.clone(),
grad.clone(),
options.clone(),
);
grads.register::<B, 4>(node.id, grad)
}
if let Some(node) = node_weight {
grads.register::<B, 4>(node.id, backward.weights_grad)
let grad = B::conv_transpose2d_weight_backward(x, weight, grad, options);
grads.register::<B, 4>(node.id, grad)
}
}
}
@ -540,20 +603,29 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let grad = grads.consume::<B, 5>(&ops.node);
let (x_state, weight_state, bias_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let bias = Some(checkpointer.retrieve_node_output(bias_state));
let backward = B::conv3d_backward(x, weight, bias, grad, options);
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
let weight =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
let bias =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
if let Some(node) = node_x {
grads.register::<B, 5>(node.id, backward.x_grad)
let grad = B::conv3d_x_backward(
x.clone(),
weight.clone(),
grad.clone(),
options.clone(),
);
grads.register::<B, 5>(node.id, grad)
}
if let Some(node) = node_weight {
grads.register::<B, 5>(node.id, backward.weights_grad)
let grad =
B::conv3d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
grads.register::<B, 5>(node.id, grad)
}
if let Some(node) = node_bias {
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
let grad = B::conv3d_bias_backward(x, weight, bias, grad);
grads.register::<B, 1>(node.id, grad)
}
}
}
@ -571,16 +643,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let grad = grads.consume::<B, 5>(&ops.node);
let (x_state, weight_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let backward = B::conv3d_backward(x, weight, None, grad, options);
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
let weight =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
if let Some(node) = node_x {
grads.register::<B, 5>(node.id, backward.x_grad)
let grad = B::conv3d_x_backward(
x.clone(),
weight.clone(),
grad.clone(),
options.clone(),
);
grads.register::<B, 5>(node.id, grad)
}
if let Some(node) = node_weight {
grads.register::<B, 5>(node.id, backward.weights_grad)
let grad = B::conv3d_weight_backward(x, weight, grad, options);
grads.register::<B, 5>(node.id, grad)
}
}
}
@ -652,20 +730,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let grad = grads.consume::<B, 5>(&ops.node);
let (x_state, weight_state, bias_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let bias = Some(checkpointer.retrieve_node_output(bias_state));
let backward = B::conv_transpose3d_backward(x, weight, bias, grad, options);
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
let weight =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
let bias =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
if let Some(node) = node_x {
grads.register::<B, 5>(node.id, backward.x_grad)
let grad = B::conv_transpose3d_x_backward(
weight.clone(),
grad.clone(),
options.clone(),
);
grads.register::<B, 5>(node.id, grad)
}
if let Some(node) = node_weight {
grads.register::<B, 5>(node.id, backward.weights_grad)
let grad = B::conv_transpose3d_weight_backward(
x.clone(),
weight,
grad.clone(),
options,
);
grads.register::<B, 5>(node.id, grad)
}
if let Some(node) = node_bias {
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
let grad = B::conv_transpose3d_bias_backward(x, bias, grad);
grads.register::<B, 1>(node.id, grad)
}
}
}
@ -683,16 +773,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let grad = grads.consume::<B, 5>(&ops.node);
let (x_state, weight_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let backward = B::conv_transpose3d_backward(x, weight, None, grad, options);
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
let weight =
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
if let Some(node) = node_x {
grads.register::<B, 5>(node.id, backward.x_grad)
let grad = B::conv_transpose3d_x_backward(
weight.clone(),
grad.clone(),
options.clone(),
);
grads.register::<B, 5>(node.id, grad)
}
if let Some(node) = node_weight {
grads.register::<B, 5>(node.id, backward.weights_grad)
let grad = B::conv_transpose3d_weight_backward(x, weight, grad, options);
grads.register::<B, 5>(node.id, grad)
}
}
}

View File

@ -5,32 +5,6 @@ use crate::{
Shape,
};
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
/// Gradient.
pub x_grad: FloatTensor<B, 4>,
/// Weights gradient.
pub weights_grad: FloatTensor<B, 4>,
/// Bias gradient.
pub bias_grad: Option<FloatTensor<B, 1>>,
}
/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
/// Gradient.
pub x_grad: FloatTensor<B, 5>,
/// Weights gradient.
pub weights_grad: FloatTensor<B, 5>,
/// Bias gradient.
pub bias_grad: Option<FloatTensor<B, 1>>,
}
/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
@ -65,19 +39,6 @@ pub struct MaxPool2dWithIndices<B: Backend> {
pub indices: IntTensor<B, 4>,
}
/// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d).
#[derive(new)]
pub struct Conv1dBackward<B: Backend> {
/// Gradient.
pub x_grad: FloatTensor<B, 3>,
/// Weights gradient.
pub weights_grad: FloatTensor<B, 3>,
/// Bias gradient.
pub bias_grad: Option<FloatTensor<B, 1>>,
}
/// Convolution options.
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
@ -221,15 +182,31 @@ pub trait ModuleOps<B: Backend> {
) -> FloatTensor<B, 3> {
conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
}
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation.
fn conv1d_backward(
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
fn conv1d_x_backward(
x: FloatTensor<B, 3>,
weight: FloatTensor<B, 3>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 3>,
options: ConvOptions<1>,
) -> Conv1dBackward<B> {
conv::conv1d_backward(x, weight, bias, output_grad, options)
) -> FloatTensor<B, 3> {
conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
fn conv1d_weight_backward(
x: FloatTensor<B, 3>,
weight: FloatTensor<B, 3>,
output_grad: FloatTensor<B, 3>,
options: ConvOptions<1>,
) -> FloatTensor<B, 3> {
conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
fn conv1d_bias_backward(
x: FloatTensor<B, 3>,
bias: FloatTensor<B, 1>,
output_grad: FloatTensor<B, 3>,
) -> FloatTensor<B, 1> {
conv::conv1d_bias_backward::<B>(x, bias, output_grad)
}
/// Two dimensional convolution.
///
@ -244,15 +221,32 @@ pub trait ModuleOps<B: Backend> {
bias: Option<FloatTensor<B, 1>>,
options: ConvOptions<2>,
) -> FloatTensor<B, 4>;
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation.
fn conv2d_backward(
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
fn conv2d_x_backward(
x: FloatTensor<B, 4>,
weight: FloatTensor<B, 4>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 4>,
options: ConvOptions<2>,
) -> Conv2dBackward<B> {
conv::conv2d_backward(x, weight, bias, output_grad, options)
) -> FloatTensor<B, 4> {
conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
fn conv2d_weight_backward(
x: FloatTensor<B, 4>,
weight: FloatTensor<B, 4>,
output_grad: FloatTensor<B, 4>,
options: ConvOptions<2>,
) -> FloatTensor<B, 4> {
conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
fn conv2d_bias_backward(
x: FloatTensor<B, 4>,
weight: FloatTensor<B, 4>,
bias: FloatTensor<B, 1>,
output_grad: FloatTensor<B, 4>,
) -> FloatTensor<B, 1> {
conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
}
/// Three dimensional convolution.
///
@ -267,15 +261,32 @@ pub trait ModuleOps<B: Backend> {
bias: Option<FloatTensor<B, 1>>,
options: ConvOptions<3>,
) -> FloatTensor<B, 5>;
/// Backward pass for the [conv3d](ModuleOps::conv3d) operation.
fn conv3d_backward(
/// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
fn conv3d_x_backward(
x: FloatTensor<B, 5>,
weight: FloatTensor<B, 5>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 5>,
options: ConvOptions<3>,
) -> Conv3dBackward<B> {
conv::conv3d_backward(x, weight, bias, output_grad, options)
) -> FloatTensor<B, 5> {
conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
fn conv3d_weight_backward(
x: FloatTensor<B, 5>,
weight: FloatTensor<B, 5>,
output_grad: FloatTensor<B, 5>,
options: ConvOptions<3>,
) -> FloatTensor<B, 5> {
conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
fn conv3d_bias_backward(
x: FloatTensor<B, 5>,
weight: FloatTensor<B, 5>,
bias: FloatTensor<B, 1>,
output_grad: FloatTensor<B, 5>,
) -> FloatTensor<B, 1> {
conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
}
/// One dimensional transposed convolution.
///
@ -292,15 +303,30 @@ pub trait ModuleOps<B: Backend> {
) -> FloatTensor<B, 3> {
conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
}
/// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation.
fn conv_transpose1d_backward(
x: FloatTensor<B, 3>,
/// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
fn conv_transpose1d_x_backward(
weight: FloatTensor<B, 3>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 3>,
options: ConvTransposeOptions<1>,
) -> Conv1dBackward<B> {
conv::conv_transpose1d_backward(x, weight, bias, output_grad, options)
) -> FloatTensor<B, 3> {
conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
}
/// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
fn conv_transpose1d_weight_backward(
x: FloatTensor<B, 3>,
weight: FloatTensor<B, 3>,
output_grad: FloatTensor<B, 3>,
options: ConvTransposeOptions<1>,
) -> FloatTensor<B, 3> {
conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
fn conv_transpose1d_bias_backward(
x: FloatTensor<B, 3>,
bias: FloatTensor<B, 1>,
output_grad: FloatTensor<B, 3>,
) -> FloatTensor<B, 1> {
conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
}
/// Two dimensional transposed convolution.
@ -316,15 +342,30 @@ pub trait ModuleOps<B: Backend> {
bias: Option<FloatTensor<B, 1>>,
options: ConvTransposeOptions<2>,
) -> FloatTensor<B, 4>;
/// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation.
fn conv_transpose2d_backward(
x: FloatTensor<B, 4>,
/// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
fn conv_transpose2d_x_backward(
weight: FloatTensor<B, 4>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 4>,
options: ConvTransposeOptions<2>,
) -> Conv2dBackward<B> {
conv::conv_transpose2d_backward(x, weight, bias, output_grad, options)
) -> FloatTensor<B, 4> {
conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
}
/// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
fn conv_transpose2d_weight_backward(
x: FloatTensor<B, 4>,
weight: FloatTensor<B, 4>,
output_grad: FloatTensor<B, 4>,
options: ConvTransposeOptions<2>,
) -> FloatTensor<B, 4> {
conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
fn conv_transpose2d_bias_backward(
x: FloatTensor<B, 4>,
bias: FloatTensor<B, 1>,
output_grad: FloatTensor<B, 4>,
) -> FloatTensor<B, 1> {
conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
}
/// Three dimensional transposed convolution.
@ -340,15 +381,30 @@ pub trait ModuleOps<B: Backend> {
bias: Option<FloatTensor<B, 1>>,
options: ConvTransposeOptions<3>,
) -> FloatTensor<B, 5>;
/// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation.
fn conv_transpose3d_backward(
x: FloatTensor<B, 5>,
/// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
fn conv_transpose3d_x_backward(
weight: FloatTensor<B, 5>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 5>,
options: ConvTransposeOptions<3>,
) -> Conv3dBackward<B> {
conv::conv_transpose3d_backward(x, weight, bias, output_grad, options)
) -> FloatTensor<B, 5> {
conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
}
/// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
fn conv_transpose3d_weight_backward(
x: FloatTensor<B, 5>,
weight: FloatTensor<B, 5>,
output_grad: FloatTensor<B, 5>,
options: ConvTransposeOptions<3>,
) -> FloatTensor<B, 5> {
conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
fn conv_transpose3d_bias_backward(
x: FloatTensor<B, 5>,
bias: FloatTensor<B, 1>,
output_grad: FloatTensor<B, 5>,
) -> FloatTensor<B, 1> {
conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
}
/// Four-dimensional unfolding.

View File

@ -1,5 +1,5 @@
#![allow(clippy::single_range_in_vec_init)]
use super::{Conv1dBackward, Conv2dBackward, Conv3dBackward, ConvOptions, ConvTransposeOptions};
use super::{ConvOptions, ConvTransposeOptions};
use crate::{backend::Backend, ops::FloatTensor, Shape};
#[cfg(not(feature = "std"))]
@ -57,19 +57,17 @@ pub fn calculate_pool_output_size(
((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1
}
/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass using convolutions.
pub(crate) fn conv1d_backward<B: Backend>(
/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `x`.
pub(crate) fn conv1d_x_backward<B: Backend>(
x: FloatTensor<B, 3>,
weight: FloatTensor<B, 3>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 3>,
options: ConvOptions<1>,
) -> Conv1dBackward<B> {
) -> FloatTensor<B, 3> {
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
let [batch_size, _, length_in] = B::float_shape(&x).dims;
let [_batch_size, channels_out, length_out] = B::float_shape(&output_grad).dims;
let [_batch_size, _, length_in] = B::float_shape(&x).dims;
let [_batch_size, _channels_out, length_out] = B::float_shape(&output_grad).dims;
let [_, _, kernel_size] = weight_shape.dims;
let padding_out = calculate_padding_out(
@ -81,8 +79,8 @@ pub(crate) fn conv1d_backward<B: Backend>(
length_out,
);
let x_grad = B::conv_transpose1d(
output_grad.clone(),
B::conv_transpose1d(
output_grad,
weight,
None,
ConvTransposeOptions::new(
@ -92,45 +90,58 @@ pub(crate) fn conv1d_backward<B: Backend>(
options.dilation,
options.groups,
),
);
let weight_grad = match options.groups == 1 {
true => conv1d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
false => conv1d_weight_grad_groups::<B>(
x,
B::float_zeros(weight_shape, &weight_device),
output_grad.clone(),
options,
),
};
Conv1dBackward::new(
x_grad,
weight_grad,
bias.map(|b| {
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
let grad = B::float_sum_dim(grad, 1);
B::float_reshape(grad, B::float_shape(&b))
}),
)
}
/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass using convolutions.
pub(crate) fn conv2d_backward<B: Backend>(
x: FloatTensor<B, 4>,
weight: FloatTensor<B, 4>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 4>,
options: ConvOptions<2>,
) -> Conv2dBackward<B> {
/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `weight`.
pub(crate) fn conv1d_weight_backward<B: Backend>(
x: FloatTensor<B, 3>,
weight: FloatTensor<B, 3>,
output_grad: FloatTensor<B, 3>,
options: ConvOptions<1>,
) -> FloatTensor<B, 3> {
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
let [batch_size, _channels_in, height_in, width_in] = B::float_shape(&x).dims;
match options.groups == 1 {
true => conv1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
false => conv1d_weight_grad_groups::<B>(
x,
B::float_zeros(weight_shape, &weight_device),
output_grad,
options,
),
}
}
/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `bias`.
pub(crate) fn conv1d_bias_backward<B: Backend>(
x: FloatTensor<B, 3>,
bias: FloatTensor<B, 1>,
output_grad: FloatTensor<B, 3>,
) -> FloatTensor<B, 1> {
let [batch_size, _, _length_in] = B::float_shape(&x).dims;
let [_batch_size, channels_out, length_out] = B::float_shape(&output_grad).dims;
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
let grad = B::float_sum_dim(grad, 1);
B::float_reshape(grad, B::float_shape(&bias))
}
/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `x`.
pub(crate) fn conv2d_x_backward<B: Backend>(
x: FloatTensor<B, 4>,
weight: FloatTensor<B, 4>,
output_grad: FloatTensor<B, 4>,
options: ConvOptions<2>,
) -> FloatTensor<B, 4> {
let weight_shape = B::float_shape(&weight);
let [_batch_size, _channels_in, height_in, width_in] = B::float_shape(&x).dims;
let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims;
let [channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;
let [_channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;
let padding_1_out = calculate_padding_out(
kernel_size_1,
@ -149,8 +160,8 @@ pub(crate) fn conv2d_backward<B: Backend>(
width_out,
);
let x_grad = B::conv_transpose2d(
output_grad.clone(),
B::conv_transpose2d(
output_grad,
weight,
None,
ConvTransposeOptions::new(
@ -160,48 +171,65 @@ pub(crate) fn conv2d_backward<B: Backend>(
options.dilation,
options.groups,
),
);
let weight_grad = match options.groups == 1 {
true => conv2d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
false => conv2d_weight_grad_groups::<B>(
x,
B::float_zeros(weight_shape, &weight_device),
output_grad.clone(),
options,
),
};
Conv2dBackward::new(
x_grad,
weight_grad,
bias.map(|b| {
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(
grad,
Shape::new([channels_out, batch_size * height_out * width_out]),
);
let grad = B::float_sum_dim(grad, 1);
B::float_reshape(grad, B::float_shape(&b))
}),
)
}
/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass using convolutions.
pub(crate) fn conv3d_backward<B: Backend>(
x: FloatTensor<B, 5>,
weight: FloatTensor<B, 5>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 5>,
options: ConvOptions<3>,
) -> Conv3dBackward<B> {
/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `weight`.
pub(crate) fn conv2d_weight_backward<B: Backend>(
x: FloatTensor<B, 4>,
weight: FloatTensor<B, 4>,
output_grad: FloatTensor<B, 4>,
options: ConvOptions<2>,
) -> FloatTensor<B, 4> {
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
let [batch_size, _channels_in, depth_in, height_in, width_in] = B::float_shape(&x).dims;
match options.groups == 1 {
true => conv2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
false => conv2d_weight_grad_groups::<B>(
x,
B::float_zeros(weight_shape, &weight_device),
output_grad,
options,
),
}
}
/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `bias`.
pub(crate) fn conv2d_bias_backward<B: Backend>(
x: FloatTensor<B, 4>,
weight: FloatTensor<B, 4>,
bias: FloatTensor<B, 1>,
output_grad: FloatTensor<B, 4>,
) -> FloatTensor<B, 1> {
let weight_shape = B::float_shape(&weight);
let [batch_size, _channels_in, _height_in, _width_in] = B::float_shape(&x).dims;
let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims;
let [channels_out, _, _kernel_size_1, _kernel_size_2] = weight_shape.dims;
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(
grad,
Shape::new([channels_out, batch_size * height_out * width_out]),
);
let grad = B::float_sum_dim(grad, 1);
B::float_reshape(grad, B::float_shape(&bias))
}
/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `x`.
pub(crate) fn conv3d_x_backward<B: Backend>(
x: FloatTensor<B, 5>,
weight: FloatTensor<B, 5>,
output_grad: FloatTensor<B, 5>,
options: ConvOptions<3>,
) -> FloatTensor<B, 5> {
let weight_shape = B::float_shape(&weight);
let [_batch_size, _channels_in, depth_in, height_in, width_in] = B::float_shape(&x).dims;
let [_, _, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
let [channels_out, _, kernel_size_1, kernel_size_2, kernel_size_3] = weight_shape.dims;
let [_channels_out, _, kernel_size_1, kernel_size_2, kernel_size_3] = weight_shape.dims;
let padding_1_out = calculate_padding_out(
kernel_size_1,
@ -228,8 +256,8 @@ pub(crate) fn conv3d_backward<B: Backend>(
width_out,
);
let x_grad = B::conv_transpose3d(
output_grad.clone(),
B::conv_transpose3d(
output_grad,
weight,
None,
ConvTransposeOptions::new(
@ -239,53 +267,64 @@ pub(crate) fn conv3d_backward<B: Backend>(
options.dilation,
options.groups,
),
);
)
}
let weight_grad = match options.groups == 1 {
true => conv3d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `weight`.
pub(crate) fn conv3d_weight_backward<B: Backend>(
x: FloatTensor<B, 5>,
weight: FloatTensor<B, 5>,
output_grad: FloatTensor<B, 5>,
options: ConvOptions<3>,
) -> FloatTensor<B, 5> {
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
match options.groups == 1 {
true => conv3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
false => conv3d_weight_grad_groups::<B>(
x,
B::float_zeros(weight_shape, &weight_device),
output_grad.clone(),
output_grad,
options,
),
};
Conv3dBackward::new(
x_grad,
weight_grad,
bias.map(|b| {
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(
grad,
Shape::new([
channels_out,
batch_size * depth_out * height_out * width_out,
]),
);
let grad = B::float_sum_dim(grad, 1);
B::float_reshape(grad, B::float_shape(&b))
}),
)
}
}
/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass using convolutions.
pub(crate) fn conv_transpose1d_backward<B: Backend>(
x: FloatTensor<B, 3>,
/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `bias`.
pub(crate) fn conv3d_bias_backward<B: Backend>(
x: FloatTensor<B, 5>,
weight: FloatTensor<B, 5>,
bias: FloatTensor<B, 1>,
output_grad: FloatTensor<B, 5>,
) -> FloatTensor<B, 1> {
let weight_shape = B::float_shape(&weight);
let [batch_size, _channels_in, _depth_in, _height_in, _width_in] = B::float_shape(&x).dims;
let [_, _, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
let [channels_out, _, _kernel_size_1, _kernel_size_2, _kernel_size_3] = weight_shape.dims;
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(
grad,
Shape::new([
channels_out,
batch_size * depth_out * height_out * width_out,
]),
);
let grad = B::float_sum_dim(grad, 1);
B::float_reshape(grad, B::float_shape(&bias))
}
/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `x`.
pub(crate) fn conv_transpose1d_x_backward<B: Backend>(
weight: FloatTensor<B, 3>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 3>,
options: ConvTransposeOptions<1>,
) -> Conv1dBackward<B> {
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
let [batch_size, _channels_in, _] = B::float_shape(&x).dims;
let [_, channels_out, length_out] = B::float_shape(&output_grad).dims;
let x_grad = B::conv1d(
output_grad.clone(),
) -> FloatTensor<B, 3> {
B::conv1d(
output_grad,
weight,
None,
ConvOptions::new(
@ -294,52 +333,54 @@ pub(crate) fn conv_transpose1d_backward<B: Backend>(
options.dilation,
options.groups,
),
);
)
}
let weight_grad = match options.groups == 1 {
true => conv_transpose1d_weight_grad_no_groups::<B>(
x,
output_grad.clone(),
weight_shape,
options,
),
/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `weight`.
pub(crate) fn conv_transpose1d_weight_backward<B: Backend>(
x: FloatTensor<B, 3>,
weight: FloatTensor<B, 3>,
output_grad: FloatTensor<B, 3>,
options: ConvTransposeOptions<1>,
) -> FloatTensor<B, 3> {
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
match options.groups == 1 {
true => conv_transpose1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
false => conv_transpose1d_weight_grad_groups::<B>(
x,
B::float_zeros(weight_shape, &weight_device),
output_grad.clone(),
output_grad,
options,
),
};
Conv1dBackward::new(
x_grad,
weight_grad,
bias.map(|b| {
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
let grad = B::float_sum_dim(grad, 1);
B::float_reshape(grad, B::float_shape(&b))
}),
)
}
}
/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass using convolutions.
pub(crate) fn conv_transpose2d_backward<B: Backend>(
x: FloatTensor<B, 4>,
/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `bias`.
pub(crate) fn conv_transpose1d_bias_backward<B: Backend>(
x: FloatTensor<B, 3>,
bias: FloatTensor<B, 1>,
output_grad: FloatTensor<B, 3>,
) -> FloatTensor<B, 1> {
let [batch_size, _channels_in, _] = B::float_shape(&x).dims;
let [_, channels_out, length_out] = B::float_shape(&output_grad).dims;
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
let grad = B::float_sum_dim(grad, 1);
B::float_reshape(grad, B::float_shape(&bias))
}
/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `x`.
pub(crate) fn conv_transpose2d_x_backward<B: Backend>(
weight: FloatTensor<B, 4>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 4>,
options: ConvTransposeOptions<2>,
) -> Conv2dBackward<B> {
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
let [batch_size, _channels_in, _, _] = B::float_shape(&x).dims;
let [_, channels_out, height_out, width_out] = B::float_shape(&output_grad).dims;
let x_grad = B::conv2d(
output_grad.clone(),
) -> FloatTensor<B, 4> {
B::conv2d(
output_grad,
weight,
None,
ConvOptions::new(
@ -348,55 +389,57 @@ pub(crate) fn conv_transpose2d_backward<B: Backend>(
options.dilation,
options.groups,
),
);
)
}
let weight_grad = match options.groups == 1 {
true => conv_transpose2d_weight_grad_no_groups::<B>(
x,
output_grad.clone(),
weight_shape,
options,
),
/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `weight`.
pub(crate) fn conv_transpose2d_weight_backward<B: Backend>(
x: FloatTensor<B, 4>,
weight: FloatTensor<B, 4>,
output_grad: FloatTensor<B, 4>,
options: ConvTransposeOptions<2>,
) -> FloatTensor<B, 4> {
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
match options.groups == 1 {
true => conv_transpose2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
false => conv_transpose2d_weight_grad_groups::<B>(
x,
B::float_zeros(weight_shape, &weight_device),
output_grad.clone(),
output_grad,
options,
),
};
Conv2dBackward::new(
x_grad,
weight_grad,
bias.map(|b| {
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(
grad,
Shape::new([channels_out, batch_size * height_out * width_out]),
);
let grad = B::float_sum_dim(grad, 1);
B::float_reshape(grad, B::float_shape(&b))
}),
)
}
}
/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass using convolutions.
pub(crate) fn conv_transpose3d_backward<B: Backend>(
x: FloatTensor<B, 5>,
/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `bias`.
pub(crate) fn conv_transpose2d_bias_backward<B: Backend>(
x: FloatTensor<B, 4>,
bias: FloatTensor<B, 1>,
output_grad: FloatTensor<B, 4>,
) -> FloatTensor<B, 1> {
let [batch_size, _channels_in, _, _] = B::float_shape(&x).dims;
let [_, channels_out, height_out, width_out] = B::float_shape(&output_grad).dims;
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(
grad,
Shape::new([channels_out, batch_size * height_out * width_out]),
);
let grad = B::float_sum_dim(grad, 1);
B::float_reshape(grad, B::float_shape(&bias))
}
/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `x`.
pub(crate) fn conv_transpose3d_x_backward<B: Backend>(
weight: FloatTensor<B, 5>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 5>,
options: ConvTransposeOptions<3>,
) -> Conv3dBackward<B> {
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
let [batch_size, _channels_in, _, _, _] = B::float_shape(&x).dims;
let [_, channels_out, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
let x_grad = B::conv3d(
output_grad.clone(),
) -> FloatTensor<B, 5> {
B::conv3d(
output_grad,
weight,
None,
ConvOptions::new(
@ -405,40 +448,50 @@ pub(crate) fn conv_transpose3d_backward<B: Backend>(
options.dilation,
options.groups,
),
);
)
}
let weight_grad = match options.groups == 1 {
true => conv_transpose3d_weight_grad_no_groups::<B>(
x,
output_grad.clone(),
weight_shape,
options,
),
/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `weight`.
pub(crate) fn conv_transpose3d_weight_backward<B: Backend>(
x: FloatTensor<B, 5>,
weight: FloatTensor<B, 5>,
output_grad: FloatTensor<B, 5>,
options: ConvTransposeOptions<3>,
) -> FloatTensor<B, 5> {
let weight_shape = B::float_shape(&weight);
let weight_device = B::float_device(&weight);
match options.groups == 1 {
true => conv_transpose3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
false => conv_transpose3d_weight_grad_groups::<B>(
x,
B::float_zeros(weight_shape, &weight_device),
output_grad.clone(),
output_grad,
options,
),
};
}
}
Conv3dBackward::new(
x_grad,
weight_grad,
bias.map(|b| {
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(
grad,
Shape::new([
channels_out,
batch_size * depth_out * height_out * width_out,
]),
);
let grad = B::float_sum_dim(grad, 1);
/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `bias`.
pub(crate) fn conv_transpose3d_bias_backward<B: Backend>(
x: FloatTensor<B, 5>,
bias: FloatTensor<B, 1>,
output_grad: FloatTensor<B, 5>,
) -> FloatTensor<B, 1> {
let [batch_size, _channels_in, _, _, _] = B::float_shape(&x).dims;
let [_, channels_out, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
B::float_reshape(grad, B::float_shape(&b))
}),
)
let grad = B::float_swap_dims(output_grad, 0, 1);
let grad = B::float_reshape(
grad,
Shape::new([
channels_out,
batch_size * depth_out * height_out * width_out,
]),
);
let grad = B::float_sum_dim(grad, 1);
B::float_reshape(grad, B::float_shape(&bias))
}
/// Execute a 1D convolution using a 2D convolution.