Add image classification script, no trainer (#16727)
* Add first draft * Improve README and run fixup * Make script aligned with other scripts, improve README * Improve script and add test * Remove print statement * Apply suggestions from code review * Add num_labels to make test pass * Improve README
This commit is contained in:
parent
db9f189121
commit
b96e82c80a
|
@ -43,7 +43,7 @@ Coming soon!
|
|||
| [**`speech-recognition`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition) | TIMIT | ✅ | - |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/speech_recognition.ipynb)
|
||||
| [**`multi-lingual speech-recognition`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition) | Common Voice | ✅ | - |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multi_lingual_speech_recognition.ipynb)
|
||||
| [**`audio-classification`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/audio-classification) | SUPERB KS | ✅ | - |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/audio_classification.ipynb)
|
||||
| [**`image-classification`**](https://github.com/huggingface/notebooks/blob/main/examples/image_classification.ipynb) | CIFAR-10 | ✅ | - |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)
|
||||
| [**`image-classification`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) | CIFAR-10 | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)
|
||||
|
||||
|
||||
## Running quick tests
|
||||
|
|
|
@ -14,21 +14,28 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Image classification example
|
||||
# Image classification examples
|
||||
|
||||
This directory contains a script, `run_image_classification.py`, that showcases how to fine-tune any model supported by the [`AutoModelForImageClassification` API](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification) (such as [ViT](https://huggingface.co/docs/transformers/main/en/model_doc/vit), [ConvNeXT](https://huggingface.co/docs/transformers/main/en/model_doc/convnext), [ResNet](https://huggingface.co/docs/transformers/main/en/model_doc/resnet), [Swin Transformer](https://huggingface.co/docs/transformers/main/en/model_doc/swin)...) using PyTorch. It can be used to fine-tune models on both well-known datasets (like [CIFAR-10](https://huggingface.co/datasets/cifar10), [Fashion MNIST](https://huggingface.co/datasets/fashion_mnist), ...) as well as on your own custom data.
|
||||
This directory contains 2 scripts that showcase how to fine-tune any model supported by the [`AutoModelForImageClassification` API](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification) (such as [ViT](https://huggingface.co/docs/transformers/main/en/model_doc/vit), [ConvNeXT](https://huggingface.co/docs/transformers/main/en/model_doc/convnext), [ResNet](https://huggingface.co/docs/transformers/main/en/model_doc/resnet), [Swin Transformer](https://huggingface.co/docs/transformers/main/en/model_doc/swin)...) using PyTorch. They can be used to fine-tune models on both [datasets from the hub](#using-datasets-from-hub) as well as on [your own custom data](#using-your-own-data).
|
||||
|
||||
This page includes 2 sections:
|
||||
- [Using datasets from the 🤗 hub](#using-datasets-from-hub)
|
||||
- [Using your own data](#using-your-own-data).
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_classification_inference_widget.png" height="400" />
|
||||
|
||||
Try out the inference widget here: https://huggingface.co/google/vit-base-patch16-224
|
||||
|
||||
## Using datasets from Hub
|
||||
Content:
|
||||
- [PyTorch version, Trainer](#pytorch-version-no-trainer)
|
||||
- [PyTorch version, no Trainer](#pytorch-version-trainer)
|
||||
|
||||
## PyTorch version, Trainer
|
||||
|
||||
Based on the script [`run_image_classification.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/image-classification/run_image_classification.py).
|
||||
|
||||
The script leverages the 🤗 [Trainer API](https://huggingface.co/docs/transformers/main_classes/trainer) to automatically take care of the training for you, running on distributed environments right away.
|
||||
|
||||
### Using datasets from Hub
|
||||
|
||||
Here we show how to fine-tune a Vision Transformer (`ViT`) on the [beans](https://huggingface.co/datasets/beans) dataset, to classify the disease type of bean leaves.
|
||||
|
||||
👀 See the results here: [nateraw/vit-base-beans](https://huggingface.co/nateraw/vit-base-beans).
|
||||
|
||||
```bash
|
||||
python run_image_classification.py \
|
||||
--dataset_name beans \
|
||||
|
@ -51,9 +58,11 @@ python run_image_classification.py \
|
|||
--seed 1337
|
||||
```
|
||||
|
||||
To fine-tune another model, simply provide the `--model_name_or_path` argument. To train on another dataset, simply set the `--dataset_name` argument.
|
||||
👀 See the results here: [nateraw/vit-base-beans](https://huggingface.co/nateraw/vit-base-beans).
|
||||
|
||||
## Using your own data
|
||||
Note that you can replace the model and dataset by simply setting the `model_name_or_path` and `dataset_name` arguments respectively, with any model or dataset from the [hub](https://huggingface.co/). For an overview of all possible arguments, we refer to the [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) of the `TrainingArguments`, which can be passed as flags.
|
||||
|
||||
### Using your own data
|
||||
|
||||
To use your own dataset, there are 2 ways:
|
||||
- you can either provide your own folders as `--train_dir` and/or `--validation_dir` arguments
|
||||
|
@ -61,7 +70,7 @@ To use your own dataset, there are 2 ways:
|
|||
|
||||
Below, we explain both in more detail.
|
||||
|
||||
### Provide them as folders
|
||||
#### Provide them as folders
|
||||
|
||||
If you provide your own folders with images, the script expects the following directory structure:
|
||||
|
||||
|
@ -88,11 +97,11 @@ python run_image_classification.py \
|
|||
|
||||
Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects.
|
||||
|
||||
#### 💡 The above will split the train dir into training and evaluation sets
|
||||
##### 💡 The above will split the train dir into training and evaluation sets
|
||||
- To control the split amount, use the `--train_val_split` flag.
|
||||
- To provide your own validation split in its own directory, you can pass the `--validation_dir <path-to-val-root>` flag.
|
||||
|
||||
### Upload your data to the hub, as a (possibly private) repo
|
||||
#### Upload your data to the hub, as a (possibly private) repo
|
||||
|
||||
It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following:
|
||||
|
||||
|
@ -117,17 +126,18 @@ dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "pa
|
|||
Next, push it to the hub!
|
||||
|
||||
```python
|
||||
# assuming you have ran the huggingface-cli login command in a terminal
|
||||
dataset.push_to_hub("name_of_your_dataset")
|
||||
|
||||
# if you want to push to a private repo, simply pass private=True:
|
||||
dataset.push_to_hub("name_of_your_dataset", private=True)
|
||||
```
|
||||
|
||||
and that's it! You can now simply train your model simply by setting the `--dataset_name` argument to the name of your dataset on the hub (as explained in [Using datasets from the 🤗 hub](#using-datasets-from-hub)).
|
||||
and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub (as explained in [Using datasets from the 🤗 hub](#using-datasets-from-hub)).
|
||||
|
||||
More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
|
||||
|
||||
# Sharing your model on 🤗 Hub
|
||||
### Sharing your model on 🤗 Hub
|
||||
|
||||
0. If you haven't already, [sign up](https://huggingface.co/join) for a 🤗 account
|
||||
|
||||
|
@ -154,3 +164,46 @@ python run_image_classification.py \
|
|||
--push_to_hub_model_id <name-your-model> \
|
||||
...
|
||||
```
|
||||
|
||||
## PyTorch version, no Trainer
|
||||
|
||||
Based on the script [`run_image_classification_no_trainer.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/image-classification/run_image_classification_no_trainer.py).
|
||||
|
||||
Like `run_image_classification.py`, this script allows you to fine-tune any of the models on the [hub](https://huggingface.co/models) on an image classification task. The main difference is that this script exposes the bare training loop, to allow you to quickly experiment and add any customization you would like.
|
||||
|
||||
It offers less options than the script with `Trainer` (for instance you can easily change the options for the optimizer
|
||||
or the dataloaders directly in the script) but still run in a distributed setup, and supports mixed precision by
|
||||
the means of the [🤗 `Accelerate`](https://github.com/huggingface/accelerate) library. You can use the script normally
|
||||
after installing it:
|
||||
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
You can then use your usual launchers to run in it in a distributed environment, but the easiest way is to run
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
and reply to the questions asked. Then
|
||||
|
||||
```bash
|
||||
accelerate test
|
||||
```
|
||||
|
||||
that will check everything is ready for training. Finally, you can launch training with
|
||||
|
||||
```bash
|
||||
accelerate launch run_image_classification_trainer.py
|
||||
```
|
||||
|
||||
This command is the same and will work for:
|
||||
|
||||
- single/multiple CPUs
|
||||
- single/multiple GPUs
|
||||
- TPUs
|
||||
|
||||
Note that this library is in alpha release so your feedback is more than welcome if you encounter any problem using it.
|
||||
|
||||
Regarding using custom data with this script, we refer to [using your own data](#using-your-own-data).
|
|
@ -0,0 +1,512 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning any 🤗 Transformers model for image classification leveraging 🤗 Accelerate."""
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import load_dataset, load_metric
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
Compose,
|
||||
Normalize,
|
||||
RandomHorizontalFlip,
|
||||
RandomResizedCrop,
|
||||
Resize,
|
||||
ToTensor,
|
||||
)
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from huggingface_hub import Repository
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoFeatureExtractor,
|
||||
AutoModelForImageClassification,
|
||||
SchedulerType,
|
||||
get_scheduler,
|
||||
)
|
||||
from transformers.utils import get_full_repo_name
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Fine-tune a Transformers model on an image classification dataset")
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default="cifar10",
|
||||
help="The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private, dataset).",
|
||||
)
|
||||
parser.add_argument("--train_dir", type=str, default=None, help="A folder containing the training data.")
|
||||
parser.add_argument("--validation_dir", type=str, default=None, help="A folder containing the validation data.")
|
||||
parser.add_argument(
|
||||
"--max_train_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help="For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_eval_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_val_split",
|
||||
type=float,
|
||||
default=0.15,
|
||||
help="Percent to split off of train for validation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
default="google/vit-base-patch16-224-in21k",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size (per device) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_eval_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size (per device) for the evaluation dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-5,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler_type",
|
||||
type=SchedulerType,
|
||||
default="linear",
|
||||
help="The scheduler type to use.",
|
||||
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
||||
)
|
||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help="If the training should continue from a checkpoint folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_tracking",
|
||||
action="store_true",
|
||||
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
if args.dataset_name is None and args.train_dir is None and args.validation_dir is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation folder.")
|
||||
|
||||
if args.push_to_hub or args.with_tracking:
|
||||
if args.output_dir is None:
|
||||
raise ValueError(
|
||||
"Need an `output_dir` to create a repo when `--push_to_hub` or `with_tracking` is specified."
|
||||
)
|
||||
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
||||
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
|
||||
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
|
||||
logger.info(accelerator.state)
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state)
|
||||
|
||||
# Setup logging, we only want one process per machine to log things on the screen.
|
||||
# accelerator.is_local_main_process is only True for one process per machine.
|
||||
logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
||||
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
|
||||
|
||||
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = load_dataset(args.dataset_name, task="image-classification")
|
||||
else:
|
||||
data_files = {}
|
||||
if args.train_dir is not None:
|
||||
data_files["train"] = os.path.join(args.train_dir, "**")
|
||||
if args.validation_dir is not None:
|
||||
data_files["validation"] = os.path.join(args.validation_dir, "**")
|
||||
dataset = load_dataset(
|
||||
"imagefolder",
|
||||
data_files=data_files,
|
||||
cache_dir=args.cache_dir,
|
||||
task="image-classification",
|
||||
)
|
||||
# See more about loading custom images at
|
||||
# https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder.
|
||||
|
||||
# If we don't have a validation split, split off a percentage of train as validation.
|
||||
args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
|
||||
if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
|
||||
split = dataset["train"].train_test_split(args.train_val_split)
|
||||
dataset["train"] = split["train"]
|
||||
dataset["validation"] = split["test"]
|
||||
|
||||
# Prepare label mappings.
|
||||
# We'll include these in the model's config to get human readable labels in the Inference API.
|
||||
labels = dataset["train"].features["labels"].names
|
||||
label2id = {label: str(i) for i, label in enumerate(labels)}
|
||||
id2label = {str(i): label for i, label in enumerate(labels)}
|
||||
|
||||
# Load pretrained model and feature extractor
|
||||
#
|
||||
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
num_labels=len(labels),
|
||||
i2label=id2label,
|
||||
label2id=label2id,
|
||||
finetuning_task="image-classification",
|
||||
)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_name_or_path)
|
||||
model = AutoModelForImageClassification.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Preprocessing the datasets
|
||||
|
||||
# Define torchvision transforms to be applied to each image.
|
||||
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
|
||||
train_transforms = Compose(
|
||||
[
|
||||
RandomResizedCrop(feature_extractor.size),
|
||||
RandomHorizontalFlip(),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
val_transforms = Compose(
|
||||
[
|
||||
Resize(feature_extractor.size),
|
||||
CenterCrop(feature_extractor.size),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess_train(example_batch):
|
||||
"""Apply _train_transforms across a batch."""
|
||||
example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
|
||||
return example_batch
|
||||
|
||||
def preprocess_val(example_batch):
|
||||
"""Apply _val_transforms across a batch."""
|
||||
example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
|
||||
return example_batch
|
||||
|
||||
with accelerator.main_process_first():
|
||||
if args.max_train_samples is not None:
|
||||
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
||||
# Set the training transforms
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
||||
if args.max_eval_samples is not None:
|
||||
dataset["validation"] = dataset["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples))
|
||||
# Set the validation transforms
|
||||
eval_dataset = dataset["validation"].with_transform(preprocess_val)
|
||||
|
||||
# DataLoaders creation:
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
labels = torch.tensor([example["labels"] for example in examples])
|
||||
return {"pixel_values": pixel_values, "labels": labels}
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
|
||||
)
|
||||
eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=args.per_device_eval_batch_size)
|
||||
|
||||
# Optimizer
|
||||
# Split weights in two groups, one with weight decay and the other not.
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
else:
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
name=args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.num_warmup_steps,
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# Figure out how many steps we should save the Accelerator states
|
||||
if hasattr(args.checkpointing_steps, "isdigit"):
|
||||
checkpointing_steps = args.checkpointing_steps
|
||||
if args.checkpointing_steps.isdigit():
|
||||
checkpointing_steps = int(args.checkpointing_steps)
|
||||
else:
|
||||
checkpointing_steps = None
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration
|
||||
if args.with_tracking:
|
||||
experiment_config = vars(args)
|
||||
# TensorBoard cannot log Enums, need the raw value
|
||||
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
|
||||
accelerator.init_trackers("image_classification_no_trainer", experiment_config)
|
||||
|
||||
# Get the metric function
|
||||
metric = load_metric("accuracy")
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
completed_steps = 0
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
|
||||
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
|
||||
accelerator.load_state(args.resume_from_checkpoint)
|
||||
resume_step = None
|
||||
path = args.resume_from_checkpoint
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
|
||||
dirs.sort(key=os.path.getctime)
|
||||
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
|
||||
if "epoch" in path:
|
||||
args.num_train_epochs -= int(path.replace("epoch_", ""))
|
||||
else:
|
||||
resume_step = int(path.replace("step_", ""))
|
||||
args.num_train_epochs -= resume_step // len(train_dataloader)
|
||||
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
|
||||
|
||||
for epoch in range(args.num_train_epochs):
|
||||
model.train()
|
||||
if args.with_tracking:
|
||||
total_loss = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# We need to skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
|
||||
continue
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
# We keep track of the loss at each epoch
|
||||
if args.with_tracking:
|
||||
total_loss += loss.detach().float()
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
accelerator.backward(loss)
|
||||
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
progress_bar.update(1)
|
||||
completed_steps += 1
|
||||
|
||||
if isinstance(checkpointing_steps, int):
|
||||
if completed_steps % checkpointing_steps == 0:
|
||||
output_dir = f"step_{completed_steps}"
|
||||
if args.output_dir is not None:
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
|
||||
if args.push_to_hub and epoch < args.num_train_epochs - 1:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
|
||||
if accelerator.is_main_process:
|
||||
feature_extractor.save_pretrained(args.output_dir)
|
||||
repo.push_to_hub(
|
||||
commit_message=f"Training in progress {completed_steps} steps",
|
||||
blocking=False,
|
||||
auto_lfs_prune=True,
|
||||
)
|
||||
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
model.eval()
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1)
|
||||
metric.add_batch(
|
||||
predictions=accelerator.gather(predictions),
|
||||
references=accelerator.gather(batch["labels"]),
|
||||
)
|
||||
|
||||
eval_metric = metric.compute()
|
||||
logger.info(f"epoch {epoch}: {eval_metric}")
|
||||
|
||||
if args.with_tracking:
|
||||
accelerator.log(
|
||||
{
|
||||
"accuracy": eval_metric,
|
||||
"train_loss": total_loss,
|
||||
"epoch": epoch,
|
||||
"step": completed_steps,
|
||||
},
|
||||
)
|
||||
|
||||
if args.push_to_hub and epoch < args.num_train_epochs - 1:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
|
||||
if accelerator.is_main_process:
|
||||
feature_extractor.save_pretrained(args.output_dir)
|
||||
repo.push_to_hub(
|
||||
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
|
||||
)
|
||||
|
||||
if args.checkpointing_steps == "epoch":
|
||||
output_dir = f"epoch_{epoch}"
|
||||
if args.output_dir is not None:
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
|
||||
if args.output_dir is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
|
||||
if accelerator.is_main_process:
|
||||
feature_extractor.save_pretrained(args.output_dir)
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||
|
||||
if args.output_dir is not None:
|
||||
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||
json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -52,6 +52,7 @@ sys.path.extend(SRC_DIRS)
|
|||
if SRC_DIRS is not None:
|
||||
import run_clm_no_trainer
|
||||
import run_glue_no_trainer
|
||||
import run_image_classification_no_trainer
|
||||
import run_mlm_no_trainer
|
||||
import run_ner_no_trainer
|
||||
import run_qa_no_trainer as run_squad_no_trainer
|
||||
|
@ -321,3 +322,25 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||
run_semantic_segmentation_no_trainer.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)
|
||||
|
||||
def test_run_image_classification_no_trainer(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_image_classification_no_trainer.py
|
||||
--dataset_name huggingface/image-classification-test-sample
|
||||
--output_dir {tmp_dir}
|
||||
--num_warmup_steps=8
|
||||
--learning_rate=3e-3
|
||||
--per_device_train_batch_size=2
|
||||
--per_device_eval_batch_size=1
|
||||
--checkpointing_steps epoch
|
||||
--with_tracking
|
||||
--seed 42
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_image_classification_no_trainer.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.50)
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))
|
||||
|
|
Loading…
Reference in New Issue