From f939b5c775b70f8e7f01765f9a4d21f9356ea7a2 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 25 Jul 2022 20:55:13 -0400 Subject: [PATCH] feat: support reshape in AD --- .../src/tensor/backend/autodiff/ops/macros.rs | 2 +- .../src/tensor/backend/autodiff/ops/mod.rs | 1 + .../tensor/backend/autodiff/ops/reshape.rs | 86 +++++++++++++++++++ 3 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 burn-tensor/src/tensor/backend/autodiff/ops/reshape.rs diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs b/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs index fc2e165c3..8dd138ffe 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs @@ -97,7 +97,7 @@ macro_rules! register_ops { P: $crate::tensor::backend::autodiff::ADElement, T: $crate::tensor::backend::autodiff::ADCompatibleTensor, { - fn partial(&self, state: &$crate::graph::ops::UnaryRecordedState) -> T { + fn partial(&self, state: &$crate::graph::ops::UnaryOpsNodeState) -> T { $partial(state) } } diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs index fa67969aa..84c29d6b4 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs @@ -1,6 +1,7 @@ mod add; mod matmul; mod mul; +mod reshape; mod sub; mod macros; diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/reshape.rs b/burn-tensor/src/tensor/backend/autodiff/ops/reshape.rs new file mode 100644 index 000000000..c871a8e48 --- /dev/null +++ b/burn-tensor/src/tensor/backend/autodiff/ops/reshape.rs @@ -0,0 +1,86 @@ +use std::sync::Arc; + +use crate::{ + backend::autodiff::{ADCompatibleTensor, ADElement, ADKind, ADTensor}, + node::{ForwardNode, ForwardNodeState}, + ops::{ForwardUnaryRecordedOps, UnaryOps, UnaryOpsNodeState}, + Shape, TensorOpsReshape, +}; + +#[derive(Debug)] +struct ADTensorOpsReshape { + shape: Shape, + _kind: ADKind

, +} + +impl ADTensorOpsReshape { + pub fn new(shape: Shape) -> Self { + Self { + shape, + _kind: ADKind::new(), + } + } +} + +impl UnaryOps for ADTensorOpsReshape +where + P: ADElement, + T1: ADCompatibleTensor + TensorOpsReshape, + T2: ADCompatibleTensor + TensorOpsReshape, +{ + fn partial(&self, state: &UnaryOpsNodeState) -> T1 { + state.output.grad().reshape(self.shape.clone()) + } +} + +impl TensorOpsReshape> + for ADTensor +where + P: ADElement, + T1: ADCompatibleTensor + TensorOpsReshape, + T2: ADCompatibleTensor + TensorOpsReshape, +{ + fn reshape(&self, shape: Shape) -> ADTensor { + let input = self.tensor(); + let out = TensorOpsReshape::reshape(&input, shape.clone()); + + let state = ForwardNodeState::new(out); + + let ops = ADTensorOpsReshape::::new(self.shape.clone()); + let ops = Arc::new(ops); + let ops = ForwardUnaryRecordedOps::new(self.node.clone(), ops); + let ops = Arc::new(ops); + + let node = ForwardNode::from_unary(&self.node, state, ops); + let node = Arc::new(node); + + let shape = shape.clone(); + let kind = self.kind.clone(); + + ADTensor { node, shape, kind } + } +} +#[cfg(test)] +mod tests { + use super::*; + use crate::{backend::autodiff::helper::ADTchTensor, Data, TensorBase, TensorOpsMatmul}; + + #[test] + fn should_diff_mul() { + let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2: Data = Data::from([4.0, 7.0, 2.0, 3.0]); + + let tensor_1 = ADTchTensor::from_data(data_1.clone()); + let tensor_2 = ADTchTensor::from_data(data_2.clone()); + + let tensor_3 = tensor_2.reshape(Shape::new([2, 2])); + let tensor_4 = &tensor_1.matmul(&tensor_3); + let grads = tensor_4.backward(); + + let grad_1 = grads.wrt(&tensor_1).unwrap(); + let grad_2 = grads.wrt(&tensor_2).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); + assert_eq!(grad_2.to_data(), Data::from([3.0, 3.0, 10.0, 10.0])); + } +}