#384 Include tests for int.rs and float.rs (#794)

This commit is contained in:
Juliano Decico Negri 2023-09-21 10:00:09 -03:00 committed by GitHub
parent 393d86e99d
commit 293020aae6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 268 additions and 4 deletions

View File

@ -15,6 +15,7 @@ pub use tensor::*;
#[cfg(test)]
mod tests {
extern crate alloc;
use super::*;
pub type TestBackend = CandleBackend<f32, i64>;

View File

@ -14,6 +14,8 @@ pub use tensor::*;
#[cfg(test)]
mod tests {
extern crate alloc;
type TestBackend = crate::TchBackend<f32>;
type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;

View File

@ -38,6 +38,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_cat!();
burn_tensor::testgen_clamp!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_create_like!();
burn_tensor::testgen_div!();
burn_tensor::testgen_erf!();
burn_tensor::testgen_exp!();
@ -54,6 +55,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_maxmin!();
burn_tensor::testgen_mul!();
burn_tensor::testgen_neg!();
burn_tensor::testgen_one_hot!();
burn_tensor::testgen_powf!();
burn_tensor::testgen_random!();
burn_tensor::testgen_repeat!();

View File

@ -1,6 +1,7 @@
#[burn_tensor_testgen::testgen(arange)]
mod tests {
use super::*;
use burn_tensor::backend::Backend;
use burn_tensor::{Data, Int, Tensor};
#[test]
@ -8,4 +9,13 @@ mod tests {
let tensor = Tensor::<TestBackend, 1, Int>::arange(2..5);
assert_eq!(tensor.into_data(), Data::from([2, 3, 4]));
}
#[test]
fn test_arange_device() {
let device = <TestBackend as Backend>::Device::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange_device(2..5, &device);
assert_eq!(tensor.clone().into_data(), Data::from([2, 3, 4]));
assert_eq!(tensor.device(), device);
}
}

View File

@ -1,6 +1,7 @@
#[burn_tensor_testgen::testgen(arange_step)]
mod tests {
use super::*;
use burn_tensor::backend::Backend;
use burn_tensor::{Data, Int, Tensor};
#[test]
@ -18,9 +19,27 @@ mod tests {
assert_eq!(tensor.into_data(), Data::from([0]));
}
#[test]
fn test_arange_step_device() {
let device = <TestBackend as Backend>::Device::default();
// Test correct sequence of numbers when the range is 0..9 and the step is 1
let tensor = Tensor::<TestBackend, 1, Int>::arange_step_device(0..9, 1, &device);
assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8]));
// Test correct sequence of numbers when the range is 0..3 and the step is 2
let tensor = Tensor::<TestBackend, 1, Int>::arange_step_device(0..3, 2, &device);
assert_eq!(tensor.into_data(), Data::from([0, 2]));
// Test correct sequence of numbers when the range is 0..2 and the step is 5
let tensor = Tensor::<TestBackend, 1, Int>::arange_step_device(0..2, 5, &device);
assert_eq!(tensor.clone().into_data(), Data::from([0]));
assert_eq!(tensor.device(), device);
}
#[test]
#[should_panic]
fn test_arange_step_panic() {
fn should_panic_when_step_is_zero() {
// Test that arange_step panics when the step is 0
let _tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..3, 0);
}

View File

@ -1,8 +1,8 @@
#[burn_tensor_testgen::testgen(cat)]
mod tests {
use super::*;
use alloc::vec::Vec;
use burn_tensor::{Bool, Data, Int, Tensor};
#[test]
fn should_support_cat_ops_2d_dim0() {
let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0]]);
@ -57,4 +57,29 @@ mod tests {
let data_expected = Data::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
#[test]
#[should_panic]
fn should_panic_when_dimensions_are_not_the_same() {
let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]);
let tensor_2 = TestTensor::from_data([[4.0, 5.0]]);
TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data();
}
#[test]
#[should_panic]
fn should_panic_when_list_of_vectors_is_empty() {
let tensor: Vec<TestTensor<2>> = vec![];
TestTensor::cat(tensor, 0).into_data();
}
#[test]
#[should_panic]
fn should_panic_when_cat_exceeds_dimension() {
let tensor_1 = TestTensor::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]);
let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]]);
TestTensor::cat(vec![tensor_1, tensor_2], 3).into_data();
}
}

View File

@ -0,0 +1,52 @@
#[burn_tensor_testgen::testgen(create_like)]
mod tests {
use super::*;
use burn_tensor::{Data, Distribution, Tensor};
#[test]
fn should_support_zeros_like() {
let tensor = TestTensor::from_floats([
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
]);
let data_actual = tensor.zeros_like().into_data();
let data_expected =
Data::from([[[0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.]]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
#[test]
fn should_support_ones_like() {
let tensor = TestTensor::from_floats([
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
]);
let data_actual = tensor.ones_like().into_data();
let data_expected =
Data::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
#[test]
fn should_support_randoms_like() {
let tensor = TestTensor::from_floats([
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
]);
let data_actual = tensor
.random_like(Distribution::Uniform(0.99999, 1.))
.into_data();
let data_expected =
Data::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
}

View File

@ -164,4 +164,14 @@ mod tests {
Data::from([[0.0, 1.0, 0.0], [0.0, 0.0, 4.0]])
);
}
#[test]
#[should_panic]
fn scatter_should_panic_on_mismatch_of_shapes() {
let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]);
let values = TestTensor::from_floats([5.0, 4.0]);
let indices = TestTensorInt::from_ints([1, 0, 2]);
tensor.scatter(0, indices, values);
}
}

View File

@ -85,4 +85,24 @@ mod tests {
])
);
}
#[test]
#[should_panic]
fn should_panic_when_inner_dimensions_are_not_equal() {
let tensor_1 = TestTensor::from_floats([[3., 3.], [4., 4.], [5., 5.], [6., 6.]]);
let tensor_2 =
TestTensor::from_floats([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]);
let tensor_3 = tensor_1.matmul(tensor_2);
assert_eq!(
tensor_3.into_data(),
Data::from([
[9., 18., 27., 36.],
[12., 24., 36., 48.],
[15., 30., 45., 60.],
[18., 36., 54., 72.]
])
);
}
}

View File

@ -8,6 +8,7 @@ mod cast;
mod cat;
mod clamp;
mod cos;
mod create_like;
mod div;
mod erf;
mod exp;
@ -24,6 +25,7 @@ mod matmul;
mod maxmin;
mod mul;
mod neg;
mod one_hot;
mod powf;
mod random;
mod repeat;

View File

@ -0,0 +1,32 @@
#[burn_tensor_testgen::testgen(one_hot)]
mod tests {
use super::*;
use burn_tensor::{Data, Int};
#[test]
fn should_support_one_hot() {
let tensor = TestTensor::<1>::one_hot(0, 5);
assert_eq!(tensor.to_data(), Data::from([1., 0., 0., 0., 0.]));
let tensor = TestTensor::<1>::one_hot(1, 5);
assert_eq!(tensor.to_data(), Data::from([0., 1., 0., 0., 0.]));
let tensor = TestTensor::<1>::one_hot(4, 5);
assert_eq!(tensor.to_data(), Data::from([0., 0., 0., 0., 1.]));
let tensor = TestTensor::<1>::one_hot(1, 2);
assert_eq!(tensor.to_data(), Data::from([0., 1.]));
}
#[test]
#[should_panic]
fn should_panic_when_index_exceeds_number_of_classes() {
let tensor = TestTensor::<1>::one_hot(1, 1);
}
#[test]
#[should_panic]
fn should_panic_when_number_of_classes_is_zero() {
let tensor = TestTensor::<1>::one_hot(0, 0);
}
}

View File

@ -116,4 +116,13 @@ mod tests {
Data::from([[2.0, 2.0, 5.0], [8.0, 8.0, 11.0]])
);
}
#[test]
#[should_panic]
fn should_select_panic_invalid_dimension() {
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]);
tensor.select(10, indices);
}
}

View File

@ -102,4 +102,48 @@ mod tests {
let data_expected = Data::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]);
assert_eq!(data_expected, data_actual);
}
#[test]
#[should_panic]
fn should_panic_when_slice_exceeds_dimension() {
let data = Data::from([0.0, 1.0, 2.0]);
let tensor = Tensor::<TestBackend, 1>::from_data(data.clone());
let data_actual = tensor.slice([0..4]).into_data();
assert_eq!(data, data_actual);
}
#[test]
#[should_panic]
fn should_panic_when_slice_with_too_many_dimensions() {
let data = Data::from([0.0, 1.0, 2.0]);
let tensor = Tensor::<TestBackend, 1>::from_data(data.clone());
let data_actual = tensor.slice([0..1, 0..1]).into_data();
assert_eq!(data, data_actual);
}
#[test]
#[should_panic]
fn should_panic_when_slice_is_desc() {
let data = Data::from([0.0, 1.0, 2.0]);
let tensor = Tensor::<TestBackend, 1>::from_data(data.clone());
let data_actual = tensor.slice([2..1]).into_data();
assert_eq!(data, data_actual);
}
#[test]
#[should_panic]
fn should_panic_when_slice_is_equal() {
let data = Data::from([0.0, 1.0, 2.0]);
let tensor = Tensor::<TestBackend, 1>::from_data(data.clone());
let data_actual = tensor.slice([1..1]).into_data();
assert_eq!(data, data_actual);
}
}

View File

@ -9,12 +9,47 @@ mod tests {
#[test]
fn test_var() {
let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
let data_actual = tensor.var(1).into_data();
let data_expected = Data::from([[2.4892], [15.3333]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
#[test]
fn test_var_mean() {
let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
let (var, mean) = tensor.var_mean(1);
let var_expected = Data::from([[2.4892], [15.3333]]);
let mean_expected = Data::from([[0.125], [1.]]);
var_expected.assert_approx_eq(&(var.into_data()), 3);
mean_expected.assert_approx_eq(&(mean.into_data()), 3);
}
#[test]
fn test_var_bias() {
let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
let data_actual = tensor.var_bias(1).into_data();
let data_expected = Data::from([[1.86688], [11.5]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
#[test]
fn test_var_mean_bias() {
let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
let (var, mean) = tensor.var_mean_bias(1);
let var_expected = Data::from([[1.86688], [11.5]]);
let mean_expected = Data::from([[0.125], [1.]]);
var_expected.assert_approx_eq(&(var.into_data()), 3);
mean_expected.assert_approx_eq(&(mean.into_data()), 3);
}
}

View File

@ -4,6 +4,7 @@
#[macro_use]
extern crate derive_new;
extern crate alloc;
mod ops;