Add tests: Slice assign vs Cat in LSTM backward (#1146)

* Add slice assign test

* Added tests, but no error

* Test for non-zero grad

* Fix clippy lints

* I'm confused

* Fix CI
Louis Fortier-Dubois 2024-01-28 21:29:06 -05:00 committed by GitHub
parent 868608222b
commit 2defd01342
2 changed files with 67 additions and 1 deletion

@@ -77,4 +77,43 @@ mod tests {
        assert_eq!(grad_1.to_data(), Data::from([[85.0, 65.0], [118.0, 82.0]]));
        assert_eq!(grad_2.to_data(), Data::from([[88.0, 15.0], [24.0, 50.0]]));
    }

    #[test]
    fn slice_assign_diff_should_give_same_results_as_cat() {
        let data_1: Data<f32, 2> = Data::from([[1.0, 2.0], [3.0, 4.0]]);
        let data_2: Data<f32, 2> = Data::from([[5.0, 6.0], [7.0, 8.0]]);
        let data_3: Data<f32, 2> = Data::from([[14.0, 97.0, 100.0, 9.0], [2.0, 3.0, 15.0, 7.0]]);
        let device = Default::default();
        let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
        let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
        let tensor_3 = TestAutodiffTensor::from_data(data_3, &device);
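
        // Path 1: assemble the [2, 4] output by writing each half into a zero
        // tensor with slice_assign, then divide by tensor_3 so the gradients
        // flowing back to tensor_1 and tensor_2 are non-trivial.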
        let slice_assign_output = TestAutodiffTensor::zeros([2, 4], &device);
        let slice_assign_output = slice_assign_output.slice_assign([0..2, 0..2], tensor_1.clone());
        let slice_assign_output = slice_assign_output.slice_assign([0..2, 2..4], tensor_2.clone());
        let slice_assign_output = slice_assign_output / tensor_3.clone();
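
        // Path 2: build the same output by concatenating along dimension 1 and
        // applying the identical division.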
        let cat_output = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 1);
        let cat_output = cat_output / tensor_3;

        slice_assign_output
            .to_data()
            .assert_approx_eq(&cat_output.to_data(), 3);
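
        // Both graphs must also agree in the backward pass: the gradients of
        // the shared leaf tensors should match to the same precision.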
        let slice_assign_grads = slice_assign_output.backward();
        let cat_grads = cat_output.backward();

        let slice_assign_grad_1 = tensor_1.grad(&slice_assign_grads).unwrap();
        let slice_assign_grad_2 = tensor_2.grad(&slice_assign_grads).unwrap();
        let cat_grad_1 = tensor_1.grad(&cat_grads).unwrap();
        let cat_grad_2 = tensor_2.grad(&cat_grads).unwrap();

        slice_assign_grad_1
            .to_data()
            .assert_approx_eq(&cat_grad_1.to_data(), 3);
        slice_assign_grad_2
            .to_data()
            .assert_approx_eq(&cat_grad_2.to_data(), 3);
    }
}

@@ -224,7 +224,10 @@ impl<B: Backend> Lstm<B> {
mod tests {
    use super::*;
    use crate::{module::Param, nn::LinearRecord, TestBackend};
    use burn_tensor::{Data, Distribution, Shape};

    #[cfg(feature = "std")]
    use crate::TestAutodiffBackend;

    #[test]
    fn test_with_uniform_initializer() {
@@ -353,4 +356,28 @@ mod tests {
        assert_eq!(cell_state.shape().dims, [8, 10, 1024]);
        assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_batched_backward_pass() {
        let device = Default::default();
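        // An LSTM with 64 input features, 32 hidden units, and bias enabled,
        // fed a random batch of 8 sequences of length 10.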
        let lstm = LstmConfig::new(64, 32, true).init(&device);
        let shape: Shape<3> = [8, 10, 64].into();
        let batched_input =
            Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);

        let (cell_state, hidden_state) = lstm.forward(batched_input.clone(), None);

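        // Combine both outputs so a single backward pass reaches every gate of
        // the LSTM.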
        let fake_loss = cell_state + hidden_state;
        let grads = fake_loss.backward();

        let some_gradient = lstm
            .output_gate
            .hidden_transform
            .weight
            .grad(&grads)
            .unwrap();

        // Assert that the gradient exists and is non-zero.
        assert!(*some_gradient.abs().sum().into_data().value.first().unwrap() > 0.);
    }
}