<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Image captioning

[[open-in-colab]]

Image captioning is the task of predicting a caption for a given image. A common real-world application is aiding
visually impaired people: by describing images to them, captioning models help them navigate different situations
and improve content accessibility overall.

This guide will show you how to:

* Fine-tune an image captioning model.
* Use the fine-tuned model for inference.

Before you begin, make sure you have all the necessary libraries installed:

```bash
pip install transformers datasets evaluate -q
pip install jiwer -q
```

We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:

```python
from huggingface_hub import notebook_login

notebook_login()
```

## Load the Pokémon BLIP captions dataset

Use the 🤗 Datasets library to load a dataset that consists of image-caption pairs. To create your own image captioning dataset
in PyTorch, you can follow [this notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/GIT/Fine_tune_GIT_on_an_image_captioning_dataset.ipynb).

```python
from datasets import load_dataset

ds = load_dataset("lambdalabs/pokemon-blip-captions")
ds
```
```bash
DatasetDict({
    train: Dataset({
        features: ['image', 'text'],
        num_rows: 833
    })
})
```

The dataset has two features, `image` and `text`.
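
To get a feel for the data, you can peek at a single example (the exact caption you see will vary):

```python
# inspect one raw example: a PIL image under "image" and a caption string under "text"
example = ds["train"][0]
print(example["text"])
example["image"]
```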

<Tip>

Many image captioning datasets contain multiple captions per image. In those cases, a common strategy is to randomly sample a caption from the available ones during training, as sketched below.

</Tip>
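
This dataset has a single caption per image, but if yours has several, a transform along these lines could draw one at random. This is a minimal sketch; the list-valued `captions` column is hypothetical, and applying it with `set_transform` means a fresh caption is sampled each time an example is accessed:

```python
import random


def sample_captions(example_batch):
    # for each image, pick one caption at random from a hypothetical
    # list-valued "captions" column; re-sampled on every access
    example_batch["text"] = [random.choice(caps) for caps in example_batch["captions"]]
    return example_batch
```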

Split the dataset's train split into a train and test set with the [`~datasets.Dataset.train_test_split`] method:

```python
ds = ds["train"].train_test_split(test_size=0.1)
train_ds = ds["train"]
test_ds = ds["test"]
```

Let's visualize a couple of samples from the training set.

```python
from textwrap import wrap
import matplotlib.pyplot as plt
import numpy as np


def plot_images(images, captions):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        plt.subplot(1, len(images), i + 1)
        # wrap long captions so they fit above each image
        caption = "\n".join(wrap(captions[i], 12))
        plt.title(caption)
        plt.imshow(images[i])
        plt.axis("off")


sample_images_to_visualize = [np.array(train_ds[i]["image"]) for i in range(5)]
sample_captions = [train_ds[i]["text"] for i in range(5)]
plot_images(sample_images_to_visualize, sample_captions)
```

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_training_images_image_cap.png" alt="Sample training images"/>
</div>

## Preprocess the dataset

Since the dataset has two modalities (image and text), the preprocessing pipeline needs to handle both the images and the captions.

To do so, load the processor class associated with the model you are about to fine-tune.

```python
from transformers import AutoProcessor

checkpoint = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint)
```

The processor will internally preprocess the image (including resizing and pixel scaling) and tokenize the caption.

```python
def transforms(example_batch):
    images = [x for x in example_batch["image"]]
    captions = [x for x in example_batch["text"]]
    inputs = processor(images=images, text=captions, padding="max_length")
    # for causal language modeling, the caption token ids double as the labels
    inputs.update({"labels": inputs["input_ids"]})
    return inputs


train_ds.set_transform(transforms)
test_ds.set_transform(transforms)
```
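
You can sanity-check the transform by pulling one processed example. With `set_transform`, the processor runs on the fly, so simply indexing the dataset triggers it:

```python
# indexing a single example runs `transforms` lazily
processed = train_ds[0]
print(processed.keys())  # expect: pixel_values, input_ids, attention_mask, labels
```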

With the dataset ready, you can now set up the model for fine-tuning.

## Load a base model

Load the ["microsoft/git-base"](https://huggingface.co/microsoft/git-base) checkpoint into an [`AutoModelForCausalLM`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM) object.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(checkpoint)
```

## Evaluate

Image captioning models are typically evaluated with the [ROUGE score](https://huggingface.co/spaces/evaluate-metric/rouge) or [Word Error Rate](https://huggingface.co/spaces/evaluate-metric/wer). For this guide, you will use the Word Error Rate (WER).

We use the 🤗 Evaluate library to do so. For potential limitations and other gotchas of WER, refer to [this guide](https://huggingface.co/spaces/evaluate-metric/wer).

```python
from evaluate import load
import torch

wer = load("wer")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predicted = logits.argmax(-1)
    # decode both predictions and references back to text before computing WER
    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"wer_score": wer_score}
```
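
As a quick illustration of the metric itself, WER is the fraction of word-level edits (substitutions, insertions, deletions) needed to turn the prediction into the reference:

```python
# one substitution in a five-word reference -> WER of 0.2
print(wer.compute(predictions=["a drawing of a pokemon"], references=["a sketch of a pokemon"]))
```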

## Train!

Now, you are ready to start fine-tuning the model. You will use the 🤗 [`Trainer`] for this.

First, define the training arguments using [`TrainingArguments`].

```python
from transformers import TrainingArguments, Trainer

model_name = checkpoint.split("/")[1]

training_args = TrainingArguments(
    output_dir=f"{model_name}-pokemon",
    learning_rate=5e-5,
    num_train_epochs=50,
    fp16=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    logging_steps=50,
    remove_unused_columns=False,
    push_to_hub=True,
    label_names=["labels"],
    load_best_model_at_end=True,
)
```

Note that with `per_device_train_batch_size=32` and `gradient_accumulation_steps=2`, the effective training batch size is 64.

Then pass the training arguments along with the datasets and the model to 🤗 Trainer.

```python
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)
```

To start training, simply call [`~Trainer.train`] on the [`Trainer`] object.

```python
trainer.train()
```
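
You can also run a standalone evaluation pass on the test split. With the `compute_metrics` function defined above, the returned dictionary should include `eval_wer_score` alongside the usual loss and runtime entries:

```python
# evaluate on `eval_dataset` (here, the test split) and report metrics
metrics = trainer.evaluate()
print(metrics)
```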

You should see the training loss drop smoothly as training progresses.

Once training is completed, share your model to the Hub with the [`~Trainer.push_to_hub`] method so everyone can use your model:

```python
trainer.push_to_hub()
```

## Inference

Load a sample image to test the fine-tuned model.

```python
from PIL import Image
import requests

url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
image = Image.open(requests.get(url, stream=True).raw)
image
```

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/test_image_image_cap.png" alt="Test image"/>
</div>

Prepare the image for the model.

```python
device = "cuda" if torch.cuda.is_available() else "cpu"

# make sure the model is on the same device as the inputs
# (right after fine-tuning with `Trainer` it usually already is)
model.to(device)

inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values
```

Call [`generate`] and decode the predictions.

```python
generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)
```
```bash
a drawing of a pink and blue pokemon
```

Looks like the fine-tuned model generated a pretty good caption!
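
Alternatively, you can wrap the fine-tuned checkpoint in an `image-to-text` [pipeline](https://huggingface.co/docs/transformers/main_classes/pipelines), which handles preprocessing and decoding for you. This is a minimal sketch; `"your-username/git-base-pokemon"` is a placeholder for whichever Hub repository `push_to_hub` created for your model:

```python
from transformers import pipeline

# "your-username/git-base-pokemon" is a placeholder for your own Hub repository
captioner = pipeline("image-to-text", model="your-username/git-base-pokemon")
captioner(image)  # returns a list like [{"generated_text": "..."}]
```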