Feat/tensor casting (#604)

This commit is contained in:
Nathaniel Simard 2023-08-08 10:02:17 -04:00 committed by GitHub
parent 8bc687e1bb
commit 441a7011ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 146 additions and 23 deletions

View File

@ -1,5 +1,5 @@
use crate::{
tensor::{BoolTensor, IntTensor},
tensor::{ADTensor, BoolTensor, IntTensor},
ADBackendDecorator,
};
@ -80,4 +80,10 @@ impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B>
fn bool_equal_elem<const D: usize>(lhs: BoolTensor<B, D>, rhs: bool) -> BoolTensor<B, D> {
B::bool_equal_elem(lhs, rhs)
}
fn bool_into_float<const D: usize>(
tensor: BoolTensor<B, D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
ADTensor::new(B::bool_into_float(tensor))
}
}

View File

@ -1,5 +1,5 @@
use crate::{
tensor::{BoolTensor, IntTensor},
tensor::{ADTensor, BoolTensor, IntTensor},
ADBackendDecorator,
};
@ -308,4 +308,9 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn int_abs<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<D> {
B::int_abs(tensor)
}
fn int_into_float<const D: usize>(
tensor: <ADBackendDecorator<B> as Backend>::IntTensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
ADTensor::new(B::int_into_float(tensor))
}
}

View File

@ -1394,6 +1394,12 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
}
}
fn into_int<const D: usize>(
tensor: ADTensor<B, D>,
) -> <ADBackendDecorator<B> as Backend>::IntTensorPrimitive<D> {
B::into_int(tensor.primitive)
}
}
/// Make sure the grad tensor has the given shape.

View File

@ -2,6 +2,7 @@
use alloc::vec;
use alloc::vec::Vec;
use burn_tensor::ops::{BoolTensorOps, IntTensorOps};
use burn_tensor::ElementConversion;
use core::ops::Range;
// Current crate
@ -117,4 +118,11 @@ impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend
let array = lhs.array.mapv(|a| a == rhs).into_shared();
NdArrayTensor { array }
}
fn bool_into_float<const D: usize>(
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = tensor.array.mapv(|a| (a as i32).elem()).into_shared();
NdArrayTensor { array }
}
}

View File

@ -2,6 +2,7 @@
use alloc::vec;
use alloc::vec::Vec;
use burn_tensor::ops::IntTensorOps;
use burn_tensor::ElementConversion;
use core::ops::Range;
// Current crate
@ -363,4 +364,11 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
NdArrayTensor::new(array)
}
fn int_into_float<const D: usize>(
tensor: <NdArrayBackend<E> as Backend>::IntTensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = tensor.array.mapv(|a| a.elem()).into_shared();
NdArrayTensor { array }
}
}

View File

@ -437,4 +437,11 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
fn clamp<const D: usize>(tensor: NdArrayTensor<E, D>, min: E, max: E) -> NdArrayTensor<E, D> {
NdArrayMathOps::clamp(tensor, min, max)
}
fn into_int<const D: usize>(
tensor: <NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::IntTensorPrimitive<D> {
let array = tensor.array.mapv(|a| a.elem()).into_shared();
NdArrayTensor { array }
}
}

View File

@ -109,6 +109,11 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
}
fn bool_into_int<const D: usize>(tensor: TchTensor<bool, D>) -> TchTensor<i64, D> {
let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
TchTensor::new(tensor)
}
fn bool_into_float<const D: usize>(tensor: TchTensor<bool, D>) -> TchTensor<E, D> {
let tensor = tensor.tensor.to_kind(E::KIND);
TchTensor::new(tensor)
}

View File

@ -361,4 +361,9 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
fn int_abs<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, D> {
tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
}
fn int_into_float<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<E, D> {
let tensor = tensor.tensor.to_kind(E::KIND);
TchTensor::new(tensor)
}
}

View File

@ -435,4 +435,9 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
TchOps::clamp(tensor, min.elem::<f64>(), max.elem::<f64>())
}
fn into_int<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<i64, D> {
let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
TchTensor::new(tensor)
}
}

View File

@ -15,7 +15,12 @@ where
}
/// Convert the bool tensor into an int tensor.
pub fn into_int(self) -> Tensor<B, D, Int> {
pub fn int(self) -> Tensor<B, D, Int> {
Tensor::new(B::bool_into_int(self.primitive))
}
/// Convert the bool tensor into an float tensor.
pub fn float(self) -> Tensor<B, D> {
Tensor::new(B::bool_into_float(self.primitive))
}
}

View File

@ -119,7 +119,7 @@ where
/// }
/// ```
pub fn int(self) -> Tensor<B, D, Int> {
Tensor::<B, D, Int>::from_data(self.into_data().convert())
Tensor::new(B::into_int(self.primitive))
}
/// Returns a new tensor with the same shape and device as the current tensor filled with zeros.

View File

@ -81,6 +81,6 @@ where
/// }
/// ```
pub fn float(self) -> Tensor<B, D, Float> {
Tensor::<B, D, Float>::from_data(self.into_data().convert())
Tensor::new(B::int_into_float(self.primitive))
}
}

View File

@ -82,6 +82,17 @@ pub trait BoolTensorOps<B: Backend> {
fn bool_into_int<const D: usize>(tensor: B::BoolTensorPrimitive<D>)
-> B::IntTensorPrimitive<D>;
/// Converts bool tensor to float tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The float tensor with the same data as the bool tensor.
fn bool_into_float<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> B::TensorPrimitive<D>;
/// Gets the device of the tensor.
///
/// # Arguments

View File

@ -131,6 +131,17 @@ pub trait IntTensorOps<B: Backend> {
value: B::IntTensorPrimitive<D1>,
) -> B::IntTensorPrimitive<D1>;
/// Converts int tensor to float tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The int tensor with the same data as the float tensor.
fn int_into_float<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::TensorPrimitive<D>;
/// Fills the tensor with values from the source tensor if the mask is true at the given
/// indices.
///

View File

@ -163,6 +163,17 @@ pub trait TensorOps<B: Backend> {
Self::arange_step(range, 1, device)
}
/// Converts float tensor to int tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The int tensor with the same data as the float tensor.
fn into_int<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::IntTensorPrimitive<D>;
/// Creates a new tensor with values from the given range with the given step size.
///
/// # Arguments

View File

@ -1,25 +1,43 @@
#[burn_tensor_testgen::testgen(cast)]
mod tests {
use super::*;
use burn_tensor::{Data, Int, Tensor};
use burn_tensor::{Bool, Data, Int, Tensor};
#[test]
fn cast_float_tensor() {
let float_data = Data::from([[1.0, 2.0, 3.0], [4.4, 5.5, 6.6]]);
let float_tensor = Tensor::<TestBackend, 2>::from_data(float_data);
fn cast_float_to_int() {
let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.4, 5.5, 6.6]]);
let int_tensor = float_tensor.int();
let data_expected = Data::from([[1, 2, 3], [4, 5, 6]]);
assert_eq!(data_expected, int_tensor.to_data());
let actual = tensor.int().into_data();
let expected = Data::from([[1, 2, 3], [4, 5, 6]]);
assert_eq!(expected, actual);
}
#[test]
fn cast_int_tensor() {
let int_data = Data::from([[1, 2, 3], [4, 5, 6]]);
let int_tensor = Tensor::<TestBackend, 2, Int>::from_data(int_data);
fn cast_int_to_float_tensor() {
let tensor = Tensor::<TestBackend, 2, Int>::from_data([[1, 2, 3], [4, 5, 6]]);
let float_tensor = int_tensor.float();
let data_expected = Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
assert_eq!(data_expected, float_tensor.to_data());
let actual = tensor.float().into_data();
let expected = Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
assert_eq!(expected, actual);
}
#[test]
fn cast_bool_to_int_tensor() {
let tensor =
Tensor::<TestBackend, 2, Bool>::from_data([[true, false, true], [false, false, true]]);
let actual = tensor.int().into_data();
let expected = Data::from([[1, 0, 1], [0, 0, 1]]);
assert_eq!(expected, actual);
}
#[test]
fn cast_bool_to_float_tensor() {
let tensor =
Tensor::<TestBackend, 2, Bool>::from_data([[true, false, true], [false, false, true]]);
let actual = tensor.float().into_data();
let expected = Data::from([[1., 0., 1.], [0., 0., 1.]]);
assert_eq!(expected, actual);
}
}

View File

@ -49,15 +49,15 @@ impl<B: Backend> Metric for AccuracyMetric<B> {
let accuracy = match self.pad_token {
Some(pad_token) => {
let mask = targets.clone().equal_elem(pad_token as i64);
let matches = outputs.equal(targets).into_int().mask_fill(mask.clone(), 0);
let num_pad = mask.into_int().sum().into_scalar().elem::<f64>();
let matches = outputs.equal(targets).int().mask_fill(mask.clone(), 0);
let num_pad = mask.int().sum().into_scalar().elem::<f64>();
matches.sum().into_scalar().elem::<f64>() / (batch_size as f64 - num_pad)
}
None => {
outputs
.equal(targets)
.into_int()
.int()
.sum()
.into_scalar()
.elem::<f64>()

View File

@ -1,4 +1,4 @@
use super::{BoolTensor, Device, IntTensor};
use super::{BoolTensor, Device, FloatTensor, IntTensor};
use crate::{
element::{FloatElement, IntElement},
kernel,
@ -112,4 +112,8 @@ where
},
)
}
fn bool_into_float<const D: usize>(tensor: BoolTensor<Self, D>) -> FloatTensor<Self, D> {
kernel::cast(tensor)
}
}

View File

@ -446,6 +446,10 @@ where
kernel::argmin(tensor, dim)
}
fn into_int<const D: usize>(tensor: FloatTensor<Self, D>) -> IntTensor<Self, D> {
kernel::cast(tensor)
}
// TODO implement clamp kernels (see https://github.com/burn-rs/burn/issues/549)
// fn clamp_min<const D: usize>(
// tensor: FloatTensor<Self, D>,

View File

@ -1,4 +1,4 @@
use super::{numeric, BoolTensor, Device, IntElem, IntTensor};
use super::{numeric, BoolTensor, Device, FloatTensor, IntElem, IntTensor};
use crate::kernel::{unary_default, unary_inplace_default};
use crate::{
element::{FloatElement, IntElement},
@ -306,4 +306,8 @@ where
unary_default::<IntAbs, I, D>(tensor)
}
fn int_into_float<const D: usize>(tensor: IntTensor<Self, D>) -> FloatTensor<Self, D> {
kernel::cast(tensor)
}
}