Update flava tests (#29611)
* update * update * update --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
df1542581e
commit
a15bd3af4e
|
@ -1287,9 +1287,9 @@ class FlavaModelIntegrationTest(unittest.TestCase):
|
||||||
outputs = model(**inputs, return_dict=True)
|
outputs = model(**inputs, return_dict=True)
|
||||||
|
|
||||||
# verify the embeddings
|
# verify the embeddings
|
||||||
self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.54943, places=4)
|
self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.53540, places=4)
|
||||||
self.assertAlmostEqual(outputs.text_embeddings.sum().item(), -198.98225, places=4)
|
self.assertAlmostEqual(outputs.text_embeddings.sum().item(), -198.98225, places=4)
|
||||||
self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -4030.466552, places=4)
|
self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -4030.4602050, places=4)
|
||||||
|
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
|
@ -1339,9 +1339,9 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
|
||||||
|
|
||||||
expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
|
expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
|
||||||
self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
|
self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
|
||||||
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 2.0736470, places=4)
|
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 2.0727925, places=4)
|
||||||
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.025580, places=4)
|
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0282096, places=4)
|
||||||
self.assertAlmostEqual(outputs.loss.item(), 11.37761, places=4)
|
self.assertAlmostEqual(outputs.loss.item(), 11.3792324, places=4)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_inference_with_itm_labels(self):
|
def test_inference_with_itm_labels(self):
|
||||||
|
@ -1390,6 +1390,6 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
|
||||||
|
|
||||||
expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
|
expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
|
||||||
self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
|
self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
|
||||||
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 2.0736470, places=4)
|
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 2.0727925, places=4)
|
||||||
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.8962264, places=4)
|
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.8965902, places=4)
|
||||||
self.assertAlmostEqual(outputs.loss.item(), 9.6090, places=4)
|
self.assertAlmostEqual(outputs.loss.item(), 9.6084213, places=4)
|
||||||
|
|
Loading…
Reference in New Issue