Fix double broadcast with tch (#1026)

* Fix double broadcast with tch

* More fixes

* Fix clippy warning
Nathaniel Simard 2023-12-01 10:02:57 -05:00 committed by GitHub
parent f6d14f1b1a
commit b0de56da29
3 changed files with 52 additions and 16 deletions


@@ -190,17 +190,30 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
lhs: TchTensor<i64, D>,
rhs: TchTensor<i64, D>,
) -> TchTensor<i64, D> {
- TchOps::div(lhs, rhs)
+ let copy = false;
+ let non_blocking = true;
+ let lhs: TchTensor<f64, D> =
+ TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
+ let rhs: TchTensor<f64, D> =
+ TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
+ let out = TchOps::div(lhs, rhs);
+ TchTensor::<i64, D>::new(out.tensor.to_dtype(tch::Kind::Int64, non_blocking, copy))
}
fn int_div_scalar<const D: usize>(lhs: TchTensor<i64, D>, rhs: i64) -> TchTensor<i64, D> {
+ let copy = false;
+ let non_blocking = true;
let lhs: TchTensor<f64, D> =
- TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, true, false));
- let output: TchTensor<i64, D> = lhs.unary_ops(
+ TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
+ let out: TchTensor<f64, D> = lhs.unary_ops(
|mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
|tensor| tensor.f_div_scalar(rhs).unwrap(),
);
- TchTensor::<i64, D>::new(output.tensor.to_dtype(tch::Kind::Int64, true, false))
+ TchTensor::<i64, D>::new(out.tensor.to_dtype(tch::Kind::Int64, non_blocking, copy))
}
fn int_neg<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, D> {
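
The patched int_div promotes both operands to a floating-point kind, divides, and truncates the quotient back to Int64; LibTorch's true division on integer tensors produces a floating-point result, so the cast round-trip keeps the output an i64 tensor. Below is a minimal standalone sketch of that round-trip against the raw tch API; Tensor::from_slice and the sample values are illustrative assumptions, not part of the patch.

use tch::{Kind, Tensor};

fn main() {
    // Illustrative i64 inputs; rhs broadcasts against lhs.
    let lhs = Tensor::from_slice(&[7i64, 8, 9]);
    let rhs = Tensor::from_slice(&[2i64]);

    let non_blocking = true;
    let copy = false;

    // Promote to a float kind, divide, then truncate back to Int64,
    // mirroring the sequence in the patched int_div.
    let lhs_f = lhs.to_dtype(Kind::Float, non_blocking, copy);
    let rhs_f = rhs.to_dtype(Kind::Float, non_blocking, copy);
    let out = (lhs_f / rhs_f).to_dtype(Kind::Int64, non_blocking, copy);

    out.print(); // 3, 4, 4
}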


@@ -1,10 +1,13 @@
use crate::{element::TchElement, LibTorch, LibTorchDevice};
use burn_tensor::{ops::TensorOps, Data, Shape};
use libc::c_void;
- use std::{marker::PhantomData, rc::Rc};
+ use std::{marker::PhantomData, sync::Arc};
/// A reference to a tensor storage.
- pub type StorageRef = Rc<*mut c_void>;
+ ///
+ /// We manually implement `Sync` and `Send` unsafely, so even if we could use `Rc`, it isn't safe.
+ #[allow(clippy::arc_with_non_send_sync)]
+ pub type StorageRef = Arc<*mut c_void>;
/// A tensor that uses the tch backend.
#[derive(Debug, PartialEq)]
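
Rc's reference count is not thread-safe, so sharing a TchTensor across threads needs Arc; the swap, however, trips clippy's arc_with_non_send_sync lint because the raw pointer inside is neither Send nor Sync, which is why, as the new doc comment notes, those traits are implemented manually and unsafely for the tensor type. A stripped-down sketch of that pattern, using a hypothetical Storage wrapper in place of the real tensor type:

use std::ffi::c_void;
use std::sync::Arc;

// Hypothetical stand-in for the real tensor: the pointer is only used as an
// identity token for the underlying storage, never dereferenced here.
struct Storage(Arc<*mut c_void>);

impl Storage {
    fn new(ptr: *mut c_void) -> Self {
        // Same allow as in the patch: the lint fires because *mut c_void is
        // neither Send nor Sync, which is exactly what is asserted below.
        #[allow(clippy::arc_with_non_send_sync)]
        let inner = Arc::new(ptr);
        Self(inner)
    }

    // Mirrors the comparison in from_existing: two handles alias the same
    // buffer when their raw storage pointers are equal.
    fn same_storage(&self, other: &Storage) -> bool {
        *self.0 == *other.0
    }
}

// SAFETY: assumed sound only for this sketch, where the pointer is treated as
// an opaque id; the real crate makes the equivalent promise for TchTensor.
unsafe impl Send for Storage {}
unsafe impl Sync for Storage {}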
@@ -23,7 +26,8 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
/// storage as the parent, you should use [from_existing](TchTensor::from_existing)
/// instead.
pub fn new(tensor: tch::Tensor) -> Self {
- let data = Rc::new(tensor.data_ptr());
+ #[allow(clippy::arc_with_non_send_sync)]
+ let data = Arc::new(tensor.data_ptr());
Self {
tensor,
@@ -39,9 +43,10 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
let storage_child = tensor.data_ptr();
+ #[allow(clippy::arc_with_non_send_sync)]
let storage = match storage_child == *storage_parent {
true => storage_parent.clone(),
- false => Rc::new(storage_child),
+ false => Arc::new(storage_child),
};
Self {
@@ -82,7 +87,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
&mut self,
func: F,
) -> Option<TchTensor<EOut, D_OUT>> {
- if Rc::strong_count(&self.storage) > 1 {
+ if Arc::strong_count(&self.storage) > 1 {
return None;
}
@@ -99,7 +104,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
FOwn: Fn(tch::Tensor) -> tch::Tensor,
FRef: Fn(&tch::Tensor) -> tch::Tensor,
{
- if Rc::strong_count(&self.storage) > 1 {
+ if Arc::strong_count(&self.storage) > 1 {
return TchTensor::from_existing(fref(&self.tensor), self.storage);
}
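
Both strong_count guards above are a cheap copy-on-write check: an operand is mutated in place only when no other live handle shares its storage, otherwise the op falls back to an out-of-place kernel. A simplified sketch of that decision with a hypothetical maybe_in_place helper (the real entry points are mut_ops and unary_ops):

use std::sync::Arc;

// Hypothetical handle; `storage` plays the role of TchTensor::storage and is
// cloned whenever another tensor view aliases the same buffer.
struct Handle {
    data: Vec<f32>,
    storage: Arc<()>,
}

// Mutate in place only when this handle is the sole owner of the storage,
// mirroring the `Arc::strong_count(&self.storage) > 1` guard above.
fn maybe_in_place(h: &mut Handle, f: impl FnOnce(&mut Vec<f32>)) -> Option<()> {
    if Arc::strong_count(&h.storage) > 1 {
        return None; // another handle still aliases the buffer
    }
    f(&mut h.data);
    Some(())
}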
@@ -119,19 +124,25 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,
FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,
{
- let lhs_num_elems = lhs.shape().num_elements();
- let rhs_num_elems = rhs.shape().num_elements();
+ let lhs_shape = lhs.shape();
+ let rhs_shape = rhs.shape();
+ let mut out_shape = Shape::new([1; D_OUT]);
- let safe_mut_lhs = lhs_num_elems > rhs_num_elems;
- let safe_mut_rhs = rhs_num_elems > lhs_num_elems;
+ for i in 0..D_OUT {
+ out_shape.dims[i] = usize::max(lhs_shape.dims[i], rhs_shape.dims[i]);
+ }
- if safe_mut_lhs {
+ let num_elements_out = out_shape.num_elements();
+ // Safe to mut lhs tensor.
+ if lhs_shape.num_elements() == num_elements_out {
if let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) {
return output;
}
}
- if safe_mut_rhs {
+ // Safe to mut rhs tensor.
+ if rhs_shape.num_elements() == num_elements_out {
if let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) {
return output;
}
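
This hunk is the core of the fix: the output of a broadcasted binary op has the element-wise maximum of the two operand shapes (the loop assumes the shapes are already broadcast-compatible), and an operand is only mutated in place when its element count equals the output's, i.e. when it already spans the full broadcast result. The old element-count comparison did not guarantee that; for [3, 1] versus [1, 2], for example, the lhs is the larger operand yet the result is [3, 2]. A standalone sketch of the new rule, with a hypothetical broadcast_out_shape helper:

// Element-wise max of the dims, mirroring the loop added above; assumes the
// shapes are broadcast-compatible (any mismatched dimension is 1).
fn broadcast_out_shape<const D: usize>(lhs: [usize; D], rhs: [usize; D]) -> [usize; D] {
    let mut out = [1usize; D];
    for i in 0..D {
        out[i] = usize::max(lhs[i], rhs[i]);
    }
    out
}

fn main() {
    let out = broadcast_out_shape([3, 1], [1, 2]);
    assert_eq!(out, [3, 2]);
    // lhs holds 3 elements, the result needs 6, so mutating lhs in place
    // would be wrong even though lhs is the larger operand.
    assert_ne!(3, out.iter().product::<usize>());
}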
@@ -155,6 +166,7 @@ impl<P: tch::kind::Element, const D: usize> Clone for TchTensor<P, D> {
}
/// A shape that can be used by LibTorch.
+ #[derive(Debug)]
pub struct TchShape<const D: usize> {
/// The shape's dimensions.
pub dims: [i64; D],


@@ -30,6 +30,17 @@ mod tests {
assert_eq!(data_expected, data_actual);
}
+ #[test]
+ fn test_mul_broadcast_2_dims() {
+ let tensor_1: Tensor<TestBackend, 2> = Tensor::from_data([0.0, 1.0, 2.0]).reshape([3, 1]);
+ let tensor_2: Tensor<TestBackend, 2> = Tensor::from_data([3.0, 4.0, 5.0]).reshape([1, 3]);
+ let data_actual = (tensor_1 * tensor_2).into_data();
+ let data_expected = Data::from([[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [6.0, 8.0, 10.0]]);
+ assert_eq!(data_expected, data_actual);
+ }
#[test]
fn should_support_mul_scalar_ops() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);