From 5d0bf59b4d5be72c8c956e0240a67d7c3100fdaf Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 20 May 2024 13:45:56 +0500 Subject: [PATCH] LLaVa-Next: Update docs with batched inference (#30857) * update docs with batch ex * Update docs/source/en/model_doc/llava_next.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * accept nested list of img --------- Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> --- docs/source/en/model_doc/llava_next.md | 41 +++++++++++++++++++ .../llava_next/image_processing_llava_next.py | 26 +++++++++++- .../test_image_processor_llava_next.py | 18 ++++++++ 3 files changed, 84 insertions(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md index a2a3913fca..a4a1419ee0 100644 --- a/docs/source/en/model_doc/llava_next.md +++ b/docs/source/en/model_doc/llava_next.md @@ -68,6 +68,8 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/ ## Usage example +### Single image inference + Here's how to load the model and perform inference in half-precision (`torch.float16`): ```python @@ -94,6 +96,45 @@ output = model.generate(**inputs, max_new_tokens=100) print(processor.decode(output[0], skip_special_tokens=True)) ``` +### Multi image inference + +LLaVa-Next can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). Here is how you can do it: + +```python +import requests +from PIL import Image +import torch +from transformers import AutoProcessor, LlavaNextForConditionalGeneration + +# Load the model in half-precision +model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, device_map="auto") +processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") + +# Get three different images +url = "https://www.ilankelman.org/stopsigns/australia.jpg" +image_stop = Image.open(requests.get(url, stream=True).raw) + +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image_cats = Image.open(requests.get(url, stream=True).raw) + +url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" +image_snowman = Image.open(requests.get(url, stream=True).raw) + +# Prepare a batched prompt, where the first one is a multi-turn conversation and the second is not +prompt = [ + "[INST] \nWhat is shown in this image? [/INST] There is a red stop sign in the image. [INST] \nWhat about this image? How many cats do you see [/INST]", + "[INST] \nWhat is shown in this image? [/INST]" +] + +# We can simply feed images in the order they have to be used in the text prompt +# Each "" token uses one image leaving the next for the subsequent "" tokens +inputs = processor(text=prompt, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device) + +# Generate +generate_ids = model.generate(**inputs, max_new_tokens=30) +processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) +``` + ## Model optimization ### Quantization using Bitsandbytes diff --git a/src/transformers/models/llava_next/image_processing_llava_next.py b/src/transformers/models/llava_next/image_processing_llava_next.py index 34de0f4db0..6295fb9562 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next.py +++ b/src/transformers/models/llava_next/image_processing_llava_next.py @@ -37,6 +37,7 @@ from ...image_utils import ( get_image_size, infer_channel_dimension_format, is_scaled_image, + is_valid_image, make_list_of_images, to_numpy_array, valid_images, @@ -52,6 +53,29 @@ if is_vision_available(): from PIL import Image +def make_batched_images(images) -> List[List[ImageInput]]: + """ + Accepts images in list or nested list format, and makes a list of images for preprocessing. + + Args: + images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): + The input image. + + Returns: + list: A list of images. + """ + if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): + return [img for img_list in images for img in img_list] + + elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): + return images + + elif is_valid_image(images): + return [images] + + raise ValueError(f"Could not make batched video from {images}") + + def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]: """ Divides an image into patches of a specified size. @@ -651,7 +675,7 @@ class LlavaNextImageProcessor(BaseImageProcessor): do_pad = do_pad if do_pad is not None else self.do_pad do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - images = make_list_of_images(images) + images = make_batched_images(images) if not valid_images(images): raise ValueError( diff --git a/tests/models/llava_next/test_image_processor_llava_next.py b/tests/models/llava_next/test_image_processor_llava_next.py index 7369f8a918..8b1f98bbca 100644 --- a/tests/models/llava_next/test_image_processor_llava_next.py +++ b/tests/models/llava_next/test_image_processor_llava_next.py @@ -199,3 +199,21 @@ class LlavaNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): @unittest.skip("LlavaNextImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy def test_call_numpy_4_channels(self): pass + + def test_nested_input(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + + # Test batched as a list of images + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched as a nested list of images, where each sublist is one batch + image_inputs_nested = [image_inputs[:3], image_inputs[3:]] + encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) + + # Image processor should return same pixel values, independently of ipnut format + self.assertTrue((encoded_images_nested == encoded_images).all())