mirror of https://github.com/tracel-ai/burn.git
feat: support neg in auto diff
This commit is contained in:
parent
445a9fbcbe
commit
9af7bde608
|
@ -0,0 +1,67 @@
|
|||
use crate::{
|
||||
backend::autodiff::{ADCompatibleTensor, ADElement, ADTensor},
|
||||
define_ops, execute_ops,
|
||||
ops::{UnaryOps, UnaryOpsNodeState},
|
||||
register_ops, TensorOpsNeg,
|
||||
};
|
||||
|
||||
register_ops!(
|
||||
ops UnaryOps<T, T>,
|
||||
name ADTensorNegOps,
|
||||
partial |state: &UnaryOpsNodeState<T, T>|{
|
||||
state.output.grad().neg()
|
||||
},
|
||||
);
|
||||
|
||||
impl<T, P, const D: usize> TensorOpsNeg<P, D> for ADTensor<P, D, T>
|
||||
where
|
||||
T: ADCompatibleTensor<P, D>,
|
||||
P: ADElement,
|
||||
{
|
||||
fn neg(&self) -> Self {
|
||||
let node = execute_ops!(
|
||||
input self.node.clone(),
|
||||
out TensorOpsNeg::neg(&self.tensor()),
|
||||
ops ADTensorNegOps::new(),
|
||||
);
|
||||
self.from_existing(node)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, P, const D: usize> std::ops::Neg for ADTensor<P, D, T>
|
||||
where
|
||||
T: ADCompatibleTensor<P, D> + 'static,
|
||||
P: ADElement + 'static,
|
||||
{
|
||||
type Output = ADTensor<P, D, T>;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
TensorOpsNeg::neg(&self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
backend::autodiff::helper::ADTchTensor, Data, TensorBase, TensorOpsMatmul, TensorOpsNeg,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn should_diff_neg() {
|
||||
let data_1 = Data::<f64, 2>::from([[1.0, 7.0], [2.0, 3.0]]);
|
||||
let data_2 = Data::<f64, 2>::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
|
||||
let tensor_1 = ADTchTensor::from_data(data_1.clone());
|
||||
let tensor_2 = ADTchTensor::from_data(data_2.clone());
|
||||
|
||||
let tensor_3 = tensor_1.matmul(&tensor_2.neg());
|
||||
let tensor_4 = tensor_3.neg();
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = grads.wrt(&tensor_1).unwrap();
|
||||
let grad_2 = grads.wrt(&tensor_2).unwrap();
|
||||
|
||||
assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
|
||||
assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue