[`Llava`] Fix llava index errors (#28032)
* fix llava index errors * forward contrib credits from original implementation and fix * better fix * final fixes and fix all tests * fix * fix nit * fix tests * add regression tests --------- Co-authored-by: gullalc <gullalc@users.noreply.github.com>
This commit is contained in:
parent
68fa1e855b
commit
29e7a1e183
|
@ -431,8 +431,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
|
||||
# Get the target length
|
||||
target_seqlen = first_layer_past_key_value.shape[-1] + 1
|
||||
|
||||
|
|
|
@ -430,10 +430,13 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
|
||||
first_layer_past_key_value = past_key_values[0][0][:, 0, :, :]
|
||||
|
||||
# Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0)
|
||||
|
||||
# Get the target length
|
||||
target_seqlen = first_layer_past_key_value.shape[-1] + 1
|
||||
target_seqlen = first_layer_past_key_value.shape[-2] + 1
|
||||
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
||||
|
|
|
@ -247,6 +247,53 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||
model_id = "llava-hf/llava-1.5-7b-hf"
|
||||
|
||||
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
prompts = [
|
||||
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
|
||||
"USER: <image>\nWhat is this?\nASSISTANT:",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip
|
||||
|
||||
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_batch(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/bakLlava-v1-hf", load_in_4bit=True)
|
||||
# The first batch is longer in terms of text, but only has 1 image. The second batch will be padded in text, but the first will be padded because images take more space!.
|
||||
prompts = [
|
||||
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
|
||||
"USER: <image>\nWhat is this?\nASSISTANT:",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = self.processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER: \nWhat is this?\nASSISTANT: Cats'] # fmt: skip
|
||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_llama_batched_regression(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model_id = "llava-hf/llava-1.5-7b-hf"
|
||||
|
||||
# Multi-image & multi-prompt (e.g. 3 images and 2 prompts now fails with SDPA, this tests if "eager" works as before)
|
||||
model = LlavaForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/llava-1.5-7b-hf", load_in_4bit=True, attn_implementation="eager"
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(model_id, pad_token="<pad>")
|
||||
|
||||
prompts = [
|
||||
|
@ -260,26 +307,28 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: the water is calm and clear\n\nThe image shows a wooden pier on a lake, with a', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
|
||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
|
||||
|
||||
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_batch(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/bakLlava-v1-hf", load_in_4bit=True)
|
||||
# The first batch is longer in terms of text, but only has 1 image. The second batch will be padded in text, but the first will be padded because images take more space!.
|
||||
prompts = [
|
||||
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
|
||||
"USER: <image>\nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: <image>\nAnd this?\nASSISTANT:",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
def test_llava_index_error_bug(self):
|
||||
# This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore
|
||||
# Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for
|
||||
# more details
|
||||
model_id = "llava-hf/llava-1.5-7b-hf"
|
||||
model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
|
||||
|
||||
inputs = self.processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
# Simulate a super long prompt
|
||||
user_prompt = "Describe the image:?\n" * 200
|
||||
prompt = f"USER: <image>\n{user_prompt}ASSISTANT:"
|
||||
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
|
||||
EXPECTED_DECODED_TEXT = ['\nUSER: What are the things I should be cautious about when I visit this place? What should I bring with me\nASSISTANT: When visiting this place, bring a camera, Park rules may apply, Per person, Sunrise over', '\nUSER: What is this?\nASSISTANT: Two cats lying on a bed!\nUSER: And this? a dock on a lake with two cats on it (Photo credit: Jocelyn R.'] # fmt: skip
|
||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
|
||||
|
||||
# Make sure that `generate` works
|
||||
_ = model.generate(**inputs, max_new_tokens=20)
|
||||
|
|
Loading…
Reference in New Issue