Fix bug in segmentation postprocessing (#20198)
* Fix post_process_instance_segmentation * Add test for label fusing
This commit is contained in:
parent
292acd71d6
commit
52c9e6af29
|
@ -1050,12 +1050,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||
# Get segmentation map and segment information of batch item
|
||||
target_size = target_sizes[i] if target_sizes is not None else None
|
||||
segmentation, segments = compute_segments(
|
||||
mask_probs_item,
|
||||
pred_scores_item,
|
||||
pred_labels_item,
|
||||
mask_threshold,
|
||||
overlap_mask_area_threshold,
|
||||
target_size,
|
||||
mask_probs=mask_probs_item,
|
||||
pred_scores=pred_scores_item,
|
||||
pred_labels=pred_labels_item,
|
||||
mask_threshold=mask_threshold,
|
||||
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||
label_ids_to_fuse=[],
|
||||
target_size=target_size,
|
||||
)
|
||||
|
||||
# Return segmentation map in run-length encoding (RLE) format
|
||||
|
@ -1143,13 +1144,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||
# Get segmentation map and segment information of batch item
|
||||
target_size = target_sizes[i] if target_sizes is not None else None
|
||||
segmentation, segments = compute_segments(
|
||||
mask_probs_item,
|
||||
pred_scores_item,
|
||||
pred_labels_item,
|
||||
mask_threshold,
|
||||
overlap_mask_area_threshold,
|
||||
label_ids_to_fuse,
|
||||
target_size,
|
||||
mask_probs=mask_probs_item,
|
||||
pred_scores=pred_scores_item,
|
||||
pred_labels=pred_labels_item,
|
||||
mask_threshold=mask_threshold,
|
||||
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||
label_ids_to_fuse=label_ids_to_fuse,
|
||||
target_size=target_size,
|
||||
)
|
||||
|
||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||
|
|
|
@ -589,3 +589,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
|||
self.assertEqual(
|
||||
el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width)
|
||||
)
|
||||
|
||||
def test_post_process_label_fusing(self):
|
||||
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||
|
||||
segmentation = feature_extractor.post_process_panoptic_segmentation(
|
||||
outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0
|
||||
)
|
||||
unfused_segments = [el["segments_info"] for el in segmentation]
|
||||
|
||||
fused_segmentation = feature_extractor.post_process_panoptic_segmentation(
|
||||
outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1}
|
||||
)
|
||||
fused_segments = [el["segments_info"] for el in fused_segmentation]
|
||||
|
||||
for el_unfused, el_fused in zip(unfused_segments, fused_segments):
|
||||
if len(el_unfused) == 0:
|
||||
self.assertEqual(len(el_unfused), len(el_fused))
|
||||
continue
|
||||
|
||||
# Get number of segments to be fused
|
||||
fuse_targets = [1 for el in el_unfused if el["label_id"] in {1}]
|
||||
num_to_fuse = 0 if len(fuse_targets) == 0 else sum(fuse_targets) - 1
|
||||
# Expected number of segments after fusing
|
||||
expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse
|
||||
num_segments_fused = max([el["id"] for el in el_fused])
|
||||
self.assertEqual(num_segments_fused, expected_num_segments)
|
||||
|
|
Loading…
Reference in New Issue