diff --git a/docs/source/en/model_doc/owlvit.mdx b/docs/source/en/model_doc/owlvit.mdx
index 84747d0a6d..0b61d7b274 100644
--- a/docs/source/en/model_doc/owlvit.mdx
+++ b/docs/source/en/model_doc/owlvit.mdx
@@ -39,19 +39,26 @@ OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CL
 
 >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
-
->>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
-
+>>> texts = [["a photo of a cat", "a photo of a dog"]]
+>>> inputs = processor(text=texts, images=image, return_tensors="pt")
 >>> outputs = model(**inputs)
 
->>> logits = outputs["logits"]  # Prediction logits of shape [batch_size, num_patches, num_max_text_queries]
->>> boxes = outputs["pred_boxes"]  # Object box boundaries of shape [batch_size, num_patches, 4]
->>> batch_size = boxes.shape[0]
->>> for i in range(batch_size):  # Loop over sets of images and text queries
-...     boxes = outputs["pred_boxes"][i]
-...     logits = torch.max(outputs["logits"][i], dim=-1)
-...     scores = torch.sigmoid(logits.values)
-...     labels = logits.indices
+>>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
+>>> target_sizes = torch.Tensor([image.size[::-1]])
+>>> # Convert outputs (bounding boxes and class logits) to COCO API
+>>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
+
+>>> i = 0  # Retrieve predictions for the first image for the corresponding text queries
+>>> text = texts[i]
+>>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
+
+>>> score_threshold = 0.1
+>>> for box, score, label in zip(boxes, scores, labels):
+...     box = [round(i, 2) for i in box.tolist()]
+...     if score >= score_threshold:
+...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
+Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
+Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
 ```
 
 This model was contributed by [adirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit).
diff --git a/src/transformers/models/owlvit/feature_extraction_owlvit.py b/src/transformers/models/owlvit/feature_extraction_owlvit.py
index 8e0a142085..1e4bc73560 100644
--- a/src/transformers/models/owlvit/feature_extraction_owlvit.py
+++ b/src/transformers/models/owlvit/feature_extraction_owlvit.py
@@ -26,7 +26,6 @@ from ...utils import TensorType, is_torch_available, logging
 
 if is_torch_available():
     import torch
-    from torch import nn
 
 
 logger = logging.get_logger(__name__)
@@ -109,18 +108,19 @@ class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin
             `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
             in the batch as predicted by the model.
""" - out_logits, out_bbox = outputs.logits, outputs.pred_boxes + logits, boxes = outputs.logits, outputs.pred_boxes - if len(out_logits) != len(target_sizes): + if len(logits) != len(target_sizes): raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") if target_sizes.shape[1] != 2: raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") - prob = nn.functional.softmax(out_logits, -1) - scores, labels = prob[..., :-1].max(-1) + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + labels = probs.indices # Convert to [x0, y0, x1, y1] format - boxes = center_to_corners_format(out_bbox) + boxes = center_to_corners_format(boxes) # Convert from relative [0, 1] to absolute [0, height] coordinates img_h, img_w = target_sizes.unbind(1) diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index c872b1b28f..cb9a385cc2 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -1300,23 +1300,31 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel): >>> import torch >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection - >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32") >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt") + >>> texts = [["a photo of a cat", "a photo of a dog"]] + >>> inputs = processor(text=texts, images=image, return_tensors="pt") >>> outputs = model(**inputs) - >>> logits = outputs["logits"] # Prediction logits of shape [batch_size, num_patches, num_max_text_queries] - >>> boxes = outputs["pred_boxes"] # Object box boundaries of shape # [batch_size, num_patches, 4] - >>> batch_size = boxes.shape[0] - >>> for i in range(batch_size): # Loop over sets of images and text queries - ... boxes = outputs["pred_boxes"][i] - ... logits = torch.max(outputs["logits"][i], dim=-1) - ... scores = torch.sigmoid(logits.values) - ... labels = logits.indices + >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2] + >>> target_sizes = torch.Tensor([image.size[::-1]]) + >>> # Convert outputs (bounding boxes and class logits) to COCO API + >>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes) + + >>> i = 0 # Retrieve predictions for the first image for the corresponding text queries + >>> text = texts[i] + >>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"] + + >>> score_threshold = 0.1 + >>> for box, score, label in zip(boxes, scores, labels): + ... box = [round(i, 2) for i in box.tolist()] + ... if score >= score_threshold: + ... 
print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}") + Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48] + Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61] ```""" output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/src/transformers/models/owlvit/processing_owlvit.py b/src/transformers/models/owlvit/processing_owlvit.py index 8dc04055bb..48060f0dcf 100644 --- a/src/transformers/models/owlvit/processing_owlvit.py +++ b/src/transformers/models/owlvit/processing_owlvit.py @@ -139,6 +139,13 @@ class OwlViTProcessor(ProcessorMixin): else: return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) + def post_process(self, *args, **kwargs): + """ + This method forwards all its arguments to [`OwlViTFeatureExtractor.post_process`]. Please refer to the + docstring of this method for more information. + """ + return self.feature_extractor.post_process(*args, **kwargs) + def batch_decode(self, *args, **kwargs): """ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 0d1216e5b8..358023c1d6 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -48,6 +48,7 @@ src/transformers/models/mobilevit/modeling_mobilevit.py src/transformers/models/opt/modeling_opt.py src/transformers/models/opt/modeling_tf_opt.py src/transformers/models/opt/modeling_flax_opt.py +src/transformers/models/owlvit/modeling_owlvit.py src/transformers/models/pegasus/modeling_pegasus.py src/transformers/models/plbart/modeling_plbart.py src/transformers/models/poolformer/modeling_poolformer.py