fix powf bugs (#1207)

This commit is contained in:
Louis Fortier-Dubois 2024-01-31 11:38:45 -05:00 committed by GitHub
parent f1d98bc5f8
commit e03facc5de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 3 deletions

View File

@ -447,8 +447,8 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
TchTensor::binary_ops_tensor(
tensor,
exponent,
|lhs, rhs| lhs.f_pow_tensor_(rhs).unwrap(),
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
|lhs, rhs| rhs.f_pow(lhs).unwrap(),
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
)
}

View File

@ -7,10 +7,10 @@ mod tests {
fn should_support_powf_ops() {
let data = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
let pow = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let pow = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 2.0]]);
let tensor_pow = Tensor::<TestBackend, 2>::from_data(pow, &Default::default());
let data_actual = tensor.powf(tensor_pow).into_data();
let data_expected = Data::from([[1.0, 1.0, 4.0], [27.0, 256.0, 3125.0]]);
let data_expected = Data::from([[1.0, 1.0, 4.0], [27.0, 256.0, 25.0]]);
data_expected.assert_approx_eq(&data_actual, 3);
}