Handle image_embeds in ViltModel (#16696)
* update * batch_size -> text_batch_size Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
161c0a2eec
commit
7f7300856d
|
@ -704,7 +704,7 @@ VILT_IMAGES_AND_TEXT_CLASSIFICATION_INPUTS_DOCSTRING = r"""
|
|||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
|
||||
image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
|
||||
image_embeds (`torch.FloatTensor` of shape `(batch_size, num_images, num_patches, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
|
||||
|
||||
|
@ -805,18 +805,22 @@ class ViltModel(ViltPreTrainedModel):
|
|||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
text_batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
|
||||
attention_mask = torch.ones(((text_batch_size, seq_length)), device=device)
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
if pixel_values is not None and image_embeds is not None:
|
||||
raise ValueError("You cannot specify both pixel_values and image_embeds at the same time")
|
||||
elif pixel_values is None and image_embeds is None:
|
||||
raise ValueError("You have to specify either pixel_values or image_embeds")
|
||||
|
||||
batch_size, num_channels, height, width = pixel_values.shape
|
||||
image_batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeds.shape[0]
|
||||
if image_batch_size != text_batch_size:
|
||||
raise ValueError("The text inputs and image inputs need to have the same batch size")
|
||||
if pixel_mask is None:
|
||||
pixel_mask = torch.ones(((batch_size, height, width)), device=device)
|
||||
pixel_mask = torch.ones((image_batch_size, self.config.image_size, self.config.image_size), device=device)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
|
@ -1338,11 +1342,17 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
|
|||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values.ndim == 4:
|
||||
if pixel_values is not None and pixel_values.ndim == 4:
|
||||
# add dummy num_images dimension
|
||||
pixel_values = pixel_values.unsqueeze(1)
|
||||
|
||||
num_images = pixel_values.shape[1]
|
||||
if image_embeds is not None and image_embeds.ndim == 3:
|
||||
# add dummy num_images dimension
|
||||
image_embeds = image_embeds.unsqueeze(1)
|
||||
|
||||
num_images = pixel_values.shape[1] if pixel_values is not None else None
|
||||
if num_images is None:
|
||||
num_images = image_embeds.shape[1] if image_embeds is not None else None
|
||||
if num_images != self.config.num_images:
|
||||
raise ValueError(
|
||||
"Make sure to match the number of images in the model with the number of images in the input."
|
||||
|
@ -1356,11 +1366,11 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
|
|||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
pixel_values=pixel_values[:, i, :, :, :],
|
||||
pixel_values=pixel_values[:, i, :, :, :] if pixel_values is not None else None,
|
||||
pixel_mask=pixel_mask[:, i, :, :] if pixel_mask is not None else None,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
image_embeds=image_embeds,
|
||||
image_embeds=image_embeds[:, i, :, :] if image_embeds is not None else None,
|
||||
image_token_type_idx=i + 1,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
|
|
Loading…
Reference in New Issue