Remainder operator (#1726)

* Adds remainder ops implementation for Tensor.

* Adds test for % operator.
This commit is contained in:
Jonas Kantic 2024-06-01 23:47:02 +02:00 committed by GitHub
parent 99e1ba4864
commit fba1e27e0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 0 deletions

View File

@ -2868,6 +2868,20 @@ where
}
}
impl<E, const D: usize, B, K> core::ops::Rem<E> for Tensor<B, D, K>
where
E: ElementConversion,
B: Backend,
K: Numeric<B>,
K::Elem: Element,
{
type Output = Self;
fn rem(self, other: E) -> Self {
Tensor::remainder_scalar(self, other)
}
}
impl<B, const D: usize, K> core::ops::Mul<Tensor<B, D, K>> for Tensor<B, D, K>
where
B: Backend,

View File

@ -95,4 +95,17 @@ mod tests {
let data_expected = Data::from([9.0, 1.0]);
data_expected.assert_approx_eq(&data_actual, 3);
}
#[test]
fn should_support_remainder_op() {
let data = Data::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]);
let device = Default::default();
let tensor = Tensor::<TestBackend, 1>::from_data(data, &device);
let output = tensor % 2.0;
let data_actual = output.into_data();
let data_expected = Data::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]);
data_expected.assert_approx_eq(&data_actual, 3);
}
}