refactor: aggregation ops (#94)
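
Moves the four aggregation operations (mean, sum, mean_dim, sum_dim) out of the standalone TensorOpsAggregation trait and into the backend-level TensorOps trait. The per-backend aggregation modules (autodiff, ndarray, tch) are removed, their implementations are folded into each backend's TensorOps impl, and the autodiff backward passes are rewritten as inline Backward structs.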

Nathaniel Simard 2022-11-12 10:23:00 -05:00 committed by GitHub
parent e7094b92ac
commit 9d832a802a
13 changed files with 354 additions and 411 deletions

View File

@@ -1,232 +0,0 @@
use crate::backend::autodiff::ADBackendDecorator;
use crate::tensor::ElementConversion;
use crate::Tensor;
use crate::{backend::Backend, tensor::ops::*};
use crate::{
define_ops, execute_ops,
graph::ops::{UnaryOps, UnaryOpsNodeState},
Shape,
};
define_ops! {
name ADTensorOpsMean,
state Shape<D>,
}
define_ops! {
name ADTensorOpsSum,
state Shape<D>,
}
define_ops! {
name ADTensorOpsMeanDim,
state (Shape<D>, usize),
}
define_ops! {
name ADTensorOpsSumDim,
state (Shape<D>, usize),
}
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<1>>
for ADTensorOpsMean<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<1>>,
) -> B::TensorPrimitive<D> {
let grad = state.output.grad();
let ones = B::ones(self.state, B::device(&grad));
let grad: Tensor<B, 1> = Tensor::new(grad);
let val = 1_f64 / self.state.num_elements() as f64;
let ones: Tensor<B, D> = Tensor::new(ones).mul_scalar(val);
ones.mul(&grad.unsqueeze()).value
}
}
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<1>>
for ADTensorOpsSum<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<1>>,
) -> B::TensorPrimitive<D> {
let grad = state.output.grad();
let ones = B::ones(self.state, B::device(&grad));
let grad: Tensor<B, 1> = Tensor::new(grad);
let ones: Tensor<B, D> = Tensor::new(ones);
ones.mul(&grad.unsqueeze()).value
}
}
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for ADTensorOpsMeanDim<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
let (shape, dim) = self.state;
let grad = state.output.grad().sum_dim(dim);
let ones = B::ones(shape, B::device(&grad));
let val = 1_f64 / shape.dims[dim] as f64;
let ones = B::mul_scalar(&ones, &B::Elem::from_elem(val));
B::mul(&ones, &grad)
}
}
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for ADTensorOpsSumDim<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
let (shape, dim) = self.state;
let grad = state.output.grad().sum_dim(dim);
let ones = B::ones(shape, B::device(&grad));
B::mul(&ones, &grad)
}
}
impl<B: Backend, const D: usize> TensorOpsAggregation<ADBackendDecorator<B>, D>
for <ADBackendDecorator<B> as Backend>::TensorPrimitive<D>
{
fn mean(&self) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<1> {
execute_ops!(
input self.node.clone(),
out TensorOpsAggregation::mean(&self.tensor()),
ops ADTensorOpsMean::<B, D>::new(self.shape),
)
}
fn sum(&self) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<1> {
execute_ops!(
input self.node.clone(),
out TensorOpsAggregation::sum(&self.tensor()),
ops ADTensorOpsSum::<B, D>::new(self.shape),
)
}
fn mean_dim(&self, dim: usize) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
execute_ops!(
input self.node.clone(),
out TensorOpsAggregation::mean_dim(&self.tensor(), dim),
ops ADTensorOpsMeanDim::<B, D>::new((self.shape, dim)),
)
}
fn sum_dim(&self, dim: usize) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
execute_ops!(
input self.node.clone(),
out TensorOpsAggregation::sum_dim(&self.tensor(), dim),
ops ADTensorOpsSumDim::<B, D>::new((self.shape, dim)),
)
}
}
#[cfg(test)]
mod tests {
use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};
#[test]
fn should_diff_mean() {
let data_1 = Data::<f64, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f64, 2>::from([[4.0, -7.0], [2.0, 3.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_1.mul(&tensor_3.mean().unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[3.5, 9.5], [3.5, 9.5]]), 5);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[-0.75, -0.75], [3.0, 3.0]]), 5);
}
#[test]
fn should_diff_sum() {
let data_1 = Data::<f64, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f64, 2>::from([[4.0, -7.0], [2.0, 3.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_1.mul(&tensor_3.sum().unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[14.0, 38.0], [14.0, 38.0]]), 5);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[-3.0, -3.0], [12.0, 12.0]]), 5);
}
#[test]
fn should_diff_mean_dim() {
let data_1 = Data::<f64, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f64, 2>::from([[4.0, -7.0], [2.0, 3.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_1.mul(&tensor_3.mean_dim(1).unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[4.0, 36.0], [3.0, -17.0]]), 5);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[9.0, 9.0], [35.5, 35.5]]), 5);
}
#[test]
fn should_diff_sum_dim() {
let data_1 = Data::<f64, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f64, 2>::from([[4.0, -7.0], [2.0, 3.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_1.mul(&tensor_3.sum_dim(1).unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[8.0, 72.0], [6.0, -34.0]]), 5);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[18.0, 18.0], [71.0, 71.0]]), 5);
}
}

View File

@@ -1,4 +1,3 @@
mod aggregation;
mod arg;
mod base;
mod cat;

View File

@@ -1,12 +1,13 @@
use super::{binary_ops_wrapper, unary_ops_wrapper};
use crate::tensor::ElementConversion;
use crate::{
backend::{
autodiff::{ADBackendDecorator, ADTensor},
Backend,
},
graph::ops::{BinaryOps, BinaryOpsNodeState, UnaryOps, UnaryOpsNodeState},
ops::{Ones, TensorOps, TensorOpsAggregation, Zeros},
Data, Shape,
ops::{Ones, TensorOps, Zeros},
Data, Shape, Tensor,
};
use std::ops::Range;
@@ -508,7 +509,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
for i in 0..D2 {
if shape_value.dims[i] == 1 && shape_grad.dims[i] != 1 {
grad = grad.sum_dim(i);
grad = B::sum_dim(&grad, i);
}
}
@@ -708,4 +709,138 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
ADTensor::from_tensor(B::detach(tensor.tensor_ref()))
}
fn mean<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<1> {
#[derive(new, Debug)]
struct Backward<B: Backend, const D: usize> {
shape: Shape<D>,
_b: B,
}
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<1>>
for Backward<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<1>>,
) -> B::TensorPrimitive<D> {
let grad = state.output.grad();
let ones = B::ones(self.shape, B::device(&grad));
let grad: Tensor<B, 1> = Tensor::new(grad);
let val = 1_f64 / self.shape.num_elements() as f64;
let ones: Tensor<B, D> = Tensor::new(ones).mul_scalar(val);
ones.mul(&grad.unsqueeze()).value
}
}
let shape = B::shape(tensor.tensor_ref());
let output = B::mean(tensor.tensor_ref());
let ops = Backward::<B, D>::new(*shape, B::default());
unary_ops_wrapper(tensor.node.clone(), output, ops)
}
fn sum<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<1> {
#[derive(new, Debug)]
struct Backward<B: Backend, const D: usize> {
shape: Shape<D>,
_b: B,
}
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<1>>
for Backward<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<1>>,
) -> B::TensorPrimitive<D> {
let grad = state.output.grad();
let ones = B::ones(self.shape, B::device(&grad));
let grad: Tensor<B, 1> = Tensor::new(grad);
let ones: Tensor<B, D> = Tensor::new(ones);
ones.mul(&grad.unsqueeze()).value
}
}
let shape = B::shape(tensor.tensor_ref());
let output = B::sum(tensor.tensor_ref());
let ops = Backward::<B, D>::new(*shape, B::default());
unary_ops_wrapper(tensor.node.clone(), output, ops)
}
fn mean_dim<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
dim: usize,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
#[derive(new, Debug)]
struct Backward<B: Backend, const D: usize> {
shape: Shape<D>,
dim: usize,
_b: B,
}
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for Backward<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
let grad = B::sum_dim(&state.output.grad(), self.dim);
let ones = B::ones(self.shape, B::device(&grad));
let val = 1_f64 / self.shape.dims[self.dim] as f64;
let ones = B::mul_scalar(&ones, &B::Elem::from_elem(val));
B::mul(&ones, &grad)
}
}
let shape = B::shape(tensor.tensor_ref());
let output = B::mean_dim(tensor.tensor_ref(), dim);
let ops = Backward::<B, D>::new(*shape, dim, B::default());
unary_ops_wrapper(tensor.node.clone(), output, ops)
}
fn sum_dim<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
dim: usize,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
#[derive(new, Debug)]
struct Backward<B: Backend, const D: usize> {
shape: Shape<D>,
dim: usize,
_b: B,
}
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for Backward<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
let grad = B::sum_dim(&state.output.grad(), self.dim);
let ones = B::ones(self.shape, B::device(&grad));
B::mul(&ones, &grad)
}
}
let shape = B::shape(tensor.tensor_ref());
let output = B::sum_dim(tensor.tensor_ref(), dim);
let ops = Backward::<B, D>::new(*shape, dim, B::default());
unary_ops_wrapper(tensor.node.clone(), output, ops)
}
}
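
The inline Backward structs above encode the standard aggregation gradients: for y = mean(x) over N elements each input element receives grad_out / N, for y = sum(x) it receives grad_out unchanged, and the *_dim variants apply the same rule along one axis (mean_dim divides by shape.dims[dim]). A minimal sketch of the two scalar rules in plain Rust, independent of the burn types above:
// Sketch only (std Rust, no burn dependency): the gradient rules used by the
// mean/sum Backward structs above.
fn main() {
    let x = [1.0_f64, 7.0, -2.0, -3.0];
    let n = x.len() as f64;
    let grad_out = 2.0; // gradient flowing back into the scalar output
    // d mean(x) / d x_i = 1 / N  =>  every element receives grad_out / N
    let grad_mean: Vec<f64> = x.iter().map(|_| grad_out / n).collect();
    // d sum(x) / d x_i = 1       =>  every element receives grad_out
    let grad_sum: Vec<f64> = x.iter().map(|_| grad_out).collect();
    assert_eq!(grad_mean, vec![0.5; 4]);
    assert_eq!(grad_sum, vec![2.0; 4]);
}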

View File

@@ -24,7 +24,6 @@ pub trait Backend:
+ Zeros<Self::TensorPrimitive<D>>
+ Ones<Self::TensorPrimitive<D>>
+ TensorOpsPrecision<Self, D>
+ TensorOpsAggregation<Self, D>
+ TensorOpsExp<Self::Elem, D>
+ TensorOpsArg<Self, D>
+ TensorOpsCat<Self::Elem, D>

View File

@@ -1,92 +0,0 @@
use crate::{
backend::Backend,
tensor::{
backend::ndarray::{NdArrayBackend, NdArrayTensor},
ops::*,
},
Data, NdArrayElement,
};
use ndarray::Axis;
macro_rules! keepdim {
(
$D:expr,
$dim:expr,
$self:expr,
mean
) => {{
let tensor: NdArrayTensor<E, $D> = mean_dim(&$self, $dim);
let mut shape = $self.shape.clone();
shape.dims[$dim] = 1;
NdArrayBackend::reshape(&tensor, shape)
}};
(
$D:expr,
$dim:expr,
$self:expr,
sum
) => {{
let tensor: NdArrayTensor<E, $D> = sum_dim(&$self, $dim);
let mut shape = $self.shape.clone();
shape.dims[$dim] = 1;
NdArrayBackend::reshape(&tensor, shape)
}};
}
impl<E: NdArrayElement, const D: usize> TensorOpsAggregation<NdArrayBackend<E>, D>
for NdArrayTensor<E, D>
{
fn mean(&self) -> NdArrayTensor<E, 1> {
let data = Data::from([self.array.mean().unwrap()]);
NdArrayTensor::from_data(data)
}
fn sum(&self) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<1> {
let data = Data::from([self.array.sum()]);
NdArrayTensor::from_data(data)
}
fn mean_dim(&self, dim: usize) -> Self {
match D {
1 => keepdim!(0, dim, self, mean),
2 => keepdim!(1, dim, self, mean),
3 => keepdim!(2, dim, self, mean),
4 => keepdim!(3, dim, self, mean),
5 => keepdim!(4, dim, self, mean),
6 => keepdim!(5, dim, self, mean),
_ => panic!("Dim not supported {}", D),
}
}
fn sum_dim(&self, dim: usize) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
match D {
1 => keepdim!(0, dim, self, sum),
2 => keepdim!(1, dim, self, sum),
3 => keepdim!(2, dim, self, sum),
4 => keepdim!(3, dim, self, sum),
5 => keepdim!(4, dim, self, sum),
6 => keepdim!(5, dim, self, sum),
_ => panic!("Dim not supported {}", D),
}
}
}
fn mean_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
tensor: &NdArrayTensor<E, D1>,
dim: usize,
) -> NdArrayTensor<E, D2> {
let array = tensor.array.mean_axis(Axis(dim)).unwrap().into_shared();
let shape = tensor.shape.remove_dim(dim);
NdArrayTensor { array, shape }
}
fn sum_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
tensor: &NdArrayTensor<E, D1>,
dim: usize,
) -> NdArrayTensor<E, D2> {
let array = tensor.array.sum_axis(Axis(dim)).into_shared();
let shape = tensor.shape.remove_dim(dim);
NdArrayTensor { array, shape }
}

View File

@@ -1,4 +1,3 @@
mod aggregation;
mod arg;
mod cat;
mod creation;

View File

@@ -1,12 +1,36 @@
use std::ops::Range;
use super::{BatchMatrix, NdArrayBackend, NdArrayTensor};
use crate::{
backend::{Backend, NdArrayDevice},
ops::TensorOps,
to_nd_array_tensor, Data, ElementConversion, NdArrayElement, Shape,
};
use ndarray::{Dim, SliceInfoElem};
use ndarray::{Axis, Dim, SliceInfoElem};
use std::ops::Range;
macro_rules! keepdim {
(
$D:expr,
$dim:expr,
$self:expr,
mean
) => {{
let tensor: NdArrayTensor<E, $D> = mean_dim(&$self, $dim);
let mut shape = $self.shape.clone();
shape.dims[$dim] = 1;
NdArrayBackend::reshape(&tensor, shape)
}};
(
$D:expr,
$dim:expr,
$self:expr,
sum
) => {{
let tensor: NdArrayTensor<E, $D> = sum_dim(&$self, $dim);
let mut shape = $self.shape.clone();
shape.dims[$dim] = 1;
NdArrayBackend::reshape(&tensor, shape)
}};
}
impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn shape<const D: usize>(
@@ -365,9 +389,44 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
array,
}
}
fn detach<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
tensor.clone()
}
fn mean<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
let data = Data::from([tensor.array.mean().unwrap()]);
NdArrayTensor::from_data(data)
}
fn sum<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
let data = Data::from([tensor.array.sum()]);
NdArrayTensor::from_data(data)
}
fn mean_dim<const D: usize>(tensor: &NdArrayTensor<E, D>, dim: usize) -> NdArrayTensor<E, D> {
match D {
1 => keepdim!(0, dim, tensor, mean),
2 => keepdim!(1, dim, tensor, mean),
3 => keepdim!(2, dim, tensor, mean),
4 => keepdim!(3, dim, tensor, mean),
5 => keepdim!(4, dim, tensor, mean),
6 => keepdim!(5, dim, tensor, mean),
_ => panic!("Dim not supported {}", D),
}
}
fn sum_dim<const D: usize>(tensor: &NdArrayTensor<E, D>, dim: usize) -> NdArrayTensor<E, D> {
match D {
1 => keepdim!(0, dim, tensor, sum),
2 => keepdim!(1, dim, tensor, sum),
3 => keepdim!(2, dim, tensor, sum),
4 => keepdim!(3, dim, tensor, sum),
5 => keepdim!(4, dim, tensor, sum),
6 => keepdim!(5, dim, tensor, sum),
_ => panic!("Dim not supported {}", D),
}
}
}
fn to_slice_args<const D1: usize, const D2: usize>(
@@ -391,3 +450,23 @@ fn to_slice_args<const D1: usize, const D2: usize>(
}
slices
}
fn mean_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
tensor: &NdArrayTensor<E, D1>,
dim: usize,
) -> NdArrayTensor<E, D2> {
let array = tensor.array.mean_axis(Axis(dim)).unwrap().into_shared();
let shape = tensor.shape.remove_dim(dim);
NdArrayTensor { array, shape }
}
fn sum_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
tensor: &NdArrayTensor<E, D1>,
dim: usize,
) -> NdArrayTensor<E, D2> {
let array = tensor.array.sum_axis(Axis(dim)).into_shared();
let shape = tensor.shape.remove_dim(dim);
NdArrayTensor { array, shape }
}
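
The keepdim! macro above reduces along dim with mean_dim/sum_dim (which drop the axis) and then reshapes the result so the reduced dimension is kept with size 1. A small sketch of the same keep-dim behaviour, assuming only the ndarray crate and using insert_axis instead of the reshape the backend performs:
use ndarray::{array, Axis};
fn main() {
    let x = array![[1.0_f64, 2.0], [3.0, 4.0]];
    // Reduce axis 1 (shape [2, 2] -> [2]), then re-insert it (-> [2, 1]).
    let sums = x.sum_axis(Axis(1)).insert_axis(Axis(1));
    let means = x.mean_axis(Axis(1)).unwrap().insert_axis(Axis(1));
    assert_eq!(sums.shape(), &[2, 1]);
    assert_eq!(means[[0, 0]], 1.5); // mean of [1.0, 2.0]
}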

View File

@@ -1,61 +0,0 @@
use crate::{
backend::Backend,
tensor::{
backend::tch::{TchBackend, TchTensor},
ops::*,
Shape,
},
TchElement,
};
impl<E: TchElement, const D: usize> TensorOpsAggregation<TchBackend<E>, D> for TchTensor<E, D> {
fn mean(&self) -> <TchBackend<E> as Backend>::TensorPrimitive<1> {
let kind = self.kind;
let tensor = self.tensor.mean(kind.kind());
let shape = Shape::new([1]);
TchTensor {
tensor,
kind,
shape,
}
}
fn sum(&self) -> <TchBackend<E> as Backend>::TensorPrimitive<1> {
let kind = self.kind;
let tensor = self.tensor.sum(kind.kind());
let shape = Shape::new([1]);
TchTensor {
tensor,
kind,
shape,
}
}
fn mean_dim(&self, dim: usize) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
let kind = self.kind;
let tensor = self.tensor.mean_dim(&[dim as i64], true, kind.kind());
let shape = Shape::from(tensor.size());
TchTensor {
tensor,
kind,
shape,
}
}
fn sum_dim(&self, dim: usize) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
let kind = self.kind;
let tensor = self
.tensor
.sum_dim_intlist(&[dim as i64], true, kind.kind());
let shape = Shape::from(tensor.size());
TchTensor {
tensor,
kind,
shape,
}
}
}

View File

@@ -1,4 +1,3 @@
mod aggregation;
mod arg;
mod cat;
mod creation;

View File

@@ -318,9 +318,34 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
kind: TchKind::<bool>::new(),
}
}
fn detach<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
tensor.clone()
}
fn mean<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, 1> {
let tensor = tensor.tensor.mean(tensor.kind.kind());
to_tensor(tensor)
}
fn sum<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, 1> {
let tensor = tensor.tensor.sum(tensor.kind.kind());
to_tensor(tensor)
}
fn mean_dim<const D: usize>(tensor: &TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
let tensor = tensor
.tensor
.mean_dim(&[dim as i64], true, tensor.kind.kind());
to_tensor(tensor)
}
fn sum_dim<const D: usize>(tensor: &TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
let tensor = tensor
.tensor
.sum_dim_intlist(&[dim as i64], true, tensor.kind.kind());
to_tensor(tensor)
}
}
fn to_tensor<const D: usize, E: TchElement>(tensor: tch::Tensor) -> TchTensor<E, D> {
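
Both calls above pass keepdim = true, so mean_dim and sum_dim_intlist keep the reduced dimension with size 1, matching the ndarray backend's keepdim! macro. A sketch against the tch crate directly, assuming the same tch version and call signatures that this diff uses:
use tch::{Kind, Tensor};
fn main() {
    let x = Tensor::of_slice(&[1.0_f32, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
    // keepdim = true keeps the reduced axis: shape [2, 2] -> [2, 1] instead of [2].
    let row_sums = x.sum_dim_intlist(&[1], true, Kind::Float);
    let row_means = x.mean_dim(&[1], true, Kind::Float);
    assert_eq!(row_sums.size(), vec![2, 1]);
    assert_eq!(row_means.size(), vec![2, 1]);
}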

View File

@@ -257,22 +257,22 @@ where
/// Aggregate all elements in the tensor with the mean operation.
pub fn mean(&self) -> Tensor<B, 1> {
Tensor::new(self.value.mean())
Tensor::new(B::mean(&self.value))
}
/// Aggregate all elements in the tensor with the sum operation.
pub fn sum(&self) -> Tensor<B, 1> {
Tensor::new(self.value.sum())
Tensor::new(B::sum(&self.value))
}
/// Aggregate all elements along the given *dimension* or *axis* in the tensor with the mean operation.
pub fn mean_dim(&self, dim: usize) -> Self {
Self::new(self.value.mean_dim(dim))
Self::new(B::mean_dim(&self.value, dim))
}
/// Aggregate all elements along the given *dimension* or *axis* in the tensor with the sum operation.
pub fn sum_dim(&self, dim: usize) -> Self {
Self::new(self.value.sum_dim(dim))
Self::new(B::sum_dim(&self.value, dim))
}
/// Calculate the variance along the given dimension.
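
The caller-facing Tensor API is unchanged; only the delegation target moves from methods on the tensor primitive to the backend functions. A usage sketch, generic over any backend B (the summarize helper and the import paths are illustrative, not part of this commit):
use burn_tensor::{backend::Backend, Tensor};
fn summarize<B: Backend>(x: &Tensor<B, 2>) -> Tensor<B, 1> {
    // mean_dim/sum_dim keep the reduced dimension with size 1,
    // while mean/sum collapse everything into a single-element tensor.
    let per_row = x.mean_dim(1); // shape [rows, 1]
    per_row.sum().mul(&x.mean()) // combine two rank-1 tensors element-wise
}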

View File

@@ -170,13 +170,14 @@ pub trait TensorOps<B: Backend> {
rhs: &B::Elem,
) -> B::BoolTensorPrimitive<D>;
fn detach<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
}
pub trait TensorOpsAggregation<B: Backend, const D: usize> {
fn mean(&self) -> B::TensorPrimitive<1>;
fn sum(&self) -> B::TensorPrimitive<1>;
fn mean_dim(&self, dim: usize) -> B::TensorPrimitive<D>;
fn sum_dim(&self, dim: usize) -> B::TensorPrimitive<D>;
fn mean<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;
fn sum<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;
fn mean_dim<const D: usize>(
tensor: &B::TensorPrimitive<D>,
dim: usize,
) -> B::TensorPrimitive<D>;
fn sum_dim<const D: usize>(tensor: &B::TensorPrimitive<D>, dim: usize)
-> B::TensorPrimitive<D>;
}
pub trait TensorOpsPrecision<B: Backend, const D: usize> {
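
Compared with the removed TensorOpsAggregation trait, aggregation is now reached through the backend type rather than through a method on the tensor primitive, which is exactly how the autodiff code above calls it (B::sum_dim(&grad, i)). A sketch of the new calling convention; the keepdim_sum helper is hypothetical and it assumes, as the code above does, that Backend requires TensorOps<Self>:
use burn_tensor::backend::Backend;
// Before: tensor.sum_dim(dim) via TensorOpsAggregation on the primitive.
// After:  B::sum_dim(&tensor, dim) via TensorOps on the backend.
fn keepdim_sum<B: Backend, const D: usize>(
    t: &B::TensorPrimitive<D>,
    dim: usize,
) -> B::TensorPrimitive<D> {
    B::sum_dim(t, dim)
}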

View File

@@ -1,13 +1,59 @@
use super::super::TestADBackend;
use burn_tensor::{Data, Tensor};
use crate::tensor::TestADTensor;
use burn_tensor::Data;
#[test]
fn test_sum_dim_grad() {
fn should_diff_mean() {
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, -7.0], [2.0, 3.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_1.mul(&tensor_3.mean().unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[3.5, 9.5], [3.5, 9.5]]), 5);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[-0.75, -0.75], [3.0, 3.0]]), 5);
}
#[test]
fn should_diff_sum_1() {
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, -7.0], [2.0, 3.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_1.mul(&tensor_3.sum().unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[14.0, 38.0], [14.0, 38.0]]), 5);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[-3.0, -3.0], [12.0, 12.0]]), 5);
}
#[test]
fn should_diff_sum_2() {
let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]);
let tensor_1 = Tensor::<TestADBackend, 2>::from_data(data_1);
let tensor_2 = Tensor::<TestADBackend, 2>::from_data(data_2);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_3.sum_dim(1);
@@ -24,3 +70,49 @@ fn test_sum_dim_grad() {
.to_data()
.assert_approx_eq(&Data::from([[690.0, 690.0], [958.0, 958.0]]), 3);
}
#[test]
fn should_diff_mean_dim() {
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, -7.0], [2.0, 3.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_1.mul(&tensor_3.mean_dim(1).unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[4.0, 36.0], [3.0, -17.0]]), 5);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[9.0, 9.0], [35.5, 35.5]]), 5);
}
#[test]
fn should_diff_sum_dim() {
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, -7.0], [2.0, 3.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_1.mul(&tensor_3.sum_dim(1).unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[8.0, 72.0], [6.0, -34.0]]), 5);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[18.0, 18.0], [71.0, 71.0]]), 5);
}