refactor: move tensor base ops

This commit is contained in:
nathaniel 2022-07-27 10:05:38 -04:00
parent 6f45e878f1
commit 122cd842a2
28 changed files with 32 additions and 35 deletions

View File

@ -1,5 +1,4 @@
use crate::tensor::ops::*;
use crate::tensor::TensorBase;
use half::bf16;
use half::f16;

View File

@ -74,7 +74,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
#[test]
fn should_diff_add() {

View File

@ -116,7 +116,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
#[test]
fn should_diff_matmul_with_index() {

View File

@ -37,7 +37,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
#[test]
fn should_diff_matmul() {

View File

@ -79,7 +79,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
#[test]
fn should_diff_mul() {

View File

@ -46,7 +46,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
#[test]
fn should_diff_neg() {

View File

@ -65,7 +65,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
#[test]
fn should_diff_mul() {

View File

@ -79,7 +79,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
#[test]
fn should_diff_sub() {

View File

@ -27,7 +27,7 @@ impl<T: ADCompatibleTensor<P, D>, P: ADElement, const D: usize> TensorOpsTranspo
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data, TensorBase};
use crate::tensor::{backend::autodiff::helper::ADTchTensor, Data};
#[test]
fn should_diff_transpose() {

View File

@ -31,7 +31,7 @@ impl<T, P, const D: usize> AsNode<T> for ADTensor<P, D, T> {
#[cfg(test)]
mod tests {
use crate::tensor::{backend::autodiff::helper::ADTchTensor, ops::*, Data, TensorBase};
use crate::tensor::{backend::autodiff::helper::ADTchTensor, ops::*, Data};
#[test]
fn should_diff_full_complex_1() {

View File

@ -1,6 +1,6 @@
#[cfg(test)]
mod tests {
use crate::tensor::{backend::autodiff::helper::ADTchTensor, ops::*, Data, TensorBase};
use crate::tensor::{backend::autodiff::helper::ADTchTensor, ops::*, Data};
#[test]
fn should_behave_the_same_with_multithread() {

View File

@ -4,7 +4,8 @@ use crate::{
graph::node::ForwardNodeRef,
tensor::{
backend::autodiff::{ADCompatibleTensor, ADElement},
Data, Shape, TensorBase,
ops::TensorBase,
Data, Shape,
},
};

View File

@ -48,7 +48,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{Data, TensorBase};
use crate::tensor::Data;
#[test]
fn should_support_add_ops() {

View File

@ -57,7 +57,7 @@ fn to_slice_args<const D1: usize, const D2: usize>(
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{Data, TensorBase};
use crate::tensor::Data;
#[test]
fn should_support_full_indexing_1d() {

View File

@ -31,7 +31,7 @@ where
#[cfg(test)]
mod tests {
use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*, Data, TensorBase};
use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*, Data};
#[test]
fn should_matmul_d2() {

View File

@ -49,7 +49,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{Data, TensorBase};
use crate::tensor::Data;
#[test]
fn should_support_mul_ops() {

View File

@ -29,7 +29,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{Data, TensorBase};
use crate::tensor::Data;
#[test]
fn should_support_neg_ops() {

View File

@ -45,7 +45,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{Data, TensorBase};
use crate::tensor::Data;
#[test]
fn should_support_sub_ops() {

View File

@ -1,4 +1,4 @@
use crate::tensor::{Data, Shape, TensorBase};
use crate::tensor::{ops::TensorBase, Data, Shape};
use ndarray::{s, ArcArray, Array, Axis, Dim, Dimension, Ix2, Ix3, IxDyn};
#[derive(Debug, Clone)]

View File

@ -54,7 +54,6 @@ impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> s
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::TensorBase;
#[test]
fn should_support_add_ops() {

View File

@ -59,7 +59,7 @@ impl<
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{Data, TensorBase};
use crate::tensor::Data;
#[test]
fn should_support_full_indexing_1d() {

View File

@ -48,7 +48,7 @@ impl<P: tch::kind::Element + Into<f64>, const D: usize> std::ops::Mul<TchTensor<
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{Data, TensorBase};
use crate::tensor::Data;
#[test]
fn should_support_mul_ops() {

View File

@ -29,7 +29,7 @@ impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> s
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{Data, TensorBase};
use crate::tensor::Data;
#[test]
fn should_support_neg_ops() {

View File

@ -54,7 +54,6 @@ impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> s
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::TensorBase;
#[test]
fn should_support_sub_ops() {

View File

@ -1,4 +1,4 @@
use crate::tensor::{Data, Shape, TensorBase};
use crate::tensor::{ops::TensorBase, Data, Shape};
#[derive(Debug, PartialEq)]
pub struct TchTensor<P: tch::kind::Element, const D: usize> {

View File

@ -4,8 +4,6 @@ pub mod ops;
mod data;
mod print;
mod shape;
mod tensor;
pub use data::*;
pub use shape::*;
pub use tensor::*;

View File

@ -1,6 +1,13 @@
use crate::tensor::{Shape, TensorBase};
use super::Data;
use crate::tensor::Shape;
use std::ops::Range;
pub trait TensorBase<P, const D: usize> {
fn shape(&self) -> &Shape<D>;
fn into_data(self) -> Data<P, D>;
fn to_data(&self) -> Data<P, D>;
}
pub trait TensorOpsAdd<P, const D: usize>:
std::ops::Add<Self, Output = Self> + std::ops::Add<P, Output = Self>
where
@ -52,6 +59,7 @@ pub trait TensorOpsIndex<P, const D1: usize, const D2: usize> {
pub trait Zeros<T> {
fn zeros(&self) -> T;
}
pub trait Ones<T> {
fn ones(&self) -> T;
}

View File

@ -1,7 +0,0 @@
use crate::tensor::{Data, Shape};
pub trait TensorBase<P, const D: usize> {
fn shape(&self) -> &Shape<D>;
fn into_data(self) -> Data<P, D>;
fn to_data(&self) -> Data<P, D>;
}