From ab39b8779b3346178896a724f21745ebfe7ef9a3 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Sat, 12 Nov 2022 12:27:31 -0500 Subject: [PATCH] refactor: erf ops (#99) --- .../src/tensor/backend/autodiff/ops/erf.rs | 60 ------------------- .../src/tensor/backend/autodiff/ops/mod.rs | 1 - .../src/tensor/backend/autodiff/ops/tensor.rs | 31 ++++++++++ burn-tensor/src/tensor/backend/base.rs | 1 - .../src/tensor/backend/ndarray/ops/erf.rs | 19 ------ .../src/tensor/backend/ndarray/ops/mod.rs | 1 - .../src/tensor/backend/ndarray/tensor_ops.rs | 10 ++++ burn-tensor/src/tensor/backend/tch/ops/erf.rs | 21 ------- burn-tensor/src/tensor/backend/tch/ops/mod.rs | 1 - .../src/tensor/backend/tch/tensor_ops.rs | 4 ++ burn-tensor/src/tensor/base.rs | 2 +- burn-tensor/src/tensor/ops/base.rs | 5 +- burn-tensor/tests/tensor/grad/erf.rs | 25 ++++++++ burn-tensor/tests/tensor/grad/mod.rs | 1 + 14 files changed, 73 insertions(+), 109 deletions(-) delete mode 100644 burn-tensor/src/tensor/backend/autodiff/ops/erf.rs delete mode 100644 burn-tensor/src/tensor/backend/ndarray/ops/erf.rs delete mode 100644 burn-tensor/src/tensor/backend/tch/ops/erf.rs create mode 100644 burn-tensor/tests/tensor/grad/erf.rs diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/erf.rs b/burn-tensor/src/tensor/backend/autodiff/ops/erf.rs deleted file mode 100644 index 767780c00..000000000 --- a/burn-tensor/src/tensor/backend/autodiff/ops/erf.rs +++ /dev/null @@ -1,60 +0,0 @@ -use crate::tensor::backend::Backend; -use crate::{ - execute_ops, - graph::ops::{UnaryOps, UnaryOpsNodeState}, - register_ops, - tensor::{backend::autodiff::ADTensor, ops::*}, - ElementConversion, -}; - -register_ops!( - ops UnaryOps, - name ADTensorErfOps, - partial |state: &UnaryOpsNodeState, B::TensorPrimitive>|{ - let value = state.input.value(); - let exponent = B::neg(&B::powf(&value, 2.0)); - let numerator = B::mul_scalar(&B::exp(&exponent), &2.0.to_elem()); - let denominator = std::f64::consts::PI.sqrt().to_elem(); - let value = B::div_scalar(&numerator, &denominator); - - B::mul(&state.output.grad(), &value) - }, -); - -impl TensorOpsErf for ADTensor { - fn erf(&self) -> Self { - execute_ops!( - input self.node.clone(), - out TensorOpsErf::erf(&self.tensor()), - ops ADTensorErfOps::::new(), - ) - } -} - -#[cfg(test)] -mod tests { - use crate::tensor::{backend::autodiff::helper::TestADTensor, Data}; - - #[test] - fn should_diff_erf() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - - let tensor_1 = TestADTensor::from_data(data_1); - let tensor_2 = TestADTensor::from_data(data_2); - - let tensor_3 = tensor_1.matmul(&tensor_2.erf()); - let tensor_4 = tensor_3.matmul(&tensor_2); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[8.0, 8.0], [8.0, 8.0]]), 3); - } -} diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs index 73f147821..3d5294683 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs @@ -1,7 +1,6 @@ mod base; mod cat; mod creation; -mod erf; mod module; mod tensor; diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs b/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs index 4104e0c88..b2837f0b9 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs @@ -1014,4 +1014,35 @@ impl TensorOps> for ADBackendDecorator { unary_ops_wrapper(tensor.node.clone(), output, ops) } + + fn erf( + tensor: & as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + #[derive(Default, Debug)] + struct Backward { + _b: B, + } + + impl UnaryOps, B::TensorPrimitive> + for Backward + { + fn partial( + &self, + state: &UnaryOpsNodeState, B::TensorPrimitive>, + ) -> B::TensorPrimitive { + let value = state.input.value(); + let exponent = B::neg(&B::powf(&value, 2.0)); + let numerator = B::mul_scalar(&B::exp(&exponent), &2.0.to_elem()); + let denominator = std::f64::consts::PI.sqrt().to_elem(); + let value = B::div_scalar(&numerator, &denominator); + + B::mul(&state.output.grad(), &value) + } + } + + let output = B::erf(tensor.tensor_ref()); + let ops = Backward::::default(); + + unary_ops_wrapper(tensor.node.clone(), output, ops) + } } diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs index 4dd3b11ff..1be5d7d51 100644 --- a/burn-tensor/src/tensor/backend/base.rs +++ b/burn-tensor/src/tensor/backend/base.rs @@ -24,7 +24,6 @@ pub trait Backend: + Zeros> + Ones> + TensorOpsCat - + TensorOpsErf + ReLU + Clone + Send diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/erf.rs b/burn-tensor/src/tensor/backend/ndarray/ops/erf.rs deleted file mode 100644 index 1b2b8433c..000000000 --- a/burn-tensor/src/tensor/backend/ndarray/ops/erf.rs +++ /dev/null @@ -1,19 +0,0 @@ -use crate::{ - tensor::{backend::ndarray::NdArrayTensor, ops::*}, - ElementConversion, NdArrayElement, -}; - -impl TensorOpsErf for NdArrayTensor -where - E: NdArrayElement, -{ - fn erf(&self) -> Self { - let array = self - .array - .mapv(|a| libm::erf(a.to_f64().unwrap()).to_elem()) - .into_shared(); - let shape = self.shape; - - Self { array, shape } - } -} diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs b/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs index 1020a84f8..5c68c64de 100644 --- a/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs @@ -1,3 +1,2 @@ mod cat; mod creation; -mod erf; diff --git a/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs b/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs index 2e30f9c9b..8dfbfe664 100644 --- a/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs +++ b/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs @@ -474,6 +474,16 @@ impl TensorOps> for NdArrayBackend { NdArrayTensor { array, shape } } + + fn erf(tensor: &NdArrayTensor) -> NdArrayTensor { + let array = tensor + .array + .mapv(|a| libm::erf(a.to_f64().unwrap()).to_elem()) + .into_shared(); + let shape = tensor.shape; + + NdArrayTensor { array, shape } + } } fn to_slice_args( diff --git a/burn-tensor/src/tensor/backend/tch/ops/erf.rs b/burn-tensor/src/tensor/backend/tch/ops/erf.rs deleted file mode 100644 index 29d4b9e50..000000000 --- a/burn-tensor/src/tensor/backend/tch/ops/erf.rs +++ /dev/null @@ -1,21 +0,0 @@ -use crate::{ - tensor::{backend::tch::TchTensor, ops::*}, - TchElement, -}; - -impl TensorOpsErf for TchTensor -where - E: TchElement, -{ - fn erf(&self) -> Self { - let tensor = self.tensor.erf(); - let kind = self.kind; - let shape = self.shape; - - Self { - tensor, - shape, - kind, - } - } -} diff --git a/burn-tensor/src/tensor/backend/tch/ops/mod.rs b/burn-tensor/src/tensor/backend/tch/ops/mod.rs index 1020a84f8..5c68c64de 100644 --- a/burn-tensor/src/tensor/backend/tch/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/tch/ops/mod.rs @@ -1,3 +1,2 @@ mod cat; mod creation; -mod erf; diff --git a/burn-tensor/src/tensor/backend/tch/tensor_ops.rs b/burn-tensor/src/tensor/backend/tch/tensor_ops.rs index ac6becc9c..8ebc7f7f1 100644 --- a/burn-tensor/src/tensor/backend/tch/tensor_ops.rs +++ b/burn-tensor/src/tensor/backend/tch/tensor_ops.rs @@ -378,6 +378,10 @@ impl TensorOps> for TchBackend { fn powf(tensor: &TchTensor, value: f32) -> TchTensor { to_tensor(tensor.tensor.pow_tensor_scalar(value as f64)) } + + fn erf(tensor: &TchTensor) -> TchTensor { + to_tensor(tensor.tensor.erf()) + } } fn to_tensor(tensor: tch::Tensor) -> TchTensor { diff --git a/burn-tensor/src/tensor/base.rs b/burn-tensor/src/tensor/base.rs index f92ea458c..c6d4b81d4 100644 --- a/burn-tensor/src/tensor/base.rs +++ b/burn-tensor/src/tensor/base.rs @@ -75,7 +75,7 @@ where /// /// `y = erf(x)` pub fn erf(&self) -> Self { - Self::new(self.value.erf()) + Self::new(B::erf(&self.value)) } /// Applies element wise power operation. diff --git a/burn-tensor/src/tensor/ops/base.rs b/burn-tensor/src/tensor/ops/base.rs index 835fcfe59..3d1604ef3 100644 --- a/burn-tensor/src/tensor/ops/base.rs +++ b/burn-tensor/src/tensor/ops/base.rs @@ -195,16 +195,13 @@ pub trait TensorOps { fn exp(tensor: &B::TensorPrimitive) -> B::TensorPrimitive; fn log(tensor: &B::TensorPrimitive) -> B::TensorPrimitive; fn powf(tensor: &B::TensorPrimitive, value: f32) -> B::TensorPrimitive; + fn erf(tensor: &B::TensorPrimitive) -> B::TensorPrimitive; } pub trait TensorOpsCat { fn cat(tensors: Vec<&Self>, dim: usize) -> Self; } -pub trait TensorOpsErf { - fn erf(&self) -> Self; -} - pub trait Zeros { fn zeros(&self) -> T; } diff --git a/burn-tensor/tests/tensor/grad/erf.rs b/burn-tensor/tests/tensor/grad/erf.rs new file mode 100644 index 000000000..262e0c736 --- /dev/null +++ b/burn-tensor/tests/tensor/grad/erf.rs @@ -0,0 +1,25 @@ +use crate::tensor::TestADTensor; +use burn_tensor::Data; + +#[test] +fn should_diff_erf() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + + let tensor_1 = TestADTensor::from_data(data_1); + let tensor_2 = TestADTensor::from_data(data_2); + + let tensor_3 = tensor_1.matmul(&tensor_2.erf()); + let tensor_4 = tensor_3.matmul(&tensor_2); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[8.0, 8.0], [8.0, 8.0]]), 3); +} diff --git a/burn-tensor/tests/tensor/grad/mod.rs b/burn-tensor/tests/tensor/grad/mod.rs index 47c127623..c3ce57396 100644 --- a/burn-tensor/tests/tensor/grad/mod.rs +++ b/burn-tensor/tests/tensor/grad/mod.rs @@ -2,6 +2,7 @@ mod add; mod aggregation; mod cross_entropy; mod div; +mod erf; mod exp; mod index; mod log;