diff --git a/tests/test_modeling_bert_generation.py b/tests/test_modeling_bert_generation.py
index aa22970a96..2048c127e9 100755
--- a/tests/test_modeling_bert_generation.py
+++ b/tests/test_modeling_bert_generation.py
@@ -297,3 +297,33 @@ class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
     def test_model_from_pretrained(self):
         model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
         self.assertIsNotNone(model)
+
+
+@require_torch
+class BertGenerationEncoderIntegrationTest(unittest.TestCase):
+    @slow
+    def test_inference_no_head_absolute_embedding(self):
+        model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
+        input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
+        output = model(input_ids)[0]
+        expected_shape = torch.Size([1, 8, 1024])
+        self.assertEqual(output.shape, expected_shape)
+        expected_slice = torch.tensor(
+            [[[0.1775, 0.0083, -0.0321], [1.6002, 0.1287, 0.3912], [2.1473, 0.5791, 0.6066]]]
+        )
+        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
+
+
+@require_torch
+class BertGenerationDecoderIntegrationTest(unittest.TestCase):
+    @slow
+    def test_inference_no_head_absolute_embedding(self):
+        model = BertGenerationDecoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
+        input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
+        output = model(input_ids)[0]
+        expected_shape = torch.Size([1, 8, 50358])
+        self.assertEqual(output.shape, expected_shape)
+        expected_slice = torch.tensor(
+            [[[-0.5788, -2.5994, -3.7054], [0.0438, 4.7997, 1.8795], [1.5862, 6.6409, 4.4638]]]
+        )
+        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))