refactor: args ops (#96)

Nathaniel Simard 2022-11-12 11:29:42 -05:00 committed by GitHub
parent 0b77ef5dbc
commit 7684857282
12 changed files with 101 additions and 142 deletions
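In short, the standalone `TensorOpsArg` trait, which each backend implemented directly on its tensor primitive, is removed; `argmax` and `argmin` move into the backend-wide `TensorOps` trait as associated functions that take the tensor as an explicit argument. A condensed sketch of the trait-level change, trimmed to just these two ops (the real traits carry many more items):

// Before: a dedicated trait implemented on every backend's tensor primitive.
pub trait TensorOpsArg<B: Backend, const D: usize> {
    fn argmax(&self, dim: usize) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
    fn argmin(&self, dim: usize) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
}

// After: associated functions on the backend-level TensorOps trait.
pub trait TensorOps<B: Backend> {
    fn argmax<const D: usize>(
        tensor: &B::TensorPrimitive<D>,
        dim: usize,
    ) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;

    fn argmin<const D: usize>(
        tensor: &B::TensorPrimitive<D>,
        dim: usize,
    ) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
}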

View File

@@ -1,21 +0,0 @@
use crate::backend::autodiff::ADBackendDecorator;
use crate::backend::Backend;
use crate::tensor::ops::*;
impl<B: Backend, const D: usize> TensorOpsArg<ADBackendDecorator<B>, D>
for <ADBackendDecorator<B> as Backend>::TensorPrimitive<D>
{
fn argmax(
&self,
dim: usize,
) -> <<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
TensorOpsArg::argmax(&self.tensor(), dim)
}
fn argmin(
&self,
dim: usize,
) -> <<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
TensorOpsArg::argmin(&self.tensor(), dim)
}
}

View File

@@ -1,4 +1,3 @@
mod arg;
mod base;
mod cat;
mod creation;

View File

@@ -917,4 +917,18 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
ops,
)
}
fn argmax<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
dim: usize,
) -> <<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
B::argmax(tensor.tensor_ref(), dim)
}
fn argmin<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
dim: usize,
) -> <<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
B::argmin(tensor.tensor_ref(), dim)
}
}
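Because `argmax` and `argmin` return integer index tensors, there is no gradient to record, so the autodiff decorator just unwraps the tracked tensor (`tensor_ref()`) and forwards the call to the wrapped backend `B`. A self-contained sketch of that pass-through pattern with made-up types, none of which are burn's:

// A tracked tensor wraps a backend primitive plus tape bookkeeping (elided here).
struct Tracked<T> {
    primitive: T,
}

impl<T> Tracked<T> {
    fn primitive_ref(&self) -> &T {
        &self.primitive
    }
}

// Forward a non-differentiable reduction straight to the inner backend op:
// nothing is registered on the autodiff graph.
fn forward_argmax<T, I>(tensor: &Tracked<T>, dim: usize, inner: impl Fn(&T, usize) -> I) -> I {
    inner(tensor.primitive_ref(), dim)
}

fn main() {
    // Pretend the primitive is a Vec<f64> and the inner op is a plain argmax over it.
    let tracked = Tracked { primitive: vec![0.5, 2.0, 1.0] };
    let idx = forward_argmax(&tracked, 0, |v: &Vec<f64>, _dim| {
        v.iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i as i64)
            .unwrap()
    });
    assert_eq!(idx, 1);
}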

View File

@@ -24,7 +24,6 @@ pub trait Backend:
+ Zeros<Self::TensorPrimitive<D>>
+ Ones<Self::TensorPrimitive<D>>
+ TensorOpsExp<Self::Elem, D>
+ TensorOpsArg<Self, D>
+ TensorOpsCat<Self::Elem, D>
+ TensorOpsLog<Self::Elem, D>
+ TensorOpsErf<Self::Elem, D>

View File

@@ -1,74 +0,0 @@
use crate::backend::ndarray::NdArrayBackend;
use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*};
use crate::{Data, NdArrayElement};
use std::cmp::Ordering;
impl<E, const D: usize> TensorOpsArg<NdArrayBackend<E>, D> for NdArrayTensor<E, D>
where
E: NdArrayElement,
{
fn argmax(&self, dim: usize) -> NdArrayTensor<i64, D> {
arg(self, dim, cmp_min)
}
fn argmin(&self, dim: usize) -> NdArrayTensor<i64, D> {
arg(self, dim, cmp_max)
}
}
fn arg<E: NdArrayElement, F, const D: usize>(
tensor: &NdArrayTensor<E, D>,
dim: usize,
cmp: F,
) -> NdArrayTensor<i64, D>
where
F: Fn(&f64, &f64) -> Ordering,
{
let mut data = <NdArrayBackend<E> as TensorOps<NdArrayBackend<E>>>::to_data::<D>(tensor);
let batch_size = tensor.shape.dims[dim];
let mut start = 0;
let mut end = tensor.shape.dims[dim];
let mut output = Vec::new();
while end <= data.value.len() {
let data_dim = &mut data.value[start..end];
let mut sorted: Vec<f64> = data_dim.iter().map(|a| a.to_elem()).collect();
sorted.sort_by(&cmp);
let max = sorted[0];
let data_dim = &mut data.value[start..end];
let mut index: i64 = 0;
for elem in data_dim {
let as_float: f64 = elem.to_elem();
if as_float == max {
break;
}
index += 1;
}
output.push(index);
start += batch_size;
end += batch_size;
}
let mut shape = tensor.shape;
shape.dims[dim] = 1;
NdArrayTensor::from_data(Data::new(output, shape))
}
fn cmp_max(a: &f64, b: &f64) -> Ordering {
if a < b {
return Ordering::Less;
} else if a > b {
return Ordering::Greater;
}
Ordering::Equal
}
fn cmp_min(a: &f64, b: &f64) -> Ordering {
if a > b {
return Ordering::Less;
} else if a < b {
return Ordering::Greater;
}
Ordering::Equal
}

View File

@@ -1,4 +1,3 @@
mod arg;
mod cat;
mod creation;
mod erf;

View File

@@ -5,7 +5,7 @@ use crate::{
to_nd_array_tensor, Data, ElementConversion, NdArrayElement, Shape,
};
use ndarray::{Axis, Dim, SliceInfoElem};
use std::ops::Range;
use std::{cmp::Ordering, ops::Range};
macro_rules! keepdim {
(
@@ -445,6 +445,14 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
array,
}
}
fn argmax<const D: usize>(tensor: &NdArrayTensor<E, D>, dim: usize) -> NdArrayTensor<i64, D> {
arg(tensor, dim, cmp_min)
}
fn argmin<const D: usize>(tensor: &NdArrayTensor<E, D>, dim: usize) -> NdArrayTensor<i64, D> {
arg(tensor, dim, cmp_max)
}
}
fn to_slice_args<const D1: usize, const D2: usize>(
@@ -488,3 +496,61 @@ fn sum_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
NdArrayTensor { array, shape }
}
fn arg<E: NdArrayElement, F, const D: usize>(
tensor: &NdArrayTensor<E, D>,
dim: usize,
cmp: F,
) -> NdArrayTensor<i64, D>
where
F: Fn(&f64, &f64) -> Ordering,
{
let batch_size = tensor.shape.dims[dim];
let mut data = NdArrayBackend::to_data::<D>(tensor);
let mut start = 0;
let mut end = tensor.shape.dims[dim];
let mut output = Vec::new();
while end <= data.value.len() {
let data_dim = &mut data.value[start..end];
let mut sorted: Vec<f64> = data_dim.iter().map(|a| a.to_elem()).collect();
sorted.sort_by(&cmp);
let max = sorted[0];
let data_dim = &mut data.value[start..end];
let mut index: i64 = 0;
for elem in data_dim {
let as_float: f64 = elem.to_elem();
if as_float == max {
break;
}
index += 1;
}
output.push(index);
start += batch_size;
end += batch_size;
}
let mut shape = tensor.shape;
shape.dims[dim] = 1;
NdArrayTensor::from_data(Data::new(output, shape))
}
fn cmp_max(a: &f64, b: &f64) -> Ordering {
if a < b {
return Ordering::Less;
} else if a > b {
return Ordering::Greater;
}
Ordering::Equal
}
fn cmp_min(a: &f64, b: &f64) -> Ordering {
if a > b {
return Ordering::Less;
} else if a < b {
return Ordering::Greater;
}
Ordering::Equal
}
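The shared `arg` helper walks the flattened tensor data in consecutive windows of `shape.dims[dim]` elements, sorts a copy of each window with the supplied comparator so the sought extreme lands at index 0, and then records the first position in the window holding that value. A standalone sketch of the same window-scan logic on a plain slice (function name and example values are ours, for illustration only):

use std::cmp::Ordering;

// For each consecutive window of `window` elements, return the index of the
// value that sorts first under `cmp` (mirroring the helper above).
fn arg_indices(values: &[f64], window: usize, cmp: impl Fn(&f64, &f64) -> Ordering) -> Vec<i64> {
    let mut output = Vec::new();
    let (mut start, mut end) = (0, window);
    while end <= values.len() {
        let slice = &values[start..end];
        let mut sorted = slice.to_vec();
        sorted.sort_by(&cmp);
        let target = sorted[0];
        let index = slice.iter().position(|v| *v == target).unwrap_or(0) as i64;
        output.push(index);
        start += window;
        end += window;
    }
    output
}

fn main() {
    // Two rows of three values, reduced along the last dimension.
    let data = [1.0, 5.0, 2.0, 7.0, 0.0, 3.0];
    // Descending comparator, like cmp_min above: the maximum sorts first.
    let by_max = |a: &f64, b: &f64| b.partial_cmp(a).unwrap();
    assert_eq!(arg_indices(&data, 3, by_max), vec![1, 0]); // argmax per row
}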

View File

@@ -1,35 +0,0 @@
use crate::backend::tch::TchBackend;
use crate::tensor::TchElement;
use crate::tensor::{
backend::tch::{TchKind, TchTensor},
ops::*,
};
impl<E, const D: usize> TensorOpsArg<TchBackend<E>, D> for TchTensor<E, D>
where
E: TchElement,
{
fn argmax(&self, dim: usize) -> TchTensor<i64, D> {
let tensor = self.tensor.argmax(dim as i64, true);
let mut shape = self.shape;
shape.dims[dim] = 1;
TchTensor {
tensor,
kind: TchKind::<i64>::new(),
shape,
}
}
fn argmin(&self, dim: usize) -> TchTensor<i64, D> {
let tensor = self.tensor.argmin(dim as i64, true);
let mut shape = self.shape;
shape.dims[dim] = 1;
TchTensor {
tensor,
kind: TchKind::<i64>::new(),
shape,
}
}
}

View File

@@ -1,4 +1,3 @@
mod arg;
mod cat;
mod creation;
mod erf;

View File

@@ -356,6 +356,16 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
let tensor = tensor.tensor.to_kind(TchKind::<E>::new().kind());
to_tensor(tensor)
}
fn argmax<const D: usize>(tensor: &TchTensor<E, D>, dim: usize) -> TchTensor<i64, D> {
let tensor = tensor.tensor.argmax(dim as i64, true);
to_tensor(tensor)
}
fn argmin<const D: usize>(tensor: &TchTensor<E, D>, dim: usize) -> TchTensor<i64, D> {
let tensor = tensor.tensor.argmin(dim as i64, true);
to_tensor(tensor)
}
}
fn to_tensor<const D: usize, E: TchElement>(tensor: tch::Tensor) -> TchTensor<E, D> {

View File

@@ -478,7 +478,7 @@
/// }
/// ```
pub fn argmax(&self, dim: usize) -> Tensor<B::IntegerBackend, D> {
Tensor::new(self.value.argmax(dim))
Tensor::new(B::argmax(&self.value, dim))
}
/// Applies the argmin function along the given dimension and returns an integer tensor.
@@ -497,7 +497,7 @@
/// }
/// ```
pub fn argmin(&self, dim: usize) -> Tensor<B::IntegerBackend, D> {
Tensor::new(self.value.argmin(dim))
Tensor::new(B::argmin(&self.value, dim))
}
/// Concatenates all tensors into a new one along the given dimension.
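At the public `Tensor` level nothing changes: `argmax` and `argmin` keep their signatures and still return an integer-backend tensor of the same rank, with the reduced dimension kept at size 1; only the internal dispatch moves from a method on the primitive to the backend-level function. A small sketch of generic caller code that is unaffected by the refactor (the import paths and the helper name are assumptions, not taken from this diff):

use burn_tensor::{backend::Backend, Tensor};

// Index of the largest value along the last dimension, e.g. the predicted
// class per row of a batch of logits (assumes D >= 1; helper name is ours).
fn best_class<B: Backend, const D: usize>(logits: &Tensor<B, D>) -> Tensor<B::IntegerBackend, D> {
    // As in the backend implementations above, the reduced dimension is kept with size 1.
    logits.argmax(D - 1)
}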

View File

@@ -184,11 +184,14 @@ pub trait TensorOps<B: Backend> {
fn from_full_precision<const D: usize>(
tensor: &<B::FullPrecisionBackend as Backend>::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
}
pub trait TensorOpsArg<B: Backend, const D: usize> {
fn argmax(&self, dim: usize) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
fn argmin(&self, dim: usize) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
fn argmax<const D: usize>(
tensor: &B::TensorPrimitive<D>,
dim: usize,
) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
fn argmin<const D: usize>(
tensor: &B::TensorPrimitive<D>,
dim: usize,
) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
}
pub trait TensorOpsExp<E, const D: usize> {