Addition of abs tensor opperator #506 (#553)

This commit is contained in:
mmalczak 2023-08-02 00:25:14 +02:00 committed by GitHub
parent 87125da6c9
commit 73fb0eaa7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 205 additions and 3 deletions

View File

@ -305,4 +305,7 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
B::int_min_dim_with_indices(tensor, dim)
}
fn int_abs<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<D> {
B::int_abs(tensor)
}
}

View File

@ -1147,6 +1147,28 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
}
fn abs<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
#[derive(Debug)]
struct Abs;
impl<B: Backend, const D: usize> Backward<B, D, 1> for Abs {
type State = B::TensorPrimitive<D>;
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| B::mul(grad, ops.state));
}
}
match Abs.prepare([tensor.node], [tensor.graph]).statefull() {
OpsKind::Tracked(prep) => {
let output = B::abs(tensor.primitive.clone());
let state = B::div(tensor.primitive, output.clone());
prep.finish(state, output)
}
OpsKind::UnTracked(prep) => prep.finish(B::abs(tensor.primitive)),
}
}
fn cos<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
#[derive(Debug)]
struct Cos;

View File

@ -0,0 +1,28 @@
#[burn_tensor_testgen::testgen(ad_abs)]
mod tests {
use super::*;
use burn_tensor::Data;
#[test]
fn should_diff_abs() {
let data_1 = Data::<f32, 2>::from([[0.0, -1.0], [3.0, 4.0]]);
let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, -10.0]]);
let tensor_1 = TestADTensor::from_data(data_1).require_grad();
let tensor_2 = TestADTensor::from_data(data_2).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[71.0, 107.0], [71.0, 107.0]]), 3);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[84.0, 42.0], [90.0, 54.0]]), 3);
}
}

View File

@ -1,5 +1,6 @@
#![allow(missing_docs)]
mod abs;
mod add;
mod aggregation;
mod avgpool1d;
@ -85,6 +86,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_sin!();
burn_autodiff::testgen_ad_softmax!();
burn_autodiff::testgen_ad_sqrt!();
burn_autodiff::testgen_ad_abs!();
burn_autodiff::testgen_ad_sub!();
burn_autodiff::testgen_ad_tanh!();
burn_autodiff::testgen_ad_transpose!();

View File

@ -1,6 +1,6 @@
use burn_tensor::Element;
use libm::{exp, log, log1p, pow, sqrt};
use libm::{expf, log1pf, logf, powf, sqrtf};
use libm::{exp, fabs, log, log1p, pow, sqrt};
use libm::{expf, fabsf, log1pf, logf, powf, sqrtf};
use ndarray::LinalgScalar;
pub(crate) trait FloatNdArrayElement: NdArrayElement + LinalgScalar
@ -28,6 +28,8 @@ pub(crate) trait ExpElement {
fn powf_elem(self, value: f32) -> Self;
fn powi_elem(self, value: i32) -> Self;
fn sqrt_elem(self) -> Self;
fn abs_elem(self) -> Self;
fn int_abs_elem(self) -> Self;
}
impl FloatNdArrayElement for f64 {}
@ -76,6 +78,16 @@ macro_rules! make_elem {
fn sqrt_elem(self) -> Self {
sqrt(self as f64) as $ty
}
#[inline(always)]
fn abs_elem(self) -> Self {
fabs(self as f64) as $ty
}
#[inline(always)]
fn int_abs_elem(self) -> Self {
(self as i64).abs() as $ty
}
}
};
(
@ -120,6 +132,16 @@ macro_rules! make_elem {
fn sqrt_elem(self) -> Self {
sqrtf(self as f32) as $ty
}
#[inline(always)]
fn abs_elem(self) -> Self {
fabsf(self as f32) as $ty
}
#[inline(always)]
fn int_abs_elem(self) -> Self {
(self as i32).abs() as $ty
}
}
};
}

View File

@ -5,6 +5,7 @@ use burn_tensor::ops::IntTensorOps;
use core::ops::Range;
// Current crate
use crate::element::ExpElement;
use crate::element::FloatNdArrayElement;
use crate::NdArrayDevice;
use crate::{tensor::NdArrayTensor, NdArrayBackend};
@ -356,4 +357,10 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::clamp(tensor, min, max)
}
fn int_abs<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, D> {
let array = tensor.array.mapv_into(|a| a.int_abs_elem()).into_shared();
NdArrayTensor::new(array)
}
}

View File

@ -380,6 +380,12 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
NdArrayTensor::new(array)
}
fn abs<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared();
NdArrayTensor::new(array)
}
fn cos<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array

View File

@ -357,4 +357,8 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
) -> TchTensor<i64, D> {
TchOps::clamp(tensor, min, max)
}
fn int_abs<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, D> {
tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
}
}

View File

@ -390,6 +390,10 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt())
}
fn abs<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
}
fn cos<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos())
}

View File

@ -443,6 +443,11 @@ where
pub fn clamp_max(self, max: K::Elem) -> Self {
Self::new(K::clamp_max(self.primitive, max))
}
/// Apply element wise absolute value operation
pub fn abs(self) -> Self {
Self::new(K::abs(self.primitive))
}
}
/// Trait that list all operations that can be applied on all numerical tensors.
@ -1445,6 +1450,26 @@ where
/// [Tensor::clamp_max](Tensor::clamp_max) function, which is more high-level and designed for public use.
fn clamp_max<const D: usize>(tensor: Self::Primitive<D>, max: Self::Elem)
-> Self::Primitive<D>;
/// Calculate absolute value on all elements of a tensor
///
/// # Arguments
///
/// * `tensor` - The tensor to apply abs to.
///
/// # Returns
///
/// A tensor with absolute values.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For calculating abs of the elements of a tensor, users should prefer the [Tensor::abs](Tensor::abs) function,
/// which is more high-level and designed for public use.
fn abs<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D>;
}
impl<B: Backend> Numeric<B> for Int {
@ -1695,6 +1720,10 @@ impl<B: Backend> Numeric<B> for Int {
) -> Self::Primitive<D> {
B::int_clamp_max(tensor, max)
}
fn abs<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {
B::int_abs(tensor)
}
}
impl<B: Backend> Numeric<B> for Float {
@ -1946,6 +1975,10 @@ impl<B: Backend> Numeric<B> for Float {
) -> Self::Primitive<D> {
B::clamp_max(tensor, max)
}
fn abs<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {
B::abs(tensor)
}
}
impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K>

View File

@ -866,4 +866,15 @@ pub trait IntTensorOps<B: Backend> {
(values, indices)
}
/// Returns a new tensor with absolute values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take absolute value of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with absolute values.
fn int_abs<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<D>;
}

View File

@ -916,6 +916,17 @@ pub trait TensorOps<B: Backend> {
/// A tensor with the same shape as `tensor` with square root values.
fn sqrt<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
/// Returns a new tensor with absolute values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take absolute value of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with absolute values.
fn abs<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
/// Returns a new tensor with cosine values.
///
/// # Arguments

View File

@ -56,6 +56,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_sin!();
burn_tensor::testgen_slice!();
burn_tensor::testgen_sqrt!();
burn_tensor::testgen_abs!();
burn_tensor::testgen_squeeze!();
burn_tensor::testgen_sub!();
burn_tensor::testgen_tanh!();

View File

@ -0,0 +1,24 @@
#[burn_tensor_testgen::testgen(abs)]
mod tests {
use super::*;
use burn_tensor::{Data, Int, Tensor};
#[test]
fn should_support_abs_ops() {
let data = Data::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = tensor.abs().into_data();
let data_expected = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
assert_eq!(data_expected, data_actual);
let data = Data::from([[0, -1, 2], [3, 4, -5]]);
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data);
let data_actual = tensor.abs().into_data();
let data_expected = Data::from([[0, 1, 2], [3, 4, 5]]);
assert_eq!(data_expected, data_actual);
}
}

View File

@ -1,3 +1,4 @@
mod abs;
mod add;
mod aggregation;
mod arange;

View File

@ -371,6 +371,17 @@ where
unary_default::<Sqrt, F, D>(tensor)
}
fn abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Abs, func "abs");
unary_inplace!(AbsInplace, func "abs");
if tensor.can_mut() {
return unary_inplace_default::<AbsInplace, F, D>(tensor);
}
unary_default::<Abs, F, D>(tensor)
}
fn cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Cos, func "cos");
unary_inplace!(CosInplace, func "cos");

View File

@ -1,7 +1,8 @@
use super::{numeric, BoolTensor, Device, IntElem, IntTensor};
use crate::kernel::{unary_default, unary_inplace_default};
use crate::{
element::{FloatElement, IntElement},
kernel, GraphicsApi, WgpuBackend,
kernel, unary, unary_inplace, GraphicsApi, WgpuBackend,
};
use burn_tensor::{ops::IntTensorOps, Data, Shape};
use std::ops::Range;
@ -294,4 +295,15 @@ where
// ) -> IntTensor<Self, D> {
// kernel::clamp(tensor, min, max)
// }
fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
unary!(IntAbs, func "abs");
unary_inplace!(IntAbsInplace, func "abs");
if tensor.can_mut() {
return unary_inplace_default::<IntAbsInplace, I, D>(tensor);
}
unary_default::<IntAbs, I, D>(tensor)
}
}