Correct llava mask & fix missing setter for `vocab_size` (#29389)
* correct llava mask * fix vipllava as wlel * mask out embedding for padding tokens * add test * fix style * add setter * fix test on suggestion
This commit is contained in:
parent
aa17cf986f
commit
13b23704a8
|
@ -147,6 +147,10 @@ class LlavaConfig(PretrainedConfig):
|
|||
)
|
||||
return self._vocab_size
|
||||
|
||||
@vocab_size.setter
|
||||
def vocab_size(self, value):
|
||||
self._vocab_size = value
|
||||
|
||||
def to_dict(self):
|
||||
output = super().to_dict()
|
||||
output.pop("_vocab_size", None)
|
||||
|
|
|
@ -344,6 +344,12 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
|
@ -449,10 +455,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
|
||||
# Get the target length
|
||||
target_seqlen = first_layer_past_key_value.shape[-1] + 1
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
|
@ -467,7 +474,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
|
||||
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
outputs = self.language_model(
|
||||
|
|
|
@ -356,7 +356,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
|||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
|
||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == self.config.image_token_index
|
||||
|
@ -418,6 +417,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
|||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
|
|
|
@ -347,6 +347,12 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
|
@ -442,11 +448,11 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||
# Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0)
|
||||
|
||||
# Get the target length
|
||||
target_seqlen = first_layer_past_key_value.shape[-2] + 1
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
|
@ -461,7 +467,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
|
||||
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
outputs = self.language_model(
|
||||
|
|
|
@ -27,7 +27,14 @@ from transformers import (
|
|||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
@ -470,10 +477,45 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
|
||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
|
||||
|
||||
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_batched_generation(self):
|
||||
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf").to(torch_device)
|
||||
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
|
||||
prompt1 = "<image>\n<image>\nUSER: What's the the difference of two images?\nASSISTANT:"
|
||||
prompt2 = "<image>\nUSER: Describe the image.\nASSISTANT:"
|
||||
prompt3 = "<image>\nUSER: Describe the image.\nASSISTANT:"
|
||||
url1 = "https://images.unsplash.com/photo-1552053831-71594a27632d?q=80&w=3062&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
url2 = "https://images.unsplash.com/photo-1617258683320-61900b281ced?q=80&w=3087&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
image1 = Image.open(requests.get(url1, stream=True).raw)
|
||||
image2 = Image.open(requests.get(url2, stream=True).raw)
|
||||
|
||||
inputs = processor(
|
||||
text=[prompt1, prompt2, prompt3],
|
||||
images=[image1, image2, image1, image2],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device)
|
||||
|
||||
model = model.eval()
|
||||
|
||||
EXPECTED_OUTPUT = [
|
||||
"\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog holding a flower in one",
|
||||
"\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding",
|
||||
"\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama",
|
||||
]
|
||||
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=20)
|
||||
outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
self.assertEqual(outputs, EXPECTED_OUTPUT)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_llava_index_error_bug(self):
|
||||
|
|
Loading…
Reference in New Issue