[OWL-ViT] Make model consistent with CLIP (#20144)

* Apply fix

* Fix test

* Remove another argument which is not used

* Fix pipeline test

* Add argument back, add deprecation warning

* Add warning add other location

* Use warnings instead

* Add num_channels to config

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
NielsRogge 2022-11-11 11:36:17 +01:00 committed by GitHub
parent d3c0566679
commit cbbeca3d17
3 changed files with 21 additions and 27 deletions

View File

@@ -165,6 +165,8 @@ class OwlViTVisionConfig(PretrainedConfig):
             Number of hidden layers in the Transformer encoder.
         num_attention_heads (`int`, *optional*, defaults to 12):
             Number of attention heads for each attention layer in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            Number of channels in the input images.
         image_size (`int`, *optional*, defaults to 768):
             The size (resolution) of each image.
         patch_size (`int`, *optional*, defaults to 32):
@@ -206,6 +208,7 @@
         intermediate_size=3072,
         num_hidden_layers=12,
         num_attention_heads=12,
+        num_channels=3,
         image_size=768,
         patch_size=32,
         hidden_act="quick_gelu",
@@ -222,6 +225,7 @@
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
         self.image_size = image_size
         self.patch_size = patch_size
         self.hidden_act = hidden_act
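
For reference, a minimal sketch of how the new field surfaces on the config (the values are just the defaults documented above):

    from transformers import OwlViTVisionConfig

    # num_channels now lives on the vision config, mirroring CLIPVisionConfig
    config = OwlViTVisionConfig()  # defaults: num_channels=3, image_size=768, patch_size=32
    print(config.num_channels)  # 3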

View File

@@ -15,6 +15,7 @@
 """ PyTorch OWL-ViT model."""
+import warnings
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
@@ -516,9 +517,6 @@ OWLVIT_INPUTS_DOCSTRING = r"""
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
-        return_base_image_embeds (`bool`, *optional*):
-            Whether or not to return unprojected image embeddings. Set to `True` when `OwlViTModel` is called within
-            `OwlViTForObjectDetection`.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -785,7 +783,6 @@ class OwlViTVisionTransformer(nn.Module):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        use_hidden_state: Optional[bool] = True,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -809,10 +806,7 @@
         last_hidden_state = encoder_outputs[0]
         pooled_output = last_hidden_state[:, 0, :]
-        if use_hidden_state:
-            pooled_output = self.post_layernorm(last_hidden_state)
-        else:
-            pooled_output = self.post_layernorm(pooled_output)
+        pooled_output = self.post_layernorm(pooled_output)
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
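
In effect, OwlViTVisionTransformer now pools the same way CLIP does: take the CLS token, then apply post_layernorm to it alone. A standalone shape sketch using the base defaults (plain torch instead of the real module, for illustration only):

    import torch

    batch_size, hidden_size = 2, 768
    seq_len = (768 // 32) ** 2 + 1  # num_patches + 1 for the CLS token
    last_hidden_state = torch.randn(batch_size, seq_len, hidden_size)

    pooled_output = last_hidden_state[:, 0, :]  # CLS token only
    pooled_output = torch.nn.LayerNorm(hidden_size)(pooled_output)
    print(pooled_output.shape)  # torch.Size([2, 768]) rather than (2, 577, 768) as before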
@@ -963,7 +957,6 @@ class OwlViTModel(OwlViTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        return_projected: Optional[bool] = True,
     ) -> torch.FloatTensor:
         r"""
         Returns:
@@ -1000,10 +993,8 @@
         pooled_output = vision_outputs[1]  # pooled_output
         # Return projected output
-        if return_projected:
-            image_features = self.visual_projection(pooled_output)
-        else:
-            image_features = pooled_output
+        image_features = self.visual_projection(pooled_output)
         return image_features

     @add_start_docstrings_to_model_forward(OWLVIT_INPUTS_DOCSTRING)
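
With return_projected gone, get_image_features always returns the projected features, matching CLIP. A hedged usage sketch, assuming the public google/owlvit-base-patch32 checkpoint and a random tensor standing in for processor-prepared pixel values:

    import torch
    from transformers import OwlViTModel

    model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
    pixel_values = torch.randn(1, 3, 768, 768)  # placeholder; normally built by OwlViTProcessor

    image_features = model.get_image_features(pixel_values=pixel_values)
    print(image_features.shape)  # (1, config.projection_dim), e.g. (1, 512) here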
@@ -1044,15 +1035,11 @@
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        # Whether to return unprojected image features
-        return_base_image_embeds = return_base_image_embeds if return_base_image_embeds is not None else False
         vision_outputs = self.vision_model(
             pixel_values=pixel_values,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            use_hidden_state=False,
         )
         # Get embeddings for all text queries in all batch samples
@@ -1070,12 +1057,12 @@
         image_embeds = self.visual_projection(image_embeds)
         # normalized features
-        image_embeds_norm = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
-        text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
+        image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
+        text_embeds = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
         # cosine similarity as logits
         logit_scale = self.logit_scale.exp()
-        logits_per_text = torch.matmul(text_embeds_norm, image_embeds_norm.t()) * logit_scale
+        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
         logits_per_image = logits_per_text.t()
         loss = None
@@ -1083,11 +1070,13 @@
             loss = owlvit_loss(logits_per_text)
         if return_base_image_embeds:
+            warnings.warn(
+                "`return_base_image_embeds` is deprecated and will be removed in v4.27 of Transformers, one can"
+                " obtain the base (unprojected) image embeddings from outputs.vision_model_output.",
+                FutureWarning,
+            )
             last_hidden_state = vision_outputs[0]
             image_embeds = self.vision_model.post_layernorm(last_hidden_state)
-        else:
-            image_embeds = image_embeds_norm
-            text_embeds = text_embeds_norm
         if not return_dict:
             output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
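
The migration path the warning points to, sketched with the public google/owlvit-base-patch32 checkpoint and dummy inputs (a processor would normally build both; the token ids below are placeholders):

    import torch
    from transformers import OwlViTModel

    model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
    input_ids = torch.tensor([[49406, 320, 2368, 49407]])  # placeholder CLIP-tokenizer ids
    pixel_values = torch.randn(1, 3, 768, 768)

    outputs = model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)

    # Instead of passing return_base_image_embeds=True, recover the unprojected image
    # embeddings from the vision tower's output, as OwlViTForObjectDetection now does:
    last_hidden_state = outputs.vision_model_output.last_hidden_state
    base_image_embeds = model.vision_model.post_layernorm(last_hidden_state)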
@@ -1276,11 +1265,12 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
             attention_mask=attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_base_image_embeds=True,
             return_dict=True,
         )
         # Resize class token
-        image_embeds = outputs[-3]
+        last_hidden_state = outputs.vision_model_output[0]
+        image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)
         new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
         class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
@@ -1296,7 +1286,7 @@
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
-        text_embeds = outputs[-4]
+        text_embeds = outputs.text_embeds
         # Last hidden states from text and vision transformers
         text_model_last_hidden_state = outputs[-2][0]

View File

@@ -120,7 +120,7 @@ class OwlViTVisionModelTester:
         # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
         num_patches = (self.image_size // self.patch_size) ** 2
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
-        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
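
To see the updated expectation end to end, a tiny randomly initialized vision tower can be used (the config values below are arbitrary small numbers chosen for speed, not the tester's actual ones):

    import torch
    from transformers import OwlViTVisionConfig, OwlViTVisionModel

    config = OwlViTVisionConfig(
        hidden_size=32, intermediate_size=37, num_hidden_layers=2,
        num_attention_heads=4, image_size=32, patch_size=4,
    )
    model = OwlViTVisionModel(config)
    out = model(pixel_values=torch.randn(1, 3, 32, 32))

    print(out.last_hidden_state.shape)  # (1, (32 // 4) ** 2 + 1, 32) = (1, 65, 32)
    print(out.pooler_output.shape)  # (1, 32), matching the updated assertion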