[OWL-ViT] Make model consistent with CLIP (#20144)
* Apply fix * Fix test * Remove another argument which is not used * Fix pipeline test * Add argument back, add deprecation warning * Add warning add other location * Use warnings instead * Add num_channels to config Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
This commit is contained in:
parent
d3c0566679
commit
cbbeca3d17
|
@ -165,6 +165,8 @@ class OwlViTVisionConfig(PretrainedConfig):
|
|||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of channels in the input images.
|
||||
image_size (`int`, *optional*, defaults to 768):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`int`, *optional*, defaults to 32):
|
||||
|
@ -206,6 +208,7 @@ class OwlViTVisionConfig(PretrainedConfig):
|
|||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
num_channels=3,
|
||||
image_size=768,
|
||||
patch_size=32,
|
||||
hidden_act="quick_gelu",
|
||||
|
@ -222,6 +225,7 @@ class OwlViTVisionConfig(PretrainedConfig):
|
|||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.hidden_act = hidden_act
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
""" PyTorch OWL-ViT model."""
|
||||
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
|
@ -516,9 +517,6 @@ OWLVIT_INPUTS_DOCSTRING = r"""
|
|||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_base_image_embeds (`bool`, *optional*):
|
||||
Whether or not to return unprojected image embeddings. Set to `True` when `OwlViTModel` is called within
|
||||
`OwlViTForObjectDetection`.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
@ -785,7 +783,6 @@ class OwlViTVisionTransformer(nn.Module):
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
use_hidden_state: Optional[bool] = True,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
@ -809,10 +806,7 @@ class OwlViTVisionTransformer(nn.Module):
|
|||
last_hidden_state = encoder_outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
|
||||
if use_hidden_state:
|
||||
pooled_output = self.post_layernorm(last_hidden_state)
|
||||
else:
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
@ -963,7 +957,6 @@ class OwlViTModel(OwlViTPreTrainedModel):
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
return_projected: Optional[bool] = True,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Returns:
|
||||
|
@ -1000,10 +993,8 @@ class OwlViTModel(OwlViTPreTrainedModel):
|
|||
pooled_output = vision_outputs[1] # pooled_output
|
||||
|
||||
# Return projected output
|
||||
if return_projected:
|
||||
image_features = self.visual_projection(pooled_output)
|
||||
else:
|
||||
image_features = pooled_output
|
||||
image_features = self.visual_projection(pooled_output)
|
||||
|
||||
return image_features
|
||||
|
||||
@add_start_docstrings_to_model_forward(OWLVIT_INPUTS_DOCSTRING)
|
||||
|
@ -1044,15 +1035,11 @@ class OwlViTModel(OwlViTPreTrainedModel):
|
|||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Whether to return unprojected image features
|
||||
return_base_image_embeds = return_base_image_embeds if return_base_image_embeds is not None else False
|
||||
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_hidden_state=False,
|
||||
)
|
||||
|
||||
# Get embeddings for all text queries in all batch samples
|
||||
|
@ -1070,12 +1057,12 @@ class OwlViTModel(OwlViTPreTrainedModel):
|
|||
image_embeds = self.visual_projection(image_embeds)
|
||||
|
||||
# normalized features
|
||||
image_embeds_norm = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
|
||||
text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
|
||||
image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
|
||||
text_embeds = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_text = torch.matmul(text_embeds_norm, image_embeds_norm.t()) * logit_scale
|
||||
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
||||
logits_per_image = logits_per_text.t()
|
||||
|
||||
loss = None
|
||||
|
@ -1083,11 +1070,13 @@ class OwlViTModel(OwlViTPreTrainedModel):
|
|||
loss = owlvit_loss(logits_per_text)
|
||||
|
||||
if return_base_image_embeds:
|
||||
warnings.warn(
|
||||
"`return_base_image_embeds` is deprecated and will be removed in v4.27 of Transformers, one can "
|
||||
" obtain the base (unprojected) image embeddings from outputs.vision_model_output.",
|
||||
FutureWarning,
|
||||
)
|
||||
last_hidden_state = vision_outputs[0]
|
||||
image_embeds = self.vision_model.post_layernorm(last_hidden_state)
|
||||
else:
|
||||
image_embeds = image_embeds_norm
|
||||
text_embeds = text_embeds_norm
|
||||
|
||||
if not return_dict:
|
||||
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
||||
|
@ -1276,11 +1265,12 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
|||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_base_image_embeds=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# Resize class token
|
||||
image_embeds = outputs[-3]
|
||||
last_hidden_state = outputs.vision_model_output[0]
|
||||
image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)
|
||||
new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
|
||||
class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
|
||||
|
||||
|
@ -1296,7 +1286,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
|||
image_embeds.shape[-1],
|
||||
)
|
||||
image_embeds = image_embeds.reshape(new_size)
|
||||
text_embeds = outputs[-4]
|
||||
text_embeds = outputs.text_embeds
|
||||
|
||||
# Last hidden states from text and vision transformers
|
||||
text_model_last_hidden_state = outputs[-2][0]
|
||||
|
|
|
@ -120,7 +120,7 @@ class OwlViTVisionModelTester:
|
|||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||
num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
Loading…
Reference in New Issue