diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs b/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs index fc2e165c3..8dd138ffe 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs @@ -97,7 +97,7 @@ macro_rules! register_ops { P: $crate::tensor::backend::autodiff::ADElement, T: $crate::tensor::backend::autodiff::ADCompatibleTensor
,
{
- fn partial(&self, state: &$crate::graph::ops::UnaryRecordedState {
+ shape: Shape ,
+}
+
+impl {
+ pub fn new(shape: Shape
+where
+ P: ADElement,
+ T1: ADCompatibleTensor + TensorOpsReshape ,
+ T2: ADCompatibleTensor + TensorOpsReshape ,
+{
+ fn partial(&self, state: &UnaryOpsNodeState TensorOpsReshape >
+ for ADTensor
+where
+ P: ADElement,
+ T1: ADCompatibleTensor + TensorOpsReshape ,
+ T2: ADCompatibleTensor + TensorOpsReshape ,
+{
+ fn reshape(&self, shape: Shape {
+ let input = self.tensor();
+ let out = TensorOpsReshape::reshape(&input, shape.clone());
+
+ let state = ForwardNodeState::new(out);
+
+ let ops = ADTensorOpsReshape:: ::new(self.shape.clone());
+ let ops = Arc::new(ops);
+ let ops = ForwardUnaryRecordedOps::new(self.node.clone(), ops);
+ let ops = Arc::new(ops);
+
+ let node = ForwardNode::from_unary(&self.node, state, ops);
+ let node = Arc::new(node);
+
+ let shape = shape.clone();
+ let kind = self.kind.clone();
+
+ ADTensor { node, shape, kind }
+ }
+}
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{backend::autodiff::helper::ADTchTensor, Data, TensorBase, TensorOpsMatmul};
+
+ #[test]
+ fn should_diff_mul() {
+ let data_1: Data