Handle image_embeds in ViltModel (#16696)

* update

* batch_size -> text_batch_size

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar 2022-04-11 22:16:20 +02:00 committed by GitHub
parent 161c0a2eec
commit 7f7300856d
1 changed file with 21 additions and 11 deletions

@@ -704,7 +704,7 @@ VILT_IMAGES_AND_TEXT_CLASSIFICATION_INPUTS_DOCSTRING = r"""
             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
             model's internal embedding lookup matrix.
-        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_images, num_patches, hidden_size)`, *optional*):
             Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
             This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
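
For orientation, the expected `image_embeds` shapes now differ between the base model and the two-image classification head. A minimal sketch with hypothetical sizes (`num_patches` and `hidden_size` are illustrative, not taken from the diff):

import torch

hidden_size = 768   # hypothetical; ViLT-base uses 768
num_patches = 145   # hypothetical; depends on image resolution

# ViltModel still takes one image per example:
#   (batch_size, num_patches, hidden_size)
vilt_model_embeds = torch.randn(4, num_patches, hidden_size)

# ViltForImagesAndTextClassification now documents an extra num_images axis:
#   (batch_size, num_images, num_patches, hidden_size), e.g. num_images=2 for NLVR2
classification_embeds = torch.randn(4, 2, num_patches, hidden_size)
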
@@ -805,18 +805,22 @@ class ViltModel(ViltPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-        batch_size, seq_length = input_shape
+        text_batch_size, seq_length = input_shape
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         if attention_mask is None:
-            attention_mask = torch.ones(((batch_size, seq_length)), device=device)
+            attention_mask = torch.ones(((text_batch_size, seq_length)), device=device)
 
-        if pixel_values is None:
-            raise ValueError("You have to specify pixel_values")
+        if pixel_values is not None and image_embeds is not None:
+            raise ValueError("You cannot specify both pixel_values and image_embeds at the same time")
+        elif pixel_values is None and image_embeds is None:
+            raise ValueError("You have to specify either pixel_values or image_embeds")
 
-        batch_size, num_channels, height, width = pixel_values.shape
+        image_batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeds.shape[0]
+        if image_batch_size != text_batch_size:
+            raise ValueError("The text inputs and image inputs need to have the same batch size")
         if pixel_mask is None:
-            pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+            pixel_mask = torch.ones((image_batch_size, self.config.image_size, self.config.image_size), device=device)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
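
The new contract in `ViltModel.forward` is: exactly one of `pixel_values` / `image_embeds`, with matching text and image batch sizes. A standalone sketch of that check (the helper name and tensor sizes are made up for illustration; the error messages mirror the diff):

import torch

def check_vilt_inputs(text_batch_size, pixel_values=None, image_embeds=None):
    # hypothetical helper mirroring the validation added to ViltModel.forward
    if pixel_values is not None and image_embeds is not None:
        raise ValueError("You cannot specify both pixel_values and image_embeds at the same time")
    elif pixel_values is None and image_embeds is None:
        raise ValueError("You have to specify either pixel_values or image_embeds")
    image_batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeds.shape[0]
    if image_batch_size != text_batch_size:
        raise ValueError("The text inputs and image inputs need to have the same batch size")
    return image_batch_size

image_embeds = torch.randn(2, 145, 768)            # (batch_size, num_patches, hidden_size), sizes made up
check_vilt_inputs(2, image_embeds=image_embeds)    # ok, returns 2
# check_vilt_inputs(2)                             # raises: neither input given
# check_vilt_inputs(3, image_embeds=image_embeds)  # raises: batch sizes differ
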
@@ -1338,11 +1342,17 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        if pixel_values.ndim == 4:
+        if pixel_values is not None and pixel_values.ndim == 4:
             # add dummy num_images dimension
             pixel_values = pixel_values.unsqueeze(1)
 
-        num_images = pixel_values.shape[1]
+        if image_embeds is not None and image_embeds.ndim == 3:
+            # add dummy num_images dimension
+            image_embeds = image_embeds.unsqueeze(1)
+
+        num_images = pixel_values.shape[1] if pixel_values is not None else None
+        if num_images is None:
+            num_images = image_embeds.shape[1] if image_embeds is not None else None
         if num_images != self.config.num_images:
             raise ValueError(
                 "Make sure to match the number of images in the model with the number of images in the input."
@@ -1356,11 +1366,11 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
                 input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids,
-                pixel_values=pixel_values[:, i, :, :, :],
+                pixel_values=pixel_values[:, i, :, :, :] if pixel_values is not None else None,
                 pixel_mask=pixel_mask[:, i, :, :] if pixel_mask is not None else None,
                 head_mask=head_mask,
                 inputs_embeds=inputs_embeds,
-                image_embeds=image_embeds,
+                image_embeds=image_embeds[:, i, :, :] if image_embeds is not None else None,
                 image_token_type_idx=i + 1,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
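
Inside the per-image loop, the classification head now slices both `pixel_values` and `image_embeds` along the new `num_images` axis before calling the base `ViltModel`. A small shape sketch of that slicing (all sizes are hypothetical):

import torch

batch_size, num_images, num_patches, hidden_size = 2, 2, 145, 768  # hypothetical sizes

# 4D input, as in the updated docstring
image_embeds = torch.randn(batch_size, num_images, num_patches, hidden_size)

# a 3D input would first get a dummy num_images dimension, as in the hunk above
single = torch.randn(batch_size, num_patches, hidden_size)
assert single.unsqueeze(1).shape == (batch_size, 1, num_patches, hidden_size)

# each iteration hands the base ViltModel one 3D slice, matching its own image_embeds shape
for i in range(num_images):
    per_image = image_embeds[:, i, :, :]
    assert per_image.shape == (batch_size, num_patches, hidden_size)
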