Fix #2091 bug (in-place after expand) (#2114)

This commit is contained in:
Dilshod Tadjibaev 2024-08-07 16:37:20 -05:00 committed by GitHub
parent e39485322d
commit 6b61ad5a61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 11 deletions

View File

@ -496,8 +496,9 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
tensor: TchTensor<E, D>,
shape: Shape<D2>,
) -> TchTensor<E, D2> {
let tensor = tensor.tensor.broadcast_to(shape.dims.map(|x| x as i64));
TchTensor::new(tensor)
let storage = tensor.storage.clone();
let broadcasted_tensor = tensor.tensor.broadcast_to(shape.dims.map(|x| x as i64));
TchTensor::from_existing(broadcasted_tensor, storage)
}
pub fn sort<const D: usize>(

View File

@ -60,13 +60,16 @@ impl Storage {
}
}
/// A tensor that uses the tch backend.
/// A tensor using the tch backend.
#[derive(Debug, PartialEq)]
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
/// Handle to the tensor. Call methods on this field.
pub tensor: tch::Tensor,
/// The tensor's storage
pub storage: Storage,
/// The element type of the tensor.
phantom: PhantomData<E>,
}
@ -84,8 +87,8 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
Self {
tensor,
phantom: PhantomData,
storage,
phantom: PhantomData,
}
}
@ -158,7 +161,17 @@ unsafe impl<E: tch::kind::Element, const D: usize> Send for TchTensor<E, D> {}
unsafe impl<E: tch::kind::Element, const D: usize> Sync for TchTensor<E, D> {}
impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
/// Execute an operation on a tensor if the data can be reused.
/// Checks if the tensor can be mutated in-place.
///
/// Returns `true` if the tensor's stride does not contain zero (no broadcasting)
/// and the storage can be mutated.
pub fn can_mut(&self) -> bool {
let stride_contains_zero = self.tensor.stride().iter().any(|&s| s == 0);
!stride_contains_zero && self.storage.can_mut()
}
/// Executes an operation on a tensor if the data can be reused.
pub fn mut_ops<
F: Fn(&mut tch::Tensor) -> tch::Tensor,
EOut: tch::kind::Element,
@ -167,14 +180,15 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
&mut self,
func: F,
) -> Option<TchTensor<EOut, D_OUT>> {
if !self.storage.can_mut() {
if !self.can_mut() {
return None;
}
let data = self.storage.clone();
Some(TchTensor::from_existing(func(&mut self.tensor), data))
}
/// Execute a unary ops reusing the tensor data if possible.
/// Executes a unary operation, reusing the tensor data if possible.
pub fn unary_ops<FOwn, FRef, EOut: tch::kind::Element, const D_OUT: usize>(
self,
fown: FOwn,
@ -184,14 +198,14 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
FOwn: Fn(tch::Tensor) -> tch::Tensor,
FRef: Fn(&tch::Tensor) -> tch::Tensor,
{
if !self.storage.can_mut() {
if !self.can_mut() {
return TchTensor::from_existing(fref(&self.tensor), self.storage);
}
TchTensor::from_existing(fown(self.tensor), self.storage)
}
/// Execute a binary ops reusing the tensor data if possible.
/// Executes a binary operation, reusing the tensor data if possible.
pub fn binary_ops_tensor<FLMut, FRMut, FRef, EOut: tch::kind::Element, const D_OUT: usize>(
mut lhs: Self,
mut rhs: Self,
@ -214,14 +228,14 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
let num_elements_out = out_shape.num_elements();
// Safe to mut lhs tensor.
// Attempt to mutate lhs tensor
if lhs_shape.num_elements() == num_elements_out {
if let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) {
return output;
}
}
// Safe to mut rhs tensor.
// Attempt to mutate rhs tensor
if rhs_shape.num_elements() == num_elements_out {
if let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) {
return output;

View File

@ -115,4 +115,16 @@ mod tests {
let tensor = TestTensorInt::<1>::from([1, 2, 3]);
let _expanded_tensor = tensor.expand([-1, 3]);
}
/// Regression test for https://github.com/tracel-ai/burn/issues/2091
#[test]
fn inplace_op_after_expand() {
let tensor = TestTensorInt::<1>::from([1, 2, 3]);
let mut output = tensor.expand([2, 3]);
output = output + 1;
output
.into_data()
.assert_eq(&TensorData::from([[2, 3, 4], [2, 3, 4]]), false);
}
}