Improve OWL-ViT postprocessing (#20980)
* add post_process_object_detection method
* style changes

parent e901914da7
commit cd2457809f
docs/source/en/model_doc/owlvit.mdx
@@ -80,7 +80,7 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi
 [[autodoc]] OwlViTImageProcessor
     - preprocess
     - post_process
+    - post_process_object_detection
     - post_process_image_guided_detection
 
 ## OwlViTFeatureExtractor
src/transformers/models/owlvit/image_processing_owlvit.py
@@ -14,7 +14,8 @@
 # limitations under the License.
 """Image processor class for OwlViT"""
 
-from typing import Dict, List, Optional, Union
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -344,6 +345,12 @@ class OwlViTImageProcessor(BaseImageProcessor):
             in the batch as predicted by the model.
         """
         # TODO: (amy) add support for other frameworks
+        warnings.warn(
+            "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
+            " `post_process_object_detection`",
+            FutureWarning,
+        )
+
         logits, boxes = outputs.logits, outputs.pred_boxes
 
         if len(logits) != len(target_sizes):
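Downstream code can check the new deprecation path with the standard `warnings` machinery. A minimal sketch, assuming an `OwlViTImageProcessor` instance plus `outputs`/`target_sizes` from a prior forward pass are already in scope:

```python
import warnings

# Sketch only: `image_processor`, `outputs`, and `target_sizes` are assumed
# to exist (e.g. from an OwlViTForObjectDetection forward pass).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    image_processor.post_process(outputs=outputs, target_sizes=target_sizes)

# The deprecated entry point should now emit a FutureWarning.
assert any(issubclass(w.category, FutureWarning) for w in caught)
```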
@@ -367,6 +374,61 @@ class OwlViTImageProcessor(BaseImageProcessor):
 
         return results
 
+    def post_process_object_detection(
+        self, outputs, threshold: float = 0.1, target_sizes: Union[TensorType, List[Tuple]] = None
+    ):
+        """
+        Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+        bottom_right_x, bottom_right_y) format.
+
+        Args:
+            outputs ([`OwlViTObjectDetectionOutput`]):
+                Raw outputs of the model.
+            threshold (`float`, *optional*):
+                Score threshold to keep object detection predictions.
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+            in the batch as predicted by the model.
+        """
+        # TODO: (amy) add support for other frameworks
+        logits, boxes = outputs.logits, outputs.pred_boxes
+
+        if target_sizes is not None:
+            if len(logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
+        probs = torch.max(logits, dim=-1)
+        scores = torch.sigmoid(probs.values)
+        labels = probs.indices
+
+        # Convert to [x0, y0, x1, y1] format
+        boxes = center_to_corners_format(boxes)
+
+        # Convert from relative [0, 1] to absolute [0, height] coordinates
+        if target_sizes is not None:
+            if isinstance(target_sizes, List):
+                img_h = torch.Tensor([i[0] for i in target_sizes])
+                img_w = torch.Tensor([i[1] for i in target_sizes])
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+            boxes = boxes * scale_fct[:, None, :]
+
+        results = []
+        for s, l, b in zip(scores, labels, boxes):
+            score = s[s > threshold]
+            label = l[s > threshold]
+            box = b[s > threshold]
+            results.append({"scores": score, "labels": label, "boxes": box})
+
+        return results
+
     # TODO: (Amy) Make compatible with other frameworks
     def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_threshold=0.3, target_sizes=None):
         """
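The core of the new method is the conversion from relative (center_x, center_y, width, height) boxes to absolute corner coordinates. A self-contained sketch of that math with dummy tensors; `cxcywh_to_xyxy` below is a local stand-in for the library's `center_to_corners_format` helper:

```python
import torch

def cxcywh_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
    # (center_x, center_y, width, height) -> (x0, y0, x1, y1)
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)

# One image (height=480, width=640), two boxes in relative [0, 1] coordinates
boxes = torch.tensor([[[0.5, 0.5, 0.2, 0.4], [0.25, 0.25, 0.1, 0.1]]])
img_h, img_w = torch.tensor([480.0]), torch.tensor([640.0])

corners = cxcywh_to_xyxy(boxes)
scale = torch.stack([img_w, img_h, img_w, img_h], dim=1)  # (batch_size, 4)
absolute = corners * scale[:, None, :]
print(absolute)  # boxes in absolute pixel coordinates, e.g. [256., 144., 384., 336.]
```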
src/transformers/models/owlvit/modeling_owlvit.py
@@ -204,8 +204,8 @@ class OwlViTObjectDetectionOutput(ModelOutput):
         pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
             Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
             values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
-            possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the unnormalized
-            bounding boxes.
+            possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to retrieve the
+            unnormalized bounding boxes.
         text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`):
             The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
         image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
@@ -248,13 +248,13 @@ class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
         target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
             Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
             values are normalized in [0, 1], relative to the size of each individual target image in the batch
-            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
-            unnormalized bounding boxes.
+            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
+            retrieve the unnormalized bounding boxes.
         query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
             Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
             values are normalized in [0, 1], relative to the size of each individual query image in the batch
-            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
-            unnormalized bounding boxes.
+            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
+            retrieve the unnormalized bounding boxes.
         image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
             Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
             image embeddings for each patch.
@@ -1644,18 +1644,18 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
 
         >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
         >>> target_sizes = torch.Tensor([image.size[::-1]])
-        >>> # Convert outputs (bounding boxes and class logits) to COCO API
-        >>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
+        >>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
+        >>> results = processor.post_process_object_detection(
+        ...     outputs=outputs, threshold=0.1, target_sizes=target_sizes
+        ... )
 
         >>> i = 0  # Retrieve predictions for the first image for the corresponding text queries
         >>> text = texts[i]
         >>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
 
-        >>> score_threshold = 0.1
         >>> for box, score, label in zip(boxes, scores, labels):
         ...     box = [round(i, 2) for i in box.tolist()]
-        ...     if score >= score_threshold:
-        ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
+        ...     print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
         Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
         Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
         ```"""
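The doctest edit mirrors the API change: score filtering moves out of user code and into the post-processing call itself. Schematically, reusing `processor`, `outputs`, and `target_sizes` from the example above:

```python
# Before: post_process returned everything; callers filtered by hand
results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
keep = results[0]["scores"] >= 0.1

# After: the score threshold is an argument of the call itself
results = processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes
)
```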
src/transformers/models/owlvit/processing_owlvit.py
@@ -179,6 +179,13 @@ class OwlViTProcessor(ProcessorMixin):
         """
         return self.image_processor.post_process(*args, **kwargs)
 
+    def post_process_object_detection(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please refer
+        to the docstring of this method for more information.
+        """
+        return self.image_processor.post_process_object_detection(*args, **kwargs)
+
     def post_process_image_guided_detection(self, *args, **kwargs):
         """
         This method forwards all its arguments to [`OwlViTImageProcessor.post_process_one_shot_object_detection`].
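Since the processor method is a thin pass-through, both spellings below should resolve to the same implementation. A sketch, again assuming `outputs` and `target_sizes` come from a prior `OwlViTForObjectDetection` forward pass:

```python
from transformers import OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

# Both calls hit the same underlying image-processor method.
via_processor = processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes
)
via_image_processor = processor.image_processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes
)
```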
src/transformers/pipelines/zero_shot_object_detection.py
@@ -173,12 +173,11 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
         for model_output in model_outputs:
             label = model_output["candidate_label"]
             model_output = BaseModelOutput(model_output)
-            outputs = self.feature_extractor.post_process(
-                outputs=model_output, target_sizes=model_output["target_size"]
+            outputs = self.feature_extractor.post_process_object_detection(
+                outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"]
             )[0]
-            keep = outputs["scores"] >= threshold
 
-            for index in keep.nonzero():
+            for index in outputs["scores"].nonzero():
                 score = outputs["scores"][index].item()
                 box = self._get_bounding_box(outputs["boxes"][index][0])
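At the pipeline level the change is invisible to users apart from the thresholding now happening inside post-processing; typical usage (the 0.1 threshold is illustrative):

```python
from transformers import pipeline

detector = pipeline("zero-shot-object-detection")
predictions = detector(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    candidate_labels=["cat", "remote", "couch"],
    threshold=0.1,  # forwarded to post_process_object_detection
)
for pred in predictions:
    # Each prediction is a dict with "score", "label", and "box" keys.
    print(pred["label"], round(pred["score"], 3), pred["box"])
```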
tests/pipelines/test_pipelines_zero_shot_object_detection.py
@@ -131,7 +131,8 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=Pipeline
         object_detector = pipeline("zero-shot-object-detection")
 
         outputs = object_detector(
-            "http://images.cocodataset.org/val2017/000000039769.jpg", candidate_labels=["cat", "remote", "couch"]
+            "http://images.cocodataset.org/val2017/000000039769.jpg",
+            candidate_labels=["cat", "remote", "couch"],
         )
         self.assertEqual(
             nested_simplify(outputs, decimals=4),