Video-LLaVA

Overview

Video-LLaVA is an open-source multimodal LLM trained by fine-tuning Llama/Vicuna on multimodal instruction-following data generated by Llava1.5 and VideChat. It is an auto-regressive language model based on the transformer architecture. Video-LLaVA unifies visual representations into the language feature space, enabling an LLM to perform visual reasoning on both images and videos.

The Video-LLaVA model was proposed in Video-LLaVA: Learning United Visual Representation by Alignment Before Projection by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munang Ning, Peng Jin, Li Yuan.

The abstract from the paper is the following:

The Large Vision-Language Model (LVLM) has enhanced the performance of various downstream tasks in visual-language understanding. Most existing approaches encode images and videos into separate feature spaces, which are then fed as inputs to large language models. However, due to the lack of unified tokenization for images and videos, namely misalignment before projection, it becomes challenging for a Large Language Model (LLM) to learn multi-modal interactions from several poor projection layers. In this work, we unify visual representation into the language feature space to advance the foundational LLM towards a unified LVLM. As a result, we establish a simple but robust LVLM baseline, Video-LLaVA, which learns from a mixed dataset of images and videos, mutually enhancing each other. Video-LLaVA achieves superior performances on a broad range of 9 image benchmarks across 5 image question-answering datasets and 4 image benchmark toolkits. Additionally, our Video-LLaVA also outperforms Video-ChatGPT by 5.8%, 9.9%, 18.6%, and 10.1% on MSRVTT, MSVD, TGIF, and ActivityNet, respectively. Notably, extensive experiments demonstrate that Video-LLaVA mutually benefits images and videos within a unified visual representation, outperforming models designed specifically for images or videos. We aim for this work to provide modest insights into the multi-modal inputs for the LLM.

Usage tips:

  • We advise users to use padding_side="left" for batched generation, as it leads to more accurate results. To do so, set processor.tokenizer.padding_side = "left" before generating.

  • Note that the model has not been explicitly trained to process multiple images/videos in the same prompt. Although this is technically possible, you may experience inaccurate results.

  • Note that video inputs should have exactly 8 frames, since the models were trained in that setting.
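
To illustrate the left-padding tip, here is a small pure-Python sketch of what left padding does to a batch of prompts of different lengths. The `left_pad` helper, the pad id `0`, and the token ids are illustrative assumptions and not part of the library; in practice the tokenizer performs the padding. The point is that with left padding every row ends on a real token, so generated tokens directly follow the prompt.

```python
def left_pad(sequences, pad_id=0):
    """Pad token-id lists on the left to a common length.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens
    and 0 for padding. `pad_id=0` is an arbitrary illustrative value.
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append([pad_id] * n_pad + list(seq))
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


# Two prompts of different lengths (made-up token ids).
ids, mask = left_pad([[5, 6, 7], [8, 9]])
print(ids)   # [[5, 6, 7], [0, 8, 9]]
print(mask)  # [[1, 1, 1], [0, 1, 1]]
```

With right padding, the shorter row would end in pad tokens and the model would have to continue from padding rather than from the prompt.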

This model was contributed by RaushanTurganbay. The original code can be found here.

Usage example

Single Media Mode

The model can accept both images and videos as input. Here is an example of inference in half-precision (torch.float16):

import av
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor

def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])

# Load the model in half-precision
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", torch_dtype=torch.float16, device_map="auto")
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")

# Load the video as an np.array, uniformly sampling 8 frames
video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
container = av.open(video_path)
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
video = read_video_pyav(container, indices)

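# For better results, we recommend prompting the model in the following format
prompt = "USER: <video>Why is this funny? ASSISTANT:"
# Move the inputs to the model's device; floating-point tensors are cast to float16
inputs = processor(text=prompt, videos=video, return_tensors="pt").to(model.device, torch.float16)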

out = model.generate(**inputs, max_new_tokens=60)
processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
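
The example above samples frame indices with `np.arange` and a floating-point step. NumPy's documentation notes that non-integer steps can be sensitive to rounding, so here is a minimal alternative sketch that uses only integer arithmetic and always returns exactly 8 indices. The helper name `sample_frame_indices` is hypothetical; its result can be passed as `indices` to `read_video_pyav`, which only needs indexing and membership tests.

```python
def sample_frame_indices(total_frames, num_frames=8):
    """Return `num_frames` evenly spaced frame indices in [0, total_frames)."""
    if total_frames < num_frames:
        raise ValueError(
            f"need at least {num_frames} frames, got {total_frames}"
        )
    # Integer arithmetic avoids rounding surprises from a float step.
    return [i * total_frames // num_frames for i in range(num_frames)]


print(sample_frame_indices(80))  # [0, 10, 20, 30, 40, 50, 60, 70]
print(sample_frame_indices(30))  # [0, 3, 7, 11, 15, 18, 22, 26]
```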

For multi-turn conversations, change the prompt format to:

"USER: <video>What do you see in this video? ASSISTANT: A baby reading a book. USER: Why is it funny? ASSISTANT:"

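The multi-turn format above can be assembled programmatically. The following is a minimal sketch, assuming the media placeholder goes in front of the first user message only, as in the examples on this page; `build_prompt` is a hypothetical helper, not a library function.

```python
def build_prompt(turns, media_token="<video>"):
    """Build a "USER: ... ASSISTANT: ..." prompt from (user, assistant) pairs.

    Use None as the assistant text of the final turn to leave it open
    for generation. The media placeholder is added to the first turn only.
    """
    parts = []
    for i, (user_text, assistant_text) in enumerate(turns):
        prefix = media_token if i == 0 else ""
        part = f"USER: {prefix}{user_text} ASSISTANT:"
        if assistant_text is not None:
            part += f" {assistant_text}"
        parts.append(part)
    return " ".join(parts)


prompt = build_prompt([
    ("What do you see in this video?", "A baby reading a book."),
    ("Why is it funny?", None),
])
print(prompt)
# USER: <video>What do you see in this video? ASSISTANT: A baby reading a book. USER: Why is it funny? ASSISTANT:
```
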
Mixed Media Mode

The model can also generate from interleaved image-video inputs. Note, however, that it was not trained in an interleaved image-video setting, which might affect performance. Below is an example of mixed-media input; add the following lines to the code snippet above:

from PIL import Image
import requests

# Generate from image and video mixed inputs
# Load an image and write a new prompt
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image> How many cats are there in the image? ASSISTANT: There are two cats. USER: <video>Why is this video funny? ASSISTANT:"

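# `video` is the array of 8 frames loaded in the snippet above
inputs = processor(text=prompt, images=image, videos=video, padding=True, return_tensors="pt").to(model.device, torch.float16)

# Generate; max_new_tokens limits only the generated text, not the long multimodal prompt
generate_ids = model.generate(**inputs, max_new_tokens=50)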
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
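
When mixing media, a mismatch between placeholders and inputs is an easy mistake to make. Below is a small sanity-check sketch, assuming one `<image>` or `<video>` placeholder per media item, as in the prompts on this page. `check_placeholders` is a hypothetical helper and not part of the processor API.

```python
def check_placeholders(prompt, num_images=0, num_videos=0):
    """Raise ValueError if placeholder counts do not match the media counts."""
    expected = {"<image>": num_images, "<video>": num_videos}
    for token, n_expected in expected.items():
        n_found = prompt.count(token)
        if n_found != n_expected:
            raise ValueError(
                f"prompt has {n_found} {token} placeholder(s) "
                f"but {n_expected} input(s) were provided"
            )


prompt = (
    "USER: <image> How many cats are there in the image? "
    "ASSISTANT: There are two cats. "
    "USER: <video>Why is this video funny? ASSISTANT:"
)
check_placeholders(prompt, num_images=1, num_videos=1)  # passes silently
```

Running such a check before calling the processor gives a clearer error message than a shape mismatch later in the pipeline.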

Model optimization

Quantization using Bitsandbytes for memory efficiency

The model can be loaded in lower bits, significantly reducing the memory burden while largely preserving the performance of the original model. This allows for efficient deployment in resource-constrained settings.

First, install bitsandbytes by running pip install bitsandbytes and make sure you have access to a CUDA-compatible GPU. Then load the quantized model by passing a BitsAndBytesConfig as shown below:

import torch
from transformers import VideoLlavaForConditionalGeneration, BitsAndBytesConfig

# specify how to quantize the model
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", quantization_config=quantization_config, device_map="auto")
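
As a rough back-of-the-envelope illustration of the memory savings, the sketch below estimates weight storage at different precisions. It assumes about 7 billion parameters (inferred from the "7B" in the checkpoint name) and counts weights only; it ignores quantization metadata, layers that remain unquantized, activations, and the KV cache, so actual usage will be higher.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Approximate memory needed to store the weights, in GiB."""
    return num_params * bits_per_param / 8 / 2**30


NUM_PARAMS = 7e9  # assumption: roughly 7B parameters, per the checkpoint name

for label, bits in [("float16", 16), ("4-bit (nf4)", 4)]:
    print(f"{label}: ~{weight_memory_gib(NUM_PARAMS, bits):.1f} GiB")
# float16: ~13.0 GiB
# 4-bit (nf4): ~3.3 GiB
```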

Flash-Attention 2 to speed up generation

Additionally, we can speed up model inference by using Flash Attention, which is a faster implementation of the attention mechanism used inside the model.

First, make sure to install the latest version of Flash Attention 2:

pip install -U flash-attn --no-build-isolation

Your hardware must also be compatible with Flash-Attention 2. Read more about it in the official documentation of the flash attention repository. FlashAttention-2 can only be used when a model is loaded in torch.float16 or torch.bfloat16.

To load and run a model using Flash Attention-2, simply add attn_implementation="flash_attention_2" when loading the model as follows:

import torch
from transformers import VideoLlavaForConditionalGeneration

# Load in half-precision with Flash Attention 2, then move the model to GPU 0
model = VideoLlavaForConditionalGeneration.from_pretrained(
    "LanguageBind/Video-LLaVA-7B-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to(0)
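
If the same script has to run on machines with and without `flash-attn` installed, one option is to choose the attention implementation at runtime. The sketch below is a minimal example under that assumption; `pick_attn_implementation` is a hypothetical helper, and it only checks that the package is importable, not that the GPU is actually compatible.

```python
import importlib.util


def pick_attn_implementation():
    """Return "flash_attention_2" if the flash_attn package is installed, else "eager"."""
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "eager"


attn_implementation = pick_attn_implementation()
print(attn_implementation)
```

The returned string can then be passed as `attn_implementation=attn_implementation` in the `from_pretrained` call above.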

VideoLlavaConfig

autodoc VideoLlavaConfig

VideoLlavaImageProcessor

autodoc VideoLlavaImageProcessor

VideoLlavaProcessor

autodoc VideoLlavaProcessor

VideoLlavaForConditionalGeneration

autodoc VideoLlavaForConditionalGeneration - forward