refactor: detach-ops (#93)

Nathaniel Simard 2022-11-12 09:44:59 -05:00 committed by GitHub
parent 3da122db09
commit e7094b92ac
12 changed files with 14 additions and 44 deletions

View File

@@ -1,9 +0,0 @@
-use crate::tensor::backend::Backend;
-use crate::tensor::{backend::autodiff::ADTensor, ops::*};
-
-impl<B: Backend, P, const D: usize> TensorOpsDetach<P, D> for ADTensor<D, B> {
-    fn detach(self) -> Self {
-        let tensor = self.tensor();
-        Self::from_tensor(tensor.detach())
-    }
-}

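The file deleted above held the autodiff backend's implementation of the standalone TensorOpsDetach trait; the same logic reappears below as a detach function on the backend's TensorOps implementation. As a minimal standalone sketch of the per-tensor-trait shape being removed (Tape and Tracked are illustrative stand-ins, not the crate's real types):

// One-method trait implemented separately by every tensor type.
trait TensorOpsDetach {
    // Consumes the tensor and returns one with no autodiff history.
    fn detach(self) -> Self;
}

#[derive(Clone, Default)]
struct Tape(Vec<&'static str>); // recorded ops

struct Tracked {
    data: Vec<f32>,
    tape: Tape,
}

impl TensorOpsDetach for Tracked {
    fn detach(self) -> Self {
        // Keep the values, drop the recorded history.
        Tracked {
            data: self.data,
            tape: Tape::default(),
        }
    }
}

fn main() {
    let t = Tracked {
        data: vec![1.0, 2.0],
        tape: Tape(vec!["add"]),
    };
    let t = t.detach();
    assert!(t.tape.0.is_empty());
}
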
View File

@@ -3,7 +3,6 @@ mod arg;
 mod base;
 mod cat;
 mod creation;
-mod detach;
 mod erf;
 mod exp;
 mod log;

View File

@@ -702,4 +702,10 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
     ) -> <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D> {
         B::lower_equal_scalar(lhs.tensor_ref(), rhs)
     }
+
+    fn detach<const D: usize>(
+        tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
+        ADTensor::from_tensor(B::detach(tensor.tensor_ref()))
+    }
 }

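Note how the addition above severs the graph: B::detach produces a plain inner tensor, and ADTensor::from_tensor wraps it as a brand-new node. A hedged sketch of that unwrap-delegate-rewrap move, with hypothetical minimal types in place of the crate's:

#[derive(Clone)]
struct InnerTensor(Vec<f32>);

struct AdTensor {
    inner: InnerTensor,
    parents: Vec<usize>, // edges into the autodiff graph
}

impl AdTensor {
    // A freshly wrapped tensor starts as a graph root: no parents.
    fn from_tensor(inner: InnerTensor) -> Self {
        AdTensor { inner, parents: Vec::new() }
    }
}

// Stand-in for the inner backend's detach (a no-op copy when eager).
fn inner_detach(t: &InnerTensor) -> InnerTensor {
    t.clone()
}

fn detach(t: &AdTensor) -> AdTensor {
    // Unwrap, delegate, rewrap: the rewrap is what cuts the history.
    AdTensor::from_tensor(inner_detach(&t.inner))
}

fn main() {
    let t = AdTensor {
        inner: InnerTensor(vec![1.0]),
        parents: vec![0, 1],
    };
    assert!(detach(&t).parents.is_empty());
}
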
View File

@@ -21,7 +21,6 @@ pub trait Backend:
     type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
     type IntegerBackend: Backend<Elem = i64, Device = Self::Device>;
     type TensorPrimitive<const D: usize>: std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
-        + TensorOpsDetach<Self::Elem, D>
         + Zeros<Self::TensorPrimitive<D>>
         + Ones<Self::TensorPrimitive<D>>
        + TensorOpsPrecision<Self, D>

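Dropping the TensorOpsDetach bound above shrinks what every backend's TensorPrimitive must implement; generic code now reaches detach through the backend alone. A sketch of that calling convention, assuming a deliberately minimal Backend trait rather than the crate's full one:

trait Backend {
    type Primitive: Clone;
    fn detach(tensor: &Self::Primitive) -> Self::Primitive;
}

struct Eager;

impl Backend for Eager {
    type Primitive = Vec<f32>;
    fn detach(tensor: &Self::Primitive) -> Self::Primitive {
        tensor.clone() // nothing to cut in an eager backend
    }
}

// Only B: Backend is required; no extra per-primitive op trait bound.
fn stop_gradient<B: Backend>(t: &B::Primitive) -> B::Primitive {
    B::detach(t)
}

fn main() {
    let t = vec![1.0_f32, 2.0];
    assert_eq!(stop_gradient::<Eager>(&t), t);
}
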
View File

@@ -1,13 +0,0 @@
-use crate::{
-    tensor::{backend::ndarray::NdArrayTensor, ops::*},
-    NdArrayElement,
-};
-
-impl<E, const D: usize> TensorOpsDetach<E, D> for NdArrayTensor<E, D>
-where
-    E: NdArrayElement,
-{
-    fn detach(self) -> Self {
-        self
-    }
-}

View File

@@ -2,7 +2,6 @@ mod aggregation;
 mod arg;
 mod cat;
 mod creation;
-mod detach;
 mod erf;
 mod exp;
 mod log;

View File

@@ -365,6 +365,9 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
             array,
         }
     }
+    fn detach<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
+        tensor.clone()
+    }
 }
 
 fn to_slice_args<const D1: usize, const D2: usize>(

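This replaces the deleted ndarray impl above, where detach consumed self and returned it unchanged. The new backend function borrows its input, so the no-op becomes a clone; whether that clone is cheap depends on NdArrayTensor sharing its storage, which is an assumption not visible in this diff. The signature contrast, with an illustrative type:

#[derive(Clone)]
struct T(Vec<f32>);

// Old shape: by value, so a no-op detach can return the input as-is.
fn detach_by_value(t: T) -> T {
    t
}

// New shape: by reference, so a no-op must produce a new owned value.
fn detach_by_ref(t: &T) -> T {
    t.clone()
}

fn main() {
    let a = T(vec![1.0]);
    let b = detach_by_ref(&a); // a still usable afterwards
    let c = detach_by_value(a); // a moved
    assert_eq!(b.0, c.0);
}
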
View File

@@ -1,13 +0,0 @@
-use crate::{
-    tensor::{backend::tch::TchTensor, ops::*},
-    TchElement,
-};
-
-impl<E, const D: usize> TensorOpsDetach<E, D> for TchTensor<E, D>
-where
-    E: TchElement,
-{
-    fn detach(self) -> Self {
-        self
-    }
-}

View File

@@ -2,7 +2,6 @@ mod aggregation;
 mod arg;
 mod cat;
 mod creation;
-mod detach;
 mod erf;
 mod exp;
 mod log;

View File

@@ -318,6 +318,9 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
             kind: TchKind::<bool>::new(),
         }
     }
+    fn detach<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
+        tensor.clone()
+    }
 }
 
 fn to_tensor<const D: usize, E: TchElement>(tensor: tch::Tensor) -> TchTensor<E, D> {

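The tch backend gets the same no-op treatment: gradients are tracked in this project's own autodiff decorator rather than in the underlying library, so the base backend's detach reduces to copying the tensor handle. If that handle is reference-counted, the clone is a pointer copy rather than a buffer copy; that is an assumption about the primitive, sketched here with Rc:

use std::rc::Rc;

#[derive(Clone)]
struct Handle {
    storage: Rc<Vec<f32>>, // shared buffer; clone bumps a refcount
}

fn detach(t: &Handle) -> Handle {
    t.clone()
}

fn main() {
    let a = Handle { storage: Rc::new(vec![1.0, 2.0]) };
    let b = detach(&a);
    // Same underlying buffer: no data was copied.
    assert!(Rc::ptr_eq(&a.storage, &b.storage));
    assert_eq!(Rc::strong_count(&a.storage), 2);
}
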
View File

@@ -519,7 +519,7 @@ where
     /// This can be used in batchers or elsewhere to ensure that previous operations are not
     /// considered in the autodiff graph.
     pub fn detach(self) -> Self {
-        Self::new(self.value.detach())
+        Self::new(B::detach(&self.value))
     }
 
     /// Unsqueeze the current tensor. Create new dimensions to fit the given size.

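The public API is untouched by this change: Tensor::detach keeps its signature and doc comment, and only its body now forwards to the backend's associated function. A minimal sketch of that forwarding, with a stand-in Backend trait rather than the crate's real one:

trait Backend {
    type Primitive: Clone;
    fn detach(tensor: &Self::Primitive) -> Self::Primitive;
}

struct Tensor<B: Backend> {
    value: B::Primitive,
}

impl<B: Backend> Tensor<B> {
    fn new(value: B::Primitive) -> Self {
        Tensor { value }
    }

    // Same public signature as before; only the callee moved from
    // self.value.detach() to B::detach(&self.value).
    fn detach(self) -> Self {
        Self::new(B::detach(&self.value))
    }
}

struct Eager;

impl Backend for Eager {
    type Primitive = Vec<f32>;
    fn detach(tensor: &Self::Primitive) -> Self::Primitive {
        tensor.clone()
    }
}

fn main() {
    let t = Tensor::<Eager>::new(vec![1.0, 2.0]).detach();
    assert_eq!(t.value, vec![1.0, 2.0]);
}
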
View File

@@ -169,6 +169,7 @@ pub trait TensorOps<B: Backend> {
         lhs: &B::TensorPrimitive<D>,
         rhs: &B::Elem,
     ) -> B::BoolTensorPrimitive<D>;
+    fn detach<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
 }
 
 pub trait TensorOpsAggregation<B: Backend, const D: usize> {
@@ -206,10 +207,6 @@ pub trait TensorOpsLog<E, const D: usize> {
     fn log(&self) -> Self;
 }
 
-pub trait TensorOpsDetach<E, const D: usize> {
-    fn detach(self) -> Self;
-}
-
 pub trait TensorOpsErf<E, const D: usize> {
     fn erf(&self) -> Self;
 }
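
The net effect of the last two hunks: detach is declared once on the backend-keyed TensorOps trait, and the standalone one-method TensorOpsDetach trait disappears. A closing sketch of that keyed-trait pattern, heavily simplified and with illustrative types (the real traits also carry element types, const dimensions, and many more ops):

trait Backend: Sized {
    type Primitive: Clone;
}

// Ops are declared once, keyed on the backend type B.
trait TensorOps<B: Backend> {
    fn detach(tensor: &B::Primitive) -> B::Primitive;
}

struct Eager;

impl Backend for Eager {
    type Primitive = Vec<f32>;
}

// Each backend implements the whole ops surface in one place,
// instead of one tiny trait impl per op per tensor type.
impl TensorOps<Eager> for Eager {
    fn detach(tensor: &Vec<f32>) -> Vec<f32> {
        tensor.clone()
    }
}

fn main() {
    let t = vec![1.0_f32, 2.0];
    assert_eq!(<Eager as TensorOps<Eager>>::detach(&t), t);
}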