Restructure DETR post-processing, return prediction scores (#19262)
* Restructure DetrFeatureExtractor post-processing methods * Update post_process_instance_segmentation and post_process_panoptic_segmentation methods to return prediction scores * Update DETR models docs
This commit is contained in:
parent
5cd16f01db
commit
36f52e9593
|
@ -171,9 +171,9 @@ mean Average Precision (mAP) and Panoptic Quality (PQ). The latter objects are i
|
|||
[[autodoc]] DetrFeatureExtractor
|
||||
- __call__
|
||||
- pad_and_create_pixel_mask
|
||||
- post_process
|
||||
- post_process_segmentation
|
||||
- post_process_panoptic
|
||||
- post_process_semantic_segmentation
|
||||
- post_process_instance_segmentation
|
||||
- post_process_panoptic_segmentation
|
||||
|
||||
## DetrModel
|
||||
|
||||
|
|
|
@ -141,11 +141,33 @@ def binary_mask_to_rle(mask):
|
|||
return [x for x in runs]
|
||||
|
||||
|
||||
def convert_segmentation_to_rle(segmentation):
|
||||
"""
|
||||
Converts given segmentation map of shape (height, width) to the run-length encoding (RLE) format.
|
||||
|
||||
Args:
|
||||
segmentation (`torch.Tensor` or `numpy.array`):
|
||||
A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
|
||||
Returns:
|
||||
`List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
|
||||
"""
|
||||
segment_ids = torch.unique(segmentation)
|
||||
|
||||
run_length_encodings = []
|
||||
for idx in segment_ids:
|
||||
mask = torch.where(segmentation == idx, 1, 0)
|
||||
rle = binary_mask_to_rle(mask)
|
||||
run_length_encodings.append(rle)
|
||||
|
||||
return run_length_encodings
|
||||
|
||||
|
||||
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
|
||||
"""
|
||||
Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
|
||||
`labels`.
|
||||
|
||||
Args:
|
||||
Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores`
|
||||
and `labels`.
|
||||
masks (`torch.Tensor`):
|
||||
A tensor of shape `(num_queries, height, width)`.
|
||||
scores (`torch.Tensor`):
|
||||
|
@ -168,6 +190,81 @@ def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_
|
|||
return masks[to_keep], scores[to_keep], labels[to_keep]
|
||||
|
||||
|
||||
def check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold=0.8):
|
||||
# Get the mask associated with the k class
|
||||
mask_k = mask_labels == k
|
||||
mask_k_area = mask_k.sum()
|
||||
|
||||
# Compute the area of all the stuff in query k
|
||||
original_area = (mask_probs[k] >= 0.5).sum()
|
||||
mask_exists = mask_k_area > 0 and original_area > 0
|
||||
|
||||
# Eliminate disconnected tiny segments
|
||||
if mask_exists:
|
||||
area_ratio = mask_k_area / original_area
|
||||
if not area_ratio.item() > overlap_mask_area_threshold:
|
||||
mask_exists = False
|
||||
|
||||
return mask_exists, mask_k
|
||||
|
||||
|
||||
def compute_segments(
|
||||
mask_probs,
|
||||
pred_scores,
|
||||
pred_labels,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
label_ids_to_fuse: Optional[Set[int]] = None,
|
||||
target_size: Tuple[int, int] = None,
|
||||
):
|
||||
height = mask_probs.shape[1] if target_size is None else target_size[0]
|
||||
width = mask_probs.shape[2] if target_size is None else target_size[1]
|
||||
|
||||
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
|
||||
segments: List[Dict] = []
|
||||
|
||||
if target_size is not None:
|
||||
mask_probs = nn.functional.interpolate(
|
||||
mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
|
||||
)[0]
|
||||
|
||||
current_segment_id = 0
|
||||
|
||||
# Weigh each mask by its prediction score
|
||||
mask_probs *= pred_scores.view(-1, 1, 1)
|
||||
mask_labels = mask_probs.argmax(0) # [height, width]
|
||||
|
||||
# Keep track of instances of each class
|
||||
stuff_memory_list: Dict[str, int] = {}
|
||||
for k in range(pred_labels.shape[0]):
|
||||
pred_class = pred_labels[k].item()
|
||||
should_fuse = pred_class in label_ids_to_fuse
|
||||
|
||||
# Check if mask exists and large enough to be a segment
|
||||
mask_exists, mask_k = check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold)
|
||||
|
||||
if mask_exists:
|
||||
if pred_class in stuff_memory_list:
|
||||
current_segment_id = stuff_memory_list[pred_class]
|
||||
else:
|
||||
current_segment_id += 1
|
||||
|
||||
# Add current object segment to final segmentation map
|
||||
segmentation[mask_k] = current_segment_id
|
||||
segment_score = round(pred_scores[k].item(), 6)
|
||||
segments.append(
|
||||
{
|
||||
"id": current_segment_id,
|
||||
"label_id": pred_class,
|
||||
"was_fused": should_fuse,
|
||||
"score": segment_score,
|
||||
}
|
||||
)
|
||||
if should_fuse:
|
||||
stuff_memory_list[pred_class] = current_segment_id
|
||||
|
||||
return segmentation, segments
|
||||
|
||||
|
||||
class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
r"""
|
||||
Constructs a DETR feature extractor.
|
||||
|
@ -1098,7 +1195,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||
|
||||
semantic_segmentation = []
|
||||
for idx in range(batch_size):
|
||||
resized_logits = torch.nn.functional.interpolate(
|
||||
resized_logits = nn.functional.interpolate(
|
||||
segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
||||
)
|
||||
semantic_map = resized_logits[0].argmax(dim=0)
|
||||
|
@ -1114,31 +1211,34 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||
outputs,
|
||||
threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
target_sizes: List[Tuple] = None,
|
||||
target_sizes: Optional[List[Tuple[int, int]]] = None,
|
||||
return_coco_annotation: Optional[bool] = False,
|
||||
):
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Args:
|
||||
Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.
|
||||
outputs ([`DetrForSegmentation`]):
|
||||
Raw outputs of the model.
|
||||
threshold (`float`, *optional*):
|
||||
The probability score threshold to keep predicted instance masks, defaults to 0.5.
|
||||
overlap_mask_area_threshold (`float`, *optional*):
|
||||
threshold (`float`, *optional*, defaults to 0.5):
|
||||
The probability score threshold to keep predicted instance masks.
|
||||
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
||||
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
||||
instance mask, defaults to 0.8.
|
||||
target_sizes (`List[Tuple]`, *optional*, defaults to `None`):
|
||||
instance mask.
|
||||
target_sizes (`List[Tuple]`, *optional*):
|
||||
List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested
|
||||
final size (height, width) of each prediction. If left to None, predictions will not be resized.
|
||||
return_coco_annotation (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format.
|
||||
return_coco_annotation (`bool`, *optional*):
|
||||
Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
|
||||
format.
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
||||
- **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
|
||||
`List[List]` run-length encoding (RLE) of the segmentation map if return_coco_format is set to `True`.
|
||||
- **segment_ids** -- A dictionary that maps segment ids to semantic class ids.
|
||||
`List[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to
|
||||
`True`. Set to `None` if no mask if found above `threshold`.
|
||||
- **segments_info** -- A dictionary that contains additional information on each segment.
|
||||
- **id** -- An integer representing the `segment_id`.
|
||||
- **label_id** -- An integer representing the segment's label / semantic class id.
|
||||
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
|
||||
- **score** -- Prediction score of segment with `segment_id`.
|
||||
"""
|
||||
class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
|
||||
masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
|
||||
|
@ -1159,76 +1259,27 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
|
||||
)
|
||||
|
||||
height, width = target_sizes[i][0], target_sizes[i][1]
|
||||
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs_item.device)
|
||||
segments: List[Dict] = []
|
||||
# No mask found
|
||||
if mask_probs_item.shape[0] <= 0:
|
||||
segmentation = None
|
||||
segments: List[Dict] = []
|
||||
continue
|
||||
|
||||
object_detected = mask_probs_item.shape[0] > 0
|
||||
|
||||
if object_detected:
|
||||
# Resize mask to corresponding target_size
|
||||
if target_sizes is not None:
|
||||
mask_probs_item = torch.nn.functional.interpolate(
|
||||
mask_probs_item.unsqueeze(0),
|
||||
size=target_sizes[i],
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)[0]
|
||||
|
||||
current_segment_id = 0
|
||||
|
||||
# Weigh each mask by its prediction score
|
||||
mask_probs_item *= pred_scores_item.view(-1, 1, 1)
|
||||
mask_labels_item = mask_probs_item.argmax(0) # [height, width]
|
||||
|
||||
# Keep track of instances of each class
|
||||
stuff_memory_list: Dict[str, int] = {}
|
||||
for k in range(pred_labels_item.shape[0]):
|
||||
# Get the mask associated with the k class
|
||||
pred_class = pred_labels_item[k].item()
|
||||
mask_k = mask_labels_item == k
|
||||
mask_k_area = mask_k.sum()
|
||||
|
||||
# Compute the area of all the stuff in query k
|
||||
original_area = (mask_probs_item[k] >= 0.5).sum()
|
||||
mask_exists = mask_k_area > 0 and original_area > 0
|
||||
|
||||
if mask_exists:
|
||||
# Eliminate segments with mask area below threshold
|
||||
area_ratio = mask_k_area / original_area
|
||||
if not area_ratio.item() > overlap_mask_area_threshold:
|
||||
continue
|
||||
|
||||
# Add corresponding class id
|
||||
if pred_class in stuff_memory_list:
|
||||
current_segment_id = stuff_memory_list[pred_class]
|
||||
else:
|
||||
current_segment_id += 1
|
||||
|
||||
# Add current object segment to final segmentation map
|
||||
segmentation[mask_k] = current_segment_id
|
||||
segments.append(
|
||||
{
|
||||
"id": current_segment_id,
|
||||
"label_id": pred_class,
|
||||
}
|
||||
)
|
||||
else:
|
||||
segmentation -= 1
|
||||
# Get segmentation map and segment information of batch item
|
||||
target_size = target_sizes[i] if target_sizes is not None else None
|
||||
segmentation, segments = compute_segments(
|
||||
mask_probs_item,
|
||||
pred_scores_item,
|
||||
pred_labels_item,
|
||||
overlap_mask_area_threshold,
|
||||
target_size,
|
||||
)
|
||||
|
||||
# Return segmentation map in run-length encoding (RLE) format
|
||||
if return_coco_annotation:
|
||||
segment_ids = torch.unique(segmentation)
|
||||
segmentation = convert_segmentation_to_rle(segmentation)
|
||||
|
||||
run_length_encodings = []
|
||||
for idx in segment_ids:
|
||||
mask = torch.where(segmentation == idx, 1, 0)
|
||||
rle = binary_mask_to_rle(mask)
|
||||
run_length_encodings.append(rle)
|
||||
|
||||
segmentation = run_length_encodings
|
||||
|
||||
results.append({"segmentation": segmentation, "segment_ids": segments})
|
||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||
return results
|
||||
|
||||
def post_process_panoptic_segmentation(
|
||||
|
@ -1237,7 +1288,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||
threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
label_ids_to_fuse: Optional[Set[int]] = None,
|
||||
target_sizes: List[Tuple] = None,
|
||||
target_sizes: Optional[List[Tuple[int, int]]] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Args:
|
||||
|
@ -1250,7 +1301,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
||||
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
||||
instance mask.
|
||||
label_ids_to_fuse (`Set[int]`, *optional*, defaults to `None`):
|
||||
label_ids_to_fuse (`Set[int]`, *optional*):
|
||||
The labels in this state will have all their instances be fused together. For instance we could say
|
||||
there can only be one sky in an image, but several persons, so the label ID for sky would be in that
|
||||
set, but not the one for person.
|
||||
|
@ -1260,13 +1311,15 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||
resized.
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
||||
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`. If
|
||||
`target_sizes` is specified, segmentation is resized to the corresponding `target_sizes` entry.
|
||||
- **segment_ids** -- A dictionary that maps segment ids to semantic class ids.
|
||||
- **id** -- An integer representing the `segment_id`.
|
||||
- **label_id** -- An integer representing the segment's label / semantic class id.
|
||||
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or
|
||||
`None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized to
|
||||
the corresponding `target_sizes` entry.
|
||||
- **segments_info** -- A dictionary that contains additional information on each segment.
|
||||
- **id** -- an integer representing the `segment_id`.
|
||||
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
|
||||
- **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
|
||||
Multiple instances of the same class / label were fused and assigned a single `segment_id`.
|
||||
- **score** -- Prediction score of segment with `segment_id`.
|
||||
"""
|
||||
|
||||
if label_ids_to_fuse is None:
|
||||
|
@ -1292,67 +1345,22 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
|
||||
)
|
||||
|
||||
height, width = target_sizes[i][0], target_sizes[i][1]
|
||||
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs_item.device)
|
||||
segments: List[Dict] = []
|
||||
# No mask found
|
||||
if mask_probs_item.shape[0] <= 0:
|
||||
segmentation = None
|
||||
segments: List[Dict] = []
|
||||
continue
|
||||
|
||||
object_detected = mask_probs_item.shape[0] > 0
|
||||
# Get segmentation map and segment information of batch item
|
||||
target_size = target_sizes[i] if target_sizes is not None else None
|
||||
segmentation, segments = compute_segments(
|
||||
mask_probs_item,
|
||||
pred_scores_item,
|
||||
pred_labels_item,
|
||||
overlap_mask_area_threshold,
|
||||
label_ids_to_fuse,
|
||||
target_size,
|
||||
)
|
||||
|
||||
if object_detected:
|
||||
# Resize mask to corresponding target_size
|
||||
if target_sizes is not None:
|
||||
mask_probs_item = torch.nn.functional.interpolate(
|
||||
mask_probs_item.unsqueeze(0),
|
||||
size=target_sizes[i],
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)[0]
|
||||
|
||||
current_segment_id = 0
|
||||
|
||||
# Weigh each mask by its prediction score
|
||||
mask_probs_item *= pred_scores_item.view(-1, 1, 1)
|
||||
mask_labels_item = mask_probs_item.argmax(0) # [height, width]
|
||||
|
||||
# Keep track of instances of each class
|
||||
stuff_memory_list: Dict[str, int] = {}
|
||||
for k in range(pred_labels_item.shape[0]):
|
||||
pred_class = pred_labels_item[k].item()
|
||||
should_fuse = pred_class in label_ids_to_fuse
|
||||
|
||||
# Get the mask associated with the k class
|
||||
mask_k = mask_labels_item == k
|
||||
mask_k_area = mask_k.sum()
|
||||
|
||||
# Compute the area of all the stuff in query k
|
||||
original_area = (mask_probs_item[k] >= 0.5).sum()
|
||||
mask_exists = mask_k_area > 0 and original_area > 0
|
||||
|
||||
if mask_exists:
|
||||
# Eliminate disconnected tiny segments
|
||||
area_ratio = mask_k_area / original_area
|
||||
if not area_ratio.item() > overlap_mask_area_threshold:
|
||||
continue
|
||||
|
||||
# Add corresponding class id
|
||||
if pred_class in stuff_memory_list:
|
||||
current_segment_id = stuff_memory_list[pred_class]
|
||||
else:
|
||||
current_segment_id += 1
|
||||
|
||||
# Add current object segment to final segmentation map
|
||||
segmentation[mask_k] = current_segment_id
|
||||
segments.append(
|
||||
{
|
||||
"id": current_segment_id,
|
||||
"label_id": pred_class,
|
||||
"was_fused": should_fuse,
|
||||
}
|
||||
)
|
||||
if should_fuse:
|
||||
stuff_memory_list[pred_class] = current_segment_id
|
||||
else:
|
||||
segmentation -= 1
|
||||
|
||||
results.append({"segmentation": segmentation, "segment_ids": segments})
|
||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||
return results
|
||||
|
|
|
@ -1605,12 +1605,12 @@ class DetrForSegmentation(DetrPreTrainedModel):
|
|||
|
||||
>>> # Use the `post_process_panoptic_segmentation` method of `DetrFeatureExtractor` to retrieve post-processed panoptic segmentation maps
|
||||
>>> # Segmentation results are returned as a list of dictionaries
|
||||
>>> result = feature_extractor.post_process_panoptic_segmentation(outputs, processed_sizes)
|
||||
>>> result = feature_extractor.post_process_panoptic_segmentation(outputs, target_size=[(300, 500)])
|
||||
|
||||
>>> # A tensor of shape (height, width) where each value denotes a segment id
|
||||
>>> panoptic_seg = result[0]["segmentation"]
|
||||
>>> # Get mapping of segment ids to semantic class ids
|
||||
>>> panoptic_segments_info = result[0]["segment_ids"]
|
||||
>>> panoptic_segments_info = result[0]["segments_info"]
|
||||
```"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
|
Loading…
Reference in New Issue