Add image classification script, no trainer (#16727)

* Add first draft

* Improve README and run fixup

* Make script aligned with other scripts, improve README

* Improve script and add test

* Remove print statement

* Apply suggestions from code review

* Add num_labels to make test pass

* Improve README
This commit is contained in:
NielsRogge 2022-04-19 16:32:08 +02:00 committed by GitHub
parent db9f189121
commit b96e82c80a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 604 additions and 16 deletions

View File

@ -43,7 +43,7 @@ Coming soon!
| [**`speech-recognition`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition) | TIMIT | ✅ | - |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/speech_recognition.ipynb)
| [**`multi-lingual speech-recognition`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition) | Common Voice | ✅ | - |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multi_lingual_speech_recognition.ipynb)
| [**`audio-classification`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/audio-classification) | SUPERB KS | ✅ | - |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/audio_classification.ipynb)
| [**`image-classification`**](https://github.com/huggingface/notebooks/blob/main/examples/image_classification.ipynb) | CIFAR-10 | ✅ | - |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)
| [**`image-classification`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) | CIFAR-10 | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)
## Running quick tests

View File

@ -14,21 +14,28 @@ See the License for the specific language governing permissions and
limitations under the License.
-->
# Image classification example
# Image classification examples
This directory contains a script, `run_image_classification.py`, that showcases how to fine-tune any model supported by the [`AutoModelForImageClassification` API](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification) (such as [ViT](https://huggingface.co/docs/transformers/main/en/model_doc/vit), [ConvNeXT](https://huggingface.co/docs/transformers/main/en/model_doc/convnext), [ResNet](https://huggingface.co/docs/transformers/main/en/model_doc/resnet), [Swin Transformer](https://huggingface.co/docs/transformers/main/en/model_doc/swin)...) using PyTorch. It can be used to fine-tune models on both well-known datasets (like [CIFAR-10](https://huggingface.co/datasets/cifar10), [Fashion MNIST](https://huggingface.co/datasets/fashion_mnist), ...) as well as on your own custom data.
This directory contains 2 scripts that showcase how to fine-tune any model supported by the [`AutoModelForImageClassification` API](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification) (such as [ViT](https://huggingface.co/docs/transformers/main/en/model_doc/vit), [ConvNeXT](https://huggingface.co/docs/transformers/main/en/model_doc/convnext), [ResNet](https://huggingface.co/docs/transformers/main/en/model_doc/resnet), [Swin Transformer](https://huggingface.co/docs/transformers/main/en/model_doc/swin)...) using PyTorch. They can be used to fine-tune models on both [datasets from the hub](#using-datasets-from-hub) as well as on [your own custom data](#using-your-own-data).
This page includes 2 sections:
- [Using datasets from the 🤗 hub](#using-datasets-from-hub)
- [Using your own data](#using-your-own-data).
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_classification_inference_widget.png" height="400" />
Try out the inference widget here: https://huggingface.co/google/vit-base-patch16-224
## Using datasets from Hub
Content:
- [PyTorch version, Trainer](#pytorch-version-no-trainer)
- [PyTorch version, no Trainer](#pytorch-version-trainer)
## PyTorch version, Trainer
Based on the script [`run_image_classification.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/image-classification/run_image_classification.py).
The script leverages the 🤗 [Trainer API](https://huggingface.co/docs/transformers/main_classes/trainer) to automatically take care of the training for you, running on distributed environments right away.
### Using datasets from Hub
Here we show how to fine-tune a Vision Transformer (`ViT`) on the [beans](https://huggingface.co/datasets/beans) dataset, to classify the disease type of bean leaves.
👀 See the results here: [nateraw/vit-base-beans](https://huggingface.co/nateraw/vit-base-beans).
```bash
python run_image_classification.py \
--dataset_name beans \
@ -51,9 +58,11 @@ python run_image_classification.py \
--seed 1337
```
To fine-tune another model, simply provide the `--model_name_or_path` argument. To train on another dataset, simply set the `--dataset_name` argument.
👀 See the results here: [nateraw/vit-base-beans](https://huggingface.co/nateraw/vit-base-beans).
## Using your own data
Note that you can replace the model and dataset by simply setting the `model_name_or_path` and `dataset_name` arguments respectively, with any model or dataset from the [hub](https://huggingface.co/). For an overview of all possible arguments, we refer to the [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) of the `TrainingArguments`, which can be passed as flags.
### Using your own data
To use your own dataset, there are 2 ways:
- you can either provide your own folders as `--train_dir` and/or `--validation_dir` arguments
@ -61,7 +70,7 @@ To use your own dataset, there are 2 ways:
Below, we explain both in more detail.
### Provide them as folders
#### Provide them as folders
If you provide your own folders with images, the script expects the following directory structure:
@ -88,11 +97,11 @@ python run_image_classification.py \
Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects.
#### 💡 The above will split the train dir into training and evaluation sets
##### 💡 The above will split the train dir into training and evaluation sets
- To control the split amount, use the `--train_val_split` flag.
- To provide your own validation split in its own directory, you can pass the `--validation_dir <path-to-val-root>` flag.
### Upload your data to the hub, as a (possibly private) repo
#### Upload your data to the hub, as a (possibly private) repo
It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following:
@ -117,17 +126,18 @@ dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "pa
Next, push it to the hub!
```python
# assuming you have ran the huggingface-cli login command in a terminal
dataset.push_to_hub("name_of_your_dataset")
# if you want to push to a private repo, simply pass private=True:
dataset.push_to_hub("name_of_your_dataset", private=True)
```
and that's it! You can now simply train your model simply by setting the `--dataset_name` argument to the name of your dataset on the hub (as explained in [Using datasets from the 🤗 hub](#using-datasets-from-hub)).
and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub (as explained in [Using datasets from the 🤗 hub](#using-datasets-from-hub)).
More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
# Sharing your model on 🤗 Hub
### Sharing your model on 🤗 Hub
0. If you haven't already, [sign up](https://huggingface.co/join) for a 🤗 account
@ -154,3 +164,46 @@ python run_image_classification.py \
--push_to_hub_model_id <name-your-model> \
...
```
## PyTorch version, no Trainer
Based on the script [`run_image_classification_no_trainer.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/image-classification/run_image_classification_no_trainer.py).
Like `run_image_classification.py`, this script allows you to fine-tune any of the models on the [hub](https://huggingface.co/models) on an image classification task. The main difference is that this script exposes the bare training loop, to allow you to quickly experiment and add any customization you would like.
It offers less options than the script with `Trainer` (for instance you can easily change the options for the optimizer
or the dataloaders directly in the script) but still run in a distributed setup, and supports mixed precision by
the means of the [🤗 `Accelerate`](https://github.com/huggingface/accelerate) library. You can use the script normally
after installing it:
```bash
pip install accelerate
```
You can then use your usual launchers to run in it in a distributed environment, but the easiest way is to run
```bash
accelerate config
```
and reply to the questions asked. Then
```bash
accelerate test
```
that will check everything is ready for training. Finally, you can launch training with
```bash
accelerate launch run_image_classification_trainer.py
```
This command is the same and will work for:
- single/multiple CPUs
- single/multiple GPUs
- TPUs
Note that this library is in alpha release so your feedback is more than welcome if you encounter any problem using it.
Regarding using custom data with this script, we refer to [using your own data](#using-your-own-data).

View File

@ -0,0 +1,512 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning any 🤗 Transformers model for image classification leveraging 🤗 Accelerate."""
import argparse
import json
import logging
import math
import os
from pathlib import Path
import datasets
import torch
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from torchvision.transforms import (
CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor,
)
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoModelForImageClassification,
SchedulerType,
get_scheduler,
)
from transformers.utils import get_full_repo_name
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
def parse_args():
parser = argparse.ArgumentParser(description="Fine-tune a Transformers model on an image classification dataset")
parser.add_argument(
"--dataset_name",
type=str,
default="cifar10",
help="The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private, dataset).",
)
parser.add_argument("--train_dir", type=str, default=None, help="A folder containing the training data.")
parser.add_argument("--validation_dir", type=str, default=None, help="A folder containing the validation data.")
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help="For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set.",
)
parser.add_argument(
"--max_eval_samples",
type=int,
default=None,
help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set.",
)
parser.add_argument(
"--train_val_split",
type=float,
default=0.15,
help="Percent to split off of train for validation",
)
parser.add_argument(
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
default="google/vit-base-patch16-224-in21k",
)
parser.add_argument(
"--per_device_train_batch_size",
type=int,
default=8,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--per_device_eval_batch_size",
type=int,
default=8,
help="Batch size (per device) for the evaluation dataloader.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--lr_scheduler_type",
type=SchedulerType,
default="linear",
help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
)
parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument(
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--checkpointing_steps",
type=str,
default=None,
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
args = parser.parse_args()
# Sanity checks
if args.dataset_name is None and args.train_dir is None and args.validation_dir is None:
raise ValueError("Need either a dataset name or a training/validation folder.")
if args.push_to_hub or args.with_tracking:
if args.output_dir is None:
raise ValueError(
"Need an `output_dir` to create a repo when `--push_to_hub` or `with_tracking` is specified."
)
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
return args
def main():
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
logger.info(accelerator.state)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state)
# Setup logging, we only want one process per machine to log things on the screen.
# accelerator.is_local_main_process is only True for one process per machine.
logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(args.dataset_name, task="image-classification")
else:
data_files = {}
if args.train_dir is not None:
data_files["train"] = os.path.join(args.train_dir, "**")
if args.validation_dir is not None:
data_files["validation"] = os.path.join(args.validation_dir, "**")
dataset = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=args.cache_dir,
task="image-classification",
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder.
# If we don't have a validation split, split off a percentage of train as validation.
args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
split = dataset["train"].train_test_split(args.train_val_split)
dataset["train"] = split["train"]
dataset["validation"] = split["test"]
# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = dataset["train"].features["labels"].names
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}
# Load pretrained model and feature extractor
#
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
config = AutoConfig.from_pretrained(
args.model_name_or_path,
num_labels=len(labels),
i2label=id2label,
label2id=label2id,
finetuning_task="image-classification",
)
feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_name_or_path)
model = AutoModelForImageClassification.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
)
# Preprocessing the datasets
# Define torchvision transforms to be applied to each image.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
train_transforms = Compose(
[
RandomResizedCrop(feature_extractor.size),
RandomHorizontalFlip(),
ToTensor(),
normalize,
]
)
val_transforms = Compose(
[
Resize(feature_extractor.size),
CenterCrop(feature_extractor.size),
ToTensor(),
normalize,
]
)
def preprocess_train(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
return example_batch
def preprocess_val(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
return example_batch
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
if args.max_eval_samples is not None:
dataset["validation"] = dataset["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples))
# Set the validation transforms
eval_dataset = dataset["validation"].with_transform(preprocess_val)
# DataLoaders creation:
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["labels"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}
train_dataloader = DataLoader(
train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
)
eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=args.per_device_eval_batch_size)
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
else:
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps,
num_training_steps=args.max_train_steps,
)
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
if args.checkpointing_steps.isdigit():
checkpointing_steps = int(args.checkpointing_steps)
else:
checkpointing_steps = None
# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("image_classification_no_trainer", experiment_config)
# Get the metric function
metric = load_metric("accuracy")
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint)
resume_step = None
path = args.resume_from_checkpoint
else:
# Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path:
args.num_train_epochs -= int(path.replace("epoch_", ""))
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
for epoch in range(args.num_train_epochs):
model.train()
if args.with_tracking:
total_loss = 0
for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
continue
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
completed_steps += 1
if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir)
repo.push_to_hub(
commit_message=f"Training in progress {completed_steps} steps",
blocking=False,
auto_lfs_prune=True,
)
if completed_steps >= args.max_train_steps:
break
model.eval()
for step, batch in enumerate(eval_dataloader):
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
metric.add_batch(
predictions=accelerator.gather(predictions),
references=accelerator.gather(batch["labels"]),
)
eval_metric = metric.compute()
logger.info(f"epoch {epoch}: {eval_metric}")
if args.with_tracking:
accelerator.log(
{
"accuracy": eval_metric,
"train_loss": total_loss,
"epoch": epoch,
"step": completed_steps,
},
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir)
repo.push_to_hub(
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
)
if args.checkpointing_steps == "epoch":
output_dir = f"epoch_{epoch}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
if args.output_dir is not None:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
if args.output_dir is not None:
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
if __name__ == "__main__":
main()

View File

@ -52,6 +52,7 @@ sys.path.extend(SRC_DIRS)
if SRC_DIRS is not None:
import run_clm_no_trainer
import run_glue_no_trainer
import run_image_classification_no_trainer
import run_mlm_no_trainer
import run_ner_no_trainer
import run_qa_no_trainer as run_squad_no_trainer
@ -321,3 +322,25 @@ class ExamplesTestsNoTrainer(TestCasePlus):
run_semantic_segmentation_no_trainer.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)
def test_run_image_classification_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_image_classification_no_trainer.py
--dataset_name huggingface/image-classification-test-sample
--output_dir {tmp_dir}
--num_warmup_steps=8
--learning_rate=3e-3
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--checkpointing_steps epoch
--with_tracking
--seed 42
""".split()
with patch.object(sys, "argv", testargs):
run_image_classification_no_trainer.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.50)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))