Add flatten op to the tensor base (#260)

Dilshod Tadjibaev 2023-03-31 15:44:48 -05:00 committed by GitHub
parent c14c7977ec
commit 7364d09d32
5 changed files with 116 additions and 0 deletions


@@ -48,6 +48,70 @@ where
Tensor::new(K::reshape::<D, D2>(self.primitive, shape.into()))
}
/// Flatten the tensor along a given range of dimensions.
///
/// This function collapses the specified range of dimensions into a single dimension,
/// effectively flattening the tensor in that range.
///
/// # Arguments
///
/// - `start_dim`: The starting dimension of the range to be flattened.
/// - `end_dim`: The ending dimension of the range to be flattened (inclusive).
///
/// # Type Parameters
///
/// - `D2`: The resulting number of dimensions in the flattened tensor.
///
/// # Returns
///
/// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
///
/// # Example
///
/// ```rust
///
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Shape};
///
/// fn example<B: Backend>() {
/// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]));
///
/// // Given a 3D tensor with dimensions (2, 3, 4), flatten the dimensions between indices 1 and 2:
/// let flattened_tensor: Tensor::<B, 2> = tensor.flatten(1, 2);
///
/// // The resulting tensor will have dimensions (2, 12).
/// println!("{:?}", flattened_tensor.shape());
/// }
///
/// ```
pub fn flatten<const D2: usize>(self, start_dim: usize, end_dim: usize) -> Tensor<B, D2, K> {
if start_dim > end_dim {
panic!("The start dim ({start_dim}) must be smaller than the end dim ({end_dim})")
}
if D2 > D {
panic!("Result dim ({D2}) must be smaller than ({D})")
}
if D < end_dim + 1 {
panic!("The end dim ({end_dim}) must be greater than the tensor dim ({D2})")
}
let current_dims = self.shape().dims;
let mut new_dims: [usize; D2] = [0; D2];
let mut flatten_dims = 1;
for i in current_dims[start_dim..=end_dim].iter() {
flatten_dims *= i;
}
new_dims[..start_dim].copy_from_slice(&current_dims[..start_dim]);
new_dims[start_dim] = flatten_dims;
new_dims[start_dim + 1..].copy_from_slice(&current_dims[end_dim + 1..]);
Tensor::new(K::reshape::<D, D2>(self.primitive, new_dims.into()))
}
/// Returns a tensor containing the elements selected from the given ranges.
///
/// # Panics
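
The new-shape arithmetic in `flatten` above can be exercised on its own. Below is a minimal sketch in plain Rust; the `flatten_dims` helper and `main` are hypothetical illustrations, not part of this commit, assuming the same `start_dim`/`end_dim` (inclusive) semantics as the method above:

```rust
// Standalone sketch of the shape arithmetic used by `flatten` above.
// `flatten_dims` is a hypothetical helper for illustration only.
fn flatten_dims<const D: usize, const D2: usize>(
    dims: [usize; D],
    start_dim: usize,
    end_dim: usize,
) -> [usize; D2] {
    let mut new_dims = [0usize; D2];

    // Product of the sizes in the collapsed range [start_dim, end_dim].
    let collapsed: usize = dims[start_dim..=end_dim].iter().product();

    // Copy the untouched leading dims, place the collapsed size,
    // then copy the untouched trailing dims.
    new_dims[..start_dim].copy_from_slice(&dims[..start_dim]);
    new_dims[start_dim] = collapsed;
    new_dims[start_dim + 1..].copy_from_slice(&dims[end_dim + 1..]);

    new_dims
}

fn main() {
    // A (2, 3, 4, 5) shape flattened over dims 1..=2 becomes (2, 12, 5).
    let flattened: [usize; 3] = flatten_dims([2, 3, 4, 5], 1, 2);
    assert_eq!(flattened, [2, 12, 5]);
}
```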


@@ -327,6 +327,7 @@ where
/// // Shape { dims: [1, 1, 3, 3] }
/// }
/// ```
/// TODO move this function to the base.
pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2> {
if D2 < D {
panic!("Can't unsqueeze smaller tensor, got dim {D2}, expected > {D}")


@@ -39,6 +39,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_powf!();
burn_tensor::testgen_repeat!();
burn_tensor::testgen_reshape!();
burn_tensor::testgen_flatten!();
burn_tensor::testgen_sin!();
burn_tensor::testgen_tanh!();
burn_tensor::testgen_sub!();


@@ -0,0 +1,49 @@
#[burn_tensor_testgen::testgen(flatten)]
mod tests {
use super::*;
use burn_tensor::{Data, Shape, Tensor};
/// Test if the function can successfully flatten a 4D tensor to a 1D tensor.
#[test]
fn should_flatten_to_1d() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
let flattened_tensor: Tensor<TestBackend, 1> = tensor.flatten(0, 3);
let expected_shape = Shape::new([120]);
assert_eq!(flattened_tensor.shape(), expected_shape);
}
/// Test if the function can successfully flatten the middle dimensions of a 4D tensor.
#[test]
fn should_flatten_middle() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
let flattened_tensor: Tensor<TestBackend, 3> = tensor.flatten(1, 2);
let expected_shape = Shape::new([2, 12, 5]);
assert_eq!(flattened_tensor.shape(), expected_shape);
}
/// Test if the function can successfully flatten the first dimensions of a 4D tensor.
#[test]
fn should_flatten_begin() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
let flattened_tensor: Tensor<TestBackend, 2> = tensor.flatten(0, 2);
let expected_shape = Shape::new([24, 5]);
assert_eq!(flattened_tensor.shape(), expected_shape);
}
/// Test if the function can successfully flatten the last dimensions of a 4D tensor.
#[test]
fn should_flatten_end() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
let flattened_tensor: Tensor<TestBackend, 2> = tensor.flatten(1, 3);
let expected_shape = Shape::new([2, 60]);
assert_eq!(flattened_tensor.shape(), expected_shape);
}
/// Test if the function panics when the start dimension is greater than the end dimension.
#[test]
#[should_panic]
fn should_flatten_panic() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
let _flattened_tensor: Tensor<TestBackend, 2> = tensor.flatten(2, 0);
}
}
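
The tests above cover full, middle, leading, and trailing flattens plus the start > end panic. One boundary case left implicit is `start_dim == end_dim`, which the checks in `flatten` permit and which should behave as a shape-preserving reshape. A hedged sketch of such a test (hypothetical, not part of this commit; it would live inside the same `mod tests` block alongside the cases above):

```rust
/// Hypothetical extra case (not part of this commit): flattening a range
/// that contains a single dimension should leave the shape unchanged.
#[test]
fn should_flatten_single_dim() {
    let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
    let flattened_tensor: Tensor<TestBackend, 4> = tensor.flatten(1, 1);
    assert_eq!(flattened_tensor.shape(), Shape::new([2, 3, 4, 5]));
}
```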


@@ -5,6 +5,7 @@ mod cos;
mod div;
mod erf;
mod exp;
mod flatten;
mod index;
mod index_select;
mod index_select_dim;