mirror of https://github.com/tracel-ai/burn.git
refactor: args ops (#96)
This commit is contained in:
parent
0b77ef5dbc
commit
7684857282
|
@ -1,21 +0,0 @@
|
|||
use crate::backend::autodiff::ADBackendDecorator;
|
||||
use crate::backend::Backend;
|
||||
use crate::tensor::ops::*;
|
||||
|
||||
impl<B: Backend, const D: usize> TensorOpsArg<ADBackendDecorator<B>, D>
|
||||
for <ADBackendDecorator<B> as Backend>::TensorPrimitive<D>
|
||||
{
|
||||
fn argmax(
|
||||
&self,
|
||||
dim: usize,
|
||||
) -> <<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
|
||||
TensorOpsArg::argmax(&self.tensor(), dim)
|
||||
}
|
||||
|
||||
fn argmin(
|
||||
&self,
|
||||
dim: usize,
|
||||
) -> <<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
|
||||
TensorOpsArg::argmin(&self.tensor(), dim)
|
||||
}
|
||||
}
|
|
@ -1,4 +1,3 @@
|
|||
mod arg;
|
||||
mod base;
|
||||
mod cat;
|
||||
mod creation;
|
||||
|
|
|
@ -917,4 +917,18 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
ops,
|
||||
)
|
||||
}
|
||||
|
||||
fn argmax<const D: usize>(
|
||||
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> <<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
|
||||
B::argmax(tensor.tensor_ref(), dim)
|
||||
}
|
||||
|
||||
fn argmin<const D: usize>(
|
||||
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> <<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
|
||||
B::argmin(tensor.tensor_ref(), dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,7 +24,6 @@ pub trait Backend:
|
|||
+ Zeros<Self::TensorPrimitive<D>>
|
||||
+ Ones<Self::TensorPrimitive<D>>
|
||||
+ TensorOpsExp<Self::Elem, D>
|
||||
+ TensorOpsArg<Self, D>
|
||||
+ TensorOpsCat<Self::Elem, D>
|
||||
+ TensorOpsLog<Self::Elem, D>
|
||||
+ TensorOpsErf<Self::Elem, D>
|
||||
|
|
|
@ -1,74 +0,0 @@
|
|||
use crate::backend::ndarray::NdArrayBackend;
|
||||
use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*};
|
||||
use crate::{Data, NdArrayElement};
|
||||
use std::cmp::Ordering;
|
||||
|
||||
impl<E, const D: usize> TensorOpsArg<NdArrayBackend<E>, D> for NdArrayTensor<E, D>
|
||||
where
|
||||
E: NdArrayElement,
|
||||
{
|
||||
fn argmax(&self, dim: usize) -> NdArrayTensor<i64, D> {
|
||||
arg(self, dim, cmp_min)
|
||||
}
|
||||
|
||||
fn argmin(&self, dim: usize) -> NdArrayTensor<i64, D> {
|
||||
arg(self, dim, cmp_max)
|
||||
}
|
||||
}
|
||||
|
||||
fn arg<E: NdArrayElement, F, const D: usize>(
|
||||
tensor: &NdArrayTensor<E, D>,
|
||||
dim: usize,
|
||||
cmp: F,
|
||||
) -> NdArrayTensor<i64, D>
|
||||
where
|
||||
F: Fn(&f64, &f64) -> Ordering,
|
||||
{
|
||||
let mut data = <NdArrayBackend<E> as TensorOps<NdArrayBackend<E>>>::to_data::<D>(tensor);
|
||||
let batch_size = tensor.shape.dims[dim];
|
||||
let mut start = 0;
|
||||
let mut end = tensor.shape.dims[dim];
|
||||
let mut output = Vec::new();
|
||||
|
||||
while end <= data.value.len() {
|
||||
let data_dim = &mut data.value[start..end];
|
||||
let mut sorted: Vec<f64> = data_dim.iter().map(|a| a.to_elem()).collect();
|
||||
sorted.sort_by(&cmp);
|
||||
|
||||
let max = sorted[0];
|
||||
|
||||
let data_dim = &mut data.value[start..end];
|
||||
let mut index: i64 = 0;
|
||||
for elem in data_dim {
|
||||
let as_float: f64 = elem.to_elem();
|
||||
if as_float == max {
|
||||
break;
|
||||
}
|
||||
index += 1;
|
||||
}
|
||||
output.push(index);
|
||||
start += batch_size;
|
||||
end += batch_size;
|
||||
}
|
||||
let mut shape = tensor.shape;
|
||||
shape.dims[dim] = 1;
|
||||
NdArrayTensor::from_data(Data::new(output, shape))
|
||||
}
|
||||
|
||||
fn cmp_max(a: &f64, b: &f64) -> Ordering {
|
||||
if a < b {
|
||||
return Ordering::Less;
|
||||
} else if a > b {
|
||||
return Ordering::Greater;
|
||||
}
|
||||
Ordering::Equal
|
||||
}
|
||||
|
||||
fn cmp_min(a: &f64, b: &f64) -> Ordering {
|
||||
if a > b {
|
||||
return Ordering::Less;
|
||||
} else if a < b {
|
||||
return Ordering::Greater;
|
||||
}
|
||||
Ordering::Equal
|
||||
}
|
|
@ -1,4 +1,3 @@
|
|||
mod arg;
|
||||
mod cat;
|
||||
mod creation;
|
||||
mod erf;
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::{
|
|||
to_nd_array_tensor, Data, ElementConversion, NdArrayElement, Shape,
|
||||
};
|
||||
use ndarray::{Axis, Dim, SliceInfoElem};
|
||||
use std::ops::Range;
|
||||
use std::{cmp::Ordering, ops::Range};
|
||||
|
||||
macro_rules! keepdim {
|
||||
(
|
||||
|
@ -445,6 +445,14 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
array,
|
||||
}
|
||||
}
|
||||
|
||||
fn argmax<const D: usize>(tensor: &NdArrayTensor<E, D>, dim: usize) -> NdArrayTensor<i64, D> {
|
||||
arg(tensor, dim, cmp_min)
|
||||
}
|
||||
|
||||
fn argmin<const D: usize>(tensor: &NdArrayTensor<E, D>, dim: usize) -> NdArrayTensor<i64, D> {
|
||||
arg(tensor, dim, cmp_max)
|
||||
}
|
||||
}
|
||||
|
||||
fn to_slice_args<const D1: usize, const D2: usize>(
|
||||
|
@ -488,3 +496,61 @@ fn sum_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
|
|||
|
||||
NdArrayTensor { array, shape }
|
||||
}
|
||||
|
||||
fn arg<E: NdArrayElement, F, const D: usize>(
|
||||
tensor: &NdArrayTensor<E, D>,
|
||||
dim: usize,
|
||||
cmp: F,
|
||||
) -> NdArrayTensor<i64, D>
|
||||
where
|
||||
F: Fn(&f64, &f64) -> Ordering,
|
||||
{
|
||||
let batch_size = tensor.shape.dims[dim];
|
||||
|
||||
let mut data = NdArrayBackend::to_data::<D>(tensor);
|
||||
let mut start = 0;
|
||||
let mut end = tensor.shape.dims[dim];
|
||||
let mut output = Vec::new();
|
||||
|
||||
while end <= data.value.len() {
|
||||
let data_dim = &mut data.value[start..end];
|
||||
let mut sorted: Vec<f64> = data_dim.iter().map(|a| a.to_elem()).collect();
|
||||
sorted.sort_by(&cmp);
|
||||
|
||||
let max = sorted[0];
|
||||
|
||||
let data_dim = &mut data.value[start..end];
|
||||
let mut index: i64 = 0;
|
||||
for elem in data_dim {
|
||||
let as_float: f64 = elem.to_elem();
|
||||
if as_float == max {
|
||||
break;
|
||||
}
|
||||
index += 1;
|
||||
}
|
||||
output.push(index);
|
||||
start += batch_size;
|
||||
end += batch_size;
|
||||
}
|
||||
let mut shape = tensor.shape;
|
||||
shape.dims[dim] = 1;
|
||||
NdArrayTensor::from_data(Data::new(output, shape))
|
||||
}
|
||||
|
||||
fn cmp_max(a: &f64, b: &f64) -> Ordering {
|
||||
if a < b {
|
||||
return Ordering::Less;
|
||||
} else if a > b {
|
||||
return Ordering::Greater;
|
||||
}
|
||||
Ordering::Equal
|
||||
}
|
||||
|
||||
fn cmp_min(a: &f64, b: &f64) -> Ordering {
|
||||
if a > b {
|
||||
return Ordering::Less;
|
||||
} else if a < b {
|
||||
return Ordering::Greater;
|
||||
}
|
||||
Ordering::Equal
|
||||
}
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
use crate::backend::tch::TchBackend;
|
||||
use crate::tensor::TchElement;
|
||||
use crate::tensor::{
|
||||
backend::tch::{TchKind, TchTensor},
|
||||
ops::*,
|
||||
};
|
||||
|
||||
impl<E, const D: usize> TensorOpsArg<TchBackend<E>, D> for TchTensor<E, D>
|
||||
where
|
||||
E: TchElement,
|
||||
{
|
||||
fn argmax(&self, dim: usize) -> TchTensor<i64, D> {
|
||||
let tensor = self.tensor.argmax(dim as i64, true);
|
||||
let mut shape = self.shape;
|
||||
shape.dims[dim] = 1;
|
||||
|
||||
TchTensor {
|
||||
tensor,
|
||||
kind: TchKind::<i64>::new(),
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
fn argmin(&self, dim: usize) -> TchTensor<i64, D> {
|
||||
let tensor = self.tensor.argmin(dim as i64, true);
|
||||
let mut shape = self.shape;
|
||||
shape.dims[dim] = 1;
|
||||
|
||||
TchTensor {
|
||||
tensor,
|
||||
kind: TchKind::<i64>::new(),
|
||||
shape,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,4 +1,3 @@
|
|||
mod arg;
|
||||
mod cat;
|
||||
mod creation;
|
||||
mod erf;
|
||||
|
|
|
@ -356,6 +356,16 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
let tensor = tensor.tensor.to_kind(TchKind::<E>::new().kind());
|
||||
to_tensor(tensor)
|
||||
}
|
||||
|
||||
fn argmax<const D: usize>(tensor: &TchTensor<E, D>, dim: usize) -> TchTensor<i64, D> {
|
||||
let tensor = tensor.tensor.argmax(dim as i64, true);
|
||||
to_tensor(tensor)
|
||||
}
|
||||
|
||||
fn argmin<const D: usize>(tensor: &TchTensor<E, D>, dim: usize) -> TchTensor<i64, D> {
|
||||
let tensor = tensor.tensor.argmin(dim as i64, true);
|
||||
to_tensor(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
fn to_tensor<const D: usize, E: TchElement>(tensor: tch::Tensor) -> TchTensor<E, D> {
|
||||
|
|
|
@ -478,7 +478,7 @@ where
|
|||
/// }
|
||||
/// ```
|
||||
pub fn argmax(&self, dim: usize) -> Tensor<B::IntegerBackend, D> {
|
||||
Tensor::new(self.value.argmax(dim))
|
||||
Tensor::new(B::argmax(&self.value, dim))
|
||||
}
|
||||
|
||||
/// Applies the argmin function along the given dimension and returns an integer tensor.
|
||||
|
@ -497,7 +497,7 @@ where
|
|||
/// }
|
||||
/// ```
|
||||
pub fn argmin(&self, dim: usize) -> Tensor<B::IntegerBackend, D> {
|
||||
Tensor::new(self.value.argmin(dim))
|
||||
Tensor::new(B::argmin(&self.value, dim))
|
||||
}
|
||||
|
||||
/// Concatenates all tensors into a new one along the given dimension.
|
||||
|
|
|
@ -184,11 +184,14 @@ pub trait TensorOps<B: Backend> {
|
|||
fn from_full_precision<const D: usize>(
|
||||
tensor: &<B::FullPrecisionBackend as Backend>::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
}
|
||||
|
||||
pub trait TensorOpsArg<B: Backend, const D: usize> {
|
||||
fn argmax(&self, dim: usize) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
|
||||
fn argmin(&self, dim: usize) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
|
||||
fn argmax<const D: usize>(
|
||||
tensor: &B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
|
||||
fn argmin<const D: usize>(
|
||||
tensor: &B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
|
||||
}
|
||||
|
||||
pub trait TensorOpsExp<E, const D: usize> {
|
||||
|
|
Loading…
Reference in New Issue