From 36f52e9593cce6530b04cee7a16ed84b8f424a2e Mon Sep 17 00:00:00 2001
From: Alara Dirik <8944735+alaradirik@users.noreply.github.com>
Date: Mon, 3 Oct 2022 12:02:51 +0300
Subject: [PATCH] Restructure DETR post-processing, return prediction scores (#19262)

* Restructure DetrFeatureExtractor post-processing methods
* Update post_process_instance_segmentation and post_process_panoptic_segmentation methods to return prediction scores
* Update DETR models docs
---
 docs/source/en/model_doc/detr.mdx              |   6 +-
 .../models/detr/feature_extraction_detr.py     | 304 +++++++++---------
 src/transformers/models/detr/modeling_detr.py  |   4 +-
 3 files changed, 161 insertions(+), 153 deletions(-)

diff --git a/docs/source/en/model_doc/detr.mdx b/docs/source/en/model_doc/detr.mdx
index 9739ead3a4..a6025580a6 100644
--- a/docs/source/en/model_doc/detr.mdx
+++ b/docs/source/en/model_doc/detr.mdx
@@ -171,9 +171,9 @@ mean Average Precision (mAP) and Panoptic Quality (PQ). The latter objects are i
 [[autodoc]] DetrFeatureExtractor
     - __call__
     - pad_and_create_pixel_mask
-    - post_process
-    - post_process_segmentation
-    - post_process_panoptic
+    - post_process_semantic_segmentation
+    - post_process_instance_segmentation
+    - post_process_panoptic_segmentation

 ## DetrModel

diff --git a/src/transformers/models/detr/feature_extraction_detr.py b/src/transformers/models/detr/feature_extraction_detr.py
index 3ede3662a1..04fb123cf6 100644
--- a/src/transformers/models/detr/feature_extraction_detr.py
+++ b/src/transformers/models/detr/feature_extraction_detr.py
@@ -141,11 +141,33 @@ def binary_mask_to_rle(mask):
     return [x for x in runs]


+def convert_segmentation_to_rle(segmentation):
+    """
+    Converts a given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.
+
+    Args:
+        segmentation (`torch.Tensor` or `numpy.array`):
+            A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
+    Returns:
+        `List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
+    """
+    segment_ids = torch.unique(segmentation)
+
+    run_length_encodings = []
+    for idx in segment_ids:
+        mask = torch.where(segmentation == idx, 1, 0)
+        rle = binary_mask_to_rle(mask)
+        run_length_encodings.append(rle)
+
+    return run_length_encodings
+
+
 def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
     """
+    Binarizes the given masks using `object_mask_threshold` and returns the associated values of `masks`, `scores`
+    and `labels`.
+
     Args:
-    Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores`
-    and `labels`.
         masks (`torch.Tensor`):
             A tensor of shape `(num_queries, height, width)`.
        scores (`torch.Tensor`):
@@ -168,6 +190,81 @@ def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_
     return masks[to_keep], scores[to_keep], labels[to_keep]


+def check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold=0.8):
+    # Get the mask associated with query k
+    mask_k = mask_labels == k
+    mask_k_area = mask_k.sum()
+
+    # Compute the area of all the stuff in query k
+    original_area = (mask_probs[k] >= 0.5).sum()
+    mask_exists = mask_k_area > 0 and original_area > 0
+
+    # Eliminate disconnected tiny segments
+    if mask_exists:
+        area_ratio = mask_k_area / original_area
+        if not area_ratio.item() > overlap_mask_area_threshold:
+            mask_exists = False
+
+    return mask_exists, mask_k
+
+
+def compute_segments(
+    mask_probs,
+    pred_scores,
+    pred_labels,
+    overlap_mask_area_threshold: float = 0.8,
+    label_ids_to_fuse: Optional[Set[int]] = None,
+    target_size: Optional[Tuple[int, int]] = None,
+):
+    height = mask_probs.shape[1] if target_size is None else target_size[0]
+    width = mask_probs.shape[2] if target_size is None else target_size[1]
+
+    segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
+    segments: List[Dict] = []
+
+    if target_size is not None:
+        mask_probs = nn.functional.interpolate(
+            mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
+        )[0]
+
+    current_segment_id = 0
+
+    # Weigh each mask by its prediction score
+    mask_probs *= pred_scores.view(-1, 1, 1)
+    mask_labels = mask_probs.argmax(0)  # [height, width]
+
+    # Keep track of instances of each class
+    stuff_memory_list: Dict[str, int] = {}
+    for k in range(pred_labels.shape[0]):
+        pred_class = pred_labels[k].item()
+        should_fuse = pred_class in label_ids_to_fuse
+
+        # Check if the mask exists and is large enough to be a segment
+        mask_exists, mask_k = check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold)
+
+        if mask_exists:
+            if pred_class in stuff_memory_list:
+                current_segment_id = stuff_memory_list[pred_class]
+            else:
+                current_segment_id += 1
+
+            # Add current object segment to final segmentation map
+            segmentation[mask_k] = current_segment_id
+            segment_score = round(pred_scores[k].item(), 6)
+            segments.append(
+                {
+                    "id": current_segment_id,
+                    "label_id": pred_class,
+                    "was_fused": should_fuse,
+                    "score": segment_score,
+                }
+            )
+            if should_fuse:
+                stuff_memory_list[pred_class] = current_segment_id
+
+    return segmentation, segments
+
+
 class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
     r"""
     Constructs a DETR feature extractor.
@@ -1098,7 +1195,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         semantic_segmentation = []
         for idx in range(batch_size):
-            resized_logits = torch.nn.functional.interpolate(
+            resized_logits = nn.functional.interpolate(
                 segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
             )
             semantic_map = resized_logits[0].argmax(dim=0)
@@ -1114,31 +1211,34 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         outputs,
         threshold: float = 0.5,
         overlap_mask_area_threshold: float = 0.8,
-        target_sizes: List[Tuple] = None,
+        target_sizes: Optional[List[Tuple[int, int]]] = None,
         return_coco_annotation: Optional[bool] = False,
-    ):
+    ) -> List[Dict]:
         """
         Args:
         Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.
             outputs ([`DetrForSegmentation`]):
                 Raw outputs of the model.
-            threshold (`float`, *optional*):
-                The probability score threshold to keep predicted instance masks, defaults to 0.5.
-            overlap_mask_area_threshold (`float`, *optional*):
+            threshold (`float`, *optional*, defaults to 0.5):
+                The probability score threshold to keep predicted instance masks.
+            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
                 The overlap mask area threshold to merge or discard small disconnected parts within each binary
-                instance mask, defaults to 0.8.
-            target_sizes (`List[Tuple]`, *optional*, defaults to `None`):
+                instance mask.
+            target_sizes (`List[Tuple]`, *optional*):
                 List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
                 final size (height, width) of each prediction. If left to `None`, predictions will not be resized.
             return_coco_annotation (`bool`, *optional*, defaults to `False`):
                 If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format.
         Returns:
             `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
             - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
-              `List[List]` run-length encoding (RLE) of the segmentation map if return_coco_format is set to `True`.
-            - **segment_ids** -- A dictionary that maps segment ids to semantic class ids.
+              `List[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to
+              `True`. Set to `None` if no mask is found above `threshold`.
+            - **segments_info** -- A list of dictionaries containing additional information on each segment.
             - **id** -- An integer representing the `segment_id`.
-            - **label_id** -- An integer representing the segment's label / semantic class id.
+            - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
+            - **score** -- The prediction score of the segment with `segment_id`.
""" class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] @@ -1159,76 +1259,27 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels ) - height, width = target_sizes[i][0], target_sizes[i][1] - segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs_item.device) - segments: List[Dict] = [] + # No mask found + if mask_probs_item.shape[0] <= 0: + segmentation = None + segments: List[Dict] = [] + continue - object_detected = mask_probs_item.shape[0] > 0 - - if object_detected: - # Resize mask to corresponding target_size - if target_sizes is not None: - mask_probs_item = torch.nn.functional.interpolate( - mask_probs_item.unsqueeze(0), - size=target_sizes[i], - mode="bilinear", - align_corners=False, - )[0] - - current_segment_id = 0 - - # Weigh each mask by its prediction score - mask_probs_item *= pred_scores_item.view(-1, 1, 1) - mask_labels_item = mask_probs_item.argmax(0) # [height, width] - - # Keep track of instances of each class - stuff_memory_list: Dict[str, int] = {} - for k in range(pred_labels_item.shape[0]): - # Get the mask associated with the k class - pred_class = pred_labels_item[k].item() - mask_k = mask_labels_item == k - mask_k_area = mask_k.sum() - - # Compute the area of all the stuff in query k - original_area = (mask_probs_item[k] >= 0.5).sum() - mask_exists = mask_k_area > 0 and original_area > 0 - - if mask_exists: - # Eliminate segments with mask area below threshold - area_ratio = mask_k_area / original_area - if not area_ratio.item() > overlap_mask_area_threshold: - continue - - # Add corresponding class id - if pred_class in stuff_memory_list: - current_segment_id = stuff_memory_list[pred_class] - else: - current_segment_id += 1 - - # Add current object segment to final segmentation map - segmentation[mask_k] = current_segment_id - segments.append( - { - "id": current_segment_id, - "label_id": pred_class, - } - ) - else: - segmentation -= 1 + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs_item, + pred_scores_item, + pred_labels_item, + overlap_mask_area_threshold, + target_size, + ) # Return segmentation map in run-length encoding (RLE) format if return_coco_annotation: - segment_ids = torch.unique(segmentation) + segmentation = convert_segmentation_to_rle(segmentation) - run_length_encodings = [] - for idx in segment_ids: - mask = torch.where(segmentation == idx, 1, 0) - rle = binary_mask_to_rle(mask) - run_length_encodings.append(rle) - - segmentation = run_length_encodings - - results.append({"segmentation": segmentation, "segment_ids": segments}) + results.append({"segmentation": segmentation, "segments_info": segments}) return results def post_process_panoptic_segmentation( @@ -1237,7 +1288,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): threshold: float = 0.5, overlap_mask_area_threshold: float = 0.8, label_ids_to_fuse: Optional[Set[int]] = None, - target_sizes: List[Tuple] = None, + target_sizes: Optional[List[Tuple[int, int]]] = None, ) -> List[Dict]: """ Args: @@ -1250,7 +1301,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): The overlap mask 
                instance mask.
-            label_ids_to_fuse (`Set[int]`, *optional*, defaults to `None`):
+            label_ids_to_fuse (`Set[int]`, *optional*):
                 The labels in this set will have all their instances fused together. For instance we could say
                 there can only be one sky in an image, but several persons, so the label ID for sky would be in that
                 set, but not the one for person.
@@ -1260,13 +1311,15 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                 resized.
         Returns:
             `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
-            - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`. If
-              `target_sizes` is specified, segmentation is resized to the corresponding `target_sizes` entry.
-            - **segment_ids** -- A dictionary that maps segment ids to semantic class ids.
-            - **id** -- An integer representing the `segment_id`.
-            - **label_id** -- An integer representing the segment's label / semantic class id.
+            - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or
+              `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized
+              to the corresponding `target_sizes` entry.
+            - **segments_info** -- A list of dictionaries containing additional information on each segment.
+            - **id** -- An integer representing the `segment_id`.
+            - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
             - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
               Multiple instances of the same class / label were fused and assigned a single `segment_id`.
+            - **score** -- The prediction score of the segment with `segment_id`.
""" if label_ids_to_fuse is None: @@ -1292,67 +1345,22 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels ) - height, width = target_sizes[i][0], target_sizes[i][1] - segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs_item.device) - segments: List[Dict] = [] + # No mask found + if mask_probs_item.shape[0] <= 0: + segmentation = None + segments: List[Dict] = [] + continue - object_detected = mask_probs_item.shape[0] > 0 + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs_item, + pred_scores_item, + pred_labels_item, + overlap_mask_area_threshold, + label_ids_to_fuse, + target_size, + ) - if object_detected: - # Resize mask to corresponding target_size - if target_sizes is not None: - mask_probs_item = torch.nn.functional.interpolate( - mask_probs_item.unsqueeze(0), - size=target_sizes[i], - mode="bilinear", - align_corners=False, - )[0] - - current_segment_id = 0 - - # Weigh each mask by its prediction score - mask_probs_item *= pred_scores_item.view(-1, 1, 1) - mask_labels_item = mask_probs_item.argmax(0) # [height, width] - - # Keep track of instances of each class - stuff_memory_list: Dict[str, int] = {} - for k in range(pred_labels_item.shape[0]): - pred_class = pred_labels_item[k].item() - should_fuse = pred_class in label_ids_to_fuse - - # Get the mask associated with the k class - mask_k = mask_labels_item == k - mask_k_area = mask_k.sum() - - # Compute the area of all the stuff in query k - original_area = (mask_probs_item[k] >= 0.5).sum() - mask_exists = mask_k_area > 0 and original_area > 0 - - if mask_exists: - # Eliminate disconnected tiny segments - area_ratio = mask_k_area / original_area - if not area_ratio.item() > overlap_mask_area_threshold: - continue - - # Add corresponding class id - if pred_class in stuff_memory_list: - current_segment_id = stuff_memory_list[pred_class] - else: - current_segment_id += 1 - - # Add current object segment to final segmentation map - segmentation[mask_k] = current_segment_id - segments.append( - { - "id": current_segment_id, - "label_id": pred_class, - "was_fused": should_fuse, - } - ) - if should_fuse: - stuff_memory_list[pred_class] = current_segment_id - else: - segmentation -= 1 - - results.append({"segmentation": segmentation, "segment_ids": segments}) + results.append({"segmentation": segmentation, "segments_info": segments}) return results diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index dc5b562626..724c2b71a7 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -1605,12 +1605,12 @@ class DetrForSegmentation(DetrPreTrainedModel): >>> # Use the `post_process_panoptic_segmentation` method of `DetrFeatureExtractor` to retrieve post-processed panoptic segmentation maps >>> # Segmentation results are returned as a list of dictionaries - >>> result = feature_extractor.post_process_panoptic_segmentation(outputs, processed_sizes) + >>> result = feature_extractor.post_process_panoptic_segmentation(outputs, target_size=[(300, 500)]) >>> # A tensor of shape (height, width) where each value denotes a segment id >>> panoptic_seg = result[0]["segmentation"] >>> # Get mapping of segment ids to semantic class ids - >>> panoptic_segments_info = 
-        >>> panoptic_segments_info = result[0]["segment_ids"]
+        >>> panoptic_segments_info = result[0]["segments_info"]
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
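For readers who want to try the restructured API end to end, here is a minimal sketch of the new panoptic path. The `facebook/detr-resnet-50-panoptic` checkpoint and the image path are assumptions for illustration; they are not part of this patch.

```python
import torch
from PIL import Image
from transformers import DetrFeatureExtractor, DetrForSegmentation

# Placeholder input; any RGB image works here.
image = Image.open("example.jpg").convert("RGB")

feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50-panoptic")
model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes expects one (height, width) per image; PIL's size is (width, height)
result = feature_extractor.post_process_panoptic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]

# "segmentation" is a (height, width) tensor of segment ids (None if nothing
# scored above `threshold`); "segments_info" now also carries a "score" key.
for segment in result["segments_info"]:
    print(segment["id"], segment["label_id"], segment["was_fused"], segment["score"])
```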
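The instance path works the same way; below is a short sketch of the new `return_coco_annotation` flag, reusing `feature_extractor`, `outputs`, and `image` from the snippet above.

```python
# With return_coco_annotation=True, "segmentation" becomes the output of
# convert_segmentation_to_rle: one run-length encoding per unique id in the
# segmentation map (the zero background id is included whenever some pixels
# remain unclaimed by any segment).
instance_result = feature_extractor.post_process_instance_segmentation(
    outputs,
    threshold=0.5,
    target_sizes=[image.size[::-1]],
    return_coco_annotation=True,
)[0]

print(len(instance_result["segmentation"]))  # number of RLE-encoded ids
for segment in instance_result["segments_info"]:
    print(segment["id"], segment["label_id"], segment["score"])
```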
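Finally, the segment-validity rule this patch factors out into `check_segment_validity` can be exercised in isolation: query `k` survives only if the pixels it wins in the argmax map cover more than `overlap_mask_area_threshold` of its own high-confidence (`>= 0.5`) region. A toy sketch follows; the import path mirrors the module layout shown in this diff, and the score-weighting that `compute_segments` applies first is skipped for brevity.

```python
import torch
from transformers.models.detr.feature_extraction_detr import check_segment_validity

# Two queries over a 2x2 image.
mask_probs = torch.tensor(
    [
        [[0.9, 0.6], [0.2, 0.1]],  # query 0: confident on 2 pixels, wins only 1
        [[0.1, 0.7], [0.8, 0.9]],  # query 1: confident on 3 pixels, wins all 3
    ]
)
mask_labels = mask_probs.argmax(0)  # per-pixel winning query index

exists_0, _ = check_segment_validity(mask_labels, mask_probs, 0)  # ratio 1/2 <= 0.8
exists_1, _ = check_segment_validity(mask_labels, mask_probs, 1)  # ratio 3/3 > 0.8
print(bool(exists_0), bool(exists_1))  # False True
```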