mirror of https://github.com/tracel-ai/burn.git
Add tests: Slice assign vs Cat in LSTM backward (#1146)
* slice assign test * added tests but no error * test for non zero grad * clippy * i'm confused * fix ci
This commit is contained in:
parent
868608222b
commit
2defd01342
|
@ -77,4 +77,43 @@ mod tests {
|
|||
assert_eq!(grad_1.to_data(), Data::from([[85.0, 65.0], [118.0, 82.0]]));
|
||||
assert_eq!(grad_2.to_data(), Data::from([[88.0, 15.0], [24.0, 50.0]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slice_assign_diff_should_give_same_results_as_cat() {
|
||||
let data_1: Data<f32, 2> = Data::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let data_2: Data<f32, 2> = Data::from([[5.0, 6.0], [7.0, 8.0]]);
|
||||
let data_3: Data<f32, 2> = Data::from([[14.0, 97.0, 100.0, 9.0], [2.0, 3.0, 15.0, 7.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device);
|
||||
|
||||
let slice_assign_output = TestAutodiffTensor::zeros([2, 4], &Default::default());
|
||||
let slice_assign_output = slice_assign_output.slice_assign([0..2, 0..2], tensor_1.clone());
|
||||
let slice_assign_output = slice_assign_output.slice_assign([0..2, 2..4], tensor_2.clone());
|
||||
let slice_assign_output = slice_assign_output / tensor_3.clone();
|
||||
|
||||
let cat_output = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 1);
|
||||
let cat_output = cat_output / tensor_3;
|
||||
|
||||
slice_assign_output
|
||||
.to_data()
|
||||
.assert_approx_eq(&cat_output.to_data(), 3);
|
||||
|
||||
let slice_assign_grads = slice_assign_output.backward();
|
||||
let cat_grads = cat_output.backward();
|
||||
|
||||
let slice_assign_grad_1 = tensor_1.grad(&slice_assign_grads).unwrap();
|
||||
let slice_assign_grad_2 = tensor_2.grad(&slice_assign_grads).unwrap();
|
||||
let cat_grad_1 = tensor_1.grad(&cat_grads).unwrap();
|
||||
let cat_grad_2 = tensor_2.grad(&cat_grads).unwrap();
|
||||
|
||||
slice_assign_grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq(&cat_grad_1.to_data(), 3);
|
||||
slice_assign_grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq(&cat_grad_2.to_data(), 3);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -224,7 +224,10 @@ impl<B: Backend> Lstm<B> {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::{module::Param, nn::LinearRecord, TestBackend};
|
||||
use burn_tensor::{Data, Distribution};
|
||||
use burn_tensor::{Data, Distribution, Shape};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use crate::TestAutodiffBackend;
|
||||
|
||||
#[test]
|
||||
fn test_with_uniform_initializer() {
|
||||
|
@ -353,4 +356,28 @@ mod tests {
|
|||
assert_eq!(cell_state.shape().dims, [8, 10, 1024]);
|
||||
assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "std")]
|
||||
fn test_batched_backward_pass() {
|
||||
let device = Default::default();
|
||||
let lstm = LstmConfig::new(64, 32, true).init(&device);
|
||||
let shape: Shape<3> = [8, 10, 64].into();
|
||||
let batched_input =
|
||||
Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);
|
||||
|
||||
let (cell_state, hidden_state) = lstm.forward(batched_input.clone(), None);
|
||||
let fake_loss = cell_state + hidden_state;
|
||||
let grads = fake_loss.backward();
|
||||
|
||||
let some_gradient = lstm
|
||||
.output_gate
|
||||
.hidden_transform
|
||||
.weight
|
||||
.grad(&grads)
|
||||
.unwrap();
|
||||
|
||||
// Asserts the gradients exist and are non zero
|
||||
assert!(*some_gradient.abs().sum().into_data().value.first().unwrap() > 0.);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue