Fix `_merge_input_ids_with_image_features` for llava model (#28333)
* fix `_merge_input_ids_with_image_features` for llava model * Update src/transformers/models/llava/modeling_llava.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * adress comments * style and tests * ooops * test the backward too * Apply suggestions from code review Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update tests/models/vipllava/test_modeling_vipllava.py * style and quality --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
parent
976189a6df
commit
0f2f0c634f
|
@ -276,9 +276,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||
self.vocab_size = model_embeds.num_embeddings
|
||||
return model_embeds
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self, image_features, inputs_embeds, input_ids, attention_mask, position_ids
|
||||
):
|
||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||
|
@ -307,6 +305,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||
final_attention_mask = torch.zeros(
|
||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
)
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
|
@ -321,6 +323,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||
if labels is not None:
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
||||
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
||||
|
@ -335,7 +339,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
return final_embedding, final_attention_mask, position_ids
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
return final_embedding, final_attention_mask, final_labels, position_ids
|
||||
|
||||
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
|
@ -420,8 +428,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||
)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, position_ids
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
if labels is None:
|
||||
labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
|
||||
|
|
|
@ -284,9 +284,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||
self.vocab_size = model_embeds.num_embeddings
|
||||
return model_embeds
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self, image_features, inputs_embeds, input_ids, attention_mask, position_ids
|
||||
):
|
||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||
|
@ -315,6 +313,10 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||
final_attention_mask = torch.zeros(
|
||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
)
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
|
@ -329,6 +331,8 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||
if labels is not None:
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
||||
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
||||
|
@ -343,7 +347,11 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
return final_embedding, final_attention_mask, position_ids
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
return final_embedding, final_attention_mask, final_labels, position_ids
|
||||
|
||||
@add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
|
@ -419,8 +427,8 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||
image_features = torch.cat(image_features, dim=-1)
|
||||
|
||||
image_features = self.multi_modal_projector(image_features)
|
||||
inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, position_ids
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
if labels is None:
|
||||
labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
|
||||
|
|
|
@ -26,7 +26,7 @@ from transformers import (
|
|||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
@ -332,3 +332,41 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||
|
||||
# Make sure that `generate` works
|
||||
_ = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_llava_merge_inputs_error_bug(self):
|
||||
# This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
|
||||
model_id = "llava-hf/llava-1.5-7b-hf"
|
||||
model = LlavaForConditionalGeneration.from_pretrained(
|
||||
model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
).to(torch_device)
|
||||
|
||||
# Simulate some user inputs
|
||||
pixel_values = torch.randn(
|
||||
(2, 3, 336, 336),
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
|
||||
[1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900],
|
||||
],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
attention_mask = torch.tensor(
|
||||
[[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
# Make sure that the loss is properly computed
|
||||
loss = model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=input_ids,
|
||||
).loss
|
||||
loss.backward()
|
||||
|
|
|
@ -26,7 +26,7 @@ from transformers import (
|
|||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
@ -214,3 +214,41 @@ class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||
|
||||
EXPECTED_OUTPUT = "USER: <image> \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on"
|
||||
self.assertEqual(processor.decode(outputs[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_vipllava_merge_inputs_error_bug(self):
|
||||
# This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
|
||||
model_id = "llava-hf/vip-llava-7b-hf"
|
||||
model = VipLlavaForConditionalGeneration.from_pretrained(
|
||||
model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
).to(torch_device)
|
||||
|
||||
# Simulate some user inputs
|
||||
pixel_values = torch.randn(
|
||||
(2, 3, 336, 336),
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
|
||||
[1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900],
|
||||
],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
attention_mask = torch.tensor(
|
||||
[[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
# Make sure that the loss is properly computed
|
||||
loss = model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=input_ids,
|
||||
).loss
|
||||
loss.backward()
|
||||
|
|
Loading…
Reference in New Issue