Blip dynamic input resolution (#30722)

* blip with interpolated pos encoding

* feat: Add `interpolate_pos_encoding` option to the other models in the `BLIP` family.

* Include checks on the generated text content in the tests
Zafir Stojanovski 2024-05-13 13:20:16 +02:00 committed by GitHub
parent a4e530e3c8
commit f63d822242
6 changed files with 240 additions and 20 deletions
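
For context (not part of the diff itself): after this change, a BLIP-family checkpoint can be run on image resolutions other than the one it was pre-trained with by passing `interpolate_pos_encoding=True` through `generate`/`forward`. A minimal usage sketch, mirroring the integration test added below (the checkpoint name and the 500x500 processor size come from that test; the sample image URL is only illustrative):

```python
import requests
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

# Ask the image processor for a resolution above the checkpoint's default 384x384.
processor.image_processor.size = {"height": 500, "width": 500}

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # any RGB image works here
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt").to(device)

# Without interpolate_pos_encoding=True, the pre-trained position embeddings would not
# match the longer patch sequence produced by the larger input image.
generated_ids = model.generate(**inputs, interpolate_pos_encoding=True)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())
```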

src/transformers/models/blip/modeling_blip.py

@ -14,6 +14,7 @@
# limitations under the License.
""" PyTorch BLIP model."""
import math
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
@ -231,15 +232,51 @@ class BlipVisionEmbeddings(nn.Module):
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher-resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embedding.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embedding
class_pos_embed = self.position_embedding[:, 0, :]
patch_pos_embed = self.position_embedding[:, 1:, :]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
if interpolate_pos_encoding:
position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
else:
position_embedding = self.position_embedding
embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
return embeddings
@ -509,6 +546,8 @@ BLIP_VISION_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
"""
BLIP_INPUTS_DOCSTRING = r"""
@ -545,6 +584,8 @@ BLIP_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
"""
@ -657,6 +698,7 @@ class BlipVisionModel(BlipPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@ -671,7 +713,7 @@ class BlipVisionModel(BlipPreTrainedModel):
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
@ -779,6 +821,7 @@ class BlipModel(BlipPreTrainedModel):
self,
pixel_values: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> torch.FloatTensor:
r"""
Returns:
@ -804,7 +847,11 @@ class BlipModel(BlipPreTrainedModel):
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.vision_model(pixel_values=pixel_values, return_dict=return_dict)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
pooled_output = vision_outputs[1] # pooled_output
image_features = self.visual_projection(pooled_output)
@ -818,6 +865,7 @@ class BlipModel(BlipPreTrainedModel):
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> torch.FloatTensor:
r"""
Returns:
@ -846,6 +894,7 @@ class BlipModel(BlipPreTrainedModel):
output_attentions=True,
output_hidden_states=True,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
@ -876,6 +925,7 @@ class BlipModel(BlipPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BlipOutput]:
r"""
Returns:
@ -913,6 +963,7 @@ class BlipModel(BlipPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
text_outputs = self.text_model(
@ -999,6 +1050,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BlipForConditionalGenerationModelOutput]:
r"""
Returns:
@ -1033,6 +1085,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
@ -1065,6 +1118,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
pixel_values: torch.FloatTensor,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: bool = False,
**generate_kwargs,
) -> torch.LongTensor:
r"""
@ -1100,7 +1154,10 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
"""
batch_size = pixel_values.shape[0]
vision_outputs = self.vision_model(pixel_values=pixel_values)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
@ -1174,6 +1231,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BlipTextVisionModelOutput]:
r"""
Returns:
@ -1227,6 +1285,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
@ -1279,6 +1338,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: bool = False,
**generate_kwargs,
) -> torch.LongTensor:
r"""
@ -1316,7 +1376,10 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
2
```
"""
vision_outputs = self.vision_model(pixel_values=pixel_values)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
@ -1408,6 +1471,7 @@ class BlipForImageTextRetrieval(BlipPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BlipTextVisionModelOutput]:
r"""
Returns:
@ -1441,6 +1505,7 @@ class BlipForImageTextRetrieval(BlipPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
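
To make the resizing step concrete, here is a standalone sketch of the same DINO-style trick that the new `interpolate_pos_encoding` method applies: the class-token position is left untouched, and the patch-position grid is bicubically resized to the grid implied by the new image size. The dimensions used here (patch size 16, 384 pre-training resolution, 500 target resolution, hidden size 768) are illustrative assumptions, not values read from this diff:

```python
# Standalone sketch of the position-grid resizing (illustrative sizes).
import math

import torch
from torch import nn

embed_dim, patch_size = 768, 16
pretrain_size, new_size = 384, 500               # 24x24 grid -> 31x31 grid

num_positions = (pretrain_size // patch_size) ** 2        # 576 patch positions
pos_embed = torch.randn(1, num_positions + 1, embed_dim)  # [CLS] + patch positions

class_pos = pos_embed[:, :1, :]
patch_pos = pos_embed[:, 1:, :]

grid = int(math.sqrt(num_positions))             # 24
new_grid = new_size // patch_size                # 31

# Reshape to a 2D grid, resize with bicubic interpolation, then flatten back.
patch_pos = patch_pos.reshape(1, grid, grid, embed_dim).permute(0, 3, 1, 2)
patch_pos = nn.functional.interpolate(
    patch_pos,
    # the +0.1 mirrors the floating-point workaround used in the diff
    scale_factor=((new_grid + 0.1) / grid, (new_grid + 0.1) / grid),
    mode="bicubic",
    align_corners=False,
)
patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)

resized = torch.cat([class_pos, patch_pos], dim=1)
print(resized.shape)  # torch.Size([1, 962, 768]) -> 31*31 patches + 1 CLS token
```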

src/transformers/models/blip_2/modeling_blip_2.py

@ -101,15 +101,51 @@ class Blip2VisionEmbeddings(nn.Module):
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher-resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embedding.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embedding
class_pos_embed = self.position_embedding[:, 0, :]
patch_pos_embed = self.position_embedding[:, 1:, :]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
if interpolate_pos_encoding:
position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
else:
position_embedding = self.position_embedding
embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
return embeddings
@ -321,6 +357,8 @@ BLIP_2_VISION_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
"""
BLIP_2_TEXT_INPUTS_DOCSTRING = r"""
@ -402,6 +440,8 @@ BLIP_2_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
"""
@ -516,6 +556,7 @@ class Blip2VisionModel(Blip2PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@ -530,7 +571,7 @@ class Blip2VisionModel(Blip2PreTrainedModel):
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
@ -1297,6 +1338,7 @@ class Blip2Model(Blip2PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
):
r"""
Returns:
@ -1330,6 +1372,7 @@ class Blip2Model(Blip2PreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
return vision_outputs
@ -1341,6 +1384,7 @@ class Blip2Model(Blip2PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
):
r"""
Returns:
@ -1374,6 +1418,7 @@ class Blip2Model(Blip2PreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
@ -1406,6 +1451,7 @@ class Blip2Model(Blip2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
r"""
Returns:
@ -1441,6 +1487,7 @@ class Blip2Model(Blip2PreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
@ -1623,6 +1670,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
r"""
Returns:
@ -1695,6 +1743,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
@ -1779,6 +1828,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
pixel_values: torch.FloatTensor,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: bool = False,
**generate_kwargs,
) -> torch.LongTensor:
"""
@ -1800,7 +1850,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
self._preprocess_accelerate()
batch_size = pixel_values.shape[0]
image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
image_embeds = self.vision_model(
pixel_values,
return_dict=True,
interpolate_pos_encoding=interpolate_pos_encoding,
).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
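
One point worth noting for BLIP-2 (and InstructBLIP below): interpolating the position encodings only lengthens the vision encoder's output sequence. The Q-Former still cross-attends over that sequence and emits a fixed number of query tokens, so the language model's prompt length does not grow with the image resolution. A rough sketch of the sequence lengths, assuming the `Blip2VisionConfig` defaults (image size 224, patch size 14) and 32 query tokens; these values are assumptions, not taken from the diff:

```python
# Back-of-the-envelope sequence lengths under assumed config values.
patch_size, num_query_tokens = 14, 32

def vision_seq_len(height: int, width: int) -> int:
    # one position per patch, plus the [CLS] token
    return (height // patch_size) * (width // patch_size) + 1

print(vision_seq_len(224, 224))  # 257  -> pre-training resolution
print(vision_seq_len(500, 500))  # 1226 -> with interpolate_pos_encoding=True

# The Q-Former cross-attends over those image tokens but always emits
# num_query_tokens outputs, so the language model still receives the same
# number of image-derived embeddings regardless of input resolution.
print(num_query_tokens)  # 32
```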

src/transformers/models/instructblip/modeling_instructblip.py

@ -102,15 +102,51 @@ class InstructBlipVisionEmbeddings(nn.Module):
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher-resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embedding.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embedding
class_pos_embed = self.position_embedding[:, 0, :]
patch_pos_embed = self.position_embedding[:, 1:, :]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
if interpolate_pos_encoding:
position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
else:
position_embedding = self.position_embedding
embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
return embeddings
@ -328,6 +364,8 @@ INSTRUCTBLIP_VISION_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
"""
INSTRUCTBLIP_INPUTS_DOCSTRING = r"""
@ -391,6 +429,8 @@ INSTRUCTBLIP_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
"""
@ -505,6 +545,7 @@ class InstructBlipVisionModel(InstructBlipPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@ -519,7 +560,7 @@ class InstructBlipVisionModel(InstructBlipPreTrainedModel):
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
@ -1327,6 +1368,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1379,6 +1421,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
@ -1473,6 +1516,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
qformer_attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: bool = False,
**generate_kwargs,
) -> torch.LongTensor:
"""
@ -1489,6 +1533,8 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
The sequence used as a prompt for the generation.
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
Mask to avoid performing attention on padding token indices.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the positional encoding of the image embeddings.
Returns:
captions (list): A list of strings of length batch_size * num_captions.
@ -1498,7 +1544,11 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
self._preprocess_accelerate()
batch_size = pixel_values.shape[0]
image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
image_embeds = self.vision_model(
pixel_values,
return_dict=True,
interpolate_pos_encoding=interpolate_pos_encoding,
).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

tests/models/blip/test_modeling_blip.py

@ -1381,6 +1381,20 @@ class BlipModelIntegrationTest(unittest.TestCase):
[30522, 1037, 3861, 1997, 1037, 2450, 1998, 2014, 3899, 2006, 1996, 3509, 102],
)
def test_inference_interpolate_pos_encoding(self):
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(torch_device)
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
processor.image_processor.size = {"height": 500, "width": 500}
image = prepare_img()
inputs = processor(images=image, return_tensors="pt").to(torch_device)
predictions = model.generate(**inputs, interpolate_pos_encoding=True)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
self.assertEqual(predictions[0].tolist(), [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 1037, 3899, 102])
self.assertEqual(generated_text, "a woman sitting on the beach with a dog")
def test_inference_vqa(self):
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(torch_device)
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")

tests/models/blip_2/test_modeling_blip_2.py

@ -882,6 +882,22 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
)
self.assertEqual(generated_text, "it's not a city, it's a beach")
def test_inference_interpolate_pos_encoding(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
).to(torch_device)
processor.image_processor.size = {"height": 500, "width": 500}
image = prepare_img()
inputs = processor(images=image, return_tensors="pt").to(torch_device)
predictions = model.generate(**inputs, interpolate_pos_encoding=True)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 8, 2335, 15, 5, 4105, 50118])
self.assertEqual(generated_text, "a woman and dog on the beach")
def test_inference_opt_batched_beam_search(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(

tests/models/instructblip/test_modeling_instructblip.py

@ -612,3 +612,24 @@ class InstructBlipModelIntegrationTest(unittest.TestCase):
generated_text,
"The image depicts a man ironing clothes on the back of a yellow van in the middle of a busy city street. The man is wearing a yellow shirt with a bright yellow tie, and he is using an ironing board to complete his task. The image is unusual due to the fact that it shows a man ironing clothes on the back of a van in the middle of a busy city street. It is possible that the man is trying to save money by doing his laundry on the back of the van, but it is also possible that he is trying to save time by doing his laundry on the back of the van in the middle of a busy city street. Regardless of the reason for the man's actions, it is clear that he is trying to save time by doing his laundry on the back of the van in the middle of a busy city street.",
)
def test_inference_interpolate_pos_encoding(self):
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
model = InstructBlipForConditionalGeneration.from_pretrained(
"Salesforce/instructblip-flan-t5-xl",
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
).to(torch_device)
processor.image_processor.size = {"height": 500, "width": 500}
image = prepare_img()
prompt = "What's in the image?"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device)
predictions = model.generate(**inputs, interpolate_pos_encoding=True)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
self.assertEqual(
predictions[0].tolist(), [0, 37, 1023, 753, 3, 9, 2335, 3823, 30, 8, 2608, 28, 3, 9, 1782, 5, 1]
)
self.assertEqual(generated_text, "The image features a woman sitting on the beach with a dog.")