mirror of https://github.com/tracel-ai/burn.git
refactor: move tensor base ops
This commit is contained in:
parent
6f45e878f1
commit
122cd842a2
|
@ -1,5 +1,4 @@
|
||||||
use crate::tensor::ops::*;
|
use crate::tensor::ops::*;
|
||||||
use crate::tensor::TensorBase;
|
|
||||||
use half::bf16;
|
use half::bf16;
|
||||||
use half::f16;
|
use half::f16;
|
||||||
|
|
||||||
|
|
|
@ -74,7 +74,7 @@ where
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_add() {
|
fn should_diff_add() {
|
||||||
|
|
|
@ -116,7 +116,7 @@ where
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_matmul_with_index() {
|
fn should_diff_matmul_with_index() {
|
||||||
|
|
|
@ -37,7 +37,7 @@ where
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_matmul() {
|
fn should_diff_matmul() {
|
||||||
|
|
|
@ -79,7 +79,7 @@ where
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_mul() {
|
fn should_diff_mul() {
|
||||||
|
|
|
@ -46,7 +46,7 @@ where
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_neg() {
|
fn should_diff_neg() {
|
||||||
|
|
|
@ -65,7 +65,7 @@ where
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_mul() {
|
fn should_diff_mul() {
|
||||||
|
|
|
@ -79,7 +79,7 @@ where
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_sub() {
|
fn should_diff_sub() {
|
||||||
|
|
|
@ -27,7 +27,7 @@ impl<T: ADCompatibleTensor<P, D>, P: ADElement, const D: usize> TensorOpsTranspo
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_transpose() {
|
fn should_diff_transpose() {
|
||||||
|
|
|
@ -31,7 +31,7 @@ impl<T, P, const D: usize> AsNode<T> for ADTensor<P, D, T> {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, ops::*, Data, TensorBase};
|
use crate::tensor::{backend::autodiff::helper::ADTchTensor, ops::*, Data};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_full_complex_1() {
|
fn should_diff_full_complex_1() {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, ops::*, Data, TensorBase};
|
use crate::tensor::{backend::autodiff::helper::ADTchTensor, ops::*, Data};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_behave_the_same_with_multithread() {
|
fn should_behave_the_same_with_multithread() {
|
||||||
|
|
|
@ -4,7 +4,8 @@ use crate::{
|
||||||
graph::node::ForwardNodeRef,
|
graph::node::ForwardNodeRef,
|
||||||
tensor::{
|
tensor::{
|
||||||
backend::autodiff::{ADCompatibleTensor, ADElement},
|
backend::autodiff::{ADCompatibleTensor, ADElement},
|
||||||
Data, Shape, TensorBase,
|
ops::TensorBase,
|
||||||
|
Data, Shape,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ where
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{Data, TensorBase};
|
use crate::tensor::Data;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_support_add_ops() {
|
fn should_support_add_ops() {
|
||||||
|
|
|
@ -57,7 +57,7 @@ fn to_slice_args<const D1: usize, const D2: usize>(
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{Data, TensorBase};
|
use crate::tensor::Data;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_support_full_indexing_1d() {
|
fn should_support_full_indexing_1d() {
|
||||||
|
|
|
@ -31,7 +31,7 @@ where
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*, Data, TensorBase};
|
use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*, Data};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_matmul_d2() {
|
fn should_matmul_d2() {
|
||||||
|
|
|
@ -49,7 +49,7 @@ where
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{Data, TensorBase};
|
use crate::tensor::Data;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_support_mul_ops() {
|
fn should_support_mul_ops() {
|
||||||
|
|
|
@ -29,7 +29,7 @@ where
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{Data, TensorBase};
|
use crate::tensor::Data;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_support_neg_ops() {
|
fn should_support_neg_ops() {
|
||||||
|
|
|
@ -45,7 +45,7 @@ where
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{Data, TensorBase};
|
use crate::tensor::Data;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_support_sub_ops() {
|
fn should_support_sub_ops() {
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::tensor::{Data, Shape, TensorBase};
|
use crate::tensor::{ops::TensorBase, Data, Shape};
|
||||||
use ndarray::{s, ArcArray, Array, Axis, Dim, Dimension, Ix2, Ix3, IxDyn};
|
use ndarray::{s, ArcArray, Array, Axis, Dim, Dimension, Ix2, Ix3, IxDyn};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
|
|
@ -54,7 +54,6 @@ impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> s
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::TensorBase;
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_support_add_ops() {
|
fn should_support_add_ops() {
|
||||||
|
|
|
@ -59,7 +59,7 @@ impl<
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{Data, TensorBase};
|
use crate::tensor::Data;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_support_full_indexing_1d() {
|
fn should_support_full_indexing_1d() {
|
||||||
|
|
|
@ -48,7 +48,7 @@ impl<P: tch::kind::Element + Into<f64>, const D: usize> std::ops::Mul<TchTensor<
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{Data, TensorBase};
|
use crate::tensor::Data;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_support_mul_ops() {
|
fn should_support_mul_ops() {
|
||||||
|
|
|
@ -29,7 +29,7 @@ impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> s
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::{Data, TensorBase};
|
use crate::tensor::Data;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_support_neg_ops() {
|
fn should_support_neg_ops() {
|
||||||
|
|
|
@ -54,7 +54,6 @@ impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> s
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::tensor::TensorBase;
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_support_sub_ops() {
|
fn should_support_sub_ops() {
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::tensor::{Data, Shape, TensorBase};
|
use crate::tensor::{ops::TensorBase, Data, Shape};
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct TchTensor<P: tch::kind::Element, const D: usize> {
|
pub struct TchTensor<P: tch::kind::Element, const D: usize> {
|
||||||
|
|
|
@ -4,8 +4,6 @@ pub mod ops;
|
||||||
mod data;
|
mod data;
|
||||||
mod print;
|
mod print;
|
||||||
mod shape;
|
mod shape;
|
||||||
mod tensor;
|
|
||||||
|
|
||||||
pub use data::*;
|
pub use data::*;
|
||||||
pub use shape::*;
|
pub use shape::*;
|
||||||
pub use tensor::*;
|
|
||||||
|
|
|
@ -1,6 +1,13 @@
|
||||||
use crate::tensor::{Shape, TensorBase};
|
use super::Data;
|
||||||
|
use crate::tensor::Shape;
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
|
|
||||||
|
pub trait TensorBase<P, const D: usize> {
|
||||||
|
fn shape(&self) -> &Shape<D>;
|
||||||
|
fn into_data(self) -> Data<P, D>;
|
||||||
|
fn to_data(&self) -> Data<P, D>;
|
||||||
|
}
|
||||||
|
|
||||||
pub trait TensorOpsAdd<P, const D: usize>:
|
pub trait TensorOpsAdd<P, const D: usize>:
|
||||||
std::ops::Add<Self, Output = Self> + std::ops::Add<P, Output = Self>
|
std::ops::Add<Self, Output = Self> + std::ops::Add<P, Output = Self>
|
||||||
where
|
where
|
||||||
|
@ -52,6 +59,7 @@ pub trait TensorOpsIndex<P, const D1: usize, const D2: usize> {
|
||||||
pub trait Zeros<T> {
|
pub trait Zeros<T> {
|
||||||
fn zeros(&self) -> T;
|
fn zeros(&self) -> T;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Ones<T> {
|
pub trait Ones<T> {
|
||||||
fn ones(&self) -> T;
|
fn ones(&self) -> T;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +0,0 @@
|
||||||
use crate::tensor::{Data, Shape};
|
|
||||||
|
|
||||||
pub trait TensorBase<P, const D: usize> {
|
|
||||||
fn shape(&self) -> &Shape<D>;
|
|
||||||
fn into_data(self) -> Data<P, D>;
|
|
||||||
fn to_data(&self) -> Data<P, D>;
|
|
||||||
}
|
|
Loading…
Reference in New Issue