Refactor burn-tensor: Split conv backward ops to allow conditional gradient computation (#2278)

Asher Jingkong Chen, 2024-09-16 22:15:27 +08:00, committed by GitHub
parent 81ec64a929
commit 7ac5deebe2
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
3 changed files with 576 additions and 372 deletions
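Not part of the commit itself — a minimal sketch of what "conditional gradient computation" means here, assuming the `burn_tensor::ops` import paths and the new split `ModuleOps` methods introduced in the diff below. The free function `conv2d_grads_if_tracked` and the `tracked_*` flags are hypothetical stand-ins for the autodiff `if let Some(node) = node_x { ... }` checks in the first changed file; previously a single `conv2d_backward` returned a `Conv2dBackward` struct with all three gradients regardless of which ones were needed.

use burn_tensor::backend::Backend;
use burn_tensor::ops::{ConvOptions, FloatTensor, ModuleOps};

// Sketch: compute each conv2d gradient only when the corresponding input is tracked.
fn conv2d_grads_if_tracked<B: Backend>(
    x: FloatTensor<B, 4>,
    weight: FloatTensor<B, 4>,
    bias: FloatTensor<B, 1>,
    grad: FloatTensor<B, 4>,
    options: ConvOptions<2>,
    tracked_x: bool,
    tracked_weight: bool,
    tracked_bias: bool,
) -> (
    Option<FloatTensor<B, 4>>,
    Option<FloatTensor<B, 4>>,
    Option<FloatTensor<B, 1>>,
) {
    // Each branch is skipped entirely when the input does not require a gradient.
    let x_grad = tracked_x
        .then(|| B::conv2d_x_backward(x.clone(), weight.clone(), grad.clone(), options.clone()));
    let weight_grad = tracked_weight
        .then(|| B::conv2d_weight_backward(x.clone(), weight.clone(), grad.clone(), options.clone()));
    // The bias gradient consumes the remaining values, so it runs last.
    let bias_grad = tracked_bias.then(|| B::conv2d_bias_backward(x, weight, bias, grad));
    (x_grad, weight_grad, bias_grad)
}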

View File

@@ -78,20 +78,28 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 3>(&ops.node);
                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-                let backward = B::conv1d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
                 if let Some(node) = node_x {
-                    grads.register::<B, 3>(node.id, backward.x_grad)
+                    let grad = B::conv1d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 3>(node.id, backward.weights_grad)
+                    let grad = B::conv1d_weight_backward(x.clone(), weight, grad.clone(), options);
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv1d_bias_backward(x, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -109,16 +117,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 3>(&ops.node);
                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let backward = B::conv1d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
                 if let Some(node) = node_x {
-                    grads.register::<B, 3>(node.id, backward.x_grad)
+                    let grad = B::conv1d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 3>(node.id, backward.weights_grad)
+                    let grad = B::conv1d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 3>(node.id, grad)
                 }
             }
         }
@@ -188,20 +202,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 3>(&ops.node);
                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-                let backward = B::conv_transpose1d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
                 if let Some(node) = node_x {
-                    grads.register::<B, 3>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose1d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 3>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose1d_weight_backward(
+                        x.clone(),
+                        weight,
+                        grad.clone(),
+                        options,
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv_transpose1d_bias_backward(x, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -219,16 +245,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 3>(&ops.node);
                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let backward = B::conv_transpose1d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
                 if let Some(node) = node_x {
-                    grads.register::<B, 3>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose1d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 3>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose1d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 3>(node.id, grad)
                 }
             }
         }
@@ -307,20 +338,29 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 4>(&ops.node);
                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-                let backward = B::conv2d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
                 if let Some(node) = node_x {
-                    grads.register::<B, 4>(node.id, backward.x_grad)
+                    let grad = B::conv2d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 4>(node.id, backward.weights_grad)
+                    let grad =
+                        B::conv2d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv2d_bias_backward(x, weight, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -338,16 +378,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 4>(&ops.node);
                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let backward = B::conv2d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
                 if let Some(node) = node_x {
-                    grads.register::<B, 4>(node.id, backward.x_grad)
+                    let grad = B::conv2d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 4>(node.id, backward.weights_grad)
+                    let grad = B::conv2d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 4>(node.id, grad)
                 }
             }
         }
@@ -419,20 +465,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 4>(&ops.node);
                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-                let backward = B::conv_transpose2d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
                 if let Some(node) = node_x {
-                    grads.register::<B, 4>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose2d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 4>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose2d_weight_backward(
+                        x.clone(),
+                        weight,
+                        grad.clone(),
+                        options,
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv_transpose2d_bias_backward(x, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -450,16 +508,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 4>(&ops.node);
                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let backward = B::conv_transpose2d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
                 if let Some(node) = node_x {
-                    grads.register::<B, 4>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose2d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 4>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose2d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 4>(node.id, grad)
                 }
             }
         }
@@ -540,20 +603,29 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 5>(&ops.node);
                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-                let backward = B::conv3d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
                 if let Some(node) = node_x {
-                    grads.register::<B, 5>(node.id, backward.x_grad)
+                    let grad = B::conv3d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 5>(node.id, backward.weights_grad)
+                    let grad =
+                        B::conv3d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv3d_bias_backward(x, weight, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -571,16 +643,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 5>(&ops.node);
                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let backward = B::conv3d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
                 if let Some(node) = node_x {
-                    grads.register::<B, 5>(node.id, backward.x_grad)
+                    let grad = B::conv3d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 5>(node.id, backward.weights_grad)
+                    let grad = B::conv3d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 5>(node.id, grad)
                 }
             }
         }
@@ -652,20 +730,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 5>(&ops.node);
                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-                let backward = B::conv_transpose3d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
                 if let Some(node) = node_x {
-                    grads.register::<B, 5>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose3d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 5>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose3d_weight_backward(
+                        x.clone(),
+                        weight,
+                        grad.clone(),
+                        options,
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv_transpose3d_bias_backward(x, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -683,16 +773,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 5>(&ops.node);
                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let backward = B::conv_transpose3d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
                 if let Some(node) = node_x {
-                    grads.register::<B, 5>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose3d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 5>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose3d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 5>(node.id, grad)
                 }
             }
         }

View File

@@ -5,32 +5,6 @@ use crate::{
     Shape,
 };
-/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
-#[derive(new)]
-pub struct Conv2dBackward<B: Backend> {
-    /// Gradient.
-    pub x_grad: FloatTensor<B, 4>,
-    /// Weights gradient.
-    pub weights_grad: FloatTensor<B, 4>,
-    /// Bias gradient.
-    pub bias_grad: Option<FloatTensor<B, 1>>,
-}
-/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
-#[derive(new)]
-pub struct Conv3dBackward<B: Backend> {
-    /// Gradient.
-    pub x_grad: FloatTensor<B, 5>,
-    /// Weights gradient.
-    pub weights_grad: FloatTensor<B, 5>,
-    /// Bias gradient.
-    pub bias_grad: Option<FloatTensor<B, 1>>,
-}
 /// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
 #[derive(new)]
 pub struct MaxPool1dBackward<B: Backend> {
@@ -65,19 +39,6 @@ pub struct MaxPool2dWithIndices<B: Backend> {
     pub indices: IntTensor<B, 4>,
 }
-/// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d).
-#[derive(new)]
-pub struct Conv1dBackward<B: Backend> {
-    /// Gradient.
-    pub x_grad: FloatTensor<B, 3>,
-    /// Weights gradient.
-    pub weights_grad: FloatTensor<B, 3>,
-    /// Bias gradient.
-    pub bias_grad: Option<FloatTensor<B, 1>>,
-}
 /// Convolution options.
 #[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
 pub struct ConvOptions<const N: usize> {
@@ -221,15 +182,31 @@ pub trait ModuleOps<B: Backend> {
     ) -> FloatTensor<B, 3> {
         conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
     }
-    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation.
-    fn conv1d_backward(
+    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
+    fn conv1d_x_backward(
         x: FloatTensor<B, 3>,
         weight: FloatTensor<B, 3>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 3>,
         options: ConvOptions<1>,
-    ) -> Conv1dBackward<B> {
-        conv::conv1d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 3> {
+        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
+    fn conv1d_weight_backward(
+        x: FloatTensor<B, 3>,
+        weight: FloatTensor<B, 3>,
+        output_grad: FloatTensor<B, 3>,
+        options: ConvOptions<1>,
+    ) -> FloatTensor<B, 3> {
+        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
+    fn conv1d_bias_backward(
+        x: FloatTensor<B, 3>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 3>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
     }
     /// Two dimensional convolution.
     ///
@@ -244,15 +221,32 @@ pub trait ModuleOps<B: Backend> {
         bias: Option<FloatTensor<B, 1>>,
         options: ConvOptions<2>,
     ) -> FloatTensor<B, 4>;
-    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation.
-    fn conv2d_backward(
+    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
+    fn conv2d_x_backward(
         x: FloatTensor<B, 4>,
         weight: FloatTensor<B, 4>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 4>,
         options: ConvOptions<2>,
-    ) -> Conv2dBackward<B> {
-        conv::conv2d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 4> {
+        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
+    fn conv2d_weight_backward(
+        x: FloatTensor<B, 4>,
+        weight: FloatTensor<B, 4>,
+        output_grad: FloatTensor<B, 4>,
+        options: ConvOptions<2>,
+    ) -> FloatTensor<B, 4> {
+        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
+    fn conv2d_bias_backward(
+        x: FloatTensor<B, 4>,
+        weight: FloatTensor<B, 4>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 4>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
     }
     /// Three dimensional convolution.
     ///
@@ -267,15 +261,32 @@ pub trait ModuleOps<B: Backend> {
         bias: Option<FloatTensor<B, 1>>,
         options: ConvOptions<3>,
     ) -> FloatTensor<B, 5>;
-    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation.
-    fn conv3d_backward(
+    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
+    fn conv3d_x_backward(
        x: FloatTensor<B, 5>,
        weight: FloatTensor<B, 5>,
-        bias: Option<FloatTensor<B, 1>>,
        output_grad: FloatTensor<B, 5>,
        options: ConvOptions<3>,
-    ) -> Conv3dBackward<B> {
-        conv::conv3d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 5> {
+        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
+    fn conv3d_weight_backward(
+        x: FloatTensor<B, 5>,
+        weight: FloatTensor<B, 5>,
+        output_grad: FloatTensor<B, 5>,
+        options: ConvOptions<3>,
+    ) -> FloatTensor<B, 5> {
+        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
+    fn conv3d_bias_backward(
+        x: FloatTensor<B, 5>,
+        weight: FloatTensor<B, 5>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 5>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
     }
     /// One dimensional transposed convolution.
     ///
@@ -292,15 +303,30 @@ pub trait ModuleOps<B: Backend> {
     ) -> FloatTensor<B, 3> {
         conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
     }
-    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation.
-    fn conv_transpose1d_backward(
-        x: FloatTensor<B, 3>,
+    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
+    fn conv_transpose1d_x_backward(
         weight: FloatTensor<B, 3>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 3>,
         options: ConvTransposeOptions<1>,
-    ) -> Conv1dBackward<B> {
-        conv::conv_transpose1d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 3> {
+        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
+    fn conv_transpose1d_weight_backward(
+        x: FloatTensor<B, 3>,
+        weight: FloatTensor<B, 3>,
+        output_grad: FloatTensor<B, 3>,
+        options: ConvTransposeOptions<1>,
+    ) -> FloatTensor<B, 3> {
+        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
+    fn conv_transpose1d_bias_backward(
+        x: FloatTensor<B, 3>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 3>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
     }
     /// Two dimensional transposed convolution.
@@ -316,15 +342,30 @@ pub trait ModuleOps<B: Backend> {
         bias: Option<FloatTensor<B, 1>>,
         options: ConvTransposeOptions<2>,
     ) -> FloatTensor<B, 4>;
-    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation.
-    fn conv_transpose2d_backward(
-        x: FloatTensor<B, 4>,
+    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
+    fn conv_transpose2d_x_backward(
         weight: FloatTensor<B, 4>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 4>,
         options: ConvTransposeOptions<2>,
-    ) -> Conv2dBackward<B> {
-        conv::conv_transpose2d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 4> {
+        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
+    fn conv_transpose2d_weight_backward(
+        x: FloatTensor<B, 4>,
+        weight: FloatTensor<B, 4>,
+        output_grad: FloatTensor<B, 4>,
+        options: ConvTransposeOptions<2>,
+    ) -> FloatTensor<B, 4> {
+        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
+    fn conv_transpose2d_bias_backward(
+        x: FloatTensor<B, 4>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 4>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
     }
     /// Three dimensional transposed convolution.
@@ -340,15 +381,30 @@ pub trait ModuleOps<B: Backend> {
         bias: Option<FloatTensor<B, 1>>,
         options: ConvTransposeOptions<3>,
     ) -> FloatTensor<B, 5>;
-    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation.
-    fn conv_transpose3d_backward(
-        x: FloatTensor<B, 5>,
+    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
+    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B, 5>,
-        bias: Option<FloatTensor<B, 1>>,
        output_grad: FloatTensor<B, 5>,
        options: ConvTransposeOptions<3>,
-    ) -> Conv3dBackward<B> {
-        conv::conv_transpose3d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 5> {
+        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
+    fn conv_transpose3d_weight_backward(
+        x: FloatTensor<B, 5>,
+        weight: FloatTensor<B, 5>,
+        output_grad: FloatTensor<B, 5>,
+        options: ConvTransposeOptions<3>,
+    ) -> FloatTensor<B, 5> {
+        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
+    fn conv_transpose3d_bias_backward(
+        x: FloatTensor<B, 5>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 5>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
     }
     /// Four-dimensional unfolding.

View File

@@ -1,5 +1,5 @@
 #![allow(clippy::single_range_in_vec_init)]
-use super::{Conv1dBackward, Conv2dBackward, Conv3dBackward, ConvOptions, ConvTransposeOptions};
+use super::{ConvOptions, ConvTransposeOptions};
 use crate::{backend::Backend, ops::FloatTensor, Shape};
 #[cfg(not(feature = "std"))]
@@ -57,19 +57,17 @@ pub fn calculate_pool_output_size(
     ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1
 }
-/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass using convolutions.
-pub(crate) fn conv1d_backward<B: Backend>(
+/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv1d_x_backward<B: Backend>(
     x: FloatTensor<B, 3>,
     weight: FloatTensor<B, 3>,
-    bias: Option<FloatTensor<B, 1>>,
     output_grad: FloatTensor<B, 3>,
     options: ConvOptions<1>,
-) -> Conv1dBackward<B> {
+) -> FloatTensor<B, 3> {
     let weight_shape = B::float_shape(&weight);
-    let weight_device = B::float_device(&weight);
-    let [batch_size, _, length_in] = B::float_shape(&x).dims;
-    let [_batch_size, channels_out, length_out] = B::float_shape(&output_grad).dims;
+    let [_batch_size, _, length_in] = B::float_shape(&x).dims;
+    let [_batch_size, _channels_out, length_out] = B::float_shape(&output_grad).dims;
     let [_, _, kernel_size] = weight_shape.dims;
     let padding_out = calculate_padding_out(
@@ -81,8 +79,8 @@ pub(crate) fn conv1d_backward<B: Backend>(
         length_out,
     );
-    let x_grad = B::conv_transpose1d(
-        output_grad.clone(),
+    B::conv_transpose1d(
+        output_grad,
         weight,
         None,
         ConvTransposeOptions::new(
@@ -92,45 +90,58 @@ pub(crate) fn conv1d_backward<B: Backend>(
             options.dilation,
             options.groups,
         ),
-    );
-    let weight_grad = match options.groups == 1 {
-        true => conv1d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
-        false => conv1d_weight_grad_groups::<B>(
-            x,
-            B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
-            options,
-        ),
-    };
-    Conv1dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
-            let grad = B::float_sum_dim(grad, 1);
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
     )
 }
-/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass using convolutions.
-pub(crate) fn conv2d_backward<B: Backend>(
-    x: FloatTensor<B, 4>,
-    weight: FloatTensor<B, 4>,
-    bias: Option<FloatTensor<B, 1>>,
-    output_grad: FloatTensor<B, 4>,
-    options: ConvOptions<2>,
-) -> Conv2dBackward<B> {
+/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv1d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 3>,
+    weight: FloatTensor<B, 3>,
+    output_grad: FloatTensor<B, 3>,
+    options: ConvOptions<1>,
+) -> FloatTensor<B, 3> {
     let weight_shape = B::float_shape(&weight);
     let weight_device = B::float_device(&weight);
-    let [batch_size, _channels_in, height_in, width_in] = B::float_shape(&x).dims;
+    match options.groups == 1 {
+        true => conv1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
+        false => conv1d_weight_grad_groups::<B>(
+            x,
+            B::float_zeros(weight_shape, &weight_device),
+            output_grad,
+            options,
+        ),
+    }
+}
+/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv1d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 3>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 3>,
+) -> FloatTensor<B, 1> {
+    let [batch_size, _, _length_in] = B::float_shape(&x).dims;
+    let [_batch_size, channels_out, length_out] = B::float_shape(&output_grad).dims;
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
+    let grad = B::float_sum_dim(grad, 1);
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv2d_x_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    weight: FloatTensor<B, 4>,
+    output_grad: FloatTensor<B, 4>,
+    options: ConvOptions<2>,
+) -> FloatTensor<B, 4> {
+    let weight_shape = B::float_shape(&weight);
+    let [_batch_size, _channels_in, height_in, width_in] = B::float_shape(&x).dims;
     let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims;
-    let [channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;
+    let [_channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;
     let padding_1_out = calculate_padding_out(
         kernel_size_1,
@@ -149,8 +160,8 @@ pub(crate) fn conv2d_backward<B: Backend>(
         width_out,
     );
-    let x_grad = B::conv_transpose2d(
-        output_grad.clone(),
+    B::conv_transpose2d(
+        output_grad,
         weight,
         None,
         ConvTransposeOptions::new(
@@ -160,48 +171,65 @@ pub(crate) fn conv2d_backward<B: Backend>(
             options.dilation,
             options.groups,
         ),
-    );
-    let weight_grad = match options.groups == 1 {
-        true => conv2d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
-        false => conv2d_weight_grad_groups::<B>(
-            x,
-            B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
-            options,
-        ),
-    };
-    Conv2dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(
-                grad,
-                Shape::new([channels_out, batch_size * height_out * width_out]),
-            );
-            let grad = B::float_sum_dim(grad, 1);
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
     )
 }
-/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass using convolutions.
-pub(crate) fn conv3d_backward<B: Backend>(
-    x: FloatTensor<B, 5>,
-    weight: FloatTensor<B, 5>,
-    bias: Option<FloatTensor<B, 1>>,
-    output_grad: FloatTensor<B, 5>,
-    options: ConvOptions<3>,
-) -> Conv3dBackward<B> {
+/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv2d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    weight: FloatTensor<B, 4>,
+    output_grad: FloatTensor<B, 4>,
+    options: ConvOptions<2>,
+) -> FloatTensor<B, 4> {
     let weight_shape = B::float_shape(&weight);
     let weight_device = B::float_device(&weight);
-    let [batch_size, _channels_in, depth_in, height_in, width_in] = B::float_shape(&x).dims;
+    match options.groups == 1 {
+        true => conv2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
+        false => conv2d_weight_grad_groups::<B>(
+            x,
+            B::float_zeros(weight_shape, &weight_device),
+            output_grad,
+            options,
+        ),
+    }
+}
+/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv2d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    weight: FloatTensor<B, 4>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 4>,
+) -> FloatTensor<B, 1> {
+    let weight_shape = B::float_shape(&weight);
+    let [batch_size, _channels_in, _height_in, _width_in] = B::float_shape(&x).dims;
+    let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims;
+    let [channels_out, _, _kernel_size_1, _kernel_size_2] = weight_shape.dims;
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(
+        grad,
+        Shape::new([channels_out, batch_size * height_out * width_out]),
+    );
+    let grad = B::float_sum_dim(grad, 1);
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv3d_x_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    weight: FloatTensor<B, 5>,
+    output_grad: FloatTensor<B, 5>,
+    options: ConvOptions<3>,
+) -> FloatTensor<B, 5> {
+    let weight_shape = B::float_shape(&weight);
+    let [_batch_size, _channels_in, depth_in, height_in, width_in] = B::float_shape(&x).dims;
     let [_, _, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
-    let [channels_out, _, kernel_size_1, kernel_size_2, kernel_size_3] = weight_shape.dims;
+    let [_channels_out, _, kernel_size_1, kernel_size_2, kernel_size_3] = weight_shape.dims;
     let padding_1_out = calculate_padding_out(
         kernel_size_1,
@@ -228,8 +256,8 @@ pub(crate) fn conv3d_backward<B: Backend>(
         width_out,
     );
-    let x_grad = B::conv_transpose3d(
-        output_grad.clone(),
+    B::conv_transpose3d(
+        output_grad,
         weight,
         None,
         ConvTransposeOptions::new(
@@ -239,53 +267,64 @@ pub(crate) fn conv3d_backward<B: Backend>(
             options.dilation,
             options.groups,
         ),
-    );
-    let weight_grad = match options.groups == 1 {
-        true => conv3d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
+    )
+}
+/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv3d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    weight: FloatTensor<B, 5>,
+    output_grad: FloatTensor<B, 5>,
+    options: ConvOptions<3>,
+) -> FloatTensor<B, 5> {
+    let weight_shape = B::float_shape(&weight);
+    let weight_device = B::float_device(&weight);
+    match options.groups == 1 {
+        true => conv3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
         false => conv3d_weight_grad_groups::<B>(
             x,
             B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
+            output_grad,
             options,
         ),
-    };
-    Conv3dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(
-                grad,
-                Shape::new([
-                    channels_out,
-                    batch_size * depth_out * height_out * width_out,
-                ]),
-            );
-            let grad = B::float_sum_dim(grad, 1);
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
-    )
+    }
 }
-/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass using convolutions.
-pub(crate) fn conv_transpose1d_backward<B: Backend>(
-    x: FloatTensor<B, 3>,
+/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv3d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    weight: FloatTensor<B, 5>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 5>,
+) -> FloatTensor<B, 1> {
+    let weight_shape = B::float_shape(&weight);
+    let [batch_size, _channels_in, _depth_in, _height_in, _width_in] = B::float_shape(&x).dims;
+    let [_, _, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
+    let [channels_out, _, _kernel_size_1, _kernel_size_2, _kernel_size_3] = weight_shape.dims;
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(
+        grad,
+        Shape::new([
+            channels_out,
+            batch_size * depth_out * height_out * width_out,
+        ]),
+    );
+    let grad = B::float_sum_dim(grad, 1);
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv_transpose1d_x_backward<B: Backend>(
     weight: FloatTensor<B, 3>,
-    bias: Option<FloatTensor<B, 1>>,
     output_grad: FloatTensor<B, 3>,
     options: ConvTransposeOptions<1>,
-) -> Conv1dBackward<B> {
-    let weight_shape = B::float_shape(&weight);
-    let weight_device = B::float_device(&weight);
-    let [batch_size, _channels_in, _] = B::float_shape(&x).dims;
-    let [_, channels_out, length_out] = B::float_shape(&output_grad).dims;
-    let x_grad = B::conv1d(
-        output_grad.clone(),
+) -> FloatTensor<B, 3> {
+    B::conv1d(
+        output_grad,
         weight,
         None,
         ConvOptions::new(
@@ -294,52 +333,54 @@ pub(crate) fn conv_transpose1d_backward<B: Backend>(
         options.dilation,
         options.groups,
     ),
-    );
-    let weight_grad = match options.groups == 1 {
-        true => conv_transpose1d_weight_grad_no_groups::<B>(
-            x,
-            output_grad.clone(),
-            weight_shape,
-            options,
-        ),
+    )
+}
+/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv_transpose1d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 3>,
+    weight: FloatTensor<B, 3>,
+    output_grad: FloatTensor<B, 3>,
+    options: ConvTransposeOptions<1>,
+) -> FloatTensor<B, 3> {
+    let weight_shape = B::float_shape(&weight);
+    let weight_device = B::float_device(&weight);
+    match options.groups == 1 {
+        true => conv_transpose1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
         false => conv_transpose1d_weight_grad_groups::<B>(
             x,
             B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
+            output_grad,
            options,
         ),
-    };
-    Conv1dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
-            let grad = B::float_sum_dim(grad, 1);
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
-    )
+    }
 }
-/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass using convolutions.
-pub(crate) fn conv_transpose2d_backward<B: Backend>(
-    x: FloatTensor<B, 4>,
+/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv_transpose1d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 3>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 3>,
+) -> FloatTensor<B, 1> {
+    let [batch_size, _channels_in, _] = B::float_shape(&x).dims;
+    let [_, channels_out, length_out] = B::float_shape(&output_grad).dims;
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
+    let grad = B::float_sum_dim(grad, 1);
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv_transpose2d_x_backward<B: Backend>(
     weight: FloatTensor<B, 4>,
-    bias: Option<FloatTensor<B, 1>>,
     output_grad: FloatTensor<B, 4>,
     options: ConvTransposeOptions<2>,
-) -> Conv2dBackward<B> {
-    let weight_shape = B::float_shape(&weight);
-    let weight_device = B::float_device(&weight);
-    let [batch_size, _channels_in, _, _] = B::float_shape(&x).dims;
-    let [_, channels_out, height_out, width_out] = B::float_shape(&output_grad).dims;
-    let x_grad = B::conv2d(
-        output_grad.clone(),
+) -> FloatTensor<B, 4> {
+    B::conv2d(
+        output_grad,
         weight,
         None,
         ConvOptions::new(
@@ -348,55 +389,57 @@ pub(crate) fn conv_transpose2d_backward<B: Backend>(
         options.dilation,
         options.groups,
     ),
-    );
-    let weight_grad = match options.groups == 1 {
-        true => conv_transpose2d_weight_grad_no_groups::<B>(
-            x,
-            output_grad.clone(),
-            weight_shape,
-            options,
-        ),
+    )
+}
+/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv_transpose2d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    weight: FloatTensor<B, 4>,
+    output_grad: FloatTensor<B, 4>,
+    options: ConvTransposeOptions<2>,
+) -> FloatTensor<B, 4> {
+    let weight_shape = B::float_shape(&weight);
+    let weight_device = B::float_device(&weight);
+    match options.groups == 1 {
+        true => conv_transpose2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
         false => conv_transpose2d_weight_grad_groups::<B>(
             x,
             B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
+            output_grad,
             options,
         ),
-    };
-    Conv2dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(
-                grad,
-                Shape::new([channels_out, batch_size * height_out * width_out]),
-            );
-            let grad = B::float_sum_dim(grad, 1);
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
-    )
+    }
 }
-/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass using convolutions.
-pub(crate) fn conv_transpose3d_backward<B: Backend>(
-    x: FloatTensor<B, 5>,
+/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv_transpose2d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 4>,
+) -> FloatTensor<B, 1> {
+    let [batch_size, _channels_in, _, _] = B::float_shape(&x).dims;
+    let [_, channels_out, height_out, width_out] = B::float_shape(&output_grad).dims;
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(
+        grad,
+        Shape::new([channels_out, batch_size * height_out * width_out]),
+    );
+    let grad = B::float_sum_dim(grad, 1);
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv_transpose3d_x_backward<B: Backend>(
     weight: FloatTensor<B, 5>,
-    bias: Option<FloatTensor<B, 1>>,
     output_grad: FloatTensor<B, 5>,
     options: ConvTransposeOptions<3>,
-) -> Conv3dBackward<B> {
-    let weight_shape = B::float_shape(&weight);
-    let weight_device = B::float_device(&weight);
-    let [batch_size, _channels_in, _, _, _] = B::float_shape(&x).dims;
-    let [_, channels_out, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
-    let x_grad = B::conv3d(
-        output_grad.clone(),
+) -> FloatTensor<B, 5> {
+    B::conv3d(
+        output_grad,
         weight,
         None,
         ConvOptions::new(
@@ -405,40 +448,50 @@ pub(crate) fn conv_transpose3d_backward<B: Backend>(
         options.dilation,
         options.groups,
     ),
-    );
-    let weight_grad = match options.groups == 1 {
-        true => conv_transpose3d_weight_grad_no_groups::<B>(
-            x,
-            output_grad.clone(),
-            weight_shape,
-            options,
-        ),
+    )
+}
+/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv_transpose3d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    weight: FloatTensor<B, 5>,
+    output_grad: FloatTensor<B, 5>,
+    options: ConvTransposeOptions<3>,
+) -> FloatTensor<B, 5> {
+    let weight_shape = B::float_shape(&weight);
+    let weight_device = B::float_device(&weight);
+    match options.groups == 1 {
+        true => conv_transpose3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
         false => conv_transpose3d_weight_grad_groups::<B>(
             x,
             B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
+            output_grad,
             options,
         ),
-    };
-    Conv3dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(
-                grad,
-                Shape::new([
-                    channels_out,
-                    batch_size * depth_out * height_out * width_out,
-                ]),
-            );
-            let grad = B::float_sum_dim(grad, 1);
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
-    )
+    }
+}
+/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv_transpose3d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 5>,
+) -> FloatTensor<B, 1> {
+    let [batch_size, _channels_in, _, _, _] = B::float_shape(&x).dims;
+    let [_, channels_out, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(
+        grad,
+        Shape::new([
+            channels_out,
+            batch_size * depth_out * height_out * width_out,
+        ]),
+    );
+    let grad = B::float_sum_dim(grad, 1);
+    B::float_reshape(grad, B::float_shape(&bias))
 }
 /// Execute a 1D convolution using a 2D convolution.