mirror of https://github.com/tracel-ai/burn.git
Fix double broadcast with tch (#1026)
* Fix double broadcast with tch * More fixes * Fix clippy warn
This commit is contained in:
parent
f6d14f1b1a
commit
b0de56da29
|
@ -190,17 +190,30 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
|
|||
lhs: TchTensor<i64, D>,
|
||||
rhs: TchTensor<i64, D>,
|
||||
) -> TchTensor<i64, D> {
|
||||
TchOps::div(lhs, rhs)
|
||||
let copy = false;
|
||||
let non_blocking = true;
|
||||
let lhs: TchTensor<f64, D> =
|
||||
TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
|
||||
let rhs: TchTensor<f64, D> =
|
||||
TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
|
||||
|
||||
let out = TchOps::div(lhs, rhs);
|
||||
|
||||
TchTensor::<i64, D>::new(out.tensor.to_dtype(tch::Kind::Int64, non_blocking, copy))
|
||||
}
|
||||
|
||||
fn int_div_scalar<const D: usize>(lhs: TchTensor<i64, D>, rhs: i64) -> TchTensor<i64, D> {
|
||||
let copy = false;
|
||||
let non_blocking = true;
|
||||
let lhs: TchTensor<f64, D> =
|
||||
TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, true, false));
|
||||
let output: TchTensor<i64, D> = lhs.unary_ops(
|
||||
TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
|
||||
|
||||
let out: TchTensor<f64, D> = lhs.unary_ops(
|
||||
|mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
|
||||
|tensor| tensor.f_div_scalar(rhs).unwrap(),
|
||||
);
|
||||
TchTensor::<i64, D>::new(output.tensor.to_dtype(tch::Kind::Int64, true, false))
|
||||
|
||||
TchTensor::<i64, D>::new(out.tensor.to_dtype(tch::Kind::Int64, non_blocking, copy))
|
||||
}
|
||||
|
||||
fn int_neg<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, D> {
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
use crate::{element::TchElement, LibTorch, LibTorchDevice};
|
||||
use burn_tensor::{ops::TensorOps, Data, Shape};
|
||||
use libc::c_void;
|
||||
use std::{marker::PhantomData, rc::Rc};
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
|
||||
/// A reference to a tensor storage.
|
||||
pub type StorageRef = Rc<*mut c_void>;
|
||||
///
|
||||
/// We manually implement `Sync` and `Send` unsafely, so even if we could use `Rc`, it isn't safe.
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
pub type StorageRef = Arc<*mut c_void>;
|
||||
|
||||
/// A tensor that uses the tch backend.
|
||||
#[derive(Debug, PartialEq)]
|
||||
|
@ -23,7 +26,8 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
|
|||
/// storage as the parent, you should use [from_existing](TchTensor::from_existing)
|
||||
/// instead.
|
||||
pub fn new(tensor: tch::Tensor) -> Self {
|
||||
let data = Rc::new(tensor.data_ptr());
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
let data = Arc::new(tensor.data_ptr());
|
||||
|
||||
Self {
|
||||
tensor,
|
||||
|
@ -39,9 +43,10 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
|
|||
pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
|
||||
let storage_child = tensor.data_ptr();
|
||||
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
let storage = match storage_child == *storage_parent {
|
||||
true => storage_parent.clone(),
|
||||
false => Rc::new(storage_child),
|
||||
false => Arc::new(storage_child),
|
||||
};
|
||||
|
||||
Self {
|
||||
|
@ -82,7 +87,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
|
|||
&mut self,
|
||||
func: F,
|
||||
) -> Option<TchTensor<EOut, D_OUT>> {
|
||||
if Rc::strong_count(&self.storage) > 1 {
|
||||
if Arc::strong_count(&self.storage) > 1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
|
@ -99,7 +104,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
|
|||
FOwn: Fn(tch::Tensor) -> tch::Tensor,
|
||||
FRef: Fn(&tch::Tensor) -> tch::Tensor,
|
||||
{
|
||||
if Rc::strong_count(&self.storage) > 1 {
|
||||
if Arc::strong_count(&self.storage) > 1 {
|
||||
return TchTensor::from_existing(fref(&self.tensor), self.storage);
|
||||
}
|
||||
|
||||
|
@ -119,19 +124,25 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
|
|||
FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,
|
||||
FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,
|
||||
{
|
||||
let lhs_num_elems = lhs.shape().num_elements();
|
||||
let rhs_num_elems = rhs.shape().num_elements();
|
||||
let lhs_shape = lhs.shape();
|
||||
let rhs_shape = rhs.shape();
|
||||
let mut out_shape = Shape::new([1; D_OUT]);
|
||||
|
||||
let safe_mut_lhs = lhs_num_elems > rhs_num_elems;
|
||||
let safe_mut_rhs = rhs_num_elems > lhs_num_elems;
|
||||
for i in 0..D_OUT {
|
||||
out_shape.dims[i] = usize::max(lhs_shape.dims[i], rhs_shape.dims[i]);
|
||||
}
|
||||
|
||||
if safe_mut_lhs {
|
||||
let num_elements_out = out_shape.num_elements();
|
||||
|
||||
// Safe to mut lhs tensor.
|
||||
if lhs_shape.num_elements() == num_elements_out {
|
||||
if let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) {
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
if safe_mut_rhs {
|
||||
// Safe to mut rhs tensor.
|
||||
if rhs_shape.num_elements() == num_elements_out {
|
||||
if let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) {
|
||||
return output;
|
||||
}
|
||||
|
@ -155,6 +166,7 @@ impl<P: tch::kind::Element, const D: usize> Clone for TchTensor<P, D> {
|
|||
}
|
||||
|
||||
/// A shape that can be used by LibTorch.
|
||||
#[derive(Debug)]
|
||||
pub struct TchShape<const D: usize> {
|
||||
/// The shape's dimensions.
|
||||
pub dims: [i64; D],
|
||||
|
|
|
@ -30,6 +30,17 @@ mod tests {
|
|||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mul_broadcast_2_dims() {
|
||||
let tensor_1: Tensor<TestBackend, 2> = Tensor::from_data([0.0, 1.0, 2.0]).reshape([3, 1]);
|
||||
let tensor_2: Tensor<TestBackend, 2> = Tensor::from_data([3.0, 4.0, 5.0]).reshape([1, 3]);
|
||||
|
||||
let data_actual = (tensor_1 * tensor_2).into_data();
|
||||
|
||||
let data_expected = Data::from([[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [6.0, 8.0, 10.0]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_mul_scalar_ops() {
|
||||
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
|
Loading…
Reference in New Issue