Improve OWL-ViT postprocessing (#20980)

* add post_process_object_detection method

* style changes
Alara Dirik 2023-01-03 19:25:09 +03:00 committed by GitHub
parent e901914da7
commit cd2457809f
6 changed files with 87 additions and 18 deletions


@@ -80,7 +80,7 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi
[[autodoc]] OwlViTImageProcessor
- preprocess
- post_process
- post_process_object_detection
- post_process_image_guided_detection
## OwlViTFeatureExtractor


@@ -14,7 +14,8 @@
# limitations under the License.
"""Image processor class for OwlViT"""
from typing import Dict, List, Optional, Union
import warnings
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
@@ -344,6 +345,12 @@ class OwlViTImageProcessor(BaseImageProcessor):
in the batch as predicted by the model.
"""
# TODO: (amy) add support for other frameworks
warnings.warn(
"`post_process` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_object_detection`",
FutureWarning,
)
logits, boxes = outputs.logits, outputs.pred_boxes
if len(logits) != len(target_sizes):
@@ -367,6 +374,61 @@ class OwlViTImageProcessor(BaseImageProcessor):
return results
def post_process_object_detection(
self, outputs, threshold: float = 0.1, target_sizes: Union[TensorType, List[Tuple]] = None
):
"""
Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
bottom_right_x, bottom_right_y) format.
Args:
outputs ([`OwlViTObjectDetectionOutput`]):
Raw outputs of the model.
threshold (`float`, *optional*):
Score threshold to keep object detection predictions.
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
`(height, width)` of each image in the batch. If unset, predictions will not be resized.
Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
in the batch as predicted by the model.
"""
# TODO: (amy) add support for other frameworks
logits, boxes = outputs.logits, outputs.pred_boxes
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
probs = torch.max(logits, dim=-1)
scores = torch.sigmoid(probs.values)
labels = probs.indices
# Convert to [x0, y0, x1, y1] format
boxes = center_to_corners_format(boxes)
# Convert from relative [0, 1] to absolute [0, height] coordinates
if target_sizes is not None:
if isinstance(target_sizes, List):
img_h = torch.Tensor([i[0] for i in target_sizes])
img_w = torch.Tensor([i[1] for i in target_sizes])
else:
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :]
results = []
for s, l, b in zip(scores, labels, boxes):
score = s[s > threshold]
label = l[s > threshold]
box = b[s > threshold]
results.append({"scores": score, "labels": label, "boxes": box})
return results
# TODO: (Amy) Make compatible with other frameworks
def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_threshold=0.3, target_sizes=None):
"""

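Note: for reference, a minimal self-contained sketch of the tensor math the new `post_process_object_detection` method performs, with random tensors standing in for a real `OwlViTObjectDetectionOutput` (the shapes, the 480x640 image size, and the variable names are illustrative assumptions, not part of this diff):

import torch

batch_size, num_patches, num_queries = 1, 576, 2
logits = torch.randn(batch_size, num_patches, num_queries)  # per-query class logits
boxes = torch.rand(batch_size, num_patches, 4)               # (center_x, center_y, width, height), all in [0, 1]

# Score each predicted box by its most confident text query.
probs = torch.max(logits, dim=-1)
scores = torch.sigmoid(probs.values)
labels = probs.indices

# Convert (center_x, center_y, width, height) to (x0, y0, x1, y1).
cx, cy, w, h = boxes.unbind(-1)
corners = torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)

# Rescale from relative [0, 1] to absolute pixel coordinates, assuming a 480x640 image.
img_h, img_w = torch.tensor([480.0]), torch.tensor([640.0])
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
corners = corners * scale_fct[:, None, :]

# Keep only predictions above the score threshold for the first image.
threshold = 0.1
keep = scores[0] > threshold
result = {"scores": scores[0][keep], "labels": labels[0][keep], "boxes": corners[0][keep]}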

@@ -204,8 +204,8 @@ class OwlViTObjectDetectionOutput(ModelOutput):
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the unnormalized
bounding boxes.
possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to retrieve the
unnormalized bounding boxes.
text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim)`):
The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
@@ -248,13 +248,13 @@ class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual target image in the batch
(disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
unnormalized bounding boxes.
(disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
retrieve the unnormalized bounding boxes.
query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual query image in the batch
(disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
retrieve the unnormalized bounding boxes.
image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
image embeddings for each patch.
@@ -1644,18 +1644,18 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
>>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
>>> target_sizes = torch.Tensor([image.size[::-1]])
>>> # Convert outputs (bounding boxes and class logits) to COCO API
>>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
>>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
>>> results = processor.post_process_object_detection(
... outputs=outputs, threshold=0.1, target_sizes=target_sizes
... )
>>> i = 0 # Retrieve predictions for the first image for the corresponding text queries
>>> text = texts[i]
>>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
>>> score_threshold = 0.1
>>> for box, score, label in zip(boxes, scores, labels):
... box = [round(i, 2) for i in box.tolist()]
... if score >= score_threshold:
... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
```"""

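Note: the docstring diff above also doubles as migration guidance: the manual `score_threshold` check is replaced by passing `threshold` to the new method. A minimal before/after sketch, assuming `processor`, `outputs`, and `target_sizes` are prepared as in the docstring example above:

# Before (deprecated): post_process plus a manual score threshold.
results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
kept = [(box, score, label) for box, score, label in zip(boxes, scores, labels) if score >= 0.1]

# After: the threshold is applied inside post_process_object_detection.
results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]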

@@ -179,6 +179,13 @@ class OwlViTProcessor(ProcessorMixin):
"""
return self.image_processor.post_process(*args, **kwargs)
def post_process_object_detection(self, *args, **kwargs):
"""
This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please refer
to the docstring of this method for more information.
"""
return self.image_processor.post_process_object_detection(*args, **kwargs)
def post_process_image_guided_detection(self, *args, **kwargs):
"""
This method forwards all its arguments to [`OwlViTImageProcessor.post_process_image_guided_detection`].

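Note: because the new processor method is a thin forwarding wrapper, the call can equivalently go through the processor or its image processor. A short sketch, again assuming `processor`, `outputs`, and `target_sizes` from the earlier example:

# Equivalent calls: OwlViTProcessor forwards to its image processor.
results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
results = processor.image_processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes
)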

@@ -173,12 +173,11 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
for model_output in model_outputs:
label = model_output["candidate_label"]
model_output = BaseModelOutput(model_output)
outputs = self.feature_extractor.post_process(
outputs=model_output, target_sizes=model_output["target_size"]
outputs = self.feature_extractor.post_process_object_detection(
outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"]
)[0]
keep = outputs["scores"] >= threshold
for index in keep.nonzero():
for index in outputs["scores"].nonzero():
score = outputs["scores"][index].item()
box = self._get_bounding_box(outputs["boxes"][index][0])

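Note: with the threshold now applied inside `post_process_object_detection`, the pipeline loop only walks the surviving detections. A standalone sketch of that postprocessing step; the tensor values and the result key names are illustrative, and the pipeline's private `_get_bounding_box` helper is replaced by a comment:

import torch

# Assume `outputs` is one entry of the list returned by post_process_object_detection,
# i.e. tensors already filtered by `threshold` (values below are illustrative).
outputs = {
    "scores": torch.tensor([0.717, 0.707]),
    "labels": torch.tensor([0, 0]),
    "boxes": torch.tensor([[1.46, 55.26, 315.55, 472.17], [324.97, 20.44, 640.58, 373.29]]),
}

results = []
for index in outputs["scores"].nonzero():
    score = outputs["scores"][index].item()
    box = outputs["boxes"][index][0]  # the real pipeline converts this tensor to a dict via _get_bounding_box
    results.append({"score": score, "box": box.tolist()})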

@@ -131,7 +131,8 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=Pipeline
object_detector = pipeline("zero-shot-object-detection")
outputs = object_detector(
"http://images.cocodataset.org/val2017/000000039769.jpg", candidate_labels=["cat", "remote", "couch"]
"http://images.cocodataset.org/val2017/000000039769.jpg",
candidate_labels=["cat", "remote", "couch"],
)
self.assertEqual(
nested_simplify(outputs, decimals=4),