Implement chunk for different backends (#1032)

This commit is contained in:
Kelvin Wu 2023-12-21 02:35:59 +08:00 committed by GitHub
parent c1cb77ac2e
commit 7c6f017c98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 450 additions and 45 deletions

View File

@ -92,4 +92,21 @@ impl<B: Backend> BoolTensorOps<Self> for Autodiff<B> {
) -> <Autodiff<B> as Backend>::BoolTensorPrimitive<D> {
B::bool_swap_dims(tensor, dim1, dim2)
}
fn bool_narrow<const D: usize>(
tensor: BoolTensor<B, D>,
dim: usize,
start: usize,
length: usize,
) -> BoolTensor<B, D> {
B::bool_narrow(tensor, dim, start, length)
}
fn bool_chunk<const D: usize>(
tensor: BoolTensor<B, D>,
chunks: usize,
dim: usize,
) -> Vec<BoolTensor<B, D>> {
B::bool_chunk(tensor, chunks, dim)
}
}

View File

@ -316,4 +316,21 @@ impl<B: Backend> IntTensorOps<Autodiff<B>> for Autodiff<B> {
) -> <Autodiff<B> as Backend>::IntTensorPrimitive<D> {
B::int_swap_dims(tensor, dim1, dim2)
}
fn int_narrow<const D: usize>(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive<D>,
dim: usize,
start: usize,
length: usize,
) -> <Autodiff<B> as Backend>::IntTensorPrimitive<D> {
B::int_narrow(tensor, dim, start, length)
}
fn int_chunk<const D: usize>(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive<D>,
chunks: usize,
dim: usize,
) -> Vec<<Autodiff<B> as Backend>::IntTensorPrimitive<D>> {
B::int_chunk(tensor, chunks, dim)
}
}

View File

@ -88,3 +88,31 @@ pub fn slice_assign<E: CandleElement, const D1: usize, const D2: usize>(
) -> CandleTensor<E, D1> {
panic!("slice_assign not supported by Candle")
}
pub fn narrow<E: CandleElement, const D: usize>(
tensor: CandleTensor<E, D>,
dim: usize,
start: usize,
length: usize,
) -> CandleTensor<E, D> {
let tensor = tensor.tensor.narrow(dim, start, length);
match tensor {
Ok(tensor) => CandleTensor::new(tensor),
Err(e) => panic!("error narrow from Candle"),
}
}
pub fn chunk<E: CandleElement, const D: usize>(
tensor: CandleTensor<E, D>,
chunks: usize,
dim: usize,
) -> Vec<CandleTensor<E, D>> {
let tensors = tensor.tensor.chunk(chunks, dim);
match tensors {
Ok(tensors) => tensors
.into_iter()
.map(CandleTensor::new)
.collect(),
Err(e) => panic!("error chunk from Candle"),
}
}
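For reference, a minimal, hedged sketch of the underlying candle-core calls these helpers delegate to (the function name and shapes are illustrative):

fn candle_narrow_chunk_demo() -> candle_core::Result<()> {
    use candle_core::{DType, Device, Tensor};
    // 4x6 zero tensor on the CPU.
    let t = Tensor::zeros((4, 6), DType::F32, &Device::Cpu)?;
    // Keep 3 columns starting at column 2: shape becomes (4, 3).
    let narrowed = t.narrow(1, 2, 3)?;
    // Split the 3 remaining columns into three single-column chunks along dim 1.
    let parts = narrowed.chunk(3, 1)?;
    assert_eq!(parts.len(), 3);
    Ok(())
}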

View File

@ -109,4 +109,21 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
) -> <Candle<F, I> as burn_tensor::backend::Backend>::BoolTensorPrimitive<D> {
super::base::swap_dims(tensor, dim1, dim2)
}
fn bool_narrow<const D: usize>(
tensor: BoolTensor<Self, D>,
dim: usize,
start: usize,
length: usize,
) -> BoolTensor<Self, D> {
super::base::narrow(tensor, dim, start, length)
}
fn bool_chunk<const D: usize>(
tensor: BoolTensor<Self, D>,
chunks: usize,
dim: usize,
) -> Vec<BoolTensor<Self, D>> {
super::base::chunk(tensor, chunks, dim)
}
}

View File

@ -359,4 +359,21 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
) -> <Candle<F, I> as burn_tensor::backend::Backend>::IntTensorPrimitive<D> {
super::base::swap_dims(tensor, dim1, dim2)
}
fn int_narrow<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
start: usize,
length: usize,
) -> IntTensor<Self, D> {
super::base::narrow(tensor, dim, start, length)
}
fn int_chunk<const D: usize>(
tensor: IntTensor<Self, D>,
chunks: usize,
dim: usize,
) -> Vec<IntTensor<Self, D>> {
super::base::chunk(tensor, chunks, dim)
}
}

View File

@ -448,4 +448,21 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.recip().unwrap())
}
fn narrow<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
start: usize,
length: usize,
) -> FloatTensor<Self, D> {
super::base::narrow(tensor, dim, start, length)
}
fn chunk<const D: usize>(
tensor: FloatTensor<Self, D>,
chunks: usize,
dim: usize,
) -> Vec<FloatTensor<Self, D>> {
super::base::chunk(tensor, chunks, dim)
}
}

View File

@ -413,4 +413,30 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64);
TchTensor::new(tensor)
}
pub fn narrow<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
start: usize,
length: usize,
) -> TchTensor<E, D> {
TchTensor::new(
tensor
.tensor
.narrow(dim as i64, start as i64, length as i64),
)
}
pub fn chunk<const D: usize>(
tensor: TchTensor<E, D>,
chunks: usize,
dim: usize,
) -> Vec<TchTensor<E, D>> {
tensor
.tensor
.chunk(chunks as i64, dim as i64)
.into_iter()
.map(TchTensor::new)
.collect()
}
}
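Similarly, a hedged standalone sketch of the tch-rs calls wrapped above (the constructor form may differ slightly across tch versions; names and shapes are illustrative):

fn tch_narrow_chunk_demo() {
    use tch::{Device, Kind, Tensor};
    // 4x6 zero tensor on the CPU.
    let t = Tensor::zeros(&[4i64, 6], (Kind::Float, Device::Cpu));
    // Keep 3 columns starting at column 2: shape becomes [4, 3].
    let narrowed = t.narrow(1, 2, 3);
    // Split the 3 remaining columns into three single-column chunks along dim 1.
    let parts = narrowed.chunk(3, 1);
    assert_eq!(parts.len(), 3);
}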

View File

@ -114,4 +114,21 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
) -> <LibTorch<E> as Backend>::BoolTensorPrimitive<D> {
TchOps::swap_dims(tensor, dim1, dim2)
}
fn bool_narrow<const D: usize>(
tensor: TchTensor<bool, D>,
dim: usize,
start: usize,
length: usize,
) -> TchTensor<bool, D> {
TchOps::narrow(tensor, dim, start, length)
}
fn bool_chunk<const D: usize>(
tensor: TchTensor<bool, D>,
chunks: usize,
dim: usize,
) -> Vec<TchTensor<bool, D>> {
TchOps::chunk(tensor, chunks, dim)
}
}

View File

@ -401,4 +401,21 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
) -> <LibTorch<E> as Backend>::IntTensorPrimitive<D> {
TchOps::swap_dims(tensor, dim1, dim2)
}
fn int_narrow<const D: usize>(
tensor: TchTensor<i64, D>,
dim: usize,
start: usize,
length: usize,
) -> TchTensor<i64, D> {
TchOps::narrow(tensor, dim, start, length)
}
fn int_chunk<const D: usize>(
tensor: TchTensor<i64, D>,
chunks: usize,
dim: usize,
) -> Vec<TchTensor<i64, D>> {
TchOps::chunk(tensor, chunks, dim)
}
}

View File

@ -440,4 +440,21 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
TchTensor::new(tensor)
}
fn narrow<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
start: usize,
length: usize,
) -> TchTensor<E, D> {
TchOps::narrow(tensor, dim, start, length)
}
fn chunk<const D: usize>(
tensor: TchTensor<E, D>,
chunks: usize,
dim: usize,
) -> Vec<TchTensor<E, D>> {
TchOps::chunk(tensor, chunks, dim)
}
}

View File

@ -12,9 +12,10 @@ use alloc::vec;
use burn_common::{reader::Reader, stub::Mutex};
use core::{fmt::Debug, ops::Range};
use crate::{
backend::Backend, check, check::TensorCheck, Bool, Data, Float, Int, Shape, TensorKind,
};
use crate::check::TensorCheck;
use crate::tensor::api::chunk::chunk;
use crate::tensor::api::narrow::narrow;
use crate::{backend::Backend, check, Bool, Data, Float, Int, Shape, TensorKind};
/// A tensor with a given backend, shape and data type.
#[derive(new, Clone, Debug)]
@ -496,20 +497,7 @@ where
pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
check!(TensorCheck::dim_ops::<D>("narrow", dim));
check!(TensorCheck::narrow(&self, dim, start, length));
let ranges: Vec<_> = (0..D)
.map(|i| {
if i == dim {
start..(start + length)
} else {
0..self.shape().dims[i]
}
})
.collect();
let ranges_array: [_; D] = ranges.try_into().unwrap();
self.slice(ranges_array)
Self::new(narrow::<B, D, K>(self.primitive, dim, start, length))
}
/// Attempts to split the tensor along the given dimension into chunks.
@ -526,31 +514,10 @@ where
/// A vector of tensors.
pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
check!(TensorCheck::dim_ops::<D>("chunk", dim));
let size = self.shape().dims[dim];
if size < chunks {
return (0..size).map(|i| self.clone().narrow(dim, i, 1)).collect();
}
let mut tensors = Vec::with_capacity(chunks);
let mut sum_chunk_size = 0;
if size % chunks == 0 {
let chunk_size = size / chunks;
for _ in 0..chunks {
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
sum_chunk_size += chunk_size;
}
} else {
let chunk_size = (size / chunks) + 1; // assumes not divisible
for _ in 0..chunks - 1 {
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
sum_chunk_size += chunk_size;
}
let remainder = size % chunk_size;
tensors.push(self.clone().narrow(dim, sum_chunk_size, remainder));
}
tensors
chunk::<B, D, K>(self.primitive, chunks, dim)
.into_iter()
.map(|v| Self::new(v))
.collect()
}
}
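The user-facing API is unchanged by this refactor; a minimal sketch of how the two methods compose (the function name is illustrative, the input is assumed to have at least 5 rows, and any Backend works the same way):

use burn_tensor::{backend::Backend, Tensor};

fn split_rows<B: Backend>(x: Tensor<B, 2>) -> Vec<Tensor<B, 2>> {
    // Keep rows 2..5 (3 rows), then split them into three single-row chunks along dim 0.
    x.narrow(0, 2, 3).chunk(3, 0)
}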

View File

@ -0,0 +1,69 @@
use super::narrow::narrow;
use crate::{backend::Backend, BasicOps, TensorKind};
use alloc::vec::Vec;
/// Split the tensor along the given dimension into chunks.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `chunks` - The number of chunks to be produced.
/// * `dim` - The dimension along which the tensor will be split.
///
/// # Returns
///
/// A vector of tensors.
///
/// # Remarks
///
/// This is a fallback solution used only when the backend doesn't have a corresponding implementation.
/// Ideally, it should be implemented by the backend, in which case the backend implementation is
/// resolved by static dispatch. It is not designed for direct use, and importing or calling this
/// function directly is not recommended.
pub fn chunk<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
tensor: K::Primitive<D>,
chunks: usize,
dim: usize,
) -> Vec<K::Primitive<D>> {
let size = K::shape(&tensor).dims[dim];
if size < chunks {
return (0..size)
.map(|i| narrow::<B, D, K>(tensor.clone(), dim, i, 1))
.collect();
}
let mut tensors = Vec::with_capacity(chunks);
let mut sum_chunk_size = 0;
if size % chunks == 0 {
let chunk_size = size / chunks;
for _ in 0..chunks {
tensors.push(narrow::<B, D, K>(
tensor.clone(),
dim,
sum_chunk_size,
chunk_size,
));
sum_chunk_size += chunk_size;
}
} else {
let chunk_size = (size / chunks) + 1; // `size` is not evenly divisible by `chunks` in this branch
for _ in 0..chunks - 1 {
tensors.push(narrow::<B, D, K>(
tensor.clone(),
dim,
sum_chunk_size,
chunk_size,
));
sum_chunk_size += chunk_size;
}
let remainder = size % chunk_size;
tensors.push(narrow::<B, D, K>(
tensor.clone(),
dim,
sum_chunk_size,
remainder,
));
}
tensors
}
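To make the size arithmetic above concrete, a hedged standalone sketch that mirrors the branching (the helper name is illustrative and not part of the crate):

fn chunk_sizes(size: usize, chunks: usize) -> Vec<usize> {
    if size < chunks {
        // Fewer elements than requested chunks: one element per chunk, `size` chunks in total.
        return vec![1; size];
    }
    if size % chunks == 0 {
        // Evenly divisible: `chunks` chunks of equal size.
        vec![size / chunks; chunks]
    } else {
        // Not evenly divisible: `chunks - 1` chunks of `size / chunks + 1`, plus a remainder chunk.
        let chunk_size = size / chunks + 1;
        let mut sizes = vec![chunk_size; chunks - 1];
        sizes.push(size % chunk_size);
        sizes
    }
}
// chunk_sizes(10, 3) == [4, 4, 2]; chunk_sizes(8, 4) == [2, 2, 2, 2]; chunk_sizes(2, 3) == [1, 1].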

View File

@ -3,12 +3,16 @@ pub(crate) mod check;
mod autodiff;
mod base;
mod bool;
mod chunk;
mod float;
mod int;
mod kind;
mod narrow;
mod numeric;
pub use autodiff::*;
pub use base::*;
pub use chunk::chunk;
pub use kind::*;
pub use narrow::narrow;
pub use numeric::*;

View File

@ -0,0 +1,41 @@
use crate::{backend::Backend, BasicOps, TensorKind};
use alloc::vec::Vec;
/// Returns a new tensor with the given dimension narrowed to the given range.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension along which the tensor will be narrowed.
/// * `start` - The starting point of the given range.
/// * `length` - The number of elements to take along the dimension, starting at `start`.
///
/// # Panics
///
/// - If the dimension is greater than the number of dimensions of the tensor.
/// - If the given range exceeds the number of elements on the given dimension.
///
/// # Returns
///
/// A new tensor with the given dimension narrowed to the given range.
pub fn narrow<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
tensor: K::Primitive<D>,
dim: usize,
start: usize,
length: usize,
) -> K::Primitive<D> {
let shape = K::shape(&tensor);
let ranges: Vec<_> = (0..D)
.map(|i| {
if i == dim {
start..(start + length)
} else {
0..shape.dims[i]
}
})
.collect();
let ranges_array: [_; D] = ranges.try_into().unwrap();
K::slice(tensor, ranges_array)
}
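A standalone illustration of the slice ranges built above (the helper name is illustrative; no backend required): for a 3-D tensor of shape [4, 5, 6], narrowing dim 1 at start 1 with length 3 slices with [0..4, 1..4, 0..6], keeping every other dimension whole.

use core::{array, ops::Range};

fn narrow_ranges<const D: usize>(
    shape: [usize; D],
    dim: usize,
    start: usize,
    length: usize,
) -> [Range<usize>; D] {
    // Full extent on every dimension except `dim`, which is restricted to start..start + length.
    array::from_fn(|i| if i == dim { start..start + length } else { 0..shape[i] })
}
// narrow_ranges([4, 5, 6], 1, 1, 3) == [0..4, 1..4, 0..6]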

View File

@ -1,5 +1,5 @@
use super::{BoolTensor, Device, FloatTensor, IntTensor};
use crate::{backend::Backend, tensor::Shape, Data};
use crate::{backend::Backend, chunk, narrow, tensor::Shape, Bool, Data};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::ops::Range;
@ -258,4 +258,48 @@ pub trait BoolTensorOps<B: Backend> {
dim1: usize,
dim2: usize,
) -> BoolTensor<B, D>;
/// Returns a new tensor with the given dimension narrowed to the given range.
///
/// # Arguments
///
/// * `dim` - The dimension along which the tensor will be narrowed.
/// * `start` - The starting point of the given range.
/// * `length` - The number of elements to take along the dimension, starting at `start`.
///
/// # Panics
///
/// - If the dimension is greater than the number of dimensions of the tensor.
/// - If the given range exceeds the number of elements on the given dimension.
///
/// # Returns
///
/// A new tensor with the given dimension narrowed to the given range.
fn bool_narrow<const D: usize>(
tensor: BoolTensor<B, D>,
dim: usize,
start: usize,
length: usize,
) -> BoolTensor<B, D> {
narrow::<B, D, Bool>(tensor, dim, start, length)
}
/// Split the tensor along the given dimension into chunks.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `chunks` - The number of chunks to be produced.
/// * `dim` - The dimension along which the tensor will be split.
///
/// # Returns
///
/// A vector of tensors.
///
fn bool_chunk<const D: usize>(
tensor: BoolTensor<B, D>,
chunks: usize,
dim: usize,
) -> Vec<BoolTensor<B, D>> {
chunk::<B, D, Bool>(tensor, chunks, dim)
}
}

View File

@ -1,5 +1,6 @@
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
use crate::{backend::Backend, tensor::Shape, Data, ElementConversion};
use crate::{backend::Backend, tensor::Shape, Data, ElementConversion, Int};
use crate::{tensor::api::chunk, tensor::api::narrow};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::ops::Range;
@ -850,4 +851,48 @@ pub trait IntTensorOps<B: Backend> {
dim1: usize,
dim2: usize,
) -> IntTensor<B, D>;
/// Returns a new tensor with the given dimension narrowed to the given range.
///
/// # Arguments
///
/// * `dim` - The dimension along which the tensor will be narrowed.
/// * `start` - The starting point of the given range.
/// * `length` - The number of elements to take along the dimension, starting at `start`.
///
/// # Panics
///
/// - If the dimension is greater than the number of dimensions of the tensor.
/// - If the given range exceeds the number of elements on the given dimension.
///
/// # Returns
///
/// A new tensor with the given dimension narrowed to the given range.
fn int_narrow<const D: usize>(
tensor: IntTensor<B, D>,
dim: usize,
start: usize,
length: usize,
) -> IntTensor<B, D> {
narrow::<B, D, Int>(tensor, dim, start, length)
}
/// Split the tensor along the given dimension into chunks.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `chunks` - The number of chunks to be produced.
/// * `dim` - The dimension along which the tensor will be split.
///
/// # Returns
///
/// A vector of tensors.
///
fn int_chunk<const D: usize>(
tensor: IntTensor<B, D>,
chunks: usize,
dim: usize,
) -> Vec<IntTensor<B, D>> {
chunk::<B, D, Int>(tensor, chunks, dim)
}
}

View File

@ -1,5 +1,6 @@
use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion};
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Float};
use crate::{tensor::api::chunk, tensor::api::narrow};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::ops::Range;
@ -1075,4 +1076,48 @@ pub trait TensorOps<B: Backend> {
(values, index)
}
/// Returns a new tensor with the given dimension narrowed to the given range.
///
/// # Arguments
///
/// * `dim` - The dimension along which the tensor will be narrowed.
/// * `start` - The starting point of the given range.
/// * `length` - The number of elements to take along the dimension, starting at `start`.
///
/// # Panics
///
/// - If the dimension is greater than the number of dimensions of the tensor.
/// - If the given range exceeds the number of elements on the given dimension.
///
/// # Returns
///
/// A new tensor with the given dimension narrowed to the given range.
fn narrow<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
start: usize,
length: usize,
) -> FloatTensor<B, D> {
narrow::<B, D, Float>(tensor, dim, start, length)
}
/// Split the tensor along the given dimension into chunks.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `chunks` - The number of chunks to be produced.
/// * `dim` - The dimension along which the tensor will be split.
///
/// # Returns
///
/// A vector of tensors.
///
fn chunk<const D: usize>(
tensor: FloatTensor<B, D>,
chunks: usize,
dim: usize,
) -> Vec<FloatTensor<B, D>> {
chunk::<B, D, Float>(tensor, chunks, dim)
}
}