refactor: erf ops (#99)

Nathaniel Simard 2022-11-12 12:27:31 -05:00 committed by GitHub
parent ef01a4ed3f
commit ab39b8779b
14 changed files with 73 additions and 109 deletions

View File

@@ -1,60 +0,0 @@
use crate::tensor::backend::Backend;
use crate::{
    execute_ops,
    graph::ops::{UnaryOps, UnaryOpsNodeState},
    register_ops,
    tensor::{backend::autodiff::ADTensor, ops::*},
    ElementConversion,
};

register_ops!(
    ops UnaryOps,
    name ADTensorErfOps,
    partial |state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>| {
        let value = state.input.value();
        let exponent = B::neg(&B::powf(&value, 2.0));
        let numerator = B::mul_scalar(&B::exp(&exponent), &2.0.to_elem());
        let denominator = std::f64::consts::PI.sqrt().to_elem();
        let value = B::div_scalar(&numerator, &denominator);

        B::mul(&state.output.grad(), &value)
    },
);

impl<B: Backend, const D: usize> TensorOpsErf<B::Elem, D> for ADTensor<D, B> {
    fn erf(&self) -> Self {
        execute_ops!(
            input self.node.clone(),
            out TensorOpsErf::erf(&self.tensor()),
            ops ADTensorErfOps::<B, D>::new(),
        )
    }
}

#[cfg(test)]
mod tests {
    use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};

    #[test]
    fn should_diff_erf() {
        let data_1 = Data::<f64, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
        let data_2 = Data::<f64, 2>::from([[6.0, 7.0], [9.0, 10.0]]);

        let tensor_1 = TestADTensor::from_data(data_1);
        let tensor_2 = TestADTensor::from_data(data_2);

        let tensor_3 = tensor_1.matmul(&tensor_2.erf());
        let tensor_4 = tensor_3.matmul(&tensor_2);
        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3);
        grad_2
            .to_data()
            .assert_approx_eq(&Data::from([[8.0, 8.0], [8.0, 8.0]]), 3);
    }
}

View File

@@ -1,7 +1,6 @@
mod base;
mod cat;
mod creation;
mod erf;
mod module;
mod tensor;

View File

@@ -1014,4 +1014,35 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
        unary_ops_wrapper(tensor.node.clone(), output, ops)
    }

    fn erf<const D: usize>(
        tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
        #[derive(Default, Debug)]
        struct Backward<B: Backend, const D: usize> {
            _b: B,
        }

        impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
            for Backward<B, D>
        {
            fn partial(
                &self,
                state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
            ) -> B::TensorPrimitive<D> {
                let value = state.input.value();
                let exponent = B::neg(&B::powf(&value, 2.0));
                let numerator = B::mul_scalar(&B::exp(&exponent), &2.0.to_elem());
                let denominator = std::f64::consts::PI.sqrt().to_elem();
                let value = B::div_scalar(&numerator, &denominator);

                B::mul(&state.output.grad(), &value)
            }
        }

        let output = B::erf(tensor.tensor_ref());
        let ops = Backward::<B, D>::default();

        unary_ops_wrapper(tensor.node.clone(), output, ops)
    }
}
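
For reference, the `partial` implementation above multiplies the incoming gradient by the analytic derivative of the error function:

    \frac{d}{dx}\,\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}}\,e^{-x^{2}}

`numerator` holds 2·exp(−x²), `denominator` holds √π, and the result is chained with `state.output.grad()` via `B::mul`.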

View File

@@ -24,7 +24,6 @@ pub trait Backend:
        + Zeros<Self::TensorPrimitive<D>>
        + Ones<Self::TensorPrimitive<D>>
        + TensorOpsCat<Self::Elem, D>
        + TensorOpsErf<Self::Elem, D>
        + ReLU<Self::Elem, D>
        + Clone
        + Send

View File

@@ -1,19 +0,0 @@
use crate::{
    tensor::{backend::ndarray::NdArrayTensor, ops::*},
    ElementConversion, NdArrayElement,
};

impl<E, const D: usize> TensorOpsErf<E, D> for NdArrayTensor<E, D>
where
    E: NdArrayElement,
{
    fn erf(&self) -> Self {
        let array = self
            .array
            .mapv(|a| libm::erf(a.to_f64().unwrap()).to_elem())
            .into_shared();
        let shape = self.shape;

        Self { array, shape }
    }
}

View File

@@ -1,3 +1,2 @@
mod cat;
mod creation;
mod erf;

View File

@@ -474,6 +474,16 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
        NdArrayTensor { array, shape }
    }

    fn erf<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
        let array = tensor
            .array
            .mapv(|a| libm::erf(a.to_f64().unwrap()).to_elem())
            .into_shared();
        let shape = tensor.shape;

        NdArrayTensor { array, shape }
    }
}
fn to_slice_args<const D1: usize, const D2: usize>(
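
Note: Rust's standard library provides no error function, so the ndarray backend computes it element-wise with `libm::erf`, converting each element to `f64` and back through `to_elem()`.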

View File

@@ -1,21 +0,0 @@
use crate::{
    tensor::{backend::tch::TchTensor, ops::*},
    TchElement,
};

impl<E, const D: usize> TensorOpsErf<E, D> for TchTensor<E, D>
where
    E: TchElement,
{
    fn erf(&self) -> Self {
        let tensor = self.tensor.erf();
        let kind = self.kind;
        let shape = self.shape;

        Self {
            tensor,
            shape,
            kind,
        }
    }
}

View File

@@ -1,3 +1,2 @@
mod cat;
mod creation;
mod erf;

View File

@@ -378,6 +378,10 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
    fn powf<const D: usize>(tensor: &TchTensor<E, D>, value: f32) -> TchTensor<E, D> {
        to_tensor(tensor.tensor.pow_tensor_scalar(value as f64))
    }

    fn erf<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
        to_tensor(tensor.tensor.erf())
    }
}
fn to_tensor<const D: usize, E: TchElement>(tensor: tch::Tensor) -> TchTensor<E, D> {

View File

@@ -75,7 +75,7 @@ where
    ///
    /// `y = erf(x)`
    pub fn erf(&self) -> Self {
        Self::new(self.value.erf())
        Self::new(B::erf(&self.value))
    }

    /// Applies element wise power operation.
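
A minimal usage sketch of the refactored `erf` entry point (hedged: `TestBackend` is a hypothetical backend alias for illustration, not part of this diff, and the import paths are assumed):

    use burn_tensor::{Data, Tensor};

    // Hypothetical backend alias; any type implementing Backend works.
    type B = TestBackend;

    let x = Tensor::<B, 2>::from_data(Data::from([[0.0, 1.0], [2.0, 3.0]]));
    let y = x.erf(); // forwards to B::erf(&self.value), as shown above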

View File

@@ -195,16 +195,13 @@ pub trait TensorOps<B: Backend> {
    fn exp<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
    fn log<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
    fn powf<const D: usize>(tensor: &B::TensorPrimitive<D>, value: f32) -> B::TensorPrimitive<D>;
    fn erf<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
}

pub trait TensorOpsCat<E, const D: usize> {
    fn cat(tensors: Vec<&Self>, dim: usize) -> Self;
}

pub trait TensorOpsErf<E, const D: usize> {
    fn erf(&self) -> Self;
}

pub trait Zeros<T> {
    fn zeros(&self) -> T;
}
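
Taken together with the earlier hunks, this captures the shape of the whole refactor: the free-standing per-op trait `TensorOpsErf` (and its bound on `Backend`) is removed, and the operation becomes an associated function on the backend-wide `TensorOps` trait, dispatched per backend rather than per tensor type. A simplified before/after sketch:

    // Before: one trait per operation, implemented on each tensor type.
    pub trait TensorOpsErf<E, const D: usize> {
        fn erf(&self) -> Self;
    }

    // After: one associated function on the backend-level ops trait.
    pub trait TensorOps<B: Backend> {
        fn erf<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
    }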

View File

@@ -0,0 +1,25 @@
use crate::tensor::TestADTensor;
use burn_tensor::Data;

#[test]
fn should_diff_erf() {
    let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
    let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]);

    let tensor_1 = TestADTensor::from_data(data_1);
    let tensor_2 = TestADTensor::from_data(data_2);

    let tensor_3 = tensor_1.matmul(&tensor_2.erf());
    let tensor_4 = tensor_3.matmul(&tensor_2);
    let grads = tensor_4.backward();

    let grad_1 = tensor_1.grad(&grads).unwrap();
    let grad_2 = tensor_2.grad(&grads).unwrap();

    grad_1
        .to_data()
        .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3);
    grad_2
        .to_data()
        .assert_approx_eq(&Data::from([[8.0, 8.0], [8.0, 8.0]]), 3);
}
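
The expected values are round numbers because every entry of tensor_2 is at least 6, where erf has saturated: erf(6) ≈ 1 − 2×10⁻¹⁷, and the local derivative (2/√π)·exp(−x²) is on the order of 10⁻¹⁶. So tensor_2.erf() behaves like an all-ones matrix and contributes essentially nothing through erf's own backward path, which yields grad_1 ≈ [[32, 32], [32, 32]] and grad_2 ≈ [[8, 8], [8, 8]] well within the 3-decimal tolerance.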

View File

@@ -2,6 +2,7 @@ mod add;
mod aggregation;
mod cross_entropy;
mod div;
mod erf;
mod exp;
mod index;
mod log;