From 2defd01342bb93c3df47ec0d79077a3ddb106fb0 Mon Sep 17 00:00:00 2001
From: Louis Fortier-Dubois
Date: Sun, 28 Jan 2024 21:29:06 -0500
Subject: [PATCH] Add tests: Slice assign vs Cat in LSTM backward (#1146)

* slice assign test

* added tests but no error

* test for non-zero grad

* clippy

* I'm confused

* fix CI
---
 burn-autodiff/src/tests/slice.rs | 39 ++++++++++++++++++++++++++++++++
 burn-core/src/nn/rnn/lstm.rs     | 29 +++++++++++++++++++++++-
 2 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/burn-autodiff/src/tests/slice.rs b/burn-autodiff/src/tests/slice.rs
index 45de254ea..a69e397ae 100644
--- a/burn-autodiff/src/tests/slice.rs
+++ b/burn-autodiff/src/tests/slice.rs
@@ -77,4 +77,43 @@ mod tests {
         assert_eq!(grad_1.to_data(), Data::from([[85.0, 65.0], [118.0, 82.0]]));
         assert_eq!(grad_2.to_data(), Data::from([[88.0, 15.0], [24.0, 50.0]]));
     }
+
+    #[test]
+    fn slice_assign_diff_should_give_same_results_as_cat() {
+        let data_1: Data<f32, 2> = Data::from([[1.0, 2.0], [3.0, 4.0]]);
+        let data_2: Data<f32, 2> = Data::from([[5.0, 6.0], [7.0, 8.0]]);
+        let data_3: Data<f32, 2> = Data::from([[14.0, 97.0, 100.0, 9.0], [2.0, 3.0, 15.0, 7.0]]);
+
+        let device = Default::default();
+        let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
+        let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
+        let tensor_3 = TestAutodiffTensor::from_data(data_3, &device);
+
+        let slice_assign_output = TestAutodiffTensor::zeros([2, 4], &device);
+        let slice_assign_output = slice_assign_output.slice_assign([0..2, 0..2], tensor_1.clone());
+        let slice_assign_output = slice_assign_output.slice_assign([0..2, 2..4], tensor_2.clone());
+        let slice_assign_output = slice_assign_output / tensor_3.clone();
+
+        let cat_output = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 1);
+        let cat_output = cat_output / tensor_3;
+
+        slice_assign_output
+            .to_data()
+            .assert_approx_eq(&cat_output.to_data(), 3);
+
+        let slice_assign_grads = slice_assign_output.backward();
+        let cat_grads = cat_output.backward();
+
+        let slice_assign_grad_1 = tensor_1.grad(&slice_assign_grads).unwrap();
+        let slice_assign_grad_2 = tensor_2.grad(&slice_assign_grads).unwrap();
+        let cat_grad_1 = tensor_1.grad(&cat_grads).unwrap();
+        let cat_grad_2 = tensor_2.grad(&cat_grads).unwrap();
+
+        slice_assign_grad_1
+            .to_data()
+            .assert_approx_eq(&cat_grad_1.to_data(), 3);
+        slice_assign_grad_2
+            .to_data()
+            .assert_approx_eq(&cat_grad_2.to_data(), 3);
+    }
 }
diff --git a/burn-core/src/nn/rnn/lstm.rs b/burn-core/src/nn/rnn/lstm.rs
index c6d581d55..ba32d698c 100644
--- a/burn-core/src/nn/rnn/lstm.rs
+++ b/burn-core/src/nn/rnn/lstm.rs
@@ -224,7 +224,10 @@ impl<B: Backend> Lstm<B> {
 mod tests {
     use super::*;
     use crate::{module::Param, nn::LinearRecord, TestBackend};
-    use burn_tensor::{Data, Distribution};
+    use burn_tensor::{Data, Distribution, Shape};
+
+    #[cfg(feature = "std")]
+    use crate::TestAutodiffBackend;
 
     #[test]
     fn test_with_uniform_initializer() {
@@ -353,4 +356,28 @@
         assert_eq!(cell_state.shape().dims, [8, 10, 1024]);
         assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
     }
+
+    #[test]
+    #[cfg(feature = "std")]
+    fn test_batched_backward_pass() {
+        let device = Default::default();
+        let lstm = LstmConfig::new(64, 32, true).init(&device);
+        let shape: Shape<3> = [8, 10, 64].into();
+        let batched_input =
+            Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);
+
+        let (cell_state, hidden_state) = lstm.forward(batched_input.clone(), None);
+        let fake_loss = cell_state + hidden_state;
+        let grads = fake_loss.backward();
+
+        let some_gradient = lstm
+            .output_gate
+            .hidden_transform
+            .weight
+            .grad(&grads)
+            .unwrap();
+
+        // Assert that the gradient exists and is non-zero.
+        assert!(*some_gradient.abs().sum().into_data().value.first().unwrap() > 0.);
+    }
 }
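
The property the new slice test pins down, stated outside the diff: building a tensor by slice-assigning two operands into disjoint regions of a zero buffer must be indistinguishable from concatenating them, both in the forward values and in the gradients flowing back into the operands. The division by a third tensor keeps the upstream gradient from being all ones, so the backward of slice_assign is genuinely exercised. Below is a minimal standalone sketch of that check; it mirrors the tensor calls in the patch, but the `burn` facade imports and the `Autodiff<NdArray>` backend are illustrative assumptions standing in for the repo-internal `TestAutodiffTensor` alias.

    use burn::backend::{Autodiff, NdArray};
    use burn::tensor::{Data, Tensor};

    // Illustrative backend choice; the patch's tests run on TestAutodiffTensor.
    type TB = Autodiff<NdArray>;

    fn main() {
        let device = Default::default();
        let a_data: Data<f32, 2> = Data::from([[1.0, 2.0], [3.0, 4.0]]);
        let b_data: Data<f32, 2> = Data::from([[5.0, 6.0], [7.0, 8.0]]);
        let c_data: Data<f32, 2> = Data::from([[14.0, 97.0, 100.0, 9.0], [2.0, 3.0, 15.0, 7.0]]);

        let a = Tensor::<TB, 2>::from_data(a_data, &device).require_grad();
        let b = Tensor::<TB, 2>::from_data(b_data, &device).require_grad();
        let c = Tensor::<TB, 2>::from_data(c_data, &device);

        // [a | b] built by slice-assigning into a zero buffer, then a non-trivial op.
        let assigned = Tensor::<TB, 2>::zeros([2, 4], &device)
            .slice_assign([0..2, 0..2], a.clone())
            .slice_assign([0..2, 2..4], b.clone())
            / c.clone();

        // The same value built by concatenation along dimension 1.
        let concatenated = Tensor::cat(vec![a.clone(), b.clone()], 1) / c;

        // Forward values agree...
        assigned.to_data().assert_approx_eq(&concatenated.to_data(), 3);

        // ...and so do the gradients reaching `a` (the check for `b` is analogous).
        let grads_assigned = assigned.backward();
        let grads_cat = concatenated.backward();
        a.grad(&grads_assigned)
            .unwrap()
            .to_data()
            .assert_approx_eq(&a.grad(&grads_cat).unwrap().to_data(), 3);
    }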