Update video-llava docs (#30935)

* update video-llava

* Update docs/source/en/model_doc/video_llava.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Raushan Turganbay 2024-05-22 16:56:41 +05:00 committed by GitHub
parent edb14eba64
commit 934e1b84e9
2 changed files with 91 additions and 19 deletions

docs/source/en/model_doc/video_llava.md

@@ -42,21 +42,28 @@ a unified visual representation, outperforming models designed specifically for
work to provide modest insights into the multi-modal inputs
for the LLM*
## Usage tips:
- We advise users to use `padding_side="left"` when computing batched generation, as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating (see the batched generation sketch in the usage example below).
- Note that the model has not been explicitly trained to process multiple images/videos in the same prompt; although this is technically possible, you may experience inaccurate results.
- For better results, we recommend prompting the model with the correct prompt format, as shown in the usage example below.
- Note that video inputs should contain exactly 8 frames, since the models were trained in that setting.
This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
The original code can be found [here](https://github.com/PKU-YuanGroup/Video-LLaVA).
## Usage example
### Single Media Mode
The model can accept both images and videos as input. Here's example code for inference in half-precision (`torch.float16`):
```python
import av
import torch
import numpy as np
import requests
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor
def read_video_pyav(container, indices):
@@ -79,36 +86,99 @@ def read_video_pyav(container, indices):
frames.append(frame)
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", device_map="auto")
# Load the model in half-precision
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", torch_dtype=torch.float16, device_map="auto")
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
# Load the video as an np.array, uniformly sampling 8 frames
video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
container = av.open(video_path)
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
video = read_video_pyav(container, indices)
# For better results, we recommend prompting the model in the following format
prompt = "USER: <video>Why is this funny? ASSISTANT:"
inputs = processor(text=prompt, videos=video, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=60)
processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
```
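Building on the snippet above, batched generation over several prompts is also possible. Below is a minimal sketch that reuses `model`, `processor`, and `video` from the example; the prompts are illustrative, and left padding is set first as recommended in the usage tips:
```python
# Use left padding for batched generation (see the usage tips above)
processor.tokenizer.padding_side = "left"

prompts = [
    "USER: <video>What is happening in this video? ASSISTANT:",
    "USER: <video>Why is this funny? ASSISTANT:",
]

# One video per prompt; here the same 8-frame clip is reused for both prompts
inputs = processor(text=prompts, videos=[video, video], padding=True, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=60)
processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
```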
For multi-turn conversations, change the prompt format to:
```bash
"USER: <video>What do you see in this video? ASSISTANT: A baby reading a book. USER: Why is the it funny? ASSISTANT:"
```
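For example, continuing the snippet above, a second turn appends the model's first answer (the answer text here is illustrative) and asks a follow-up question:
```python
# Append the first answer and ask a follow-up question in the same prompt format
prompt = "USER: <video>What do you see in this video? ASSISTANT: A baby reading a book. USER: Why is it funny? ASSISTANT:"
inputs = processor(text=prompt, videos=video, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=60)
processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
```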
### Mixed Media Mode
The model can also generate from interleaved image-video inputs. However, note that it was not trained in the interleaved image-video setting, which might affect the performance. Below is an example usage for mixed media input; add the following lines to the above code snippet:
```python
from PIL import Image
import requests
# Generate from image and video mixed inputs
# Load an image and write a new prompt
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image> How many cats are there in the image? ASSISTANT: There are two cats. USER: <video>Why is this video funny? ASSISTANT:"
inputs = processor(text=prompt, images=image, videos=video, padding=True, return_tensors="pt")
# Generate
generate_ids = model.generate(**inputs, max_length=50)
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
```
## Model optimization
### Quantization using Bitsandbytes for memory efficiency
The model can be loaded in lower bits, significantly reducing memory burden while maintaining the performance of the original model. This allows for efficient deployment in resource-constrained cases.
First make sure to install bitsandbytes by running `pip install bitsandbytes` and to have access to a CUDA-compatible GPU device. Load the quantized model by simply adding [`BitsAndBytesConfig`](../main_classes/quantization#transformers.BitsAndBytesConfig) as shown below:
```python
import torch
from transformers import VideoLlavaForConditionalGeneration, BitsAndBytesConfig
# specify how to quantize the model
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", quantization_config=quantization_config, device_map="auto")
```
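The quantized model is then used exactly like the full-precision one. As a brief sketch, reusing the `processor`, `prompt`, and `video` prepared in the usage example above:
```python
# Optional: check how much memory the 4-bit model occupies
print(f"Memory footprint: {model.get_memory_footprint() / 1024**3:.2f} GiB")

inputs = processor(text=prompt, videos=video, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=60)
processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
```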
### Flash-Attention 2 to speed up generation
Additionally, we can greatly speed up model inference by using [Flash Attention](../perf_train_gpu_one.md#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
First, make sure to install the latest version of Flash Attention 2:
```bash
pip install -U flash-attn --no-build-isolation
```
Also, make sure your hardware is compatible with Flash-Attention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention-2 can only be used when a model is loaded in `torch.float16` or `torch.bfloat16`.
To load and run a model using Flash Attention 2, simply add `attn_implementation="flash_attention_2"` when loading the model as follows:
```python
import torch
from transformers import VideoLlavaForConditionalGeneration
model = VideoLlavaForConditionalGeneration.from_pretrained(
"LanguageBind/Video-LLaVA-7B-hf",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
).to(0)
```
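Since `quantization_config` and `attn_implementation` are independent arguments to `from_pretrained`, the two optimizations above can also be combined on hardware that supports both, for example:
```python
import torch
from transformers import BitsAndBytesConfig, VideoLlavaForConditionalGeneration

# 4-bit quantization with fp16 compute, plus Flash Attention 2 for faster attention
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = VideoLlavaForConditionalGeneration.from_pretrained(
    "LanguageBind/Video-LLaVA-7B-hf",
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
```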
## VideoLlavaConfig


@@ -123,6 +123,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
config_class = VideoLlavaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["VideoLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
@@ -474,22 +475,23 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=80)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER: Why is this video funny? ASSISTANT: The video is funny because the baby is playing with a Wii remote while sitting on the floor, and the baby is wearing glasses.Ъ. The baby's actions are amusing because it is a young child trying to interact with a video game, which is not a typical activity for a"
>>> # to generate from image and video mix
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> prompt = [
"USER: <image> How many cats are there in the image? ASSISTANT:",
"USER: <video>Why is this video funny? ASSISTANT:"
]
... "USER: <image> How many cats do you see? ASSISTANT:",
... "USER: <video>Why is this video funny? ASSISTANT:"
... ]
>>> inputs = processor(text=prompt, images=image, videos=clip, padding=True, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=50)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
['USER: How many cats do you see? ASSISTANT: There are two cats visible in the image. (or three, if you count the one in the background).', 'USER: Why is this video funny? ASSISTANT: The video is funny because it shows a baby sitting on a bed and playing with a Wii remote.Ъ. The baby is holding the remote']
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (