diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/detach.rs b/burn-tensor/src/tensor/backend/autodiff/ops/detach.rs deleted file mode 100644 index 0d1914945..000000000 --- a/burn-tensor/src/tensor/backend/autodiff/ops/detach.rs +++ /dev/null @@ -1,9 +0,0 @@ -use crate::tensor::backend::Backend; -use crate::tensor::{backend::autodiff::ADTensor, ops::*}; - -impl TensorOpsDetach for ADTensor { - fn detach(self) -> Self { - let tensor = self.tensor(); - Self::from_tensor(tensor.detach()) - } -} diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs index ffce8f406..ec40330b8 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs @@ -3,7 +3,6 @@ mod arg; mod base; mod cat; mod creation; -mod detach; mod erf; mod exp; mod log; diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs b/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs index ae4d21dd5..d54c02501 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs @@ -702,4 +702,10 @@ impl TensorOps> for ADBackendDecorator { ) -> as Backend>::BoolTensorPrimitive { B::lower_equal_scalar(lhs.tensor_ref(), rhs) } + + fn detach( + tensor: & as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + ADTensor::from_tensor(B::detach(tensor.tensor_ref())) + } } diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs index 5bfdafa80..c3324f14f 100644 --- a/burn-tensor/src/tensor/backend/base.rs +++ b/burn-tensor/src/tensor/backend/base.rs @@ -21,7 +21,6 @@ pub trait Backend: type FullPrecisionBackend: Backend; type IntegerBackend: Backend; type TensorPrimitive: std::ops::Add, Output = Self::TensorPrimitive> - + TensorOpsDetach + Zeros> + Ones> + TensorOpsPrecision diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/detach.rs b/burn-tensor/src/tensor/backend/ndarray/ops/detach.rs deleted file mode 100644 index 8bf07d933..000000000 --- a/burn-tensor/src/tensor/backend/ndarray/ops/detach.rs +++ /dev/null @@ -1,13 +0,0 @@ -use crate::{ - tensor::{backend::ndarray::NdArrayTensor, ops::*}, - NdArrayElement, -}; - -impl TensorOpsDetach for NdArrayTensor -where - E: NdArrayElement, -{ - fn detach(self) -> Self { - self - } -} diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs b/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs index 06ccecabb..1b1028949 100644 --- a/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs @@ -2,7 +2,6 @@ mod aggregation; mod arg; mod cat; mod creation; -mod detach; mod erf; mod exp; mod log; diff --git a/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs b/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs index a4b5ddd30..a3bc85b50 100644 --- a/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs +++ b/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs @@ -365,6 +365,9 @@ impl TensorOps> for NdArrayBackend { array, } } + fn detach(tensor: &NdArrayTensor) -> NdArrayTensor { + tensor.clone() + } } fn to_slice_args( diff --git a/burn-tensor/src/tensor/backend/tch/ops/detach.rs b/burn-tensor/src/tensor/backend/tch/ops/detach.rs deleted file mode 100644 index c3e02d44b..000000000 --- a/burn-tensor/src/tensor/backend/tch/ops/detach.rs +++ /dev/null @@ -1,13 +0,0 @@ -use crate::{ - tensor::{backend::tch::TchTensor, ops::*}, - TchElement, -}; - -impl TensorOpsDetach for TchTensor -where - E: TchElement, -{ - fn detach(self) -> Self { - self - } -} diff --git a/burn-tensor/src/tensor/backend/tch/ops/mod.rs b/burn-tensor/src/tensor/backend/tch/ops/mod.rs index 06ccecabb..1b1028949 100644 --- a/burn-tensor/src/tensor/backend/tch/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/tch/ops/mod.rs @@ -2,7 +2,6 @@ mod aggregation; mod arg; mod cat; mod creation; -mod detach; mod erf; mod exp; mod log; diff --git a/burn-tensor/src/tensor/backend/tch/tensor_ops.rs b/burn-tensor/src/tensor/backend/tch/tensor_ops.rs index fa520945e..a8910fd5e 100644 --- a/burn-tensor/src/tensor/backend/tch/tensor_ops.rs +++ b/burn-tensor/src/tensor/backend/tch/tensor_ops.rs @@ -318,6 +318,9 @@ impl TensorOps> for TchBackend { kind: TchKind::::new(), } } + fn detach(tensor: &TchTensor) -> TchTensor { + tensor.clone() + } } fn to_tensor(tensor: tch::Tensor) -> TchTensor { diff --git a/burn-tensor/src/tensor/base.rs b/burn-tensor/src/tensor/base.rs index 276bd7ddd..b5b7116ac 100644 --- a/burn-tensor/src/tensor/base.rs +++ b/burn-tensor/src/tensor/base.rs @@ -519,7 +519,7 @@ where /// This can be used in batchers or elsewere to ensure that previous operations are not /// considered in the autodiff graph. pub fn detach(self) -> Self { - Self::new(self.value.detach()) + Self::new(B::detach(&self.value)) } /// Unsqueeze the current tensor. Create new dimensions to fit the given size. diff --git a/burn-tensor/src/tensor/ops/base.rs b/burn-tensor/src/tensor/ops/base.rs index a07056cde..54ac1edab 100644 --- a/burn-tensor/src/tensor/ops/base.rs +++ b/burn-tensor/src/tensor/ops/base.rs @@ -169,6 +169,7 @@ pub trait TensorOps { lhs: &B::TensorPrimitive, rhs: &B::Elem, ) -> B::BoolTensorPrimitive; + fn detach(tensor: &B::TensorPrimitive) -> B::TensorPrimitive; } pub trait TensorOpsAggregation { @@ -206,10 +207,6 @@ pub trait TensorOpsLog { fn log(&self) -> Self; } -pub trait TensorOpsDetach { - fn detach(self) -> Self; -} - pub trait TensorOpsErf { fn erf(&self) -> Self; }