mirror of https://github.com/tracel-ai/burn.git
Feat/index_select_dim ops (#225)
This commit is contained in:
parent 860051ca5c · commit 9655b74b22
@@ -194,4 +194,21 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
    ) -> BoolTensor<B, D> {
        B::int_lower_equal_elem(lhs, rhs)
    }

    fn int_index_select_dim<const D: usize>(
        tensor: IntTensor<B, D>,
        dim: usize,
        indexes: IntTensor<B, 1>,
    ) -> IntTensor<B, D> {
        B::int_index_select_dim(tensor, dim, indexes)
    }

    fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: IntTensor<B, D1>,
        dim: usize,
        indexes: IntTensor<B, 1>,
        value: IntTensor<B, D2>,
    ) -> IntTensor<B, D1> {
        B::int_index_select_dim_assign(tensor, dim, indexes, value)
    }
}
@@ -420,6 +420,101 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
        }
    }

    fn index_select_dim<const D: usize>(
        tensor: ADTensor<B, D>,
        dim: usize,
        indexes: IntTensor<B, 1>,
    ) -> ADTensor<B, D> {
        #[derive(Debug)]
        struct IndexSelectDim;

        impl<B: Backend, const D: usize> Backward<B, D, 1> for IndexSelectDim {
            type State = (usize, IntTensor<B, 1>, Shape<D>, B::Device);

            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
                let (dim, indexes, shape, device) = ops.state;

                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                    let zeros = B::zeros(shape, &device);
                    B::index_select_dim_assign(zeros, dim, indexes, grad)
                });
            }
        }

        match IndexSelectDim
            .prepare([tensor.node], [tensor.graph])
            .statefull()
        {
            OpsKind::Tracked(prep) => prep.finish(
                (
                    dim,
                    indexes.clone(),
                    B::shape(&tensor.primitive),
                    B::device(&tensor.primitive),
                ),
                B::index_select_dim(tensor.primitive, dim, indexes),
            ),
            OpsKind::UnTracked(prep) => {
                prep.finish(B::index_select_dim(tensor.primitive, dim, indexes))
            }
        }
    }

    fn index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: ADTensor<B, D1>,
        dim: usize,
        indexes: IntTensor<B, 1>,
        value: ADTensor<B, D2>,
    ) -> ADTensor<B, D1> {
        #[derive(Debug)]
        struct IndexSelectDimAssign<const D2: usize>;

        impl<B: Backend, const D1: usize, const D2: usize> Backward<B, D1, 2>
            for IndexSelectDimAssign<D2>
        {
            type State = (usize, IntTensor<B, 1>, Shape<D1>, Shape<D2>, B::Device);

            fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
                let (dim, indexes, shape_lhs, shape_rhs, device) = ops.state;
                let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes));

                binary::<B, D1, D1, D2, _, _>(
                    ops.parents,
                    ops.node,
                    grads,
                    |grad| {
                        let zeros = B::zeros(shape_lhs, &device);
                        B::index_select_dim_assign(grad, dim, indexes_4lhs.unwrap(), zeros)
                    },
                    |grad| {
                        let zeros = B::zeros(shape_rhs, &device);
                        B::index_select_dim_assign(zeros, dim, indexes_4rhs.unwrap(), grad)
                    },
                );
            }
        }

        match IndexSelectDimAssign::<D2>
            .prepare([tensor.node, value.node], [tensor.graph, value.graph])
            .statefull()
        {
            OpsKind::Tracked(prep) => prep.finish(
                (
                    dim,
                    indexes.clone(),
                    B::shape(&tensor.primitive),
                    B::shape(&value.primitive),
                    B::device(&value.primitive),
                ),
                B::index_select_dim_assign(tensor.primitive, dim, indexes, value.primitive),
            ),
            OpsKind::UnTracked(prep) => prep.finish(B::index_select_dim_assign(
                tensor.primitive,
                dim,
                indexes,
                value.primitive,
            )),
        }
    }

    fn index<const D1: usize, const D2: usize>(
        tensor: ADTensor<B, D1>,
        indexes: [std::ops::Range<usize>; D2],
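Taken together, the two Backward impls above encode the gradient rules for this pair of ops: index_select_dim scatters the incoming gradient back into a zero tensor of the input shape (an index-add at the selected positions), while index_select_dim_assign passes the incoming gradient through unchanged to the original tensor (index-adding zeros is a no-op) and builds the value gradient by index-adding the incoming gradient into a zero tensor of the value's shape. A minimal sketch of the first rule, not part of the diff, with hypothetical 1-D values:

    // tensor = [a, b, c], indexes = [1, 0]
    // forward:  output = tensor.index_select_dim(0, indexes) = [b, a]
    // backward: grad_tensor = index_select_dim_assign(zeros([3]), 0, indexes, grad_output)
    //                       = [grad_output[1], grad_output[0], 0.0]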
@@ -0,0 +1,34 @@
#[burn_tensor_testgen::testgen(ad_index_select_dim)]
mod tests {
    use super::*;
    use burn_tensor::Data;

    #[test]
    fn test_select_grad() {
        let tensor_1 =
            TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
        let values =
            TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad();
        let indexes = TestADTensor::from_data(Data::from([1, 0]));

        let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
        let tensor_3 = tensor_1
            .clone()
            .index_select_dim_assign(0, indexes, values.clone());
        let tensor_4 = tensor_2.matmul(tensor_3);

        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = values.grad(&grads).unwrap();

        assert_eq!(
            grad_1.into_data(),
            Data::from([[127., 199., 271.], [172., 244., 316.]])
        );
        assert_eq!(
            grad_2.into_data(),
            Data::from([[64., 64., 64.], [19., 19., 19.]])
        );
    }
}
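As a sanity check, grad_2 can be derived by hand: backward() seeds grad(tensor_4) with ones, so grad(tensor_3) = tensor_2ᵀ · ones; with tensor_2 = tensor_1 · tensor_1ᵀ = [[5, 14], [14, 50]], that gives grad(tensor_3) = [[19, 19, 19], [64, 64, 64]], and routing those rows back through indexes = [1, 0] swaps them into grad_2 = [[64, 64, 64], [19, 19, 19]].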
@@ -11,6 +11,7 @@ mod div;
mod erf;
mod exp;
mod index;
mod index_select_dim;
mod log;
mod log1p;
mod mask;
@@ -52,6 +53,7 @@ macro_rules! testgen_all {
        burn_autodiff::testgen_ad_div!();
        burn_autodiff::testgen_ad_erf!();
        burn_autodiff::testgen_ad_exp!();
        burn_autodiff::testgen_ad_index_select_dim!();
        burn_autodiff::testgen_ad_index!();
        burn_autodiff::testgen_ad_log!();
        burn_autodiff::testgen_ad_log1p!();
@@ -27,6 +27,7 @@ extern crate alloc;
mod tests {
    type TestBackend = crate::NdArrayBackend<f32>;
    type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
    type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;

    burn_tensor::testgen_all!();
@@ -203,4 +203,39 @@ where
            _ => panic!("Dim not supported {D}"),
        }
    }

    pub fn index_select_dim<const D: usize>(
        tensor: NdArrayTensor<E, D>,
        dim: usize,
        indexes: NdArrayTensor<i64, 1>,
    ) -> NdArrayTensor<E, D> {
        let array = tensor.array.select(
            Axis(dim),
            &indexes
                .array
                .into_iter()
                .map(|i| i as usize)
                .collect::<Vec<_>>(),
        );

        NdArrayTensor::new(array.into_shared())
    }

    pub fn index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: NdArrayTensor<E, D1>,
        dim: usize,
        indexes: NdArrayTensor<i64, 1>,
        value: NdArrayTensor<E, D2>,
    ) -> NdArrayTensor<E, D1> {
        let mut output_array = tensor.array.into_owned();

        for (index_value, index) in indexes.array.into_iter().enumerate() {
            let mut view = output_array.index_axis_mut(Axis(dim), index as usize);
            let value = value.array.index_axis(Axis(0), index_value);

            view.zip_mut_with(&value, |a, b| *a = *a + *b);
        }

        NdArrayTensor::new(output_array.into_shared())
    }
}
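A standalone sketch, not part of the diff and assuming only the ndarray crate, of the two primitives the ops above are built on: select gathers copies of subviews along an axis, while the assign loop performs an index-add, so repeated indexes accumulate instead of overwriting (the numbers mirror the should_select_assign_1d test later in this diff):

    use ndarray::{array, Axis};

    fn main() {
        let a = array![[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]];
        // Gather: rows 1, 0, 1 along axis 0.
        let gathered = a.select(Axis(0), &[1, 0, 1]);
        assert_eq!(gathered, array![[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

        // Index-add: row i of `value` is added at position indexes[i].
        let mut out = array![0.0, 1.0, 2.0];
        let value = array![5.0, 4.0, 3.0, 2.0, 1.0];
        for (i, &idx) in [1_usize, 1, 0, 1, 2].iter().enumerate() {
            out[idx] += value[i];
        }
        assert_eq!(out, array![3.0, 12.0, 3.0]);
    }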
@@ -263,4 +263,21 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
    ) -> NdArrayTensor<i64, D> {
        NdArrayMathOps::mean_dim(tensor, dim)
    }

    fn int_index_select_dim<const D: usize>(
        tensor: NdArrayTensor<i64, D>,
        dim: usize,
        indexes: NdArrayTensor<i64, 1>,
    ) -> NdArrayTensor<i64, D> {
        NdArrayMathOps::index_select_dim(tensor, dim, indexes)
    }

    fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: NdArrayTensor<i64, D1>,
        dim: usize,
        indexes: NdArrayTensor<i64, 1>,
        value: NdArrayTensor<i64, D2>,
    ) -> NdArrayTensor<i64, D1> {
        NdArrayMathOps::index_select_dim_assign(tensor, dim, indexes, value)
    }
}
@@ -151,6 +151,23 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
        NdArrayOps::reshape(tensor, shape)
    }

    fn index_select_dim<const D: usize>(
        tensor: NdArrayTensor<E, D>,
        dim: usize,
        indexes: NdArrayTensor<i64, 1>,
    ) -> NdArrayTensor<E, D> {
        NdArrayMathOps::index_select_dim(tensor, dim, indexes)
    }

    fn index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: NdArrayTensor<E, D1>,
        dim: usize,
        indexes: NdArrayTensor<i64, 1>,
        value: NdArrayTensor<E, D2>,
    ) -> NdArrayTensor<E, D1> {
        NdArrayMathOps::index_select_dim_assign(tensor, dim, indexes, value)
    }

    fn index<const D1: usize, const D2: usize>(
        tensor: NdArrayTensor<E, D1>,
        indexes: [Range<usize>; D2],
@@ -10,6 +10,7 @@ pub use tensor::*;
mod tests {
    type TestBackend = crate::TchBackend<f32>;
    type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
    type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;

    burn_tensor::testgen_all!();
    burn_autodiff::testgen_all!();
@@ -45,6 +45,27 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
        TchTensor::new(tensor_original)
    }

    pub fn index_select_dim<const D: usize>(
        tensor: TchTensor<E, D>,
        dim: usize,
        indexes: TchTensor<i64, 1>,
    ) -> TchTensor<E, D> {
        let tensor = tensor.tensor.index_select(dim as i64, &indexes.tensor);
        TchTensor::new(tensor)
    }

    pub fn index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: TchTensor<E, D1>,
        dim: usize,
        indexes: TchTensor<i64, 1>,
        value: TchTensor<E, D2>,
    ) -> TchTensor<E, D1> {
        let mut indices = vec![None; D1];
        indices[dim] = Some(indexes.tensor);
        let tensor = tensor.tensor.index_put(&indices, &value.tensor, true);
        TchTensor::new(tensor)
    }

    pub fn cat<const D: usize>(tensors: Vec<TchTensor<E, D>>, dim: usize) -> TchTensor<E, D> {
        let tensors: Vec<tch::Tensor> = tensors
            .into_iter()
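The LibTorch path leans on two tch calls, used above: index_select gathers along a dimension, and index_put with accumulate = true adds values at the given indexes rather than overwriting them, matching the accumulate semantics of the ndarray loop. A standalone sketch, not part of the diff and assuming the tch crate of the same era:

    use tch::Tensor;

    fn main() {
        let tensor = Tensor::of_slice(&[0.0_f32, 1.0, 2.0]);
        let indexes = Tensor::of_slice(&[1_i64, 1, 0, 1, 2]);

        // Gather entries along dim 0.
        let selected = tensor.index_select(0, &indexes);
        assert_eq!(Vec::<f32>::from(&selected), vec![1.0, 1.0, 0.0, 1.0, 2.0]);

        // Scatter-add: repeated indexes accumulate.
        let values = Tensor::of_slice(&[5.0_f32, 4.0, 3.0, 2.0, 1.0]);
        let assigned = tensor.index_put(&[Some(indexes)], &values, true);
        assert_eq!(Vec::<f32>::from(&assigned), vec![3.0, 12.0, 3.0]);
    }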
@@ -243,4 +243,21 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
    fn int_mean_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
        TchOps::mean_dim(tensor, dim)
    }

    fn int_index_select_dim<const D: usize>(
        tensor: TchTensor<i64, D>,
        dim: usize,
        indexes: TchTensor<i64, 1>,
    ) -> TchTensor<i64, D> {
        TchOps::index_select_dim(tensor, dim, indexes)
    }

    fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: TchTensor<i64, D1>,
        dim: usize,
        indexes: TchTensor<i64, 1>,
        value: TchTensor<i64, D2>,
    ) -> TchTensor<i64, D1> {
        TchOps::index_select_dim_assign(tensor, dim, indexes, value)
    }
}
@@ -190,6 +190,23 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
        TchTensor::new(tensor)
    }

    fn index_select_dim<const D: usize>(
        tensor: TchTensor<E, D>,
        dim: usize,
        indexes: TchTensor<i64, 1>,
    ) -> TchTensor<E, D> {
        TchOps::index_select_dim(tensor, dim, indexes)
    }

    fn index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: TchTensor<E, D1>,
        dim: usize,
        indexes: TchTensor<i64, 1>,
        value: TchTensor<E, D2>,
    ) -> TchTensor<E, D1> {
        TchOps::index_select_dim_assign(tensor, dim, indexes, value)
    }

    fn index<const D1: usize, const D2: usize>(
        tensor: TchTensor<E, D1>,
        indexes: [Range<usize>; D2],
@@ -342,6 +342,27 @@ where
        self.reshape(shape)
    }

    /// Index the tensor along the given dimension using the given indexes.
    pub fn index_select_dim(self, dim: usize, indexes: Tensor<B, 1, Int>) -> Self {
        Self::new(B::index_select_dim(self.primitive, dim, indexes.primitive))
    }

    /// Return a new tensor with the same dimensions, but with the values added to
    /// the original tensor at the corresponding indexes along the given dimension.
    pub fn index_select_dim_assign<const D2: usize>(
        self,
        dim: usize,
        indexes: Tensor<B, 1, Int>,
        values: Tensor<B, D2>,
    ) -> Self {
        Self::new(B::index_select_dim_assign(
            self.primitive,
            dim,
            indexes.primitive,
            values.primitive,
        ))
    }

    pub(crate) fn relu(self) -> Self {
        Self::new(B::relu(self.primitive))
    }
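A short usage sketch of the two new public methods, not part of the diff; TestTensor and TestTensorInt are the test aliases defined earlier and stand in for any backend:

    let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
    let indexes = TestTensorInt::from_data(Data::from([1, 0]));

    // Gather rows 1 and 0: [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]].
    let selected = tensor.clone().index_select_dim(0, indexes.clone());

    // Add the rows of `values` at rows 1 and 0: [[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]].
    let values = TestTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]));
    let assigned = tensor.index_select_dim_assign(0, indexes, values);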
@@ -234,6 +234,17 @@ pub trait Numeric<B: Backend>: TensorKind<B> {
        lhs: Self::Primitive<D>,
        rhs: Self::Elem,
    ) -> Tensor<B, D, Bool>;
    fn index_select_dim<const D: usize>(
        tensor: Self::Primitive<D>,
        dim: usize,
        indexes: Tensor<B, 1, Int>,
    ) -> Self::Primitive<D>;
    fn index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: Self::Primitive<D1>,
        dim: usize,
        indexes: Tensor<B, 1, Int>,
        values: Self::Primitive<D2>,
    ) -> Self::Primitive<D1>;
}

impl<B: Backend> Numeric<B> for Int {
@@ -361,6 +372,23 @@ impl<B: Backend> Numeric<B> for Int {
    ) -> Tensor<B, D, Bool> {
        Tensor::new(B::int_lower_equal_elem(lhs, rhs))
    }

    fn index_select_dim<const D: usize>(
        tensor: Self::Primitive<D>,
        dim: usize,
        indexes: Tensor<B, 1, Int>,
    ) -> Self::Primitive<D> {
        B::int_index_select_dim(tensor, dim, indexes.primitive)
    }

    fn index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: Self::Primitive<D1>,
        dim: usize,
        indexes: Tensor<B, 1, Int>,
        values: Self::Primitive<D2>,
    ) -> Self::Primitive<D1> {
        B::int_index_select_dim_assign(tensor, dim, indexes.primitive, values)
    }
}

impl<B: Backend> Numeric<B> for Float {
@@ -488,6 +516,23 @@ impl<B: Backend> Numeric<B> for Float {
    ) -> Tensor<B, D, Bool> {
        Tensor::new(B::lower_equal_elem(lhs, rhs))
    }

    fn index_select_dim<const D: usize>(
        tensor: Self::Primitive<D>,
        dim: usize,
        indexes: Tensor<B, 1, Int>,
    ) -> Self::Primitive<D> {
        B::index_select_dim(tensor, dim, indexes.primitive)
    }

    fn index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: Self::Primitive<D1>,
        dim: usize,
        indexes: Tensor<B, 1, Int>,
        values: Self::Primitive<D2>,
    ) -> Self::Primitive<D1> {
        B::index_select_dim_assign(tensor, dim, indexes.primitive, values)
    }
}

impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K>
@@ -34,6 +34,17 @@ pub trait IntTensorOps<B: Backend> {
        indexes: [Range<usize>; D2],
        value: B::IntTensorPrimitive<D1>,
    ) -> B::IntTensorPrimitive<D1>;
    fn int_index_select_dim<const D: usize>(
        tensor: B::IntTensorPrimitive<D>,
        dim: usize,
        indexes: B::IntTensorPrimitive<1>,
    ) -> B::IntTensorPrimitive<D>;
    fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: B::IntTensorPrimitive<D1>,
        dim: usize,
        indexes: B::IntTensorPrimitive<1>,
        value: B::IntTensorPrimitive<D2>,
    ) -> B::IntTensorPrimitive<D1>;
    fn int_repeat<const D: usize>(
        tensor: B::IntTensorPrimitive<D>,
        dim: usize,
@@ -115,6 +115,17 @@ pub trait TensorOps<B: Backend> {
        tensor: B::TensorPrimitive<D1>,
        shape: Shape<D2>,
    ) -> B::TensorPrimitive<D2>;
    fn index_select_dim<const D: usize>(
        tensor: B::TensorPrimitive<D>,
        dim: usize,
        indexes: B::IntTensorPrimitive<1>,
    ) -> B::TensorPrimitive<D>;
    fn index_select_dim_assign<const D1: usize, const D2: usize>(
        tensor: B::TensorPrimitive<D1>,
        dim: usize,
        indexes: B::IntTensorPrimitive<1>,
        value: B::TensorPrimitive<D2>,
    ) -> B::TensorPrimitive<D1>;
    fn index<const D1: usize, const D2: usize>(
        tensor: B::TensorPrimitive<D1>,
        indexes: [Range<usize>; D2],
@@ -28,6 +28,7 @@ macro_rules! testgen_all {
        burn_tensor::testgen_exp!();
        burn_tensor::testgen_log!();
        burn_tensor::testgen_log1p!();
        burn_tensor::testgen_index_select_dim!();
        burn_tensor::testgen_index!();
        burn_tensor::testgen_map_comparison!();
        burn_tensor::testgen_mask!();
@@ -0,0 +1,84 @@
#[burn_tensor_testgen::testgen(index_select_dim)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Tensor};

    #[test]
    fn should_select_1d() {
        let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
        let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));

        let output = tensor.index_select_dim(0, indexes);

        assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
    }

    #[test]
    fn should_select_2d_dim0_same_num_dim() {
        let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
        let indexes = TestTensorInt::from_data(Data::from([1, 0]));

        let output = tensor.index_select_dim(0, indexes);

        assert_eq!(
            output.into_data(),
            Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]])
        );
    }

    #[test]
    fn should_select_2d_dim0_more_num_dim() {
        let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
        let indexes = TestTensorInt::from_data(Data::from([1, 0, 1, 1]));

        let output = tensor.index_select_dim(0, indexes);

        assert_eq!(
            output.into_data(),
            Data::from([
                [3.0, 4.0, 5.0],
                [0.0, 1.0, 2.0],
                [3.0, 4.0, 5.0],
                [3.0, 4.0, 5.0]
            ])
        );
    }

    #[test]
    fn should_select_2d_dim1() {
        let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
        let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));

        let output = tensor.index_select_dim(1, indexes);

        assert_eq!(
            output.into_data(),
            Data::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]])
        );
    }

    #[test]
    fn should_select_assign_1d() {
        let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
        let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0, 2.0, 1.0]));
        let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));

        let output = tensor.index_select_dim_assign(0, indexes, values);

        assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0]));
    }

    #[test]
    fn should_select_assign_2d_dim0() {
        let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
        let values = TestTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]));
        let indexes = TestTensorInt::from_data(Data::from([1, 0]));

        let output = tensor.index_select_dim_assign(0, indexes, values);

        assert_eq!(
            output.into_data(),
            Data::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]])
        );
    }
}
@@ -6,6 +6,7 @@ mod div;
mod erf;
mod exp;
mod index;
mod index_select_dim;
mod log;
mod log1p;
mod map_comparison;