Feat/index_select_dim ops (#225)

This commit is contained in:
Nathaniel Simard 2023-03-11 16:14:57 -05:00 committed by GitHub
parent 860051ca5c
commit 9655b74b22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 448 additions and 0 deletions

View File

@ -194,4 +194,21 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
) -> BoolTensor<B, D> {
B::int_lower_equal_elem(lhs, rhs)
}
/// Selects entries of an int tensor along `dim` at the positions given by `indexes`.
///
/// No autodiff node is registered here: the call is forwarded unchanged to the inner
/// backend `B`.
fn int_index_select_dim<const D: usize>(
    tensor: IntTensor<B, D>,
    dim: usize,
    indexes: IntTensor<B, 1>,
) -> IntTensor<B, D> {
    B::int_index_select_dim::<D>(tensor, dim, indexes)
}
/// Adds `value` into an int tensor at the positions given by `indexes` along `dim`.
///
/// No autodiff node is registered here: the call is forwarded unchanged to the inner
/// backend `B`.
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
    tensor: IntTensor<B, D1>,
    dim: usize,
    indexes: IntTensor<B, 1>,
    value: IntTensor<B, D2>,
) -> IntTensor<B, D1> {
    B::int_index_select_dim_assign::<D1, D2>(tensor, dim, indexes, value)
}
}

View File

@ -420,6 +420,101 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
}
/// Selects slices of `tensor` along `dim` at `indexes`, registering a backward step when the
/// input is tracked.
///
/// The backward step builds a zero tensor with the input's shape and passes it, together
/// with the output gradient, to `index_select_dim_assign` using the same `dim` and `indexes`.
fn index_select_dim<const D: usize>(
    tensor: ADTensor<B, D>,
    dim: usize,
    indexes: IntTensor<B, 1>,
) -> ADTensor<B, D> {
    #[derive(Debug)]
    struct IndexSelectDim;

    impl<B: Backend, const D: usize> Backward<B, D, 1> for IndexSelectDim {
        // (dim, indexes, input shape, input device)
        type State = (usize, IntTensor<B, 1>, Shape<D>, B::Device);

        fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
            let (dim, indexes, input_shape, device) = ops.state;

            unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                B::index_select_dim_assign(B::zeros(input_shape, &device), dim, indexes, grad)
            });
        }
    }

    let prepared = IndexSelectDim
        .prepare([tensor.node], [tensor.graph])
        .statefull();

    match prepared {
        OpsKind::Tracked(prep) => {
            // Capture shape and device before the primitive is consumed by the forward op.
            let input_shape = B::shape(&tensor.primitive);
            let device = B::device(&tensor.primitive);
            let output = B::index_select_dim(tensor.primitive, dim, indexes.clone());

            prep.finish((dim, indexes, input_shape, device), output)
        }
        OpsKind::UnTracked(prep) => {
            prep.finish(B::index_select_dim(tensor.primitive, dim, indexes))
        }
    }
}
/// Adds `value` into `tensor` at the positions given by `indexes` along `dim`, registering a
/// backward step when either input is tracked.
///
/// The forward pass is `out = tensor + scatter_add(value, indexes)`, so the backward rules are:
/// - w.r.t. `tensor`: the output gradient passes through unchanged;
/// - w.r.t. `value`: slice `i` of `value` lands in slice `indexes[i]` of the output, hence its
///   gradient is the output gradient *selected* at `indexes` along `dim`.
///
/// Scattering the gradient instead (as a previous version did) is only correct when `indexes`
/// is a self-inverse permutation, and indexes out of bounds when `indexes` is longer than the
/// target dimension.
fn index_select_dim_assign<const D1: usize, const D2: usize>(
    tensor: ADTensor<B, D1>,
    dim: usize,
    indexes: IntTensor<B, 1>,
    value: ADTensor<B, D2>,
) -> ADTensor<B, D1> {
    #[derive(Debug)]
    struct IndexSelectDimAssign<const D2: usize>;

    impl<B: Backend, const D1: usize, const D2: usize> Backward<B, D1, 2> for IndexSelectDimAssign<D2> {
        // (dim, indexes, shape of `value`)
        type State = (usize, IntTensor<B, 1>, Shape<D2>);

        fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
            let (dim, indexes, shape_rhs) = ops.state;

            binary::<B, D1, D1, D2, _, _>(
                ops.parents,
                ops.node,
                grads,
                // d(out)/d(tensor) is the identity.
                |grad| grad,
                |grad| {
                    let selected = B::index_select_dim(grad, dim, indexes);
                    // The selection already has the element layout of `value`; the reshape only
                    // bridges the rank parameter from `D1` to `D2`.
                    B::reshape(selected, shape_rhs)
                },
            );
        }
    }

    match IndexSelectDimAssign::<D2>
        .prepare([tensor.node, value.node], [tensor.graph, value.graph])
        .statefull()
    {
        OpsKind::Tracked(prep) => prep.finish(
            (dim, indexes.clone(), B::shape(&value.primitive)),
            B::index_select_dim_assign(tensor.primitive, dim, indexes, value.primitive),
        ),
        OpsKind::UnTracked(prep) => prep.finish(B::index_select_dim_assign(
            tensor.primitive,
            dim,
            indexes,
            value.primitive,
        )),
    }
}
fn index<const D1: usize, const D2: usize>(
tensor: ADTensor<B, D1>,
indexes: [std::ops::Range<usize>; D2],

View File

@ -0,0 +1,34 @@
#[burn_tensor_testgen::testgen(ad_index_select_dim)]
mod tests {
    use super::*;
    use burn_tensor::Data;

    /// Gradients of `index_select_dim_assign` must reach both the target tensor and the
    /// assigned values.
    #[test]
    fn test_select_grad() {
        let tensor_1 = TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]))
            .require_grad();
        let values = TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
            .require_grad();
        let indexes = TestADTensor::from_data(Data::from([1, 0]));

        // (tensor_1 · tensor_1ᵀ) · index_select_dim_assign(tensor_1, values)
        let gram = tensor_1.clone().matmul(tensor_1.clone().transpose());
        let assigned = tensor_1
            .clone()
            .index_select_dim_assign(0, indexes, values.clone());
        let product = gram.matmul(assigned);

        let grads = product.backward();
        let tensor_grad = tensor_1.grad(&grads).unwrap();
        let values_grad = values.grad(&grads).unwrap();

        assert_eq!(
            tensor_grad.into_data(),
            Data::from([[127., 199., 271.], [172., 244., 316.]])
        );
        assert_eq!(
            values_grad.into_data(),
            Data::from([[64., 64., 64.], [19., 19., 19.]])
        );
    }
}

View File

@ -11,6 +11,7 @@ mod div;
mod erf;
mod exp;
mod index;
mod index_select_dim;
mod log;
mod log1p;
mod mask;
@ -52,6 +53,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_div!();
burn_autodiff::testgen_ad_erf!();
burn_autodiff::testgen_ad_exp!();
burn_autodiff::testgen_ad_index_select_dim!();
burn_autodiff::testgen_ad_index!();
burn_autodiff::testgen_ad_log!();
burn_autodiff::testgen_ad_log1p!();

View File

@ -27,6 +27,7 @@ extern crate alloc;
mod tests {
type TestBackend = crate::NdArrayBackend<f32>;
type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
burn_tensor::testgen_all!();

View File

@ -203,4 +203,39 @@ where
_ => panic!("Dim not supported {D}"),
}
}
/// Selects sub-arrays of `tensor` along `dim` at the positions listed in `indexes`.
///
/// Each `i64` index is cast to `usize` before being handed to ndarray's `select`.
pub fn index_select_dim<const D: usize>(
    tensor: NdArrayTensor<E, D>,
    dim: usize,
    indexes: NdArrayTensor<i64, 1>,
) -> NdArrayTensor<E, D> {
    let positions: Vec<_> = indexes
        .array
        .into_iter()
        .map(|position| position as usize)
        .collect();
    let selected = tensor.array.select(Axis(dim), &positions);

    NdArrayTensor::new(selected.into_shared())
}
/// Returns a copy of `tensor` in which slice `i` of `value` (taken along `dim`) has been added
/// to slice `indexes[i]` of the tensor (also along `dim`).
///
/// Repeated indexes accumulate. `value` is expected to have the same shape as `tensor` except
/// along `dim`, where its length matches the number of indexes.
///
/// # Panics
///
/// Panics if an index is out of bounds for `tensor` along `dim`, or if `value` has fewer
/// slices along `dim` than there are indexes.
pub fn index_select_dim_assign<const D1: usize, const D2: usize>(
    tensor: NdArrayTensor<E, D1>,
    dim: usize,
    indexes: NdArrayTensor<i64, 1>,
    value: NdArrayTensor<E, D2>,
) -> NdArrayTensor<E, D1> {
    let mut output_array = tensor.array.into_owned();

    for (index_value, index) in indexes.array.into_iter().enumerate() {
        let mut view = output_array.index_axis_mut(Axis(dim), index as usize);
        // Slice `value` along the same axis as the output: slicing along axis 0 is only
        // correct when `dim == 0` and yields mismatched shapes otherwise.
        let value = value.array.index_axis(Axis(dim), index_value);

        view.zip_mut_with(&value, |a, b| *a = *a + *b);
    }

    NdArrayTensor::new(output_array.into_shared())
}
}

View File

@ -263,4 +263,21 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::mean_dim(tensor, dim)
}
/// Selects entries of an `i64` tensor along `dim`; forwards to `NdArrayMathOps::index_select_dim`.
fn int_index_select_dim<const D: usize>(
    tensor: NdArrayTensor<i64, D>,
    dim: usize,
    indexes: NdArrayTensor<i64, 1>,
) -> NdArrayTensor<i64, D> {
    NdArrayMathOps::index_select_dim::<D>(tensor, dim, indexes)
}
/// Adds `value` into an `i64` tensor at `indexes` along `dim`; forwards to
/// `NdArrayMathOps::index_select_dim_assign`.
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
    tensor: NdArrayTensor<i64, D1>,
    dim: usize,
    indexes: NdArrayTensor<i64, 1>,
    value: NdArrayTensor<i64, D2>,
) -> NdArrayTensor<i64, D1> {
    NdArrayMathOps::index_select_dim_assign::<D1, D2>(tensor, dim, indexes, value)
}
}

View File

@ -151,6 +151,23 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
NdArrayOps::reshape(tensor, shape)
}
/// Selects entries of a float tensor along `dim`; forwards to `NdArrayMathOps::index_select_dim`.
fn index_select_dim<const D: usize>(
    tensor: NdArrayTensor<E, D>,
    dim: usize,
    indexes: NdArrayTensor<i64, 1>,
) -> NdArrayTensor<E, D> {
    NdArrayMathOps::index_select_dim::<D>(tensor, dim, indexes)
}
/// Adds `value` into a float tensor at `indexes` along `dim`; forwards to
/// `NdArrayMathOps::index_select_dim_assign`.
fn index_select_dim_assign<const D1: usize, const D2: usize>(
    tensor: NdArrayTensor<E, D1>,
    dim: usize,
    indexes: NdArrayTensor<i64, 1>,
    value: NdArrayTensor<E, D2>,
) -> NdArrayTensor<E, D1> {
    NdArrayMathOps::index_select_dim_assign::<D1, D2>(tensor, dim, indexes, value)
}
fn index<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
indexes: [Range<usize>; D2],

View File

@ -10,6 +10,7 @@ pub use tensor::*;
mod tests {
type TestBackend = crate::TchBackend<f32>;
type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
burn_tensor::testgen_all!();
burn_autodiff::testgen_all!();

View File

@ -45,6 +45,27 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
TchTensor::new(tensor_original)
}
/// Selects entries of `tensor` along `dim` at `indexes` using the underlying tch
/// `index_select` call.
pub fn index_select_dim<const D: usize>(
    tensor: TchTensor<E, D>,
    dim: usize,
    indexes: TchTensor<i64, 1>,
) -> TchTensor<E, D> {
    TchTensor::new(tensor.tensor.index_select(dim as i64, &indexes.tensor))
}
/// Adds `value` into `tensor` at `indexes` along `dim` using tch's `index_put` with
/// accumulation enabled.
///
/// The index list holds one slot per dimension of `tensor`; only the slot for `dim` is
/// populated, the others stay `None`.
pub fn index_select_dim_assign<const D1: usize, const D2: usize>(
    tensor: TchTensor<E, D1>,
    dim: usize,
    indexes: TchTensor<i64, 1>,
    value: TchTensor<E, D2>,
) -> TchTensor<E, D1> {
    let accumulate = true;

    let mut per_dim_indexes = vec![None; D1];
    per_dim_indexes[dim] = Some(indexes.tensor);

    let updated = tensor
        .tensor
        .index_put(&per_dim_indexes, &value.tensor, accumulate);

    TchTensor::new(updated)
}
pub fn cat<const D: usize>(tensors: Vec<TchTensor<E, D>>, dim: usize) -> TchTensor<E, D> {
let tensors: Vec<tch::Tensor> = tensors
.into_iter()

View File

@ -243,4 +243,21 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
fn int_mean_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
TchOps::mean_dim(tensor, dim)
}
/// Selects entries of an `i64` tensor along `dim`; forwards to `TchOps::index_select_dim`.
fn int_index_select_dim<const D: usize>(
    tensor: TchTensor<i64, D>,
    dim: usize,
    indexes: TchTensor<i64, 1>,
) -> TchTensor<i64, D> {
    TchOps::index_select_dim::<D>(tensor, dim, indexes)
}
/// Adds `value` into an `i64` tensor at `indexes` along `dim`; forwards to
/// `TchOps::index_select_dim_assign`.
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
    tensor: TchTensor<i64, D1>,
    dim: usize,
    indexes: TchTensor<i64, 1>,
    value: TchTensor<i64, D2>,
) -> TchTensor<i64, D1> {
    TchOps::index_select_dim_assign::<D1, D2>(tensor, dim, indexes, value)
}
}

View File

@ -190,6 +190,23 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
TchTensor::new(tensor)
}
/// Selects entries of a float tensor along `dim`; forwards to `TchOps::index_select_dim`.
fn index_select_dim<const D: usize>(
    tensor: TchTensor<E, D>,
    dim: usize,
    indexes: TchTensor<i64, 1>,
) -> TchTensor<E, D> {
    TchOps::index_select_dim::<D>(tensor, dim, indexes)
}
/// Adds `value` into a float tensor at `indexes` along `dim`; forwards to
/// `TchOps::index_select_dim_assign`.
fn index_select_dim_assign<const D1: usize, const D2: usize>(
    tensor: TchTensor<E, D1>,
    dim: usize,
    indexes: TchTensor<i64, 1>,
    value: TchTensor<E, D2>,
) -> TchTensor<E, D1> {
    TchOps::index_select_dim_assign::<D1, D2>(tensor, dim, indexes, value)
}
fn index<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
indexes: [Range<usize>; D2],

View File

@ -342,6 +342,27 @@ where
self.reshape(shape)
}
/// Select the entries of this tensor along `dim` at the positions given by `indexes`.
///
/// Consumes `self` and returns a new tensor of the same rank.
pub fn index_select_dim(self, dim: usize, indexes: Tensor<B, 1, Int>) -> Self {
    let selected = B::index_select_dim(self.primitive, dim, indexes.primitive);

    Self::new(selected)
}
/// Add `values` to this tensor at the positions given by `indexes` along `dim`.
///
/// Returns a new tensor with the same dimensions as `self`; the values are added to the
/// existing entries rather than overwriting them.
pub fn index_select_dim_assign<const D2: usize>(
    self,
    dim: usize,
    indexes: Tensor<B, 1, Int>,
    values: Tensor<B, D2>,
) -> Self {
    let updated =
        B::index_select_dim_assign(self.primitive, dim, indexes.primitive, values.primitive);

    Self::new(updated)
}
pub(crate) fn relu(self) -> Self {
Self::new(B::relu(self.primitive))
}

View File

@ -234,6 +234,17 @@ pub trait Numeric<B: Backend>: TensorKind<B> {
lhs: Self::Primitive<D>,
rhs: Self::Elem,
) -> Tensor<B, D, Bool>;
/// Selects entries of `tensor` along `dim` at the positions given by `indexes`.
///
/// Takes and returns the kind-specific primitive; the result has the same rank `D` as the
/// input.
fn index_select_dim<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
indexes: Tensor<B, 1, Int>,
) -> Self::Primitive<D>;
/// Adds `values` into `tensor` at the positions given by `indexes` along `dim`.
///
/// Takes and returns the kind-specific primitive; the result has the same rank `D1` as
/// `tensor`.
fn index_select_dim_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
dim: usize,
indexes: Tensor<B, 1, Int>,
values: Self::Primitive<D2>,
) -> Self::Primitive<D1>;
}
impl<B: Backend> Numeric<B> for Int {
@ -361,6 +372,23 @@ impl<B: Backend> Numeric<B> for Int {
) -> Tensor<B, D, Bool> {
Tensor::new(B::int_lower_equal_elem(lhs, rhs))
}
/// Int kind: unwraps the index tensor and forwards to the backend's `int_index_select_dim`.
fn index_select_dim<const D: usize>(
    tensor: Self::Primitive<D>,
    dim: usize,
    indexes: Tensor<B, 1, Int>,
) -> Self::Primitive<D> {
    let indexes = indexes.primitive;

    B::int_index_select_dim(tensor, dim, indexes)
}
/// Int kind: unwraps the index tensor and forwards to the backend's
/// `int_index_select_dim_assign`.
fn index_select_dim_assign<const D1: usize, const D2: usize>(
    tensor: Self::Primitive<D1>,
    dim: usize,
    indexes: Tensor<B, 1, Int>,
    values: Self::Primitive<D2>,
) -> Self::Primitive<D1> {
    let indexes = indexes.primitive;

    B::int_index_select_dim_assign(tensor, dim, indexes, values)
}
}
impl<B: Backend> Numeric<B> for Float {
@ -488,6 +516,23 @@ impl<B: Backend> Numeric<B> for Float {
) -> Tensor<B, D, Bool> {
Tensor::new(B::lower_equal_elem(lhs, rhs))
}
/// Float kind: unwraps the index tensor and forwards to the backend's `index_select_dim`.
fn index_select_dim<const D: usize>(
    tensor: Self::Primitive<D>,
    dim: usize,
    indexes: Tensor<B, 1, Int>,
) -> Self::Primitive<D> {
    let indexes = indexes.primitive;

    B::index_select_dim(tensor, dim, indexes)
}
/// Float kind: unwraps the index tensor and forwards to the backend's
/// `index_select_dim_assign`.
fn index_select_dim_assign<const D1: usize, const D2: usize>(
    tensor: Self::Primitive<D1>,
    dim: usize,
    indexes: Tensor<B, 1, Int>,
    values: Self::Primitive<D2>,
) -> Self::Primitive<D1> {
    let indexes = indexes.primitive;

    B::index_select_dim_assign(tensor, dim, indexes, values)
}
}
impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K>

View File

@ -34,6 +34,17 @@ pub trait IntTensorOps<B: Backend> {
indexes: [Range<usize>; D2],
value: B::IntTensorPrimitive<D1>,
) -> B::IntTensorPrimitive<D1>;
/// Selects entries of an int tensor along `dim` at the positions given by `indexes`.
///
/// `indexes` is a rank-1 int tensor; the result has the same rank `D` as `tensor`.
fn int_index_select_dim<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
) -> B::IntTensorPrimitive<D>;
/// Adds `value` into an int tensor at the positions given by `indexes` along `dim`.
///
/// `indexes` is a rank-1 int tensor; the result has the same rank `D1` as `tensor`.
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
tensor: B::IntTensorPrimitive<D1>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
value: B::IntTensorPrimitive<D2>,
) -> B::IntTensorPrimitive<D1>;
fn int_repeat<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,

View File

@ -115,6 +115,17 @@ pub trait TensorOps<B: Backend> {
tensor: B::TensorPrimitive<D1>,
shape: Shape<D2>,
) -> B::TensorPrimitive<D2>;
/// Selects entries of a float tensor along `dim` at the positions given by `indexes`.
///
/// `indexes` is a rank-1 int tensor; the result has the same rank `D` as `tensor`.
fn index_select_dim<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
) -> B::TensorPrimitive<D>;
/// Adds `value` into a float tensor at the positions given by `indexes` along `dim`.
///
/// `indexes` is a rank-1 int tensor; the result has the same rank `D1` as `tensor`.
fn index_select_dim_assign<const D1: usize, const D2: usize>(
tensor: B::TensorPrimitive<D1>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
value: B::TensorPrimitive<D2>,
) -> B::TensorPrimitive<D1>;
fn index<const D1: usize, const D2: usize>(
tensor: B::TensorPrimitive<D1>,
indexes: [Range<usize>; D2],

View File

@ -28,6 +28,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_exp!();
burn_tensor::testgen_log!();
burn_tensor::testgen_log1p!();
burn_tensor::testgen_index_select_dim!();
burn_tensor::testgen_index!();
burn_tensor::testgen_map_comparison!();
burn_tensor::testgen_mask!();

View File

@ -0,0 +1,84 @@
#[burn_tensor_testgen::testgen(index_select_dim)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Tensor};

    /// Selecting from a 1D tensor may repeat and reorder elements.
    #[test]
    fn should_select_1d() {
        let input = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
        let positions = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));

        let selected = input.index_select_dim(0, positions);

        assert_eq!(selected.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
    }

    /// Selecting as many rows as the input has permutes the rows.
    #[test]
    fn should_select_2d_dim0_same_num_dim() {
        let input = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
        let positions = TestTensorInt::from_data(Data::from([1, 0]));

        let selected = input.index_select_dim(0, positions);

        assert_eq!(
            selected.into_data(),
            Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]])
        );
    }

    /// Selecting more rows than the input has grows the output along dim 0.
    #[test]
    fn should_select_2d_dim0_more_num_dim() {
        let input = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
        let positions = TestTensorInt::from_data(Data::from([1, 0, 1, 1]));

        let selected = input.index_select_dim(0, positions);

        assert_eq!(
            selected.into_data(),
            Data::from([
                [3.0, 4.0, 5.0],
                [0.0, 1.0, 2.0],
                [3.0, 4.0, 5.0],
                [3.0, 4.0, 5.0]
            ])
        );
    }

    /// Selecting along dim 1 picks columns.
    #[test]
    fn should_select_2d_dim1() {
        let input = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
        let positions = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));

        let selected = input.index_select_dim(1, positions);

        assert_eq!(
            selected.into_data(),
            Data::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]])
        );
    }

    /// Assigning with repeated indexes accumulates the values on the same element.
    #[test]
    fn should_select_assign_1d() {
        let input = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
        let addends = TestTensor::from_data(Data::from([5.0, 4.0, 3.0, 2.0, 1.0]));
        let positions = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));

        let updated = input.index_select_dim_assign(0, positions, addends);

        assert_eq!(updated.into_data(), Data::from([3.0, 12.0, 3.0]));
    }

    /// Assigning along dim 0 adds each row of the values to the indexed row.
    #[test]
    fn should_select_assign_2d_dim0() {
        let input = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
        let addends = TestTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]));
        let positions = TestTensorInt::from_data(Data::from([1, 0]));

        let updated = input.index_select_dim_assign(0, positions, addends);

        assert_eq!(
            updated.into_data(),
            Data::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]])
        );
    }
}

View File

@ -6,6 +6,7 @@ mod div;
mod erf;
mod exp;
mod index;
mod index_select_dim;
mod log;
mod log1p;
mod map_comparison;