Update examples with image processors (#21155)

* Update examples to use image processors

* Small fixes

* Resolve conflicts
amyeroberts 2023-01-19 15:14:58 +00:00 committed by GitHub
parent fc8a93507c
commit 4bc18e7a83
12 changed files with 124 additions and 137 deletions
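Across the changed files the edit follows one pattern: `AutoFeatureExtractor` (and model-specific feature extractor classes) is replaced by `AutoImageProcessor` (or the matching image processor class), and `feature_extractor` variables are renamed to `image_processor`. A minimal sketch of the migration, with an illustrative checkpoint name that is not taken from these scripts:

```python
from transformers import AutoImageProcessor

# Previously:
# from transformers import AutoFeatureExtractor
# feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)

checkpoint = "google/vit-base-patch16-224"  # illustrative checkpoint
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

# The preprocessing call itself keeps the same shape:
# inputs = image_processor(images=image, return_tensors="pt")
```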

View File

@ -22,13 +22,7 @@ The cross-attention will be randomly initialized.
from dataclasses import dataclass, field
from typing import Optional
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoTokenizer,
FlaxVisionEncoderDecoderModel,
HfArgumentParser,
)
from transformers import AutoConfig, AutoImageProcessor, AutoTokenizer, FlaxVisionEncoderDecoderModel, HfArgumentParser
@dataclass
@ -108,13 +102,13 @@ def main():
model.config.decoder_start_token_id = decoder_start_token_id
model.config.pad_token_id = pad_token_id
feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.encoder_model_name_or_path)
image_processor = AutoImageProcessor.from_pretrained(model_args.encoder_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_args.decoder_model_name_or_path)
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.pad_token_id)
model.save_pretrained(model_args.output_dir)
feature_extractor.save_pretrained(model_args.output_dir)
image_processor.save_pretrained(model_args.output_dir)
tokenizer.save_pretrained(model_args.output_dir)

View File

@ -47,7 +47,7 @@ from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from transformers import (
AutoFeatureExtractor,
AutoImageProcessor,
AutoTokenizer,
FlaxVisionEncoderDecoderModel,
HfArgumentParser,
@ -106,12 +106,12 @@ class TrainingArguments:
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
)
_block_size_doc = """
The default value `0` will preprocess (tokenization + feature extraction) the whole dataset before training and
The default value `0` will preprocess (tokenization + image processing) the whole dataset before training and
cache the results. This uses more disk space, but avoids (repeated) processing time during training. This is a
good option if your disk space is large enough to store the whole processed dataset.
If a positive value is given, the captions in the dataset will be tokenized before training and the results are
cached. During training, it iterates the dataset in chunks of size `block_size`. On each block, images are
transformed by the feature extractor with the results being kept in memory (no cache), and batches of size
transformed by the image processor with the results being kept in memory (no cache), and batches of size
`batch_size` are yielded before processing the next block. This could avoid the heavy disk usage when the
dataset is large.
"""
@ -477,7 +477,7 @@ def main():
dtype=getattr(jnp, model_args.dtype),
use_auth_token=True if model_args.use_auth_token else None,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
image_processor = AutoImageProcessor.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
@ -546,7 +546,7 @@ def main():
for image_file in examples[image_column]:
try:
image = Image.open(image_file)
feature_extractor(images=image, return_tensors="np")
image_processor(images=image, return_tensors="np")
bools.append(True)
except Exception:
bools.append(False)
@ -582,9 +582,9 @@ def main():
return model_inputs
def feature_extraction_fn(examples, check_image=True):
def image_processing_fn(examples, check_image=True):
"""
Run feature extraction on images
Run preprocessing on images
If `check_image` is `True`, the examples that fails during `Image.open()` will be caught and discarded.
Otherwise, an exception will be thrown.
@ -609,18 +609,18 @@ def main():
else:
images = [Image.open(image_file) for image_file in examples[image_column]]
encoder_inputs = feature_extractor(images=images, return_tensors="np")
encoder_inputs = image_processor(images=images, return_tensors="np")
model_inputs["pixel_values"] = encoder_inputs.pixel_values
return model_inputs
def preprocess_fn(examples, max_target_length, check_image=True):
"""Run tokenization + image feature extraction"""
"""Run tokenization + image processing"""
model_inputs = {}
# This contains image path column
model_inputs.update(tokenization_fn(examples, max_target_length))
model_inputs.update(feature_extraction_fn(model_inputs, check_image=check_image))
model_inputs.update(image_processing_fn(model_inputs, check_image=check_image))
# Remove image path column
model_inputs.pop(image_column)
@ -644,15 +644,15 @@ def main():
}
)
# If `block_size` is `0`, tokenization & image feature extraction is done at the beginning
run_feat_ext_at_beginning = training_args.block_size == 0
# If `block_size` is `0`, tokenization & image processing is done at the beginning
run_img_proc_at_beginning = training_args.block_size == 0
# Used in .map() below
function_kwarg = preprocess_fn if run_feat_ext_at_beginning else tokenization_fn
function_kwarg = preprocess_fn if run_img_proc_at_beginning else tokenization_fn
# `features` is used only for the final preprocessed dataset (for the performance purpose).
features_kwarg = features if run_feat_ext_at_beginning else None
# Keep `image_column` if the feature extraction is done during training
remove_columns_kwarg = [x for x in column_names if x != image_column or run_feat_ext_at_beginning]
processor_names = "tokenizer and feature extractor" if run_feat_ext_at_beginning else "tokenizer"
features_kwarg = features if run_img_proc_at_beginning else None
# Keep `image_column` if the image processing is done during training
remove_columns_kwarg = [x for x in column_names if x != image_column or run_img_proc_at_beginning]
processor_names = "tokenizer and image processor" if run_img_proc_at_beginning else "tokenizer"
# Store some constant
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
@ -671,9 +671,9 @@ def main():
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
# remove problematic examples
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
# (if image processing is performed at the beginning, the filtering is done during preprocessing below
# instead here.)
if not run_feat_ext_at_beginning:
if not run_img_proc_at_beginning:
train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
train_dataset = train_dataset.map(
function=function_kwarg,
@ -686,7 +686,7 @@ def main():
fn_kwargs={"max_target_length": data_args.max_target_length},
features=features_kwarg,
)
if run_feat_ext_at_beginning:
if run_img_proc_at_beginning:
# set format (for performance) since the dataset is ready to be used
train_dataset = train_dataset.with_format("numpy")
@ -705,9 +705,9 @@ def main():
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
# remove problematic examples
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
# (if image processing is performed at the beginning, the filtering is done during preprocessing below
# instead here.)
if not run_feat_ext_at_beginning:
if not run_img_proc_at_beginning:
eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
eval_dataset = eval_dataset.map(
function=function_kwarg,
@ -720,7 +720,7 @@ def main():
fn_kwargs={"max_target_length": data_args.val_max_target_length},
features=features_kwarg,
)
if run_feat_ext_at_beginning:
if run_img_proc_at_beginning:
# set format (for performance) since the dataset is ready to be used
eval_dataset = eval_dataset.with_format("numpy")
@ -735,9 +735,9 @@ def main():
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
predict_dataset = predict_dataset.select(range(max_predict_samples))
# remove problematic examples
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
# (if image processing is performed at the beginning, the filtering is done during preprocessing below
# instead here.)
if not run_feat_ext_at_beginning:
if not run_img_proc_at_beginning:
predict_dataset = predict_dataset.filter(
filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers
)
@ -752,7 +752,7 @@ def main():
fn_kwargs={"max_target_length": data_args.val_max_target_length},
features=features_kwarg,
)
if run_feat_ext_at_beginning:
if run_img_proc_at_beginning:
# set format (for performance) since the dataset is ready to be used
predict_dataset = predict_dataset.with_format("numpy")
@ -771,8 +771,8 @@ def main():
"""
Wrap the simple `data_loader` in a block-wise way if `block_size` > 0, else it's the same as `data_loader`.
If `block_size` > 0, it requires `ds` to have a column that gives image paths in order to perform image feature
extraction (with the column name being specified by `image_column`). The tokenization should be done before
If `block_size` > 0, it requires `ds` to have a column that gives image paths in order to perform image
processing (with the column name being specified by `image_column`). The tokenization should be done before
training in this case.
"""
@ -804,7 +804,7 @@ def main():
_ds = ds.select(selected_indices)
_ds = _ds.map(
feature_extraction_fn,
image_processing_fn,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[image_column],
@ -813,7 +813,7 @@ def main():
keep_in_memory=keep_in_memory,
# The images are already checked either in `.filter()` or in `preprocess_fn()`
fn_kwargs={"check_image": False},
desc=f"Running feature extraction on {split} dataset".replace(" ", " "),
desc=f"Running image processing on {split} dataset".replace(" ", " "),
)
_ds = _ds.with_format("numpy")

View File

@ -52,15 +52,15 @@ ds = datasets.load_dataset("ydshieh/coco_dataset_script", "2017", data_dir=COCO_
### Create a model from a vision encoder model and a text decoder model
Next, we create a [VisionTextDualEncoderModel](https://huggingface.co/docs/transformers/model_doc/vision-text-dual-encoder#visiontextdualencoder).
The `VisionTextDualEncoderModel` class lets you load any vision and text encoder model to create a dual encoder.
Here is an example of how to load the model using pre-trained vision and text models.
```python3
from transformers import (
VisionTextDualEncoderModel,
VisionTextDualEncoderProcessor,
AutoTokenizer,
AutoFeatureExtractor
VisionTextDualEncoderModel,
VisionTextDualEncoderProcessor,
AutoTokenizer,
AutoImageProcessor
)
model = VisionTextDualEncoderModel.from_vision_text_pretrained(
@ -68,8 +68,8 @@ model = VisionTextDualEncoderModel.from_vision_text_pretrained(
)
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
feat_ext = AutoFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
processor = VisionTextDualEncoderProcessor(feat_ext, tokenizer)
image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
processor = VisionTextDualEncoderProcessor(image_processor, tokenizer)
# save the model and processor
model.save_pretrained("clip-roberta")
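After saving (assuming the processor is saved to the same `clip-roberta` directory as the model), the dual encoder can be loaded back and queried with an image-text pair; the image path and caption below are placeholders, and the snippet is a sketch rather than part of the example script:

```python
import torch
from PIL import Image
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

model = VisionTextDualEncoderModel.from_pretrained("clip-roberta")
processor = VisionTextDualEncoderProcessor.from_pretrained("clip-roberta")

image = Image.open("path/to/an/image.jpg")  # placeholder image
inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits_per_image)  # image-text similarity scores
```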

View File

@ -38,7 +38,7 @@ from torchvision.transforms.functional import InterpolationMode
import transformers
from transformers import (
AutoFeatureExtractor,
AutoImageProcessor,
AutoModel,
AutoTokenizer,
HfArgumentParser,
@ -74,7 +74,7 @@ class ModelArguments:
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
@ -308,7 +308,7 @@ def main():
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# 5. Load pretrained model, tokenizer, and feature extractor
# 5. Load pretrained model, tokenizer, and image processor
if model_args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
@ -323,9 +323,9 @@ def main():
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)
# Load feature_extractor, in this script we only use this to get the mean and std for normalization.
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name or model_args.model_name_or_path,
# Load image_processor, in this script we only use this to get the mean and std for normalization.
image_processor = AutoImageProcessor.from_pretrained(
model_args.image_processor_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
@ -386,7 +386,7 @@ def main():
# 7. Preprocessing the datasets.
# Initialize torchvision transforms and jit it for faster processing.
image_transformations = Transform(
config.vision_config.image_size, feature_extractor.image_mean, feature_extractor.image_std
config.vision_config.image_size, image_processor.image_mean, image_processor.image_std
)
image_transformations = torch.jit.script(image_transformations)

View File

@ -38,7 +38,7 @@ import transformers
from transformers import (
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
AutoConfig,
AutoFeatureExtractor,
AutoImageProcessor,
AutoModelForImageClassification,
HfArgumentParser,
Trainer,
@ -141,7 +141,7 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
use_auth_token: bool = field(
default=False,
metadata={
@ -283,19 +283,19 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name or model_args.model_name_or_path,
image_processor = AutoImageProcessor.from_pretrained(
model_args.image_processor_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
# Define torchvision transforms to be applied to each image.
if "shortest_edge" in feature_extractor.size:
size = feature_extractor.size["shortest_edge"]
if "shortest_edge" in image_processor.size:
size = image_processor.size["shortest_edge"]
else:
size = (feature_extractor.size["height"], feature_extractor.size["width"])
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
size = (image_processor.size["height"], image_processor.size["width"])
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
_train_transforms = Compose(
[
RandomResizedCrop(size),
@ -352,7 +352,7 @@ def main():
train_dataset=dataset["train"] if training_args.do_train else None,
eval_dataset=dataset["validation"] if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=feature_extractor,
tokenizer=image_processor,
data_collator=collate_fn,
)

View File

@ -41,13 +41,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository, create_repo
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoModelForImageClassification,
SchedulerType,
get_scheduler,
)
from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification, SchedulerType, get_scheduler
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
@ -294,7 +288,7 @@ def main():
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}
# Load pretrained model and feature extractor
# Load pretrained model and image processor
#
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
@ -305,7 +299,7 @@ def main():
label2id=label2id,
finetuning_task="image-classification",
)
feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_name_or_path)
image_processor = AutoImageProcessor.from_pretrained(args.model_name_or_path)
model = AutoModelForImageClassification.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
@ -316,11 +310,11 @@ def main():
# Preprocessing the datasets
# Define torchvision transforms to be applied to each image.
if "shortest_edge" in feature_extractor.size:
size = feature_extractor.size["shortest_edge"]
if "shortest_edge" in image_processor.size:
size = image_processor.size["shortest_edge"]
else:
size = (feature_extractor.size["height"], feature_extractor.size["width"])
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
size = (image_processor.size["height"], image_processor.size["width"])
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
train_transforms = Compose(
[
RandomResizedCrop(size),
@ -505,7 +499,7 @@ def main():
save_function=accelerator.save,
)
if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir)
image_processor.save_pretrained(args.output_dir)
repo.push_to_hub(
commit_message=f"Training in progress {completed_steps} steps",
blocking=False,
@ -547,7 +541,7 @@ def main():
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir)
image_processor.save_pretrained(args.output_dir)
repo.push_to_hub(
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
)
@ -568,7 +562,7 @@ def main():
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir)
image_processor.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)

View File

@ -29,7 +29,7 @@ from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
ViTFeatureExtractor,
ViTImageProcessor,
ViTMAEConfig,
ViTMAEForPreTraining,
)
@ -102,7 +102,7 @@ class DataTrainingArguments:
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/feature extractor we are going to pre-train.
Arguments pertaining to which model/config/image processor we are going to pre-train.
"""
model_name_or_path: str = field(
@ -132,7 +132,7 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
use_auth_token: bool = field(
default=False,
metadata={
@ -230,7 +230,7 @@ def main():
ds["train"] = split["train"]
ds["validation"] = split["test"]
# Load pretrained model and feature extractor
# Load pretrained model and image processor
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
@ -260,13 +260,13 @@ def main():
}
)
# create feature extractor
if model_args.feature_extractor_name:
feature_extractor = ViTFeatureExtractor.from_pretrained(model_args.feature_extractor_name, **config_kwargs)
# create image processor
if model_args.image_processor_name:
image_processor = ViTImageProcessor.from_pretrained(model_args.image_processor_name, **config_kwargs)
elif model_args.model_name_or_path:
feature_extractor = ViTFeatureExtractor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
image_processor = ViTImageProcessor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
else:
feature_extractor = ViTFeatureExtractor()
image_processor = ViTImageProcessor()
# create model
if model_args.model_name_or_path:
@ -298,17 +298,17 @@ def main():
# transformations as done in original MAE paper
# source: https://github.com/facebookresearch/mae/blob/main/main_pretrain.py
if "shortest_edge" in feature_extractor.size:
size = feature_extractor.size["shortest_edge"]
if "shortest_edge" in image_processor.size:
size = image_processor.size["shortest_edge"]
else:
size = (feature_extractor.size["height"], feature_extractor.size["width"])
size = (image_processor.size["height"], image_processor.size["width"])
transforms = Compose(
[
Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
RandomResizedCrop(size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
RandomHorizontalFlip(),
ToTensor(),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
]
)
@ -349,7 +349,7 @@ def main():
args=training_args,
train_dataset=ds["train"] if training_args.do_train else None,
eval_dataset=ds["validation"] if training_args.do_eval else None,
tokenizer=feature_extractor,
tokenizer=image_processor,
data_collator=collate_fn,
)

View File

@ -27,10 +27,10 @@ from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalF
import transformers
from transformers import (
CONFIG_MAPPING,
FEATURE_EXTRACTOR_MAPPING,
IMAGE_PROCESSOR_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
AutoConfig,
AutoFeatureExtractor,
AutoImageProcessor,
AutoModelForMaskedImageModeling,
HfArgumentParser,
Trainer,
@ -115,7 +115,7 @@ class DataTrainingArguments:
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/feature extractor we are going to pre-train.
Arguments pertaining to which model/config/image processor we are going to pre-train.
"""
model_name_or_path: str = field(
@ -152,7 +152,7 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
use_auth_token: bool = field(
default=False,
metadata={
@ -334,17 +334,16 @@ def main():
}
)
# create feature extractor
if model_args.feature_extractor_name:
feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.feature_extractor_name, **config_kwargs)
# create image processor
if model_args.image_processor_name:
image_processor = AutoImageProcessor.from_pretrained(model_args.image_processor_name, **config_kwargs)
elif model_args.model_name_or_path:
feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
image_processor = AutoImageProcessor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
else:
FEATURE_EXTRACTOR_TYPES = {
conf.model_type: feature_extractor_class
for conf, feature_extractor_class in FEATURE_EXTRACTOR_MAPPING.items()
IMAGE_PROCESSOR_TYPES = {
conf.model_type: image_processor_class for conf, image_processor_class in IMAGE_PROCESSOR_MAPPING.items()
}
feature_extractor = FEATURE_EXTRACTOR_TYPES[model_args.model_type]()
image_processor = IMAGE_PROCESSOR_TYPES[model_args.model_type]()
# create model
if model_args.model_name_or_path:
@ -382,7 +381,7 @@ def main():
RandomResizedCrop(model_args.image_size, scale=(0.67, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
RandomHorizontalFlip(),
ToTensor(),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
]
)
@ -427,7 +426,7 @@ def main():
args=training_args,
train_dataset=ds["train"] if training_args.do_train else None,
eval_dataset=ds["validation"] if training_args.do_eval else None,
tokenizer=feature_extractor,
tokenizer=image_processor,
data_collator=collate_fn,
)

View File

@ -40,7 +40,7 @@ from datasets import Dataset, DatasetDict, Image
# your images can of course have a different extension
# semantic segmentation maps are typically stored in the png format
image_paths_train = ["path/to/image_1.jpg", "path/to/image_2.jpg", ..., "path/to/image_n.jpg"]
label_paths_train = ["path/to/annotation_1.png", "path/to/annotation_2.png", ..., "path/to/annotation_n.png"]
# same for validation
@ -52,7 +52,7 @@ def create_dataset(image_paths, label_paths):
"label": sorted(label_paths)})
dataset = dataset.cast_column("image", Image())
dataset = dataset.cast_column("label", Image())
return dataset
# step 1: create Dataset objects
@ -91,7 +91,7 @@ You can easily upload this by clicking on "Add file" in the "Files and versions"
## PyTorch version, Trainer
Based on the script [`run_semantic_segmentation.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py).
The script leverages the [🤗 Trainer API](https://huggingface.co/docs/transformers/main_classes/trainer) to automatically take care of the training for you, running on distributed environments right away.
@ -130,7 +130,7 @@ Note that you can replace the model and dataset by simply setting the `model_nam
Based on the script [`run_semantic_segmentation_no_trainer.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py).
The script leverages [🤗 `Accelerate`](https://github.com/huggingface/accelerate), which allows you to write your own training loop in PyTorch but have it run instantly on any (distributed) environment, including CPU, multi-CPU, GPU, multi-GPU and TPU. It also supports mixed precision.
First, run:
@ -161,11 +161,11 @@ The resulting model can be seen here: https://huggingface.co/nielsr/segformer-fi
This means that after training, you can easily load your trained model as follows:
```python
from transformers import AutoFeatureExtractor, AutoModelForSemanticSegmentation
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
model_name = "name_of_repo_on_the_hub_or_path_to_local_folder"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
image_processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForSemanticSegmentation.from_pretrained(model_name)
```
@ -180,7 +180,7 @@ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# prepare image for the model
inputs = feature_extractor(images=image, return_tensors="pt")
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
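The logits typically come out at a reduced spatial resolution; one common way to obtain the final segmentation map (a sketch, not part of the script) is to upsample them to the input size and take the per-pixel argmax:

```python
import torch

logits = outputs.logits  # (batch_size, num_labels, h, w) at reduced resolution
upsampled = torch.nn.functional.interpolate(
    logits, size=image.size[::-1], mode="bilinear", align_corners=False
)
segmentation_map = upsampled.argmax(dim=1)[0]  # (height, width) tensor of predicted class ids
```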
@ -201,4 +201,4 @@ For visualization of the segmentation maps, we refer to the [example notebook](h
Some datasets, like [`scene_parse_150`](https://huggingface.co/datasets/scene_parse_150), contain a "background" label that is not part of the classes. For example, the Scene Parse 150 dataset contains labels between 0 and 150, with 0 being the background class and 1 to 150 being actual classes (like "tree", "person", etc.). For these kinds of datasets, one replaces the background label (0) with 255, which is the `ignore_index` of the PyTorch model's loss function, and reduces all other labels by 1. This way, the `labels` are PyTorch tensors containing values between 0 and 149, and 255 for all background/padding pixels.
In case you're training on such a dataset, make sure to set the ``reduce_labels`` flag, which will take care of this.
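Conceptually, the label reduction amounts to the following remapping (an illustrative numpy sketch; in the scripts this is handled by the `reduce_labels` transform and the image processor's `do_reduce_labels`):

```python
import numpy as np

# Toy label map: 0 is background, 1 and 2 are real classes
label = np.array([[0, 1], [2, 0]], dtype=np.uint8)

label[label == 0] = 255    # background becomes the ignore_index
label = label - 1          # shift real classes so they start at 0
label[label == 254] = 255  # keep ignored pixels at 255 after the shift
# label now holds 255 for background and 0..num_labels-1 for real classes
```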

View File

@ -34,7 +34,7 @@ import transformers
from huggingface_hub import hf_hub_download
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoImageProcessor,
AutoModelForSemanticSegmentation,
HfArgumentParser,
Trainer,
@ -240,7 +240,7 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
use_auth_token: bool = field(
default=False,
metadata={
@ -358,7 +358,7 @@ def main():
references=labels,
num_labels=len(id2label),
ignore_index=0,
reduce_labels=feature_extractor.do_reduce_labels,
reduce_labels=image_processor.do_reduce_labels,
)
# add per category metrics as individual key-value pairs
per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
@ -385,8 +385,8 @@ def main():
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name or model_args.model_name_or_path,
image_processor = AutoImageProcessor.from_pretrained(
model_args.image_processor_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
@ -395,11 +395,11 @@ def main():
# Define torchvision transforms to be applied to each image + target.
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
if "shortest_edge" in feature_extractor.size:
if "shortest_edge" in image_processor.size:
# We instead set the target size as (shortest_edge, shortest_edge) here to ensure all images are batchable.
size = (feature_extractor.size["shortest_edge"], feature_extractor.size["shortest_edge"])
size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
else:
size = (feature_extractor.size["height"], feature_extractor.size["width"])
size = (image_processor.size["height"], image_processor.size["width"])
train_transforms = Compose(
[
ReduceLabels() if data_args.reduce_labels else Identity(),
@ -407,7 +407,7 @@ def main():
RandomHorizontalFlip(flip_prob=0.5),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
]
)
# Define torchvision transform to be applied to each image.
@ -418,7 +418,7 @@ def main():
Resize(size=size),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
]
)
@ -477,7 +477,7 @@ def main():
train_dataset=dataset["train"] if training_args.do_train else None,
eval_dataset=dataset["validation"] if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=feature_extractor,
tokenizer=image_processor,
data_collator=default_data_collator,
)

View File

@ -39,7 +39,7 @@ from accelerate.utils import set_seed
from huggingface_hub import Repository, create_repo, hf_hub_download
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoImageProcessor,
AutoModelForSemanticSegmentation,
SchedulerType,
default_data_collator,
@ -397,20 +397,20 @@ def main():
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
# Load pretrained model and feature extractor
# Load pretrained model and image processor
config = AutoConfig.from_pretrained(args.model_name_or_path, id2label=id2label, label2id=label2id)
feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_name_or_path)
image_processor = AutoImageProcessor.from_pretrained(args.model_name_or_path)
model = AutoModelForSemanticSegmentation.from_pretrained(args.model_name_or_path, config=config)
# Preprocessing the datasets
# Define torchvision transforms to be applied to each image + target.
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
if "shortest_edge" in feature_extractor.size:
if "shortest_edge" in image_processor.size:
# We instead set the target size as (shortest_edge, shortest_edge) here to ensure all images are batchable.
size = (feature_extractor.size["shortest_edge"], feature_extractor.size["shortest_edge"])
size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
else:
size = (feature_extractor.size["height"], feature_extractor.size["width"])
size = (image_processor.size["height"], image_processor.size["width"])
train_transforms = Compose(
[
ReduceLabels() if args.reduce_labels else Identity(),
@ -418,7 +418,7 @@ def main():
RandomHorizontalFlip(flip_prob=0.5),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
]
)
# Define torchvision transform to be applied to each image.
@ -429,7 +429,7 @@ def main():
Resize(size=size),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
]
)
@ -602,7 +602,7 @@ def main():
save_function=accelerator.save,
)
if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir)
image_processor.save_pretrained(args.output_dir)
repo.push_to_hub(
commit_message=f"Training in progress {completed_steps} steps",
blocking=False,
@ -657,7 +657,7 @@ def main():
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir)
image_processor.save_pretrained(args.output_dir)
repo.push_to_hub(
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
)
@ -678,7 +678,7 @@ def main():
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir)
image_processor.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)

View File

@ -340,7 +340,7 @@ def main():
model.config.id2label = id2label
# Preprocessing the dataset
# The processor does everything for us (prepare the image using LayoutLMv3FeatureExtractor
# The processor does everything for us (prepare the image using LayoutLMv3ImageProcessor
# and prepare the words, boxes and word-level labels using LayoutLMv3TokenizerFast)
def prepare_examples(examples):
images = examples[image_column_name]