mirror of https://github.com/tracel-ai/burn.git
feat: implement more ops for ndarray backend
This commit is contained in:
parent
b2f3c42376
commit
ec32fa730c
|
@ -0,0 +1,65 @@
|
|||
use crate::{backend::ndarray::NdArrayTensor, TensorOpsAdd};
|
||||
use ndarray::{Dim, Dimension, LinalgScalar, ScalarOperand};
|
||||
|
||||
impl<P, const D: usize> TensorOpsAdd<P, D> for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Clone + LinalgScalar + Default + std::fmt::Debug + ScalarOperand,
|
||||
Dim<[usize; D]>: Dimension,
|
||||
{
|
||||
fn add(&self, other: &Self) -> Self {
|
||||
let array = self.array.clone() + other.array.clone();
|
||||
let array = array.into_shared();
|
||||
let shape = self.shape.clone();
|
||||
|
||||
Self { array, shape }
|
||||
}
|
||||
fn add_scalar(&self, other: &P) -> Self {
|
||||
let array = self.array.clone() + other.clone();
|
||||
let shape = self.shape.clone();
|
||||
|
||||
Self { array, shape }
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, const D: usize> std::ops::Add<Self> for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Clone + LinalgScalar + Default + std::fmt::Debug + ScalarOperand,
|
||||
Dim<[usize; D]>: Dimension,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
TensorOpsAdd::add(&self, &rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, const D: usize> std::ops::Add<P> for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Clone + LinalgScalar + Default + std::fmt::Debug + ScalarOperand,
|
||||
Dim<[usize; D]>: Dimension,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: P) -> Self::Output {
|
||||
TensorOpsAdd::add_scalar(&self, &rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{Data, TensorBase};
|
||||
|
||||
#[test]
|
||||
fn should_support_add_ops() {
|
||||
let data_1 = Data::<f64, 2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let data_2 = Data::<f64, 2>::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]);
|
||||
let data_expected = Data::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]);
|
||||
let tensor_1 = NdArrayTensor::from(data_1);
|
||||
let tensor_2 = NdArrayTensor::from(data_2);
|
||||
|
||||
let data_actual = (tensor_1 + tensor_2).into_data();
|
||||
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,135 @@
|
|||
use crate::{backend::ndarray::NdArrayTensor, TensorOpsIndex};
|
||||
use ndarray::SliceInfoElem;
|
||||
use std::ops::Range;
|
||||
|
||||
impl<
|
||||
P: tch::kind::Element + std::fmt::Debug + Copy + Default,
|
||||
const D1: usize,
|
||||
const D2: usize,
|
||||
> TensorOpsIndex<P, D1, D2> for NdArrayTensor<P, D1>
|
||||
{
|
||||
fn index(&self, indexes: [Range<usize>; D2]) -> Self {
|
||||
let slices = to_slice_args::<D1, D2>(indexes.clone());
|
||||
let array = self
|
||||
.array
|
||||
.clone()
|
||||
.slice_move(slices.as_slice())
|
||||
.into_shared();
|
||||
let shape = self.shape.index(indexes.clone());
|
||||
|
||||
Self { array, shape }
|
||||
}
|
||||
|
||||
fn index_assign(&self, indexes: [Range<usize>; D2], values: &Self) -> Self {
|
||||
let slices = to_slice_args::<D1, D2>(indexes.clone());
|
||||
let mut array = self.array.to_owned();
|
||||
array.slice_mut(slices.as_slice()).assign(&values.array);
|
||||
let array = array.into_owned().into_shared();
|
||||
|
||||
let shape = self.shape.clone();
|
||||
|
||||
Self { array, shape }
|
||||
}
|
||||
}
|
||||
|
||||
fn to_slice_args<const D1: usize, const D2: usize>(
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> [SliceInfoElem; D1] {
|
||||
let mut slices = [SliceInfoElem::NewAxis; D1];
|
||||
for i in 0..D1 {
|
||||
if i >= D2 {
|
||||
slices[i] = SliceInfoElem::Slice {
|
||||
start: 0,
|
||||
end: None,
|
||||
step: 1,
|
||||
}
|
||||
} else {
|
||||
slices[i] = SliceInfoElem::Slice {
|
||||
start: indexes[i].start as isize,
|
||||
end: Some(indexes[i].end as isize),
|
||||
step: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
slices
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{Data, TensorBase};
|
||||
|
||||
#[test]
|
||||
fn should_support_full_indexing_1d() {
|
||||
let data = Data::<f64, 1>::from([0.0, 1.0, 2.0]);
|
||||
let tensor = NdArrayTensor::from(data.clone());
|
||||
|
||||
let data_actual = tensor.index([0..3]).into_data();
|
||||
|
||||
assert_eq!(data, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_partial_indexing_1d() {
|
||||
let data = Data::<f64, 1>::from([0.0, 1.0, 2.0]);
|
||||
let tensor = NdArrayTensor::from(data.clone());
|
||||
|
||||
let data_actual = tensor.index([1..3]).into_data();
|
||||
|
||||
let data_expected = Data::from([1.0, 2.0]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_full_indexing_2d() {
|
||||
let data = Data::<f64, 2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = NdArrayTensor::from(data.clone());
|
||||
|
||||
let data_actual_1 = tensor.index([0..2]).into_data();
|
||||
let data_actual_2 = tensor.index([0..2, 0..3]).into_data();
|
||||
|
||||
assert_eq!(data, data_actual_1);
|
||||
assert_eq!(data, data_actual_2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_partial_indexing_2d() {
|
||||
let data = Data::<f64, 2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = NdArrayTensor::from(data.clone());
|
||||
|
||||
let data_actual = tensor.index([0..2, 0..2]).into_data();
|
||||
|
||||
let data_expected = Data::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_indexe_assign_1d() {
|
||||
let data = Data::<f64, 1>::from([0.0, 1.0, 2.0]);
|
||||
let data_assigned = Data::<f64, 1>::from([10.0, 5.0]);
|
||||
|
||||
let tensor = NdArrayTensor::from(data.clone());
|
||||
let tensor_assigned = NdArrayTensor::from(data_assigned.clone());
|
||||
|
||||
let data_actual = tensor.index_assign([0..2], &tensor_assigned).into_data();
|
||||
|
||||
let data_expected = Data::<f64, 1>::from([10.0, 5.0, 2.0]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_indexe_assign_2d() {
|
||||
let data = Data::<f64, 2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let data_assigned = Data::<f64, 2>::from([[10.0, 5.0]]);
|
||||
|
||||
let tensor = NdArrayTensor::from(data.clone());
|
||||
let tensor_assigned = NdArrayTensor::from(data_assigned.clone());
|
||||
|
||||
let data_actual = tensor
|
||||
.index_assign([1..2, 0..2], &tensor_assigned)
|
||||
.into_data();
|
||||
|
||||
let data_expected = Data::<f64, 2>::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
}
|
|
@ -1,41 +1,62 @@
|
|||
use crate::{backend::ndarray::NdArrayTensor, TensorOpsMatmul};
|
||||
use ndarray::LinalgScalar;
|
||||
use crate::{
|
||||
backend::ndarray::{BatchMatrix, NdArrayTensor},
|
||||
TensorOpsMatmul,
|
||||
};
|
||||
use ndarray::{Dim, Dimension, LinalgScalar};
|
||||
|
||||
impl<P, const D: usize> TensorOpsMatmul<f32, D> for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Clone + LinalgScalar + Default,
|
||||
{
|
||||
fn matmul(&self, other: &Self) -> Self {
|
||||
let self_iter = self.arrays.iter();
|
||||
let other_iter = other.arrays.iter();
|
||||
let arrays = self_iter
|
||||
.zip(other_iter)
|
||||
.map(|(lhs, rhs)| lhs.dot(rhs))
|
||||
.map(|output| output.into_shared())
|
||||
.collect();
|
||||
macro_rules! define_from {
|
||||
(
|
||||
$n:expr
|
||||
) => {
|
||||
impl<P> TensorOpsMatmul<f32, $n> for NdArrayTensor<P, $n>
|
||||
where
|
||||
P: Clone + LinalgScalar + Default + std::fmt::Debug,
|
||||
Dim<[usize; $n]>: Dimension,
|
||||
{
|
||||
fn matmul(&self, other: &Self) -> Self {
|
||||
let batch_self = BatchMatrix::from_ndarray(self.array.clone(), self.shape.clone());
|
||||
let batch_other =
|
||||
BatchMatrix::from_ndarray(other.array.clone(), other.shape.clone());
|
||||
|
||||
let mut shape = self.shape.clone();
|
||||
shape.dims[D - 1] = other.shape.dims[D - 1];
|
||||
let self_iter = batch_self.arrays.iter();
|
||||
let other_iter = batch_other.arrays.iter();
|
||||
let arrays = self_iter
|
||||
.zip(other_iter)
|
||||
.map(|(lhs, rhs)| lhs.dot(rhs))
|
||||
.map(|output| output.into_shared())
|
||||
.collect();
|
||||
|
||||
Self { arrays, shape }
|
||||
}
|
||||
let mut shape = self.shape.clone();
|
||||
shape.dims[$n - 1] = other.shape.dims[$n - 1];
|
||||
let output = BatchMatrix::new(arrays, shape.clone());
|
||||
|
||||
Self::from(output)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
define_from!(1);
|
||||
define_from!(2);
|
||||
define_from!(3);
|
||||
define_from!(4);
|
||||
define_from!(5);
|
||||
define_from!(6);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{backend::ndarray::NdArrayTensor, Data, Shape, TensorBase, TensorOpsMatmul};
|
||||
use crate::{backend::ndarray::NdArrayTensor, Data, TensorBase, TensorOpsMatmul};
|
||||
|
||||
#[test]
|
||||
fn should_matmul_d2() {
|
||||
let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [2.0, 3.0], [1.0, 5.0]]);
|
||||
let data_2: Data<f64, 2> = Data::from([[4.0, 7.0, 5.0], [2.0, 3.0, 5.0]]);
|
||||
|
||||
let tensor_1 = NdArrayTensor::from_data(data_1.clone());
|
||||
let tensor_2 = NdArrayTensor::from_data(data_2.clone());
|
||||
let tensor_1 = NdArrayTensor::from(data_1.clone());
|
||||
let tensor_2 = NdArrayTensor::from(data_2.clone());
|
||||
|
||||
let tensor_3 = tensor_1.matmul(&tensor_2);
|
||||
|
||||
assert_eq!(tensor_3.shape, Shape::new([3, 3]));
|
||||
assert_eq!(
|
||||
tensor_3.into_data(),
|
||||
Data::from([[18.0, 28.0, 40.0], [14.0, 23.0, 25.0], [14.0, 22.0, 30.0]])
|
||||
|
@ -47,8 +68,8 @@ mod tests {
|
|||
let data_1: Data<f64, 3> = Data::from([[[1.0, 7.0], [2.0, 3.0]]]);
|
||||
let data_2: Data<f64, 3> = Data::from([[[4.0, 7.0], [2.0, 3.0]]]);
|
||||
|
||||
let tensor_1 = NdArrayTensor::from_data(data_1.clone());
|
||||
let tensor_2 = NdArrayTensor::from_data(data_2.clone());
|
||||
let tensor_1 = NdArrayTensor::from(data_1.clone());
|
||||
let tensor_2 = NdArrayTensor::from(data_2.clone());
|
||||
|
||||
let tensor_3 = tensor_1.matmul(&tensor_2);
|
||||
|
||||
|
|
|
@ -1 +1,5 @@
|
|||
mod add;
|
||||
mod index;
|
||||
mod matmul;
|
||||
mod neg;
|
||||
mod sub;
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
use ndarray::{LinalgScalar, ScalarOperand};
|
||||
|
||||
use crate::{backend::ndarray::NdArrayTensor, TensorOpsNeg};
|
||||
|
||||
impl<P, const D: usize> TensorOpsNeg<P, D> for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Clone + LinalgScalar + Default + std::fmt::Debug + ScalarOperand,
|
||||
{
|
||||
fn neg(&self) -> Self {
|
||||
let minus_one = P::zero() - P::one();
|
||||
let array = self.array.clone() * minus_one;
|
||||
let array = array.into_shared();
|
||||
let shape = self.shape.clone();
|
||||
|
||||
Self { array, shape }
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, const D: usize> std::ops::Neg for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Clone + LinalgScalar + Default + std::fmt::Debug + ScalarOperand,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
TensorOpsNeg::neg(&self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{Data, TensorBase};
|
||||
|
||||
#[test]
|
||||
fn should_support_neg_ops() {
|
||||
let data = Data::<f64, 2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = NdArrayTensor::from(data);
|
||||
|
||||
let data_actual = tensor.neg().into_data();
|
||||
|
||||
let data_expected = Data::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
use crate::{backend::ndarray::NdArrayTensor, TensorOpsSub};
|
||||
use ndarray::{LinalgScalar, ScalarOperand};
|
||||
|
||||
impl<P, const D: usize> TensorOpsSub<P, D> for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Clone + LinalgScalar + Default + std::fmt::Debug + ScalarOperand,
|
||||
{
|
||||
fn sub(&self, other: &Self) -> Self {
|
||||
let array = self.array.clone() - other.array.clone();
|
||||
let array = array.into_shared();
|
||||
let shape = self.shape.clone();
|
||||
|
||||
Self { array, shape }
|
||||
}
|
||||
fn sub_scalar(&self, other: &P) -> Self {
|
||||
let array = self.array.clone() - other.clone();
|
||||
let shape = self.shape.clone();
|
||||
|
||||
Self { array, shape }
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, const D: usize> std::ops::Sub<Self> for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Clone + LinalgScalar + Default + std::fmt::Debug + ScalarOperand,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
TensorOpsSub::sub(&self, &rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, const D: usize> std::ops::Sub<P> for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Clone + LinalgScalar + Default + std::fmt::Debug + ScalarOperand,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: P) -> Self::Output {
|
||||
TensorOpsSub::sub_scalar(&self, &rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{Data, TensorBase};
|
||||
|
||||
#[test]
|
||||
fn should_support_sub_ops() {
|
||||
let data_1 = Data::<f64, 2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let data_2 = Data::<f64, 2>::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]);
|
||||
let data_expected = Data::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]);
|
||||
let tensor_1 = NdArrayTensor::from(data_1);
|
||||
let tensor_2 = NdArrayTensor::from(data_2);
|
||||
|
||||
let data_actual = (tensor_1 - tensor_2).into_data();
|
||||
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
}
|
|
@ -2,59 +2,58 @@ use crate::Data;
|
|||
use crate::Shape;
|
||||
use crate::TensorBase;
|
||||
use ndarray::s;
|
||||
use ndarray::ArcArray;
|
||||
use ndarray::Array;
|
||||
use ndarray::Dim;
|
||||
use ndarray::Dimension;
|
||||
use ndarray::Ix2;
|
||||
use ndarray::{ArcArray, IxDyn};
|
||||
|
||||
pub struct NdArrayTensor<P, const D: usize> {
|
||||
pub array: ArcArray<P, IxDyn>,
|
||||
pub shape: Shape<D>,
|
||||
}
|
||||
|
||||
impl<P, const D: usize> TensorBase<P, D> for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Default + Clone,
|
||||
Dim<[usize; D]>: Dimension,
|
||||
{
|
||||
fn shape(&self) -> &Shape<D> {
|
||||
&self.shape
|
||||
}
|
||||
|
||||
fn into_data(self) -> Data<P, D> {
|
||||
let values = self.array.into_iter().collect();
|
||||
Data::new(values, self.shape)
|
||||
}
|
||||
|
||||
fn to_data(&self) -> Data<P, D> {
|
||||
let values = self.array.clone().into_iter().collect();
|
||||
Data::new(values, self.shape)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub struct BatchMatrix<P, const D: usize> {
|
||||
pub arrays: Vec<ArcArray<P, Ix2>>,
|
||||
pub shape: Shape<D>,
|
||||
}
|
||||
|
||||
impl<P: Default + Copy + std::fmt::Debug, const D: usize> TensorBase<P, D> for NdArrayTensor<P, D> {
|
||||
fn shape(&self) -> &Shape<D> {
|
||||
&self.shape
|
||||
}
|
||||
fn into_data(self) -> Data<P, D> {
|
||||
let mut arrays = self.arrays;
|
||||
|
||||
if D == 1 {
|
||||
let array = arrays.remove(0);
|
||||
let values = array.into_iter().collect();
|
||||
return Data::new(values, self.shape);
|
||||
}
|
||||
|
||||
let mut values = Vec::new();
|
||||
for array in arrays {
|
||||
let mut values_tmp = array.into_iter().collect();
|
||||
values.append(&mut values_tmp);
|
||||
}
|
||||
Data::new(values, self.shape)
|
||||
}
|
||||
fn to_data(&self) -> Data<P, D> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
impl<P, const D: usize> NdArrayTensor<P, D>
|
||||
impl<P, const D: usize> BatchMatrix<P, D>
|
||||
where
|
||||
P: Default + Clone,
|
||||
P: Clone,
|
||||
Dim<[usize; D]>: Dimension,
|
||||
{
|
||||
pub fn from_data(data: Data<P, D>) -> Self {
|
||||
let shape = data.shape.clone();
|
||||
pub fn from_ndarray(array: ArcArray<P, IxDyn>, shape: Shape<D>) -> Self {
|
||||
let mut arrays = Vec::new();
|
||||
|
||||
if D < 2 {
|
||||
let array = Array::from_iter(data.value.into_iter())
|
||||
.into_shared()
|
||||
.reshape((1, shape.dims[0]));
|
||||
let array = array.reshape((1, shape.dims[0]));
|
||||
arrays.push(array);
|
||||
} else {
|
||||
let batch_size = batch_size(&shape);
|
||||
let size0 = shape.dims[D - 2];
|
||||
let size1 = shape.dims[D - 1];
|
||||
let array_global = Array::from_iter(data.value.into_iter())
|
||||
.into_shared()
|
||||
.reshape((batch_size, size0, size1));
|
||||
let array_global = array.reshape((batch_size, size0, size1));
|
||||
for b in 0..batch_size {
|
||||
let array = array_global.slice(s!(b, .., ..));
|
||||
let array = array.into_owned().into_shared();
|
||||
|
@ -75,6 +74,55 @@ fn batch_size<const D: usize>(shape: &Shape<D>) -> usize {
|
|||
num_batch
|
||||
}
|
||||
|
||||
macro_rules! define_from {
|
||||
(
|
||||
$n:expr
|
||||
) => {
|
||||
impl<P> From<Data<P, $n>> for NdArrayTensor<P, $n>
|
||||
where
|
||||
P: Default + Clone,
|
||||
{
|
||||
fn from(data: Data<P, $n>) -> NdArrayTensor<P, $n> {
|
||||
let shape = data.shape.clone();
|
||||
let dim: Dim<[usize; $n]> = shape.clone().into();
|
||||
|
||||
let array: ArcArray<P, Dim<[usize; $n]>> = Array::from_iter(data.value.into_iter())
|
||||
.into_shared()
|
||||
.reshape(dim);
|
||||
let array = array.into_dyn();
|
||||
|
||||
NdArrayTensor { array, shape }
|
||||
}
|
||||
}
|
||||
impl<P> From<BatchMatrix<P, $n>> for NdArrayTensor<P, $n>
|
||||
where
|
||||
P: Default + Clone,
|
||||
{
|
||||
fn from(data: BatchMatrix<P, $n>) -> NdArrayTensor<P, $n> {
|
||||
let shape = data.shape;
|
||||
let dim: Dim<[usize; $n]> = shape.clone().into();
|
||||
let mut values = Vec::new();
|
||||
for array in data.arrays {
|
||||
values.append(&mut array.into_iter().collect());
|
||||
}
|
||||
|
||||
let array: ArcArray<P, Dim<[usize; $n]>> =
|
||||
Array::from_iter(values).into_shared().reshape(dim);
|
||||
let array = array.into_dyn();
|
||||
|
||||
NdArrayTensor { array, shape }
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
define_from!(1);
|
||||
define_from!(2);
|
||||
define_from!(3);
|
||||
define_from!(4);
|
||||
define_from!(5);
|
||||
define_from!(6);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -82,7 +130,7 @@ mod tests {
|
|||
#[test]
|
||||
fn should_support_into_and_from_data_1d() {
|
||||
let data_expected = Data::<f32, 1>::random(Shape::new([3]));
|
||||
let tensor = NdArrayTensor::from_data(data_expected.clone());
|
||||
let tensor = NdArrayTensor::from(data_expected.clone());
|
||||
|
||||
let data_actual = tensor.into_data();
|
||||
|
||||
|
@ -92,7 +140,7 @@ mod tests {
|
|||
#[test]
|
||||
fn should_support_into_and_from_data_2d() {
|
||||
let data_expected = Data::<f32, 2>::random(Shape::new([2, 3]));
|
||||
let tensor = NdArrayTensor::from_data(data_expected.clone());
|
||||
let tensor = NdArrayTensor::from(data_expected.clone());
|
||||
|
||||
let data_actual = tensor.into_data();
|
||||
|
||||
|
@ -102,7 +150,7 @@ mod tests {
|
|||
#[test]
|
||||
fn should_support_into_and_from_data_3d() {
|
||||
let data_expected = Data::<f32, 3>::random(Shape::new([2, 3, 4]));
|
||||
let tensor = NdArrayTensor::from_data(data_expected.clone());
|
||||
let tensor = NdArrayTensor::from(data_expected.clone());
|
||||
|
||||
let data_actual = tensor.into_data();
|
||||
|
||||
|
@ -112,7 +160,7 @@ mod tests {
|
|||
#[test]
|
||||
fn should_support_into_and_from_data_4d() {
|
||||
let data_expected = Data::<f32, 4>::random(Shape::new([2, 3, 4, 2]));
|
||||
let tensor = NdArrayTensor::from_data(data_expected.clone());
|
||||
let tensor = NdArrayTensor::from(data_expected.clone());
|
||||
|
||||
let data_actual = tensor.into_data();
|
||||
|
||||
|
|
Loading…
Reference in New Issue