mirror of https://github.com/tracel-ai/burn.git
feat: support reshape in AD
This commit is contained in:
parent
ae994f367a
commit
f939b5c775
|
@ -97,7 +97,7 @@ macro_rules! register_ops {
|
|||
P: $crate::tensor::backend::autodiff::ADElement,
|
||||
T: $crate::tensor::backend::autodiff::ADCompatibleTensor<P, D>,
|
||||
{
|
||||
fn partial(&self, state: &$crate::graph::ops::UnaryRecordedState<T, T>) -> T {
|
||||
fn partial(&self, state: &$crate::graph::ops::UnaryOpsNodeState<T, T>) -> T {
|
||||
$partial(state)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
mod add;
|
||||
mod matmul;
|
||||
mod mul;
|
||||
mod reshape;
|
||||
mod sub;
|
||||
|
||||
mod macros;
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
backend::autodiff::{ADCompatibleTensor, ADElement, ADKind, ADTensor},
|
||||
node::{ForwardNode, ForwardNodeState},
|
||||
ops::{ForwardUnaryRecordedOps, UnaryOps, UnaryOpsNodeState},
|
||||
Shape, TensorOpsReshape,
|
||||
};
|
||||
|
||||
/// Recorded backward op for a reshape from `D1` dimensions to `D2` dimensions.
#[derive(Debug)]
struct ADTensorOpsReshape<P, const D1: usize, const D2: usize> {
    // Shape of the original input tensor, kept so the backward pass can
    // reshape the output gradient back to the input's dimensionality.
    shape: Shape<D1>,
    // Marker tying this op to the element type `P`
    // (presumably a zero-sized kind marker — confirm against `ADKind`).
    _kind: ADKind<P>,
}
|
||||
|
||||
impl<P: Default, const D1: usize, const D2: usize> ADTensorOpsReshape<P, D1, D2> {
|
||||
pub fn new(shape: Shape<D1>) -> Self {
|
||||
Self {
|
||||
shape,
|
||||
_kind: ADKind::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1, T2, P, const D1: usize, const D2: usize> UnaryOps<T1, T2> for ADTensorOpsReshape<P, D1, D2>
|
||||
where
|
||||
P: ADElement,
|
||||
T1: ADCompatibleTensor<P, D1> + TensorOpsReshape<P, D1, D2, T2>,
|
||||
T2: ADCompatibleTensor<P, D2> + TensorOpsReshape<P, D2, D1, T1>,
|
||||
{
|
||||
fn partial(&self, state: &UnaryOpsNodeState<T1, T2>) -> T1 {
|
||||
state.output.grad().reshape(self.shape.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, const D1: usize, const D2: usize, T1, T2> TensorOpsReshape<P, D1, D2, ADTensor<P, D2, T2>>
|
||||
for ADTensor<P, D1, T1>
|
||||
where
|
||||
P: ADElement,
|
||||
T1: ADCompatibleTensor<P, D1> + TensorOpsReshape<P, D1, D2, T2>,
|
||||
T2: ADCompatibleTensor<P, D2> + TensorOpsReshape<P, D2, D1, T1>,
|
||||
{
|
||||
fn reshape(&self, shape: Shape<D2>) -> ADTensor<P, D2, T2> {
|
||||
let input = self.tensor();
|
||||
let out = TensorOpsReshape::reshape(&input, shape.clone());
|
||||
|
||||
let state = ForwardNodeState::new(out);
|
||||
|
||||
let ops = ADTensorOpsReshape::<P, D1, D2>::new(self.shape.clone());
|
||||
let ops = Arc::new(ops);
|
||||
let ops = ForwardUnaryRecordedOps::new(self.node.clone(), ops);
|
||||
let ops = Arc::new(ops);
|
||||
|
||||
let node = ForwardNode::from_unary(&self.node, state, ops);
|
||||
let node = Arc::new(node);
|
||||
|
||||
let shape = shape.clone();
|
||||
let kind = self.kind.clone();
|
||||
|
||||
ADTensor { node, shape, kind }
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{backend::autodiff::helper::ADTchTensor, Data, TensorBase, TensorOpsMatmul};

    /// Reshapes a 1-D tensor into 2-D, feeds it through a matmul, and checks
    /// that gradients flow back through the reshape to the original 1-D shape.
    /// (Renamed from `should_diff_mul`, which was a copy-paste leftover from
    /// the mul test — this test exercises reshape.)
    #[test]
    fn should_diff_reshape() {
        let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
        let data_2: Data<f64, 1> = Data::from([4.0, 7.0, 2.0, 3.0]);

        let tensor_1 = ADTchTensor::from_data(data_1.clone());
        let tensor_2 = ADTchTensor::from_data(data_2.clone());

        // Reshape [4] -> [2, 2] so it can participate in the matmul.
        let tensor_3 = tensor_2.reshape(Shape::new([2, 2]));
        let tensor_4 = &tensor_1.matmul(&tensor_3);
        let grads = tensor_4.backward();

        let grad_1 = grads.wrt(&tensor_1).unwrap();
        let grad_2 = grads.wrt(&tensor_2).unwrap();

        assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
        // grad_2 comes back in the original 1-D shape, proving the reshape
        // was correctly undone in the backward pass.
        assert_eq!(grad_2.to_data(), Data::from([3.0, 3.0, 10.0, 10.0]));
    }
}
|
Loading…
Reference in New Issue