mirror of https://github.com/tracel-ai/burn.git
Add flatten op to the tensor base (#260)
This commit is contained in:
parent
c14c7977ec
commit
7364d09d32
|
@ -48,6 +48,70 @@ where
|
|||
Tensor::new(K::reshape::<D, D2>(self.primitive, shape.into()))
|
||||
}
|
||||
|
||||
/// Flatten the tensor along a given range of dimensions.
|
||||
///
|
||||
/// This function collapses the specified range of dimensions into a single dimension,
|
||||
/// effectively flattening the tensor in that range.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `start_dim`: The starting dimension of the range to be flattened.
|
||||
/// - `end_dim`: The ending dimension of the range to be flattened (inclusive).
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// - `D2`: The resulting number of dimensions in the flattened tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
///
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Shape};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]));
|
||||
///
|
||||
/// // Given a 3D tensor with dimensions (2, 3, 4), flatten the dimensions between indices 1 and 2:
|
||||
/// let flattened_tensor: Tensor::<B, 2> = tensor.flatten(1, 2);
|
||||
///
|
||||
/// // The resulting tensor will have dimensions (2, 12).
|
||||
/// println!("{:?}", flattened_tensor.shape());
|
||||
/// }
|
||||
///
|
||||
/// ```
|
||||
pub fn flatten<const D2: usize>(self, start_dim: usize, end_dim: usize) -> Tensor<B, D2, K> {
|
||||
if start_dim > end_dim {
|
||||
panic!("The start dim ({start_dim}) must be smaller than the end dim ({end_dim})")
|
||||
}
|
||||
|
||||
if D2 > D {
|
||||
panic!("Result dim ({D2}) must be smaller than ({D})")
|
||||
}
|
||||
|
||||
if D < end_dim + 1 {
|
||||
panic!("The end dim ({end_dim}) must be greater than the tensor dim ({D2})")
|
||||
}
|
||||
|
||||
let current_dims = self.shape().dims;
|
||||
let mut new_dims: [usize; D2] = [0; D2];
|
||||
let mut flatten_dims = 1;
|
||||
|
||||
for i in current_dims[start_dim..=end_dim].iter() {
|
||||
flatten_dims *= i;
|
||||
}
|
||||
|
||||
new_dims[..start_dim].copy_from_slice(¤t_dims[..start_dim]);
|
||||
new_dims[start_dim] = flatten_dims;
|
||||
new_dims[start_dim + 1..].copy_from_slice(¤t_dims[end_dim + 1..]);
|
||||
|
||||
Tensor::new(K::reshape::<D, D2>(self.primitive, new_dims.into()))
|
||||
}
|
||||
|
||||
/// Returns a tensor containing the elements selected from the given ranges.
|
||||
///
|
||||
/// # Panics
|
||||
|
|
|
@ -327,6 +327,7 @@ where
|
|||
/// // Shape { dims: [1, 1, 3, 3] }
|
||||
/// }
|
||||
/// ```
|
||||
/// TODO move this function to the base.
|
||||
pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2> {
|
||||
if D2 < D {
|
||||
panic!("Can't unsqueeze smaller tensor, got dim {D2}, expected > {D}")
|
||||
|
|
|
@ -39,6 +39,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_powf!();
|
||||
burn_tensor::testgen_repeat!();
|
||||
burn_tensor::testgen_reshape!();
|
||||
burn_tensor::testgen_flatten!();
|
||||
burn_tensor::testgen_sin!();
|
||||
burn_tensor::testgen_tanh!();
|
||||
burn_tensor::testgen_sub!();
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
#[burn_tensor_testgen::testgen(flatten)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{Data, Shape, Tensor};
|
||||
|
||||
/// Test if the function can successfully flatten a 4D tensor to a 1D tensor.
|
||||
#[test]
|
||||
fn should_flatten_to_1d() {
|
||||
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
|
||||
let flattened_tensor: Tensor<TestBackend, 1> = tensor.flatten(0, 3);
|
||||
let expected_shape = Shape::new([120]);
|
||||
assert_eq!(flattened_tensor.shape(), expected_shape);
|
||||
}
|
||||
|
||||
/// Test if the function can successfully flatten the middle dimensions of a 4D tensor.
|
||||
#[test]
|
||||
fn should_flatten_middle() {
|
||||
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
|
||||
let flattened_tensor: Tensor<TestBackend, 3> = tensor.flatten(1, 2);
|
||||
let expected_shape = Shape::new([2, 12, 5]);
|
||||
assert_eq!(flattened_tensor.shape(), expected_shape);
|
||||
}
|
||||
|
||||
/// Test if the function can successfully flatten the first dimensions of a 4D tensor.
|
||||
#[test]
|
||||
fn should_flatten_begin() {
|
||||
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
|
||||
let flattened_tensor: Tensor<TestBackend, 2> = tensor.flatten(0, 2);
|
||||
let expected_shape = Shape::new([24, 5]);
|
||||
assert_eq!(flattened_tensor.shape(), expected_shape);
|
||||
}
|
||||
|
||||
/// Test if the function can successfully flatten the last dimensions of a 4D tensor.
|
||||
#[test]
|
||||
fn should_flatten_end() {
|
||||
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
|
||||
let flattened_tensor: Tensor<TestBackend, 2> = tensor.flatten(1, 3);
|
||||
let expected_shape = Shape::new([2, 60]);
|
||||
assert_eq!(flattened_tensor.shape(), expected_shape);
|
||||
}
|
||||
|
||||
/// Test if the function panics when the start dimension is greater than the end dimension.
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_flatten_panic() {
|
||||
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
|
||||
let flattened_tensor: Tensor<TestBackend, 2> = tensor.flatten(2, 0);
|
||||
}
|
||||
}
|
|
@ -5,6 +5,7 @@ mod cos;
|
|||
mod div;
|
||||
mod erf;
|
||||
mod exp;
|
||||
mod flatten;
|
||||
mod index;
|
||||
mod index_select;
|
||||
mod index_select_dim;
|
||||
|
|
Loading…
Reference in New Issue