From 6f0e61aa4f7368914c2492b11095a74145f4209c Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 13 Sep 2024 15:16:39 -0400 Subject: [PATCH] Change ndarray mask_where implementation to correctly deal with NaNs (#2272) * Change ndarray mask_where implementation to correctly deal with NaNs * Add test --- crates/burn-ndarray/src/ops/base.rs | 19 ++++++------- crates/burn-tensor/src/tests/ops/mask.rs | 34 ++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs index eed8915e7..8edc248d0 100644 --- a/crates/burn-ndarray/src/ops/base.rs +++ b/crates/burn-ndarray/src/ops/base.rs @@ -409,17 +409,14 @@ where mask: NdArrayTensor, source: NdArrayTensor, ) -> NdArrayTensor { - let mask_mul_4tensor = mask.array.mapv(|x| match x { - true => 0.elem(), - false => 1.elem(), - }); - let mask_mul_4source = mask.array.mapv(|x| match x { - true => 1.elem(), - false => 0.elem(), - }); - let array = (tensor.array * mask_mul_4tensor) + (source.array * mask_mul_4source); - - NdArrayTensor::new(array) + let tensor = tensor.array.broadcast(mask.array.dim()).unwrap(); + let source = source.array.broadcast(mask.array.dim()).unwrap(); + let output = Zip::from(&tensor) + .and(&mask.array) + .and(&source) + .map_collect(|&x, &mask_val, &y| if mask_val { y } else { x }) + .into_shared(); + NdArrayTensor::new(output) } pub fn mask_fill( diff --git a/crates/burn-tensor/src/tests/ops/mask.rs b/crates/burn-tensor/src/tests/ops/mask.rs index 9613d85d1..677e70960 100644 --- a/crates/burn-tensor/src/tests/ops/mask.rs +++ b/crates/burn-tensor/src/tests/ops/mask.rs @@ -22,6 +22,40 @@ mod tests { output.into_data().assert_eq(&expected, false); } + #[test] + fn should_handle_mask_where_nans() { + let device = Default::default(); + let tensor = TestTensor::from_data( + [ + [f32::NAN, f32::NAN, f32::NAN], + [f32::NAN, f32::NAN, f32::NAN], + [f32::NAN, f32::NAN, f32::NAN], + ], + &device, + ); + let mask = Tensor::::from_bool( + TensorData::from([ + [true, true, true], + [true, true, false], + [false, false, false], + ]), + &device, + ); + let value = Tensor::::from_data( + TensorData::from([[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1]]), + &device, + ); + + let output = tensor.mask_where(mask, value); + let expected = TensorData::from([ + [0.9, 0.8, 0.7], + [0.6, 0.5, f32::NAN], + [f32::NAN, f32::NAN, f32::NAN], + ]); + + output.into_data().assert_eq(&expected, false); + } + #[test] fn should_support_mask_fill_ops() { let device = Default::default();