mirror of https://github.com/tracel-ai/burn.git
refactor: move tensor base ops
This commit is contained in:
parent
6f45e878f1
commit
122cd842a2
|
@ -1,5 +1,4 @@
|
|||
use crate::tensor::ops::*;
|
||||
use crate::tensor::TensorBase;
|
||||
use half::bf16;
|
||||
use half::f16;
|
||||
|
||||
|
|
|
@ -74,7 +74,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||
|
||||
#[test]
|
||||
fn should_diff_add() {
|
||||
|
|
|
@ -116,7 +116,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||
|
||||
#[test]
|
||||
fn should_diff_matmul_with_index() {
|
||||
|
|
|
@ -37,7 +37,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||
|
||||
#[test]
|
||||
fn should_diff_matmul() {
|
||||
|
|
|
@ -79,7 +79,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||
|
||||
#[test]
|
||||
fn should_diff_mul() {
|
||||
|
|
|
@ -46,7 +46,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||
|
||||
#[test]
|
||||
fn should_diff_neg() {
|
||||
|
|
|
@ -65,7 +65,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||
|
||||
#[test]
|
||||
fn should_diff_mul() {
|
||||
|
|
|
@ -79,7 +79,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||
|
||||
#[test]
|
||||
fn should_diff_sub() {
|
||||
|
|
|
@ -27,7 +27,7 @@ impl<T: ADCompatibleTensor<P, D>, P: ADElement, const D: usize> TensorOpsTranspo
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
|
||||
|
||||
#[test]
|
||||
fn should_diff_transpose() {
|
||||
|
|
|
@ -31,7 +31,7 @@ impl<T, P, const D: usize> AsNode<T> for ADTensor<P, D, T> {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, ops::*, Data, TensorBase};
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, ops::*, Data};
|
||||
|
||||
#[test]
|
||||
fn should_diff_full_complex_1() {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, ops::*, Data, TensorBase};
|
||||
use crate::tensor::{backend::autodiff::helper::ADTchTensor, ops::*, Data};
|
||||
|
||||
#[test]
|
||||
fn should_behave_the_same_with_multithread() {
|
||||
|
|
|
@ -4,7 +4,8 @@ use crate::{
|
|||
graph::node::ForwardNodeRef,
|
||||
tensor::{
|
||||
backend::autodiff::{ADCompatibleTensor, ADElement},
|
||||
Data, Shape, TensorBase,
|
||||
ops::TensorBase,
|
||||
Data, Shape,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{Data, TensorBase};
|
||||
use crate::tensor::Data;
|
||||
|
||||
#[test]
|
||||
fn should_support_add_ops() {
|
||||
|
|
|
@ -57,7 +57,7 @@ fn to_slice_args<const D1: usize, const D2: usize>(
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{Data, TensorBase};
|
||||
use crate::tensor::Data;
|
||||
|
||||
#[test]
|
||||
fn should_support_full_indexing_1d() {
|
||||
|
|
|
@ -31,7 +31,7 @@ where
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*, Data, TensorBase};
|
||||
use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*, Data};
|
||||
|
||||
#[test]
|
||||
fn should_matmul_d2() {
|
||||
|
|
|
@ -49,7 +49,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{Data, TensorBase};
|
||||
use crate::tensor::Data;
|
||||
|
||||
#[test]
|
||||
fn should_support_mul_ops() {
|
||||
|
|
|
@ -29,7 +29,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{Data, TensorBase};
|
||||
use crate::tensor::Data;
|
||||
|
||||
#[test]
|
||||
fn should_support_neg_ops() {
|
||||
|
|
|
@ -45,7 +45,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{Data, TensorBase};
|
||||
use crate::tensor::Data;
|
||||
|
||||
#[test]
|
||||
fn should_support_sub_ops() {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::tensor::{Data, Shape, TensorBase};
|
||||
use crate::tensor::{ops::TensorBase, Data, Shape};
|
||||
use ndarray::{s, ArcArray, Array, Axis, Dim, Dimension, Ix2, Ix3, IxDyn};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
|
@ -54,7 +54,6 @@ impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> s
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::TensorBase;
|
||||
|
||||
#[test]
|
||||
fn should_support_add_ops() {
|
||||
|
|
|
@ -59,7 +59,7 @@ impl<
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{Data, TensorBase};
|
||||
use crate::tensor::Data;
|
||||
|
||||
#[test]
|
||||
fn should_support_full_indexing_1d() {
|
||||
|
|
|
@ -48,7 +48,7 @@ impl<P: tch::kind::Element + Into<f64>, const D: usize> std::ops::Mul<TchTensor<
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{Data, TensorBase};
|
||||
use crate::tensor::Data;
|
||||
|
||||
#[test]
|
||||
fn should_support_mul_ops() {
|
||||
|
|
|
@ -29,7 +29,7 @@ impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> s
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::{Data, TensorBase};
|
||||
use crate::tensor::Data;
|
||||
|
||||
#[test]
|
||||
fn should_support_neg_ops() {
|
||||
|
|
|
@ -54,7 +54,6 @@ impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> s
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::TensorBase;
|
||||
|
||||
#[test]
|
||||
fn should_support_sub_ops() {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::tensor::{Data, Shape, TensorBase};
|
||||
use crate::tensor::{ops::TensorBase, Data, Shape};
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct TchTensor<P: tch::kind::Element, const D: usize> {
|
||||
|
|
|
@ -4,8 +4,6 @@ pub mod ops;
|
|||
mod data;
|
||||
mod print;
|
||||
mod shape;
|
||||
mod tensor;
|
||||
|
||||
pub use data::*;
|
||||
pub use shape::*;
|
||||
pub use tensor::*;
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
use crate::tensor::{Shape, TensorBase};
|
||||
use super::Data;
|
||||
use crate::tensor::Shape;
|
||||
use std::ops::Range;
|
||||
|
||||
pub trait TensorBase<P, const D: usize> {
|
||||
fn shape(&self) -> &Shape<D>;
|
||||
fn into_data(self) -> Data<P, D>;
|
||||
fn to_data(&self) -> Data<P, D>;
|
||||
}
|
||||
|
||||
pub trait TensorOpsAdd<P, const D: usize>:
|
||||
std::ops::Add<Self, Output = Self> + std::ops::Add<P, Output = Self>
|
||||
where
|
||||
|
@ -52,6 +59,7 @@ pub trait TensorOpsIndex<P, const D1: usize, const D2: usize> {
|
|||
pub trait Zeros<T> {
|
||||
fn zeros(&self) -> T;
|
||||
}
|
||||
|
||||
pub trait Ones<T> {
|
||||
fn ones(&self) -> T;
|
||||
}
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
use crate::tensor::{Data, Shape};
|
||||
|
||||
pub trait TensorBase<P, const D: usize> {
|
||||
fn shape(&self) -> &Shape<D>;
|
||||
fn into_data(self) -> Data<P, D>;
|
||||
fn to_data(&self) -> Data<P, D>;
|
||||
}
|
Loading…
Reference in New Issue