diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index ed4ea4750..cc6577404 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -496,8 +496,9 @@ impl TchOps { tensor: TchTensor, shape: Shape, ) -> TchTensor { - let tensor = tensor.tensor.broadcast_to(shape.dims.map(|x| x as i64)); - TchTensor::new(tensor) + let storage = tensor.storage.clone(); + let broadcasted_tensor = tensor.tensor.broadcast_to(shape.dims.map(|x| x as i64)); + TchTensor::from_existing(broadcasted_tensor, storage) } pub fn sort( diff --git a/crates/burn-tch/src/tensor.rs b/crates/burn-tch/src/tensor.rs index bc4db651f..148c3ee43 100644 --- a/crates/burn-tch/src/tensor.rs +++ b/crates/burn-tch/src/tensor.rs @@ -60,13 +60,16 @@ impl Storage { } } -/// A tensor that uses the tch backend. +/// A tensor using the tch backend. #[derive(Debug, PartialEq)] pub struct TchTensor { /// Handle to the tensor. Call methods on this field. pub tensor: tch::Tensor, + /// The tensor's storage pub storage: Storage, + + /// The element type of the tensor. phantom: PhantomData, } @@ -84,8 +87,8 @@ impl TchTensor { Self { tensor, - phantom: PhantomData, storage, + phantom: PhantomData, } } @@ -158,7 +161,17 @@ unsafe impl Send for TchTensor {} unsafe impl Sync for TchTensor {} impl TchTensor { - /// Execute an operation on a tensor if the data can be reused. + /// Checks if the tensor can be mutated in-place. + /// + /// Returns `true` if the tensor's stride does not contain zero (no broadcasting) + /// and the storage can be mutated. + pub fn can_mut(&self) -> bool { + let stride_contains_zero = self.tensor.stride().iter().any(|&s| s == 0); + + !stride_contains_zero && self.storage.can_mut() + } + + /// Executes an operation on a tensor if the data can be reused. pub fn mut_ops< F: Fn(&mut tch::Tensor) -> tch::Tensor, EOut: tch::kind::Element, @@ -167,14 +180,15 @@ impl TchTensor { &mut self, func: F, ) -> Option> { - if !self.storage.can_mut() { + if !self.can_mut() { return None; } let data = self.storage.clone(); Some(TchTensor::from_existing(func(&mut self.tensor), data)) } - /// Execute a unary ops reusing the tensor data if possible. + + /// Executes a unary operation, reusing the tensor data if possible. pub fn unary_ops( self, fown: FOwn, @@ -184,14 +198,14 @@ impl TchTensor { FOwn: Fn(tch::Tensor) -> tch::Tensor, FRef: Fn(&tch::Tensor) -> tch::Tensor, { - if !self.storage.can_mut() { + if !self.can_mut() { return TchTensor::from_existing(fref(&self.tensor), self.storage); } TchTensor::from_existing(fown(self.tensor), self.storage) } - /// Execute a binary ops reusing the tensor data if possible. + /// Executes a binary operation, reusing the tensor data if possible. pub fn binary_ops_tensor( mut lhs: Self, mut rhs: Self, @@ -214,14 +228,14 @@ impl TchTensor { let num_elements_out = out_shape.num_elements(); - // Safe to mut lhs tensor. + // Attempt to mutate lhs tensor if lhs_shape.num_elements() == num_elements_out { if let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) { return output; } } - // Safe to mut rhs tensor. + // Attempt to mutate rhs tensor if rhs_shape.num_elements() == num_elements_out { if let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) { return output; diff --git a/crates/burn-tensor/src/tests/ops/expand.rs b/crates/burn-tensor/src/tests/ops/expand.rs index 7c223ec24..3ac85d17e 100644 --- a/crates/burn-tensor/src/tests/ops/expand.rs +++ b/crates/burn-tensor/src/tests/ops/expand.rs @@ -115,4 +115,16 @@ mod tests { let tensor = TestTensorInt::<1>::from([1, 2, 3]); let _expanded_tensor = tensor.expand([-1, 3]); } + + /// Regression test for https://github.com/tracel-ai/burn/issues/2091 + #[test] + fn inplace_op_after_expand() { + let tensor = TestTensorInt::<1>::from([1, 2, 3]); + let mut output = tensor.expand([2, 3]); + output = output + 1; + + output + .into_data() + .assert_eq(&TensorData::from([[2, 3, 4], [2, 3, 4]]), false); + } }