mirror of https://github.com/tracel-ai/burn.git
Implement chunk for different backends (#1032)
This commit is contained in:
parent
c1cb77ac2e
commit
7c6f017c98
|
@ -92,4 +92,21 @@ impl<B: Backend> BoolTensorOps<Self> for Autodiff<B> {
|
|||
) -> <Autodiff<B> as Backend>::BoolTensorPrimitive<D> {
|
||||
B::bool_swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn bool_narrow<const D: usize>(
|
||||
tensor: BoolTensor<B, D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> BoolTensor<B, D> {
|
||||
B::bool_narrow(tensor, dim, start, length)
|
||||
}
|
||||
|
||||
fn bool_chunk<const D: usize>(
|
||||
tensor: BoolTensor<B, D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<BoolTensor<B, D>> {
|
||||
B::bool_chunk(tensor, chunks, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -316,4 +316,21 @@ impl<B: Backend> IntTensorOps<Autodiff<B>> for Autodiff<B> {
|
|||
) -> <Autodiff<B> as Backend>::IntTensorPrimitive<D> {
|
||||
B::int_swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn int_narrow<const D: usize>(
|
||||
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> <Autodiff<B> as Backend>::IntTensorPrimitive<D> {
|
||||
B::int_narrow(tensor, dim, start, length)
|
||||
}
|
||||
|
||||
fn int_chunk<const D: usize>(
|
||||
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive<D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<<Autodiff<B> as Backend>::IntTensorPrimitive<D>> {
|
||||
B::int_chunk(tensor, chunks, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -88,3 +88,31 @@ pub fn slice_assign<E: CandleElement, const D1: usize, const D2: usize>(
|
|||
) -> CandleTensor<E, D1> {
|
||||
panic!("slice_assign not supported by Candle")
|
||||
}
|
||||
|
||||
pub fn narrow<E: CandleElement, const D: usize>(
|
||||
tensor: CandleTensor<E, D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> CandleTensor<E, D> {
|
||||
let tensor = tensor.tensor.narrow(dim, start, length);
|
||||
match tensor {
|
||||
Ok(tensor) => CandleTensor::new(tensor),
|
||||
Err(e) => panic!("error narrow from Candle"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chunk<E: CandleElement, const D: usize>(
|
||||
tensor: CandleTensor<E, D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<CandleTensor<E, D>> {
|
||||
let tensors = tensor.tensor.chunk(chunks, dim);
|
||||
match tensors {
|
||||
Ok(tensors) => tensors
|
||||
.into_iter()
|
||||
.map(|tensor| CandleTensor::new(tensor))
|
||||
.collect(),
|
||||
Err(e) => panic!("error chunk from Candle"),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -109,4 +109,21 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
|
|||
) -> <Candle<F, I> as burn_tensor::backend::Backend>::BoolTensorPrimitive<D> {
|
||||
super::base::swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn bool_narrow<const D: usize>(
|
||||
tensor: BoolTensor<Self, D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> BoolTensor<Self, D> {
|
||||
super::base::narrow(tensor, dim, start, length)
|
||||
}
|
||||
|
||||
fn bool_chunk<const D: usize>(
|
||||
tensor: BoolTensor<Self, D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<BoolTensor<Self, D>> {
|
||||
super::base::chunk(tensor, chunks, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -359,4 +359,21 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
|
|||
) -> <Candle<F, I> as burn_tensor::backend::Backend>::IntTensorPrimitive<D> {
|
||||
super::base::swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn int_narrow<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> IntTensor<Self, D> {
|
||||
super::base::narrow(tensor, dim, start, length)
|
||||
}
|
||||
|
||||
fn int_chunk<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<IntTensor<Self, D>> {
|
||||
super::base::chunk(tensor, chunks, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -448,4 +448,21 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
|
|||
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
CandleTensor::new(tensor.tensor.recip().unwrap())
|
||||
}
|
||||
|
||||
fn narrow<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> FloatTensor<Self, D> {
|
||||
super::base::narrow(tensor, dim, start, length)
|
||||
}
|
||||
|
||||
fn chunk<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<FloatTensor<Self, D>> {
|
||||
super::base::chunk(tensor, chunks, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -413,4 +413,30 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn narrow<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> TchTensor<E, D> {
|
||||
TchTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.narrow(dim as i64, start as i64, length as i64),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn chunk<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<TchTensor<E, D>> {
|
||||
tensor
|
||||
.tensor
|
||||
.chunk(chunks as i64, dim as i64)
|
||||
.into_iter()
|
||||
.map(|tensor| TchTensor::new(tensor))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -114,4 +114,21 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
|
|||
) -> <LibTorch<E> as Backend>::BoolTensorPrimitive<D> {
|
||||
TchOps::swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn bool_narrow<const D: usize>(
|
||||
tensor: TchTensor<bool, D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> TchTensor<bool, D> {
|
||||
TchOps::narrow(tensor, dim, start, length)
|
||||
}
|
||||
|
||||
fn bool_chunk<const D: usize>(
|
||||
tensor: TchTensor<bool, D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<TchTensor<bool, D>> {
|
||||
TchOps::chunk(tensor, chunks, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -401,4 +401,21 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
|
|||
) -> <LibTorch<E> as Backend>::IntTensorPrimitive<D> {
|
||||
TchOps::swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn int_narrow<const D: usize>(
|
||||
tensor: TchTensor<i64, D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> TchTensor<i64, D> {
|
||||
TchOps::narrow(tensor, dim, start, length)
|
||||
}
|
||||
|
||||
fn int_chunk<const D: usize>(
|
||||
tensor: TchTensor<i64, D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<TchTensor<i64, D>> {
|
||||
TchOps::chunk(tensor, chunks, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -440,4 +440,21 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
|
|||
let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn narrow<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> TchTensor<E, D> {
|
||||
TchOps::narrow(tensor, dim, start, length)
|
||||
}
|
||||
|
||||
fn chunk<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<TchTensor<E, D>> {
|
||||
TchOps::chunk(tensor, chunks, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,9 +12,10 @@ use alloc::vec;
|
|||
use burn_common::{reader::Reader, stub::Mutex};
|
||||
use core::{fmt::Debug, ops::Range};
|
||||
|
||||
use crate::{
|
||||
backend::Backend, check, check::TensorCheck, Bool, Data, Float, Int, Shape, TensorKind,
|
||||
};
|
||||
use crate::check::TensorCheck;
|
||||
use crate::tensor::api::chunk::chunk;
|
||||
use crate::tensor::api::narrow::narrow;
|
||||
use crate::{backend::Backend, check, Bool, Data, Float, Int, Shape, TensorKind};
|
||||
|
||||
/// A tensor with a given backend, shape and data type.
|
||||
#[derive(new, Clone, Debug)]
|
||||
|
@ -496,20 +497,7 @@ where
|
|||
pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
|
||||
check!(TensorCheck::dim_ops::<D>("narrow", dim));
|
||||
check!(TensorCheck::narrow(&self, dim, start, length));
|
||||
|
||||
let ranges: Vec<_> = (0..D)
|
||||
.map(|i| {
|
||||
if i == dim {
|
||||
start..(start + length)
|
||||
} else {
|
||||
0..self.shape().dims[i]
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let ranges_array: [_; D] = ranges.try_into().unwrap();
|
||||
|
||||
self.slice(ranges_array)
|
||||
Self::new(narrow::<B, D, K>(self.primitive, dim, start, length))
|
||||
}
|
||||
|
||||
/// Attempts to split the tensor along the given dimension into chunks.
|
||||
|
@ -526,31 +514,10 @@ where
|
|||
/// A vector of tensors.
|
||||
pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
|
||||
check!(TensorCheck::dim_ops::<D>("chunk", dim));
|
||||
|
||||
let size = self.shape().dims[dim];
|
||||
if size < chunks {
|
||||
return (0..size).map(|i| self.clone().narrow(dim, i, 1)).collect();
|
||||
}
|
||||
|
||||
let mut tensors = Vec::with_capacity(chunks);
|
||||
let mut sum_chunk_size = 0;
|
||||
if size % chunks == 0 {
|
||||
let chunk_size = size / chunks;
|
||||
for _ in 0..chunks {
|
||||
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
|
||||
sum_chunk_size += chunk_size;
|
||||
}
|
||||
} else {
|
||||
let chunk_size = (size / chunks) + 1; // assumes not divisible
|
||||
for _ in 0..chunks - 1 {
|
||||
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
|
||||
sum_chunk_size += chunk_size;
|
||||
}
|
||||
let remainder = size % chunk_size;
|
||||
tensors.push(self.clone().narrow(dim, sum_chunk_size, remainder));
|
||||
}
|
||||
|
||||
tensors
|
||||
chunk::<B, D, K>(self.primitive, chunks, dim)
|
||||
.into_iter()
|
||||
.map(|v| Self::new(v))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
use super::narrow::narrow;
|
||||
use crate::{backend::Backend, BasicOps, TensorKind};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// Split the tensor along the given dimension into chunks.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `chunks` - The number of chunks to be produced
|
||||
/// * `times` - The dimension along which the tensor will be split.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vectors of tensors
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.
|
||||
/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
|
||||
/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
pub fn chunk<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
|
||||
tensor: K::Primitive<D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<K::Primitive<D>> {
|
||||
let size = K::shape(&tensor).dims[dim];
|
||||
if size < chunks {
|
||||
return (0..size)
|
||||
.map(|i| narrow::<B, D, K>(tensor.clone(), dim, i, 1))
|
||||
.collect();
|
||||
}
|
||||
|
||||
let mut tensors = Vec::with_capacity(chunks);
|
||||
let mut sum_chunk_size = 0;
|
||||
if size % chunks == 0 {
|
||||
let chunk_size = size / chunks;
|
||||
for _ in 0..chunks {
|
||||
tensors.push(narrow::<B, D, K>(
|
||||
tensor.clone(),
|
||||
dim,
|
||||
sum_chunk_size,
|
||||
chunk_size,
|
||||
));
|
||||
sum_chunk_size += chunk_size;
|
||||
}
|
||||
} else {
|
||||
let chunk_size = (size / chunks) + 1; // assumes not divisible
|
||||
for _ in 0..chunks - 1 {
|
||||
tensors.push(narrow::<B, D, K>(
|
||||
tensor.clone(),
|
||||
dim,
|
||||
sum_chunk_size,
|
||||
chunk_size,
|
||||
));
|
||||
sum_chunk_size += chunk_size;
|
||||
}
|
||||
let remainder = size % chunk_size;
|
||||
tensors.push(narrow::<B, D, K>(
|
||||
tensor.clone(),
|
||||
dim,
|
||||
sum_chunk_size,
|
||||
remainder,
|
||||
));
|
||||
}
|
||||
|
||||
tensors
|
||||
}
|
|
@ -3,12 +3,16 @@ pub(crate) mod check;
|
|||
mod autodiff;
|
||||
mod base;
|
||||
mod bool;
|
||||
mod chunk;
|
||||
mod float;
|
||||
mod int;
|
||||
mod kind;
|
||||
mod narrow;
|
||||
mod numeric;
|
||||
|
||||
pub use autodiff::*;
|
||||
pub use base::*;
|
||||
pub use chunk::chunk;
|
||||
pub use kind::*;
|
||||
pub use narrow::narrow;
|
||||
pub use numeric::*;
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
use crate::{backend::Backend, BasicOps, TensorKind};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// Returns a new tensor with the given dimension narrowed to the given range.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `dim` - The dimension along which the tensor will be narrowed.
|
||||
/// * `start` - The starting point of the given range.
|
||||
/// * `length` - The ending point of the given range.
|
||||
/// # Panics
|
||||
///
|
||||
/// - If the dimension is greater than the number of dimensions of the tensor.
|
||||
/// - If the given range exceeds the number of elements on the given dimension.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new tensor with the given dimension narrowed to the given range.
|
||||
pub fn narrow<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
|
||||
tensor: K::Primitive<D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> K::Primitive<D> {
|
||||
let shape = K::shape(&tensor);
|
||||
|
||||
let ranges: Vec<_> = (0..D)
|
||||
.map(|i| {
|
||||
if i == dim {
|
||||
start..(start + length)
|
||||
} else {
|
||||
0..shape.dims[i]
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let ranges_array: [_; D] = ranges.try_into().unwrap();
|
||||
|
||||
K::slice(tensor, ranges_array)
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
use super::{BoolTensor, Device, FloatTensor, IntTensor};
|
||||
use crate::{backend::Backend, tensor::Shape, Data};
|
||||
use crate::{backend::Backend, chunk, narrow, tensor::Shape, Bool, Data};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
use core::ops::Range;
|
||||
|
@ -258,4 +258,48 @@ pub trait BoolTensorOps<B: Backend> {
|
|||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> BoolTensor<B, D>;
|
||||
|
||||
/// Returns a new tensor with the given dimension narrowed to the given range.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The dimension along which the tensor will be narrowed.
|
||||
/// * `start` - The starting point of the given range.
|
||||
/// * `length` - The ending point of the given range.
|
||||
/// # Panics
|
||||
///
|
||||
/// - If the dimension is greater than the number of dimensions of the tensor.
|
||||
/// - If the given range exceeds the number of elements on the given dimension.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new tensor with the given dimension narrowed to the given range.
|
||||
fn bool_narrow<const D: usize>(
|
||||
tensor: BoolTensor<B, D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> BoolTensor<B, D> {
|
||||
narrow::<B, D, Bool>(tensor, dim, start, length)
|
||||
}
|
||||
|
||||
/// Split the tensor along the given dimension into chunks.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `chunks` - The number of chunks to be produced
|
||||
/// * `times` - The dimension along which the tensor will be split.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vectors of tensors
|
||||
///
|
||||
fn bool_chunk<const D: usize>(
|
||||
tensor: BoolTensor<B, D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<BoolTensor<B, D>> {
|
||||
chunk::<B, D, Bool>(tensor, chunks, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
||||
use crate::{backend::Backend, tensor::Shape, Data, ElementConversion};
|
||||
use crate::{backend::Backend, tensor::Shape, Data, ElementConversion, Int};
|
||||
use crate::{tensor::api::chunk, tensor::api::narrow};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
use core::ops::Range;
|
||||
|
@ -850,4 +851,48 @@ pub trait IntTensorOps<B: Backend> {
|
|||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> IntTensor<B, D>;
|
||||
|
||||
/// Returns a new tensor with the given dimension narrowed to the given range.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The dimension along which the tensor will be narrowed.
|
||||
/// * `start` - The starting point of the given range.
|
||||
/// * `length` - The ending point of the given range.
|
||||
/// # Panics
|
||||
///
|
||||
/// - If the dimension is greater than the number of dimensions of the tensor.
|
||||
/// - If the given range exceeds the number of elements on the given dimension.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new tensor with the given dimension narrowed to the given range.
|
||||
fn int_narrow<const D: usize>(
|
||||
tensor: IntTensor<B, D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> IntTensor<B, D> {
|
||||
narrow::<B, D, Int>(tensor, dim, start, length)
|
||||
}
|
||||
|
||||
/// Split the tensor along the given dimension into chunks.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `chunks` - The number of chunks to be produced
|
||||
/// * `times` - The dimension along which the tensor will be split.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vectors of tensors
|
||||
///
|
||||
fn int_chunk<const D: usize>(
|
||||
tensor: IntTensor<B, D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<IntTensor<B, D>> {
|
||||
chunk::<B, D, Int>(tensor, chunks, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
|
||||
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion};
|
||||
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Float};
|
||||
use crate::{tensor::api::chunk, tensor::api::narrow};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
use core::ops::Range;
|
||||
|
@ -1075,4 +1076,48 @@ pub trait TensorOps<B: Backend> {
|
|||
|
||||
(values, index)
|
||||
}
|
||||
|
||||
/// Returns a new tensor with the given dimension narrowed to the given range.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The dimension along which the tensor will be narrowed.
|
||||
/// * `start` - The starting point of the given range.
|
||||
/// * `length` - The ending point of the given range.
|
||||
/// # Panics
|
||||
///
|
||||
/// - If the dimension is greater than the number of dimensions of the tensor.
|
||||
/// - If the given range exceeds the number of elements on the given dimension.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new tensor with the given dimension narrowed to the given range.
|
||||
fn narrow<const D: usize>(
|
||||
tensor: FloatTensor<B, D>,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
length: usize,
|
||||
) -> FloatTensor<B, D> {
|
||||
narrow::<B, D, Float>(tensor, dim, start, length)
|
||||
}
|
||||
|
||||
/// Split the tensor along the given dimension into chunks.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `chunks` - The number of chunks to be produced
|
||||
/// * `times` - The dimension along which the tensor will be split.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vectors of tensors
|
||||
///
|
||||
fn chunk<const D: usize>(
|
||||
tensor: FloatTensor<B, D>,
|
||||
chunks: usize,
|
||||
dim: usize,
|
||||
) -> Vec<FloatTensor<B, D>> {
|
||||
chunk::<B, D, Float>(tensor, chunks, dim)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue