mirror of https://github.com/tracel-ai/burn.git
refactor: detach-ops (#93)
This commit is contained in:
parent
3da122db09
commit
e7094b92ac
|
@ -1,9 +0,0 @@
|
|||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::{backend::autodiff::ADTensor, ops::*};
|
||||
|
||||
impl<B: Backend, P, const D: usize> TensorOpsDetach<P, D> for ADTensor<D, B> {
|
||||
fn detach(self) -> Self {
|
||||
let tensor = self.tensor();
|
||||
Self::from_tensor(tensor.detach())
|
||||
}
|
||||
}
|
|
@ -3,7 +3,6 @@ mod arg;
|
|||
mod base;
|
||||
mod cat;
|
||||
mod creation;
|
||||
mod detach;
|
||||
mod erf;
|
||||
mod exp;
|
||||
mod log;
|
||||
|
|
|
@ -702,4 +702,10 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
) -> <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D> {
|
||||
B::lower_equal_scalar(lhs.tensor_ref(), rhs)
|
||||
}
|
||||
|
||||
fn detach<const D: usize>(
|
||||
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
|
||||
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
|
||||
ADTensor::from_tensor(B::detach(tensor.tensor_ref()))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ pub trait Backend:
|
|||
type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
|
||||
type IntegerBackend: Backend<Elem = i64, Device = Self::Device>;
|
||||
type TensorPrimitive<const D: usize>: std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
|
||||
+ TensorOpsDetach<Self::Elem, D>
|
||||
+ Zeros<Self::TensorPrimitive<D>>
|
||||
+ Ones<Self::TensorPrimitive<D>>
|
||||
+ TensorOpsPrecision<Self, D>
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
use crate::{
|
||||
tensor::{backend::ndarray::NdArrayTensor, ops::*},
|
||||
NdArrayElement,
|
||||
};
|
||||
|
||||
impl<E, const D: usize> TensorOpsDetach<E, D> for NdArrayTensor<E, D>
|
||||
where
|
||||
E: NdArrayElement,
|
||||
{
|
||||
fn detach(self) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
|
@ -2,7 +2,6 @@ mod aggregation;
|
|||
mod arg;
|
||||
mod cat;
|
||||
mod creation;
|
||||
mod detach;
|
||||
mod erf;
|
||||
mod exp;
|
||||
mod log;
|
||||
|
|
|
@ -365,6 +365,9 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
array,
|
||||
}
|
||||
}
|
||||
fn detach<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
|
||||
tensor.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn to_slice_args<const D1: usize, const D2: usize>(
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
use crate::{
|
||||
tensor::{backend::tch::TchTensor, ops::*},
|
||||
TchElement,
|
||||
};
|
||||
|
||||
impl<E, const D: usize> TensorOpsDetach<E, D> for TchTensor<E, D>
|
||||
where
|
||||
E: TchElement,
|
||||
{
|
||||
fn detach(self) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
|
@ -2,7 +2,6 @@ mod aggregation;
|
|||
mod arg;
|
||||
mod cat;
|
||||
mod creation;
|
||||
mod detach;
|
||||
mod erf;
|
||||
mod exp;
|
||||
mod log;
|
||||
|
|
|
@ -318,6 +318,9 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
kind: TchKind::<bool>::new(),
|
||||
}
|
||||
}
|
||||
fn detach<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
|
||||
tensor.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn to_tensor<const D: usize, E: TchElement>(tensor: tch::Tensor) -> TchTensor<E, D> {
|
||||
|
|
|
@ -519,7 +519,7 @@ where
|
|||
/// This can be used in batchers or elsewere to ensure that previous operations are not
|
||||
/// considered in the autodiff graph.
|
||||
pub fn detach(self) -> Self {
|
||||
Self::new(self.value.detach())
|
||||
Self::new(B::detach(&self.value))
|
||||
}
|
||||
|
||||
/// Unsqueeze the current tensor. Create new dimensions to fit the given size.
|
||||
|
|
|
@ -169,6 +169,7 @@ pub trait TensorOps<B: Backend> {
|
|||
lhs: &B::TensorPrimitive<D>,
|
||||
rhs: &B::Elem,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
fn detach<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
}
|
||||
|
||||
pub trait TensorOpsAggregation<B: Backend, const D: usize> {
|
||||
|
@ -206,10 +207,6 @@ pub trait TensorOpsLog<E, const D: usize> {
|
|||
fn log(&self) -> Self;
|
||||
}
|
||||
|
||||
pub trait TensorOpsDetach<E, const D: usize> {
|
||||
fn detach(self) -> Self;
|
||||
}
|
||||
|
||||
pub trait TensorOpsErf<E, const D: usize> {
|
||||
fn erf(&self) -> Self;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue