Fix image segmentation and object detection pipeline tests (#18100)

This commit is contained in:
Sylvain Gugger 2022-07-11 12:41:56 -04:00 committed by GitHub
parent b0520f594c
commit 6c8017a5c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 14 deletions

View File

@ -147,7 +147,6 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
pass
@require_torch
@unittest.skip("Test is broken, fix me please!")
def test_small_model_pt(self):
model_id = "hf-internal-testing/tiny-detr-mobilenetsv3-panoptic"
@ -165,12 +164,12 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
[
{
"score": 0.004,
"label": "LABEL_0",
"label": "LABEL_215",
"mask": "34eecd16bbfb0f476083ef947d81bf66",
},
{
"score": 0.004,
"label": "LABEL_0",
"label": "LABEL_215",
"mask": "34eecd16bbfb0f476083ef947d81bf66",
},
],
@ -193,24 +192,24 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
[
{
"score": 0.004,
"label": "LABEL_0",
"label": "LABEL_215",
"mask": "34eecd16bbfb0f476083ef947d81bf66",
},
{
"score": 0.004,
"label": "LABEL_0",
"label": "LABEL_215",
"mask": "34eecd16bbfb0f476083ef947d81bf66",
},
],
[
{
"score": 0.004,
"label": "LABEL_0",
"label": "LABEL_215",
"mask": "34eecd16bbfb0f476083ef947d81bf66",
},
{
"score": 0.004,
"label": "LABEL_0",
"label": "LABEL_215",
"mask": "34eecd16bbfb0f476083ef947d81bf66",
},
],

View File

@ -105,7 +105,6 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
pass
@require_torch
@unittest.skip("Test is broken, fix me please!")
def test_small_model_pt(self):
model_id = "hf-internal-testing/tiny-detr-mobilenetsv3"
@ -118,8 +117,8 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3376, "label": "LABEL_0", "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3376, "label": "LABEL_0", "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359}},
],
)
@ -135,12 +134,12 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3376, "label": "LABEL_0", "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3376, "label": "LABEL_0", "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359}},
],
[
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3376, "label": "LABEL_0", "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3376, "label": "LABEL_0", "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359}},
],
],
)