mirror of https://github.com/tracel-ai/burn.git
Fix flaky tests + add feature flag (#362)
This commit is contained in:
parent
56e40ae63b
commit
3ef2a18d87
|
@ -70,12 +70,18 @@ mod tests {
|
|||
let (mean, std) = (0.0, 1.0);
|
||||
let normal: Tensor<TB, 1> = Initializer::Normal(mean, std).init([1000]);
|
||||
let (var_act, mean_act) = normal.var_mean(0);
|
||||
var_act
|
||||
.to_data()
|
||||
.assert_approx_eq(&Data::from([std as f32]), 1);
|
||||
mean_act
|
||||
.to_data()
|
||||
.assert_approx_eq(&Data::from([mean as f32]), 1);
|
||||
|
||||
let var_act: f32 = var_act.into_scalar().elem();
|
||||
let mean_act: f32 = mean_act.into_scalar().elem();
|
||||
|
||||
assert!(
|
||||
var_act > 0.9 && var_act < 1.1,
|
||||
"Expected variance to be between 1.0 += 0.1, but got {var_act}"
|
||||
);
|
||||
assert!(
|
||||
mean_act > -0.1 && mean_act < 0.1,
|
||||
"Expected mean to be between 0.0 += 0.1, but got {mean_act}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -371,6 +371,8 @@ mod tests {
|
|||
#[test]
|
||||
fn test_autoregressive_norm_last() {
|
||||
let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
|
||||
TestBackend::seed(0);
|
||||
|
||||
test_autoregressive(
|
||||
TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers)
|
||||
.with_norm_first(false),
|
||||
|
@ -380,6 +382,8 @@ mod tests {
|
|||
#[test]
|
||||
fn test_autoregressive_norm_first() {
|
||||
let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
|
||||
TestBackend::seed(0);
|
||||
|
||||
test_autoregressive(
|
||||
TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
|
||||
)
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
#[cfg(feature = "audio")]
|
||||
use burn_dataset::{audio::SpeechCommandsDataset, Dataset};
|
||||
|
||||
fn main() {
|
||||
#[cfg(feature = "audio")]
|
||||
fn speech_command() {
|
||||
let index: usize = 4835;
|
||||
let test = SpeechCommandsDataset::test();
|
||||
let item = test.get(index).unwrap();
|
||||
|
@ -13,3 +15,8 @@ fn main() {
|
|||
assert_eq!(item.sample_rate, 16000);
|
||||
assert_eq!(item.audio_samples.len(), 16000);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
#[cfg(feature = "audio")]
|
||||
speech_command()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue