Add examples for detection models finetuning (#30422)

* Training script for object detection

* Evaluation script for object detection

* Training script for object detection with eval loop outside trainer

* Trainer DETR finetuning

* No trainer DETR finetuning

* Eval script

* Refine object detection example with trainer

* Remove commented code and enable telemetry

* No trainer example

* Add requirements for object detection examples

* Add test for trainer example

* Readme draft

* Fix uploading to HUB

* Readme improvements

* Update eval script

* Adding tests for object-detection examples

* Add object-detection example

* Add object-detection resources to docs

* Update README with custom dataset instructions

* Update year

* Replace valid with validation

* Update instructions for custom dataset

* Remove eval script

* Remove use_auth_token

* Add copied from and telemetry

* Fixup

* Update readme

* Fix id2label

* Fix links in docs

* Update examples/pytorch/object-detection/run_object_detection.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update examples/pytorch/object-detection/run_object_detection.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Move description to the top

* Fix Trainer example

* Update no trainer example

* Update albumentations version

---------

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Pavel Iakubovskii 2024-05-08 11:42:07 +01:00 committed by GitHub
parent 508c0bfe55
commit 998dbe068b
13 changed files with 1612 additions and 5 deletions


@@ -33,7 +33,8 @@ This model was contributed by [DepuMeng](https://huggingface.co/DepuMeng). The o
## Resources
- [Object detection task guide](../tasks/object_detection)
- Scripts for finetuning [`ConditionalDetrForObjectDetection`] with [`Trainer`] or [Accelerate](https://huggingface.co/docs/accelerate/index) can be found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection).
- See also: [Object detection task guide](../tasks/object_detection).
## ConditionalDetrConfig


@@ -43,6 +43,7 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
<PipelineTag pipeline="object-detection"/>
- Demo notebooks regarding inference + fine-tuning on a custom dataset for [`DeformableDetrForObjectDetection`] can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/Deformable-DETR).
- Scripts for finetuning [`DeformableDetrForObjectDetection`] with [`Trainer`] or [Accelerate](https://huggingface.co/docs/accelerate/index) can be found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection).
- See also: [Object detection task guide](../tasks/object_detection).
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.


@@ -39,7 +39,8 @@ The original code can be found [here](https://github.com/jozhang97/DETA).
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DETA.
- Demo notebooks for DETA can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/DETA).
- See also: [Object detection task guide](../tasks/object_detection)
- Scripts for finetuning [`DetaForObjectDetection`] with [`Trainer`] or [Accelerate](https://huggingface.co/docs/accelerate/index) can be found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection).
- See also: [Object detection task guide](../tasks/object_detection).
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.


@@ -162,8 +162,9 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
<PipelineTag pipeline="object-detection"/>
- All example notebooks illustrating fine-tuning [`DetrForObjectDetection`] and [`DetrForSegmentation`] on a custom dataset an be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/DETR).
- See also: [Object detection task guide](../tasks/object_detection)
- All example notebooks illustrating fine-tuning [`DetrForObjectDetection`] and [`DetrForSegmentation`] on a custom dataset can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/DETR).
- Scripts for finetuning [`DetrForObjectDetection`] with [`Trainer`] or [Accelerate](https://huggingface.co/docs/accelerate/index) can be found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection).
- See also: [Object detection task guide](../tasks/object_detection).
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.


@@ -39,6 +39,7 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
<PipelineTag pipeline="object-detection"/>
- All example notebooks illustrating inference + fine-tuning [`YolosForObjectDetection`] on a custom dataset can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/YOLOS).
- Scripts for finetuning [`YolosForObjectDetection`] with [`Trainer`] or [Accelerate](https://huggingface.co/docs/accelerate/index) can be found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection).
- See also: [Object detection task guide](../tasks/object_detection)
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.


@@ -46,6 +46,7 @@ Coming soon!
| [**`image-pretraining`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining) | [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k) | ✅ | - |✅ | /
| [**`image-classification`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) | [CIFAR-10](https://huggingface.co/datasets/cifar10) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)
| [**`semantic-segmentation`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/semantic-segmentation) | [SCENE_PARSE_150](https://huggingface.co/datasets/scene_parse_150) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/semantic_segmentation.ipynb)
| [**`object-detection`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection) | [CPPE-5](https://huggingface.co/datasets/cppe-5) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/object_detection.ipynb)
## Running quick tests


@@ -25,4 +25,7 @@ torchaudio
jiwer
librosa
evaluate >= 0.2.0
albumentations
timm
albumentations >= 1.4.5
torchmetrics
pycocotools


@@ -0,0 +1,232 @@
<!---
Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Object detection examples
This directory contains 2 scripts that showcase how to fine-tune any model supported by the [`AutoModelForObjectDetection` API](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForObjectDetection) (such as [DETR](https://huggingface.co/docs/transformers/main/en/model_doc/detr), [DETA](https://huggingface.co/docs/transformers/main/en/model_doc/deta), [Deformable DETR](https://huggingface.co/docs/transformers/main/en/model_doc/deformable_detr)) using PyTorch.
Content:
* [PyTorch version, Trainer](#pytorch-version-trainer)
* [PyTorch version, no Trainer](#pytorch-version-no-trainer)
* [Reload and perform inference](#reload-and-perform-inference)
* [Note on custom data](#note-on-custom-data)
## PyTorch version, Trainer
Based on the script [`run_object_detection.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/object-detection/run_object_detection.py).
The script leverages the [🤗 Trainer API](https://huggingface.co/docs/transformers/main_classes/trainer) to automatically take care of the training for you, running on distributed environments right away.
Here we show how to fine-tune a [DETR](https://huggingface.co/facebook/detr-resnet-50) model on the [CPPE-5](https://huggingface.co/datasets/cppe-5) dataset:
```bash
python run_object_detection.py \
--model_name_or_path facebook/detr-resnet-50 \
--dataset_name cppe-5 \
--do_train true \
--do_eval true \
--output_dir detr-finetuned-cppe-5-10k-steps \
--num_train_epochs 100 \
--image_square_size 600 \
--fp16 true \
--learning_rate 5e-5 \
--weight_decay 1e-4 \
--dataloader_num_workers 4 \
--dataloader_prefetch_factor 2 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 1 \
--remove_unused_columns false \
--eval_do_concat_batches false \
--ignore_mismatched_sizes true \
--metric_for_best_model eval_map \
--greater_is_better true \
--load_best_model_at_end true \
--logging_strategy epoch \
--evaluation_strategy epoch \
--save_strategy epoch \
--save_total_limit 2 \
--push_to_hub true \
--push_to_hub_model_id detr-finetuned-cppe-5-10k-steps \
--hub_strategy end \
--seed 1337
```
> Note:
`--eval_do_concat_batches false` is required for correct evaluation of detection models;
`--ignore_mismatched_sizes true` is required to load a detection model for finetuning with a different number of classes.
The resulting model can be seen here: https://huggingface.co/qubvel-hf/detr-resnet-50-finetuned-10k-cppe5. The corresponding Weights and Biases report can be found [here](https://api.wandb.ai/links/qubvel-hf-co/bnm0r5ex). Note that it's always advisable to check the original paper for the details regarding training hyperparameters. The hyperparameters for this example were not tuned. To improve model quality you could try:
- changing the image size parameter (`--image_square_size`)
- changing training parameters, such as learning rate, batch size, warmup, optimizer and many more (see [TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments))
- adding more image augmentations (we created a helpful [HF Space](https://huggingface.co/spaces/qubvel-hf/albumentations-demo) to choose some); a sketch of an extended pipeline is shown below
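For example, here is a minimal sketch of a slightly extended augmentation pipeline, assuming you edit the `train_augment_and_transform` definition inside `run_object_detection.py` (the extra transforms are only illustrative suggestions, not part of the script):

```python
import albumentations as A

# Illustrative only: a slightly richer pipeline than the script's default.
# Keep bbox_params so bounding boxes are transformed together with the images.
max_size = 600  # should match --image_square_size
train_augment_and_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.HueSaturationValue(p=0.1),
        A.GaussNoise(p=0.1),   # extra: simulate sensor noise
        A.RandomGamma(p=0.1),  # extra: vary exposure
        A.LongestMaxSize(max_size=max_size),
        A.PadIfNeeded(max_size, max_size, border_mode=0, value=(128, 128, 128)),
    ],
    bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True, min_area=25),
)
```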
Note that you can replace the model and dataset by simply setting the `model_name_or_path` and `dataset_name` arguments respectively, with a model or dataset from the [hub](https://huggingface.co/).
For the dataset, make sure it provides labels in the same format as the [CPPE-5](https://huggingface.co/datasets/cppe-5) dataset and that boxes are provided in [COCO format](https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/#coco).
![W&B report](https://i.imgur.com/ASNjamQ.png)
## PyTorch version, no Trainer
Based on the script [`run_object_detection_no_trainer.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/object-detection/run_object_detection_no_trainer.py).
The script leverages [🤗 `Accelerate`](https://github.com/huggingface/accelerate), which allows you to write your own training loop in PyTorch but have it run instantly on any (distributed) environment, including CPU, multi-CPU, GPU, multi-GPU and TPU. It also supports mixed precision.
First, run:
```bash
accelerate config
```
and reply to the questions asked regarding the environment on which you'd like to train. Then
```bash
accelerate test
```
that will check everything is ready for training. Finally, you can launch training with
```bash
accelerate launch run_object_detection_no_trainer.py \
--model_name_or_path "facebook/detr-resnet-50" \
--dataset_name cppe-5 \
--output_dir "detr-resnet-50-finetuned" \
--num_train_epochs 100 \
--image_square_size 600 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--checkpointing_steps epoch \
--learning_rate 5e-5 \
--ignore_mismatched_sizes \
--with_tracking \
--push_to_hub
```
and boom, you're training, possibly on multiple GPUs, logging everything to all trackers found in your environment (like Weights and Biases, Tensorboard) and regularly pushing your model to the hub (with the repo name being equal to `args.output_dir` at your HF username) 🤗
With the default settings, the script fine-tunes a [DETR](https://huggingface.co/facebook/detr-resnet-50) model on the [CPPE-5](https://huggingface.co/datasets/cppe-5) dataset. The resulting model can be seen here: https://huggingface.co/qubvel-hf/detr-resnet-50-finetuned-10k-cppe5-no-trainer.
## Reload and perform inference
This means that after training, you can easily load your trained model and perform inference as follows:
```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection
# Name of repo on the hub or path to a local folder
model_name = "qubvel-hf/detr-resnet-50-finetuned-10k-cppe5"
image_processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForObjectDetection.from_pretrained(model_name)
# Load image for inference
url = "https://images.pexels.com/photos/8413299/pexels-photo-8413299.jpeg?auto=compress&cs=tinysrgb&w=630&h=375&dpr=2"
image = Image.open(requests.get(url, stream=True).raw)
# Prepare image for the model
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# Post process model predictions
# this includes conversion to Pascal VOC format and filtering non-confident boxes
width, height = image.size
target_sizes = torch.tensor([height, width]).unsqueeze(0) # add batch dim
results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(i, 2) for i in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )
```
And visualize with the following code:
```python
from PIL import ImageDraw
draw = ImageDraw.Draw(image)
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(i, 2) for i in box.tolist()]
    x, y, x2, y2 = tuple(box)
    draw.rectangle((x, y, x2, y2), outline="red", width=1)
    draw.text((x, y), model.config.id2label[label.item()], fill="white")
image
```
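If you are running this in a plain Python script rather than a notebook, you can save the annotated image instead, for example with `image.save("result.png")`.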
## Note on custom data
In case you'd like to use the script with custom data, you can prepare your data in the following way:
```bash
custom_dataset/
└── train
    ├── 0001.jpg
    ├── 0002.jpg
    ├── ...
    └── metadata.jsonl
└── validation
    └── ...
└── test
    └── ...
```
Where `metadata.jsonl` is a file with the following structure:
```json
{"file_name": "0001.jpg", "objects": {"bbox": [[302.0, 109.0, 73.0, 52.0]], "categories": [0], "id": [1], "area": [50.0]}}
{"file_name": "0002.jpg", "objects": {"bbox": [[810.0, 100.0, 57.0, 28.0]], "categories": [1], "id": [2], "area": [40.0]}}
...
```
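If your annotations already live in Python objects, a minimal sketch of writing such a file could look like this (the `samples` list below is a hypothetical example, not something produced by the scripts):

```python
import json

# Hypothetical in-memory annotations; adapt this to however your labels are stored.
samples = [
    {"file_name": "0001.jpg", "objects": {"bbox": [[302.0, 109.0, 73.0, 52.0]], "categories": [0], "id": [1], "area": [50.0]}},
    {"file_name": "0002.jpg", "objects": {"bbox": [[810.0, 100.0, 57.0, 28.0]], "categories": [1], "id": [2], "area": [40.0]}},
]

# Write one JSON object per line, next to the corresponding images
with open("custom_dataset/train/metadata.jsonl", "w") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")
```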
Then, you can load the dataset with just a few lines of code:
```python
from datasets import load_dataset
# Load dataset
dataset = load_dataset("imagefolder", data_dir="custom_dataset/")
# >>> DatasetDict({
# ... train: Dataset({
# ... features: ['image', 'objects'],
# ... num_rows: 2
# ... })
# ... })
# Push to hub (assumes you have run the huggingface-cli login command in a terminal/notebook)
dataset.push_to_hub("name of repo on the hub")
# optionally, you can push to a private repo on the hub
# dataset.push_to_hub("name of repo on the hub", private=True)
```
As a final step, for training you should provide an `id2label` mapping in the following way:
```python
id2label = {0: "Car", 1: "Bird", ...}
```
Just find it in the code and replace it for simplicity, or save it as a `json` file locally and push it to the hub alongside the dataset!
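For instance, here is a minimal sketch of keeping the mapping in a JSON file (the scripts do not read such a file automatically, so the file name is just one convenient choice):

```python
import json

id2label = {0: "Car", 1: "Bird"}

# Save the mapping next to your dataset
with open("id2label.json", "w") as f:
    json.dump(id2label, f)

# Load it back; JSON keys are strings, so convert them to int
with open("id2label.json") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}
```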
See also: [Dataset Creation Guide](https://huggingface.co/docs/datasets/image_dataset#create-an-image-dataset)


@@ -0,0 +1,5 @@
albumentations >= 1.4.5
timm
datasets
torchmetrics
pycocotools


@@ -0,0 +1,523 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning any 🤗 Transformers model supported by AutoModelForObjectDetection for object detection leveraging the Trainer API."""
import logging
import os
import sys
from dataclasses import dataclass, field
from functools import partial
from typing import Any, List, Mapping, Optional, Tuple, Union
import albumentations as A
import numpy as np
import torch
from datasets import load_dataset
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import transformers
from transformers import (
AutoConfig,
AutoImageProcessor,
AutoModelForObjectDetection,
HfArgumentParser,
Trainer,
TrainingArguments,
)
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import center_to_corners_format
from transformers.trainer import EvalPrediction
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.40.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")
@dataclass
class ModelOutput:
logits: torch.Tensor
pred_boxes: torch.Tensor
def format_image_annotations_as_coco(
image_id: str, categories: List[int], areas: List[float], bboxes: List[Tuple[float]]
) -> dict:
"""Format one set of image annotations to the COCO format
Args:
image_id (str): image id. e.g. "0001"
categories (List[int]): list of categories/class labels corresponding to provided bounding boxes
areas (List[float]): list of corresponding areas to provided bounding boxes
bboxes (List[Tuple[float]]): list of bounding boxes provided in COCO format
([x_min, y_min, width, height] in absolute coordinates)
Returns:
dict: {
"image_id": image id,
"annotations": list of formatted annotations
}
"""
annotations = []
for category, area, bbox in zip(categories, areas, bboxes):
formatted_annotation = {
"image_id": image_id,
"category_id": category,
"iscrowd": 0,
"area": area,
"bbox": list(bbox),
}
annotations.append(formatted_annotation)
return {
"image_id": image_id,
"annotations": annotations,
}
def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
"""
Convert bounding boxes from YOLO format (x_center, y_center, width, height) in range [0, 1]
to Pascal VOC format (x_min, y_min, x_max, y_max) in absolute coordinates.
Args:
boxes (torch.Tensor): Bounding boxes in YOLO format
image_size (Tuple[int, int]): Image size in format (height, width)
Returns:
torch.Tensor: Bounding boxes in Pascal VOC format (x_min, y_min, x_max, y_max)
"""
# convert center to corners format
boxes = center_to_corners_format(boxes)
# convert to absolute coordinates
height, width = image_size
boxes = boxes * torch.tensor([[width, height, width, height]])
return boxes
def augment_and_transform_batch(
examples: Mapping[str, Any], transform: A.Compose, image_processor: AutoImageProcessor
) -> BatchFeature:
"""Apply augmentations and format annotations in COCO format for object detection task"""
images = []
annotations = []
for image_id, image, objects in zip(examples["image_id"], examples["image"], examples["objects"]):
image = np.array(image.convert("RGB"))
# apply augmentations
output = transform(image=image, bboxes=objects["bbox"], category=objects["category"])
images.append(output["image"])
# format annotations in COCO format
formatted_annotations = format_image_annotations_as_coco(
image_id, output["category"], objects["area"], output["bboxes"]
)
annotations.append(formatted_annotations)
# Apply the image processor transformations: resizing, rescaling, normalization
result = image_processor(images=images, annotations=annotations, return_tensors="pt")
return result
def collate_fn(batch: List[BatchFeature]) -> Mapping[str, Union[torch.Tensor, List[Any]]]:
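# Pixel values can be stacked into one batch tensor, but labels must stay a list of
# per-image dicts because the number of boxes varies between images.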
data = {}
data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
data["labels"] = [x["labels"] for x in batch]
if "pixel_mask" in batch[0]:
data["pixel_mask"] = torch.stack([x["pixel_mask"] for x in batch])
return data
@torch.no_grad()
def compute_metrics(
evaluation_results: EvalPrediction,
image_processor: AutoImageProcessor,
threshold: float = 0.0,
id2label: Optional[Mapping[int, str]] = None,
) -> Mapping[str, float]:
"""
Compute mean average precision (mAP), mean average recall (mAR) and their variants for the object detection task.
Args:
evaluation_results (EvalPrediction): Predictions and targets from evaluation.
threshold (float, optional): Threshold to filter predicted boxes by confidence. Defaults to 0.0.
id2label (Optional[dict], optional): Mapping from class id to class name. Defaults to None.
Returns:
Mapping[str, float]: Metrics in a form of dictionary {<metric_name>: <metric_value>}
"""
predictions, targets = evaluation_results.predictions, evaluation_results.label_ids
# For metric computation we need to provide:
# - targets in a form of list of dictionaries with keys "boxes", "labels"
# - predictions in a form of list of dictionaries with keys "boxes", "scores", "labels"
image_sizes = []
post_processed_targets = []
post_processed_predictions = []
# Collect targets in the required format for metric computation
for batch in targets:
# collect image sizes, we will need them for predictions post processing
batch_image_sizes = torch.tensor([x["orig_size"] for x in batch])
image_sizes.append(batch_image_sizes)
# collect targets in the required format for metric computation
# boxes were converted to YOLO format needed for model training
# here we will convert them to Pascal VOC format (x_min, y_min, x_max, y_max)
for image_target in batch:
boxes = torch.tensor(image_target["boxes"])
boxes = convert_bbox_yolo_to_pascal(boxes, image_target["orig_size"])
labels = torch.tensor(image_target["class_labels"])
post_processed_targets.append({"boxes": boxes, "labels": labels})
# Collect predictions in the required format for metric computation,
# model produces boxes in YOLO format, then image_processor converts them to Pascal VOC format
for batch, target_sizes in zip(predictions, image_sizes):
batch_logits, batch_boxes = batch[1], batch[2]
output = ModelOutput(logits=torch.tensor(batch_logits), pred_boxes=torch.tensor(batch_boxes))
post_processed_output = image_processor.post_process_object_detection(
output, threshold=threshold, target_sizes=target_sizes
)
post_processed_predictions.extend(post_processed_output)
# Compute metrics
metric = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
metric.update(post_processed_predictions, post_processed_targets)
metrics = metric.compute()
# Replace list of per class metrics with separate metric for each class
classes = metrics.pop("classes")
map_per_class = metrics.pop("map_per_class")
mar_100_per_class = metrics.pop("mar_100_per_class")
for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
class_name = id2label[class_id.item()] if id2label is not None else class_id.item()
metrics[f"map_{class_name}"] = class_map
metrics[f"mar_100_{class_name}"] = class_mar
metrics = {k: round(v.item(), 4) for k, v in metrics.items()}
return metrics
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify
them on the command line.
"""
dataset_name: str = field(
default="cppe-5",
metadata={
"help": "Name of a dataset from the hub (could be your own, possibly private dataset hosted on the hub)."
},
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_val_split: Optional[float] = field(
default=0.15, metadata={"help": "Percent to split off of train for validation."}
)
image_square_size: Optional[int] = field(
default=600,
metadata={
"help": "Image longest size will be resized to this value, then image will be padded to square."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
default="facebook/detr-resnet-50",
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
ignore_mismatched_sizes: bool = field(
default=False,
metadata={
"help": "Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (if for instance, you are instantiating a model with 10 labels from a checkpoint with 3 labels)."
},
)
token: str = field(
default=None,
metadata={
"help": (
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
)
},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
"execute code present on the Hub on your local machine."
)
},
)
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_object_detection", model_args, data_args)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# Detecting last checkpoint.
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
checkpoint = get_last_checkpoint(training_args.output_dir)
if checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# ------------------------------------------------------------------------------------------------
# Load dataset, prepare splits
# ------------------------------------------------------------------------------------------------
dataset = load_dataset(data_args.dataset_name, cache_dir=model_args.cache_dir)
# If we don't have a validation split, split off a percentage of train as validation
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
split = dataset["train"].train_test_split(data_args.train_val_split, seed=training_args.seed)
dataset["train"] = split["train"]
dataset["validation"] = split["test"]
# Get dataset categories and prepare mappings for label_name <-> label_id
categories = dataset["train"].features["objects"].feature["category"].names
id2label = dict(enumerate(categories))
label2id = {v: k for k, v in id2label.items()}
# ------------------------------------------------------------------------------------------------
# Load pretrained config, model and image processor
# ------------------------------------------------------------------------------------------------
common_pretrained_args = {
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.token,
"trust_remote_code": model_args.trust_remote_code,
}
config = AutoConfig.from_pretrained(
model_args.config_name or model_args.model_name_or_path,
label2id=label2id,
id2label=id2label,
**common_pretrained_args,
)
model = AutoModelForObjectDetection.from_pretrained(
model_args.model_name_or_path,
config=config,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
**common_pretrained_args,
)
image_processor = AutoImageProcessor.from_pretrained(
model_args.image_processor_name or model_args.model_name_or_path,
# At this moment we recommend using external transform to pad and resize images.
# It's faster and yields much better results for object-detection models.
do_pad=False,
do_resize=False,
# We will save image size parameter in config just for reference
size={"longest_edge": data_args.image_square_size},
**common_pretrained_args,
)
# ------------------------------------------------------------------------------------------------
# Define image augmentations and dataset transforms
# ------------------------------------------------------------------------------------------------
max_size = data_args.image_square_size
basic_transforms = [
A.LongestMaxSize(max_size=max_size),
A.PadIfNeeded(max_size, max_size, border_mode=0, value=(128, 128, 128), position="top_left"),
]
train_augment_and_transform = A.Compose(
[
A.Compose(
[
A.SmallestMaxSize(max_size=max_size, p=1.0),
A.RandomSizedBBoxSafeCrop(height=max_size, width=max_size, p=1.0),
],
p=0.2,
),
A.OneOf(
[
A.Blur(blur_limit=7, p=0.5),
A.MotionBlur(blur_limit=7, p=0.5),
A.Defocus(radius=(1, 5), alias_blur=(0.1, 0.25), p=0.1),
],
p=0.1,
),
A.Perspective(p=0.1),
A.HorizontalFlip(p=0.5),
A.RandomBrightnessContrast(p=0.5),
A.HueSaturationValue(p=0.1),
*basic_transforms,
],
bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True, min_area=25),
)
validation_transform = A.Compose(
basic_transforms,
bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True),
)
# Make transform functions for batch and apply for dataset splits
train_transform_batch = partial(
augment_and_transform_batch, transform=train_augment_and_transform, image_processor=image_processor
)
validation_transform_batch = partial(
augment_and_transform_batch, transform=validation_transform, image_processor=image_processor
)
dataset["train"] = dataset["train"].with_transform(train_transform_batch)
dataset["validation"] = dataset["validation"].with_transform(validation_transform_batch)
dataset["test"] = dataset["test"].with_transform(validation_transform_batch)
# ------------------------------------------------------------------------------------------------
# Model training and evaluation with Trainer API
# ------------------------------------------------------------------------------------------------
eval_compute_metrics_fn = partial(
compute_metrics, image_processor=image_processor, id2label=id2label, threshold=0.0
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset["train"] if training_args.do_train else None,
eval_dataset=dataset["validation"] if training_args.do_eval else None,
tokenizer=image_processor,
data_collator=collate_fn,
compute_metrics=eval_compute_metrics_fn,
)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
# Final evaluation
if training_args.do_eval:
metrics = trainer.evaluate(eval_dataset=dataset["test"], metric_key_prefix="test")
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
# Write model card and (optionally) push to hub
kwargs = {
"finetuned_from": model_args.model_name_or_path,
"dataset": data_args.dataset_name,
"tags": ["object-detection", "vision"],
}
if training_args.push_to_hub:
trainer.push_to_hub(**kwargs)
else:
trainer.create_model_card(**kwargs)
if __name__ == "__main__":
main()


@@ -0,0 +1,784 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning 🤗 Transformers model for object detection with Accelerate."""
import argparse
import json
import logging
import math
import os
from functools import partial
from pathlib import Path
from typing import Any, List, Mapping, Tuple, Union
import albumentations as A
import datasets
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import HfApi
from torch.utils.data import DataLoader
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from tqdm.auto import tqdm
import transformers
from transformers import (
AutoConfig,
AutoImageProcessor,
AutoModelForObjectDetection,
SchedulerType,
get_scheduler,
)
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import center_to_corners_format
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.40.0.dev0")
logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
# Copied from examples/pytorch/object-detection/run_object_detection.format_image_annotations_as_coco
def format_image_annotations_as_coco(
image_id: str, categories: List[int], areas: List[float], bboxes: List[Tuple[float]]
) -> dict:
"""Format one set of image annotations to the COCO format
Args:
image_id (str): image id. e.g. "0001"
categories (List[int]): list of categories/class labels corresponding to provided bounding boxes
areas (List[float]): list of corresponding areas to provided bounding boxes
bboxes (List[Tuple[float]]): list of bounding boxes provided in COCO format
([x_min, y_min, width, height] in absolute coordinates)
Returns:
dict: {
"image_id": image id,
"annotations": list of formatted annotations
}
"""
annotations = []
for category, area, bbox in zip(categories, areas, bboxes):
formatted_annotation = {
"image_id": image_id,
"category_id": category,
"iscrowd": 0,
"area": area,
"bbox": list(bbox),
}
annotations.append(formatted_annotation)
return {
"image_id": image_id,
"annotations": annotations,
}
# Copied from examples/pytorch/object-detection/run_object_detection.convert_bbox_yolo_to_pascal
def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
"""
Convert bounding boxes from YOLO format (x_center, y_center, width, height) in range [0, 1]
to Pascal VOC format (x_min, y_min, x_max, y_max) in absolute coordinates.
Args:
boxes (torch.Tensor): Bounding boxes in YOLO format
image_size (Tuple[int, int]): Image size in format (height, width)
Returns:
torch.Tensor: Bounding boxes in Pascal VOC format (x_min, y_min, x_max, y_max)
"""
# convert center to corners format
boxes = center_to_corners_format(boxes)
# convert to absolute coordinates
height, width = image_size
boxes = boxes * torch.tensor([[width, height, width, height]])
return boxes
# Copied from examples/pytorch/object-detection/run_object_detection.augment_and_transform_batch
def augment_and_transform_batch(
examples: Mapping[str, Any], transform: A.Compose, image_processor: AutoImageProcessor
) -> BatchFeature:
"""Apply augmentations and format annotations in COCO format for object detection task"""
images = []
annotations = []
for image_id, image, objects in zip(examples["image_id"], examples["image"], examples["objects"]):
image = np.array(image.convert("RGB"))
# apply augmentations
output = transform(image=image, bboxes=objects["bbox"], category=objects["category"])
images.append(output["image"])
# format annotations in COCO format
formatted_annotations = format_image_annotations_as_coco(
image_id, output["category"], objects["area"], output["bboxes"]
)
annotations.append(formatted_annotations)
# Apply the image processor transformations: resizing, rescaling, normalization
result = image_processor(images=images, annotations=annotations, return_tensors="pt")
return result
# Copied from examples/pytorch/object-detection/run_object_detection.collate_fn
def collate_fn(batch: List[BatchFeature]) -> Mapping[str, Union[torch.Tensor, List[Any]]]:
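# Pixel values can be stacked into one batch tensor, but labels must stay a list of
# per-image dicts because the number of boxes varies between images.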
data = {}
data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
data["labels"] = [x["labels"] for x in batch]
if "pixel_mask" in batch[0]:
data["pixel_mask"] = torch.stack([x["pixel_mask"] for x in batch])
return data
def nested_to_cpu(objects):
"""Move nested tesnors in objects to CPU if they are on GPU"""
if isinstance(objects, torch.Tensor):
return objects.cpu()
elif isinstance(objects, Mapping):
return type(objects)({k: nested_to_cpu(v) for k, v in objects.items()})
elif isinstance(objects, (list, tuple)):
return type(objects)([nested_to_cpu(v) for v in objects])
elif isinstance(objects, (np.ndarray, str, int, float, bool)):
return objects
raise ValueError(f"Unsupported type {type(objects)}")
def evaluation_loop(
model: torch.nn.Module,
image_processor: AutoImageProcessor,
accelerator: Accelerator,
dataloader: DataLoader,
id2label: Mapping[int, str],
) -> dict:
model.eval()
metric = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
for step, batch in enumerate(tqdm(dataloader, disable=not accelerator.is_local_main_process)):
with torch.no_grad():
outputs = model(**batch)
# For metric computation we need to collect ground truth and predicted boxes in the same format
# 1. Collect predicted boxes, classes, scores
# image_processor converts boxes from YOLO format to Pascal VOC format
# ([x_min, y_min, x_max, y_max] in absolute coordinates)
image_size = torch.stack([example["orig_size"] for example in batch["labels"]], dim=0)
predictions = image_processor.post_process_object_detection(outputs, threshold=0.0, target_sizes=image_size)
predictions = nested_to_cpu(predictions)
# 2. Collect ground truth boxes in the same format for metric computation
# Do the same, convert YOLO boxes to Pascal VOC format
target = []
for label in batch["labels"]:
label = nested_to_cpu(label)
boxes = convert_bbox_yolo_to_pascal(label["boxes"], label["orig_size"])
labels = label["class_labels"]
target.append({"boxes": boxes, "labels": labels})
metric.update(predictions, target)
metrics = metric.compute()
# Replace list of per class metrics with separate metric for each class
classes = metrics.pop("classes")
map_per_class = metrics.pop("map_per_class")
mar_100_per_class = metrics.pop("mar_100_per_class")
for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
class_name = id2label[class_id.item()]
metrics[f"map_{class_name}"] = class_map
metrics[f"mar_100_{class_name}"] = class_mar
# Convert metrics to float
metrics = {k: round(v.item(), 4) for k, v in metrics.items()}
return metrics
def parse_args():
parser = argparse.ArgumentParser(description="Finetune a transformers model for object detection task")
parser.add_argument(
"--model_name_or_path",
type=str,
help="Path to a pretrained model or model identifier from huggingface.co/models.",
default="facebook/detr-resnet-50",
)
parser.add_argument(
"--dataset_name",
type=str,
help="Name of the dataset on the hub.",
default="cppe-5",
)
parser.add_argument(
"--train_val_split",
type=float,
default=0.15,
help="Fraction of the dataset to be used for validation.",
)
parser.add_argument(
"--ignore_mismatched_sizes",
action="store_true",
help="Ignore mismatched sizes between the model and the dataset.",
)
parser.add_argument(
"--image_square_size",
type=int,
default=1333,
help="Image longest size will be resized to this value, then image will be padded to square.",
)
parser.add_argument(
"--cache_dir",
type=str,
help="Path to a folder in which the model and dataset will be cached.",
)
parser.add_argument(
"--use_auth_token",
action="store_true",
help="Whether to use an authentication token to access the model repository.",
)
parser.add_argument(
"--per_device_train_batch_size",
type=int,
default=8,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--per_device_eval_batch_size",
type=int,
default=8,
help="Batch size (per device) for the evaluation dataloader.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=4,
help="Number of workers to use for the dataloaders.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--adam_beta1",
type=float,
default=0.9,
help="Beta1 for AdamW optimizer",
)
parser.add_argument(
"--adam_beta2",
type=float,
default=0.999,
help="Beta2 for AdamW optimizer",
)
parser.add_argument(
"--adam_epsilon",
type=float,
default=1e-8,
help="Epsilon for AdamW optimizer",
)
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--lr_scheduler_type",
type=SchedulerType,
default="linear",
help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
)
parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument(
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--trust_remote_code",
type=bool,
default=False,
help=(
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
"execute code present on the Hub on your local machine."
),
)
parser.add_argument(
"--checkpointing_steps",
type=str,
default=None,
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--with_tracking",
required=False,
action="store_true",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
"Only applicable when `--with_tracking` is passed."
),
)
args = parser.parse_args()
# Sanity checks
if args.push_to_hub or args.with_tracking:
if args.output_dir is None:
raise ValueError(
"Need an `output_dir` to create a repo when `--push_to_hub` or `with_tracking` is specified."
)
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
return args
def main():
args = parse_args()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_object_detection_no_trainer", args)
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator_log_kwargs = {}
if args.with_tracking:
accelerator_log_kwargs["log_with"] = args.report_to
accelerator_log_kwargs["project_dir"] = args.output_dir
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
# We set device_specific to True as we want different data augmentation per device.
if args.seed is not None:
set_seed(args.seed, device_specific=True)
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
# Retrieve or infer repo_name
repo_name = args.hub_model_id
if repo_name is None:
repo_name = Path(args.output_dir).absolute().name
# Create repo and retrieve repo_id
api = HfApi()
repo_id = api.create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
# Load dataset
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
dataset = load_dataset(args.dataset_name, cache_dir=args.cache_dir)
# If we don't have a validation split, split off a percentage of train as validation.
args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
split = dataset["train"].train_test_split(args.train_val_split, seed=args.seed)
dataset["train"] = split["train"]
dataset["validation"] = split["test"]
# Get dataset categories and prepare mappings for label_name <-> label_id
categories = dataset["train"].features["objects"].feature["category"].names
id2label = dict(enumerate(categories))
label2id = {v: k for k, v in id2label.items()}
# ------------------------------------------------------------------------------------------------
# Load pretrained config, model and image processor
# ------------------------------------------------------------------------------------------------
common_pretrained_args = {
"cache_dir": args.cache_dir,
"token": args.hub_token,
"trust_remote_code": args.trust_remote_code,
}
config = AutoConfig.from_pretrained(
args.model_name_or_path, label2id=label2id, id2label=id2label, **common_pretrained_args
)
model = AutoModelForObjectDetection.from_pretrained(
args.model_name_or_path,
config=config,
ignore_mismatched_sizes=args.ignore_mismatched_sizes,
**common_pretrained_args,
)
image_processor = AutoImageProcessor.from_pretrained(
args.model_name_or_path,
# At this moment we recommend using external transform to pad and resize images.
# It's faster and yields much better results for object-detection models.
do_pad=False,
do_resize=False,
# We will save image size parameter in config just for reference
size={"longest_edge": args.image_square_size},
**common_pretrained_args,
)
# ------------------------------------------------------------------------------------------------
# Define image augmentations and dataset transforms
# ------------------------------------------------------------------------------------------------
max_size = args.image_square_size
basic_transforms = [
A.LongestMaxSize(max_size=max_size),
A.PadIfNeeded(max_size, max_size, border_mode=0, value=(128, 128, 128), position="top_left"),
]
train_augment_and_transform = A.Compose(
[
A.Compose(
[
A.SmallestMaxSize(max_size=max_size, p=1.0),
A.RandomSizedBBoxSafeCrop(height=max_size, width=max_size, p=1.0),
],
p=0.2,
),
A.OneOf(
[
A.Blur(blur_limit=7, p=0.5),
A.MotionBlur(blur_limit=7, p=0.5),
A.Defocus(radius=(1, 5), alias_blur=(0.1, 0.25), p=0.1),
],
p=0.1,
),
A.Perspective(p=0.1),
A.HorizontalFlip(p=0.5),
A.RandomBrightnessContrast(p=0.5),
A.HueSaturationValue(p=0.1),
*basic_transforms,
],
bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True, min_area=25),
)
validation_transform = A.Compose(
basic_transforms,
bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True),
)
# Make transform functions for batch and apply for dataset splits
train_transform_batch = partial(
augment_and_transform_batch, transform=train_augment_and_transform, image_processor=image_processor
)
validation_transform_batch = partial(
augment_and_transform_batch, transform=validation_transform, image_processor=image_processor
)
with accelerator.main_process_first():
train_dataset = dataset["train"].with_transform(train_transform_batch)
valid_dataset = dataset["validation"].with_transform(validation_transform_batch)
test_dataset = dataset["test"].with_transform(validation_transform_batch)
dataloader_common_args = {
"num_workers": args.dataloader_num_workers,
"collate_fn": collate_fn,
}
train_dataloader = DataLoader(
train_dataset, shuffle=True, batch_size=args.per_device_train_batch_size, **dataloader_common_args
)
valid_dataloader = DataLoader(
valid_dataset, shuffle=False, batch_size=args.per_device_eval_batch_size, **dataloader_common_args
)
test_dataloader = DataLoader(
test_dataset, shuffle=False, batch_size=args.per_device_eval_batch_size, **dataloader_common_args
)
# ------------------------------------------------------------------------------------------------
# Define optimizer, scheduler and prepare everything with the accelerator
# ------------------------------------------------------------------------------------------------
# Optimizer
optimizer = torch.optim.AdamW(
list(model.parameters()),
lr=args.learning_rate,
betas=[args.adam_beta1, args.adam_beta2],
eps=args.adam_epsilon,
)
# Figure out how many steps we should save the Accelerator states
checkpointing_steps = args.checkpointing_steps
if checkpointing_steps is not None and checkpointing_steps.isdigit():
checkpointing_steps = int(checkpointing_steps)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps
if overrode_max_train_steps
else args.max_train_steps * accelerator.num_processes,
)
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, valid_dataloader, test_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, valid_dataloader, test_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initialize automatically on the main process.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("object_detection_no_trainer", experiment_config)
# ------------------------------------------------------------------------------------------------
# Run training with evaluation on each epoch
# ------------------------------------------------------------------------------------------------
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
checkpoint_path = args.resume_from_checkpoint
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(checkpoint_path)
# Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
completed_steps = starting_epoch * num_update_steps_per_epoch
else:
# need to multiply by `gradient_accumulation_steps` to reflect real steps
resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
starting_epoch = resume_step // len(train_dataloader)
completed_steps = resume_step // args.gradient_accumulation_steps
resume_step -= starting_epoch * len(train_dataloader)
# update the progress_bar if we resumed from a checkpoint
progress_bar.update(completed_steps)
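# `resume_from_checkpoint` is expected to point at a folder produced by `accelerator.save_state`
# below, named either `step_{N}` or `epoch_{N}`; the basename is parsed above to recover how
# many steps/epochs were already completed.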
for epoch in range(starting_epoch, args.num_train_epochs):
model.train()
if args.with_tracking:
total_loss = 0
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We skip the first `n` batches in the dataloader when resuming from a checkpoint
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
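# Note: inside `accelerator.accumulate`, gradients are only synchronized across processes on
# accumulation boundaries, so calling step()/zero_grad() on every micro-batch as above is the
# pattern the API expects (intermediate micro-steps are assumed not to trigger a full update).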
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
args.output_dir,
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
)
if accelerator.is_main_process:
image_processor.save_pretrained(args.output_dir)
api.upload_folder(
commit_message=f"Training in progress epoch {epoch}",
folder_path=args.output_dir,
repo_id=repo_id,
repo_type="model",
token=args.hub_token,
)
if completed_steps >= args.max_train_steps:
break
logger.info("***** Running evaluation *****")
metrics = evaluation_loop(model, image_processor, accelerator, valid_dataloader, id2label)
logger.info(f"epoch {epoch}: {metrics}")
if args.with_tracking:
accelerator.log(
{
"train_loss": total_loss.item() / len(train_dataloader),
**metrics,
"epoch": epoch,
"step": completed_steps,
},
step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process:
image_processor.save_pretrained(args.output_dir)
api.upload_folder(
commit_message=f"Training in progress epoch {epoch}",
folder_path=args.output_dir,
repo_id=repo_id,
repo_type="model",
token=args.hub_token,
)
if args.checkpointing_steps == "epoch":
output_dir = f"epoch_{epoch}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
# ------------------------------------------------------------------------------------------------
# Run evaluation on test dataset and save the model
# ------------------------------------------------------------------------------------------------
logger.info("***** Running evaluation on test dataset *****")
metrics = evaluation_loop(model, image_processor, accelerator, test_dataloader, id2label)
metrics = {f"test_{k}": v for k, v in metrics.items()}
logger.info(f"Test metrics: {metrics}")
if args.with_tracking:
accelerator.end_training()
if args.output_dir is not None:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process:
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
json.dump(metrics, f, indent=2)
image_processor.save_pretrained(args.output_dir)
if args.push_to_hub:
api.upload_folder(
commit_message="End of training",
folder_path=args.output_dir,
repo_id=repo_id,
repo_type="model",
token=args.hub_token,
ignore_patterns=["epoch_*"],
)
if __name__ == "__main__":
main()
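# A minimal launch sketch for this script with Accelerate (illustrative only; the model and
# dataset names below are assumptions, and the flags mirror the argument parser above and the
# no-trainer test added in this commit):
#
#   accelerate launch run_object_detection_no_trainer.py \
#       --model_name_or_path facebook/detr-resnet-50 \
#       --dataset_name cppe-5 \
#       --output_dir detr-finetuned-cppe5-no-trainer \
#       --num_train_epochs 10 \
#       --per_device_train_batch_size 2 \
#       --per_device_eval_batch_size 2 \
#       --checkpointing_steps epoch \
#       --learning_rate 5e-5 \
#       --push_to_hub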

View File

@ -331,3 +331,27 @@ class ExamplesTestsNoTrainer(TestCasePlus):
self.assertGreaterEqual(result["eval_accuracy"], 0.4)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "step_1")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))
@slow
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
def test_run_object_detection_no_trainer(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
{self.examples_dir}/pytorch/object-detection/run_object_detection_no_trainer.py
--model_name_or_path qubvel-hf/detr-resnet-50-finetuned-10k-cppe5
--dataset_name qubvel-hf/cppe-5-sample
--output_dir {tmp_dir}
--max_train_steps=10
--num_warmup_steps=2
--learning_rate=1e-6
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--checkpointing_steps epoch
""".split()
run_command(self._launch_args + testargs)
result = get_results(tmp_dir)
self.assertGreaterEqual(result["test_map"], 0.10)

View File

@ -48,6 +48,7 @@ SRC_DIRS = [
"speech-pretraining",
"image-pretraining",
"semantic-segmentation",
"object-detection",
]
]
sys.path.extend(SRC_DIRS)
@ -62,6 +63,7 @@ if SRC_DIRS is not None:
import run_mae
import run_mlm
import run_ner
import run_object_detection
import run_qa as run_squad
import run_semantic_segmentation
import run_seq2seq_qa as run_squad_seq2seq
@ -609,3 +611,31 @@ class ExamplesTests(TestCasePlus):
run_semantic_segmentation.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_overall_accuracy"], 0.1)
@patch.dict(os.environ, {"WANDB_DISABLED": "true"})
def test_run_object_detection(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_object_detection.py
--model_name_or_path qubvel-hf/detr-resnet-50-finetuned-10k-cppe5
--output_dir {tmp_dir}
--dataset_name qubvel-hf/cppe-5-sample
--do_train
--do_eval
--remove_unused_columns False
--overwrite_output_dir True
--eval_do_concat_batches False
--max_steps 10
--learning_rate=1e-6
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--seed 32
""".split()
if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16")
with patch.object(sys, "argv", testargs):
run_object_detection.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["test_map"], 0.1)
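# For reference, a hedged sketch of launching the Trainer-based example outside of the tests
# (model and dataset names are illustrative; the flags mirror the test args above):
#
#   python run_object_detection.py \
#       --model_name_or_path facebook/detr-resnet-50 \
#       --dataset_name cppe-5 \
#       --output_dir detr-finetuned-cppe5 \
#       --do_train --do_eval \
#       --remove_unused_columns False \
#       --eval_do_concat_batches False \
#       --per_device_train_batch_size 8 \
#       --learning_rate 5e-5 \
#       --num_train_epochs 10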