mirror of https://github.com/tracel-ai/burn.git
parent
e39485322d
commit
6b61ad5a61
|
@ -496,8 +496,9 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
tensor: TchTensor<E, D>,
|
||||
shape: Shape<D2>,
|
||||
) -> TchTensor<E, D2> {
|
||||
let tensor = tensor.tensor.broadcast_to(shape.dims.map(|x| x as i64));
|
||||
TchTensor::new(tensor)
|
||||
let storage = tensor.storage.clone();
|
||||
let broadcasted_tensor = tensor.tensor.broadcast_to(shape.dims.map(|x| x as i64));
|
||||
TchTensor::from_existing(broadcasted_tensor, storage)
|
||||
}
|
||||
|
||||
pub fn sort<const D: usize>(
|
||||
|
|
|
@ -60,13 +60,16 @@ impl Storage {
|
|||
}
|
||||
}
|
||||
|
||||
/// A tensor that uses the tch backend.
|
||||
/// A tensor using the tch backend.
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
|
||||
/// Handle to the tensor. Call methods on this field.
|
||||
pub tensor: tch::Tensor,
|
||||
|
||||
/// The tensor's storage
|
||||
pub storage: Storage,
|
||||
|
||||
/// The element type of the tensor.
|
||||
phantom: PhantomData<E>,
|
||||
}
|
||||
|
||||
|
@ -84,8 +87,8 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
|
|||
|
||||
Self {
|
||||
tensor,
|
||||
phantom: PhantomData,
|
||||
storage,
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -158,7 +161,17 @@ unsafe impl<E: tch::kind::Element, const D: usize> Send for TchTensor<E, D> {}
|
|||
unsafe impl<E: tch::kind::Element, const D: usize> Sync for TchTensor<E, D> {}
|
||||
|
||||
impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
|
||||
/// Execute an operation on a tensor if the data can be reused.
|
||||
/// Checks if the tensor can be mutated in-place.
|
||||
///
|
||||
/// Returns `true` if the tensor's stride does not contain zero (no broadcasting)
|
||||
/// and the storage can be mutated.
|
||||
pub fn can_mut(&self) -> bool {
|
||||
let stride_contains_zero = self.tensor.stride().iter().any(|&s| s == 0);
|
||||
|
||||
!stride_contains_zero && self.storage.can_mut()
|
||||
}
|
||||
|
||||
/// Executes an operation on a tensor if the data can be reused.
|
||||
pub fn mut_ops<
|
||||
F: Fn(&mut tch::Tensor) -> tch::Tensor,
|
||||
EOut: tch::kind::Element,
|
||||
|
@ -167,14 +180,15 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
|
|||
&mut self,
|
||||
func: F,
|
||||
) -> Option<TchTensor<EOut, D_OUT>> {
|
||||
if !self.storage.can_mut() {
|
||||
if !self.can_mut() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let data = self.storage.clone();
|
||||
Some(TchTensor::from_existing(func(&mut self.tensor), data))
|
||||
}
|
||||
/// Execute a unary ops reusing the tensor data if possible.
|
||||
|
||||
/// Executes a unary operation, reusing the tensor data if possible.
|
||||
pub fn unary_ops<FOwn, FRef, EOut: tch::kind::Element, const D_OUT: usize>(
|
||||
self,
|
||||
fown: FOwn,
|
||||
|
@ -184,14 +198,14 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
|
|||
FOwn: Fn(tch::Tensor) -> tch::Tensor,
|
||||
FRef: Fn(&tch::Tensor) -> tch::Tensor,
|
||||
{
|
||||
if !self.storage.can_mut() {
|
||||
if !self.can_mut() {
|
||||
return TchTensor::from_existing(fref(&self.tensor), self.storage);
|
||||
}
|
||||
|
||||
TchTensor::from_existing(fown(self.tensor), self.storage)
|
||||
}
|
||||
|
||||
/// Execute a binary ops reusing the tensor data if possible.
|
||||
/// Executes a binary operation, reusing the tensor data if possible.
|
||||
pub fn binary_ops_tensor<FLMut, FRMut, FRef, EOut: tch::kind::Element, const D_OUT: usize>(
|
||||
mut lhs: Self,
|
||||
mut rhs: Self,
|
||||
|
@ -214,14 +228,14 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
|
|||
|
||||
let num_elements_out = out_shape.num_elements();
|
||||
|
||||
// Safe to mut lhs tensor.
|
||||
// Attempt to mutate lhs tensor
|
||||
if lhs_shape.num_elements() == num_elements_out {
|
||||
if let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) {
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
// Safe to mut rhs tensor.
|
||||
// Attempt to mutate rhs tensor
|
||||
if rhs_shape.num_elements() == num_elements_out {
|
||||
if let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) {
|
||||
return output;
|
||||
|
|
|
@ -115,4 +115,16 @@ mod tests {
|
|||
let tensor = TestTensorInt::<1>::from([1, 2, 3]);
|
||||
let _expanded_tensor = tensor.expand([-1, 3]);
|
||||
}
|
||||
|
||||
/// Regression test for https://github.com/tracel-ai/burn/issues/2091
|
||||
#[test]
|
||||
fn inplace_op_after_expand() {
|
||||
let tensor = TestTensorInt::<1>::from([1, 2, 3]);
|
||||
let mut output = tensor.expand([2, 3]);
|
||||
output = output + 1;
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[2, 3, 4], [2, 3, 4]]), false);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue