Change ndarray mask_where implementation to correctly deal with NaNs (#2272)

* Change ndarray mask_where implementation to correctly deal with NaNs

* Add test
This commit is contained in:
Guillaume Lagrange 2024-09-13 15:16:39 -04:00 committed by GitHub
parent 2fbad48f64
commit 6f0e61aa4f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 42 additions and 11 deletions

View File

@ -409,17 +409,14 @@ where
mask: NdArrayTensor<bool, D>, mask: NdArrayTensor<bool, D>,
source: NdArrayTensor<E, D>, source: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> { ) -> NdArrayTensor<E, D> {
let mask_mul_4tensor = mask.array.mapv(|x| match x { let tensor = tensor.array.broadcast(mask.array.dim()).unwrap();
true => 0.elem(), let source = source.array.broadcast(mask.array.dim()).unwrap();
false => 1.elem(), let output = Zip::from(&tensor)
}); .and(&mask.array)
let mask_mul_4source = mask.array.mapv(|x| match x { .and(&source)
true => 1.elem(), .map_collect(|&x, &mask_val, &y| if mask_val { y } else { x })
false => 0.elem(), .into_shared();
}); NdArrayTensor::new(output)
let array = (tensor.array * mask_mul_4tensor) + (source.array * mask_mul_4source);
NdArrayTensor::new(array)
} }
pub fn mask_fill<const D: usize>( pub fn mask_fill<const D: usize>(

View File

@ -22,6 +22,40 @@ mod tests {
output.into_data().assert_eq(&expected, false); output.into_data().assert_eq(&expected, false);
} }
#[test]
fn should_handle_mask_where_nans() {
let device = Default::default();
let tensor = TestTensor::from_data(
[
[f32::NAN, f32::NAN, f32::NAN],
[f32::NAN, f32::NAN, f32::NAN],
[f32::NAN, f32::NAN, f32::NAN],
],
&device,
);
let mask = Tensor::<TestBackend, 2, Bool>::from_bool(
TensorData::from([
[true, true, true],
[true, true, false],
[false, false, false],
]),
&device,
);
let value = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1]]),
&device,
);
let output = tensor.mask_where(mask, value);
let expected = TensorData::from([
[0.9, 0.8, 0.7],
[0.6, 0.5, f32::NAN],
[f32::NAN, f32::NAN, f32::NAN],
]);
output.into_data().assert_eq(&expected, false);
}
#[test] #[test]
fn should_support_mask_fill_ops() { fn should_support_mask_fill_ops() {
let device = Default::default(); let device = Default::default();