Add the ImageClassificationPipeline (#11598)

* Add the ImageClassificationPipeline

* Code review

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>

* Have `load_image` at the module level

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
Lysandre Debut, 2021-05-07 14:08:40 +02:00, committed by GitHub
parent e7bff0aabe
commit 39084ca663
16 changed files with 428 additions and 74 deletions

View File

@ -37,6 +37,7 @@ jobs:
- name: Install dependencies
run: |
pip install --upgrade pip
sudo apt -y update && sudo apt install -y libsndfile1-dev
pip install .[dev]
- name: Create model files
run: |

View File

@ -36,6 +36,7 @@ There are two categories of pipeline abstractions to be aware about:
- :class:`~transformers.ZeroShotClassificationPipeline`
- :class:`~transformers.Text2TextGenerationPipeline`
- :class:`~transformers.TableQuestionAnsweringPipeline`
- :class:`~transformers.ImageClassificationPipeline`
The pipeline abstraction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -79,6 +80,13 @@ FillMaskPipeline
:special-members: __call__
:members:
ImageClassificationPipeline
=======================================================================================================================
.. autoclass:: transformers.ImageClassificationPipeline
:special-members: __call__
:members:
NerPipeline
=======================================================================================================================
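
As a quick illustration of the newly documented ImageClassificationPipeline entry above, a minimal usage sketch (the image URL is the one used in this commit's tests; with no model argument, the task default registered in this commit, google/vit-base-patch16-224, is loaded):

from transformers import pipeline

# Minimal sketch: the new task identifier resolves to ImageClassificationPipeline.
classifier = pipeline("image-classification")
predictions = classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
# `predictions` is a list of {"score": float, "label": str} dicts (5 by default).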

View File

@ -128,6 +128,13 @@ AutoModelForTableQuestionAnswering
:members:
AutoModelForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoModelForImageClassification
:members:
TFAutoModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
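
For readers who want the AutoModelForImageClassification class documented above without the pipeline wrapper, a hedged sketch of the equivalent manual flow (checkpoint and image URL mirror the pipeline default and tests in this commit):

import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = extractor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])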

View File

@ -244,6 +244,7 @@ _import_structure = {
"CsvPipelineDataFormat",
"FeatureExtractionPipeline",
"FillMaskPipeline",
"ImageClassificationPipeline",
"JsonPipelineDataFormat",
"NerPipeline",
"PipedPipelineDataFormat",
@ -483,6 +484,7 @@ if is_torch_available():
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
"AutoModelForCausalLM",
"AutoModelForImageClassification",
"AutoModelForMaskedLM",
"AutoModelForMultipleChoice",
"AutoModelForNextSentencePrediction",
@ -1640,6 +1642,7 @@ if TYPE_CHECKING:
CsvPipelineDataFormat,
FeatureExtractionPipeline,
FillMaskPipeline,
ImageClassificationPipeline,
JsonPipelineDataFormat,
NerPipeline,
PipedPipelineDataFormat,
@ -1845,6 +1848,7 @@ if TYPE_CHECKING:
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForNextSentencePrediction,
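
With these `_import_structure` and `TYPE_CHECKING` entries in place, both new names become importable from the top level:

from transformers import AutoModelForImageClassification, ImageClassificationPipeline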

View File

@ -226,7 +226,7 @@ class FeatureExtractionMixin:
:func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g.,
``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``.
``./my_model_directory/preprocessor_config.json``.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
standard cache should not be used.

View File

@ -14,34 +14,26 @@
# limitations under the License.
""" AutoFeatureExtractor class. """
import os
from collections import OrderedDict
from transformers import DeiTFeatureExtractor, Speech2TextFeatureExtractor, ViTFeatureExtractor
from ... import DeiTConfig, PretrainedConfig, Speech2TextConfig, ViTConfig, Wav2Vec2Config
from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import is_speech_available, is_vision_available
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .configuration_auto import replace_list_option_in_docstrings
if is_speech_available():
from ..speech_to_text.feature_extraction_speech_to_text import Speech2TextFeatureExtractor
else:
Speech2TextFeatureExtractor = None
if is_vision_available():
from ..deit.feature_extraction_deit import DeiTFeatureExtractor
from ..vit.feature_extraction_vit import ViTFeatureExtractor
else:
DeiTFeatureExtractor = None
ViTFeatureExtractor = None
# Build the list of all feature extractors
from ...file_utils import FEATURE_EXTRACTOR_NAME
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
FEATURE_EXTRACTOR_MAPPING = OrderedDict(
[
("deit", DeiTFeatureExtractor),
("s2t", Speech2TextFeatureExtractor),
("vit", ViTFeatureExtractor),
("wav2vec2", Wav2Vec2FeatureExtractor),
(DeiTConfig, DeiTFeatureExtractor),
(Speech2TextConfig, Speech2TextFeatureExtractor),
(ViTConfig, ViTFeatureExtractor),
(Wav2Vec2Config, Wav2Vec2FeatureExtractor),
]
)
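
Note the key change here: the mapping is now indexed by configuration classes rather than by name fragments, so a resolved config selects its extractor directly. A small illustrative lookup (the import path follows the module shown above):

from transformers import Wav2Vec2Config
from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING

# Config classes as keys: no string matching against the checkpoint name needed.
extractor_cls = FEATURE_EXTRACTOR_MAPPING[Wav2Vec2Config]
# extractor_cls is Wav2Vec2FeatureExtractor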
@ -89,7 +81,7 @@ class AutoFeatureExtractor:
:func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g.,
``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``.
``./my_model_directory/preprocessor_config.json``.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
standard cache should not be used.
@ -134,20 +126,29 @@ class AutoFeatureExtractor:
>>> feature_extractor = AutoFeatureExtractor.from_pretrained('./test/saved_model/')
"""
config = kwargs.pop("config", None)
kwargs["_from_auto"] = True
is_feature_extraction_file = os.path.isfile(pretrained_model_name_or_path)
is_directory = os.path.isdir(pretrained_model_name_or_path) and os.path.exists(
os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
)
if not is_feature_extraction_file and not is_directory:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
kwargs["_from_auto"] = True
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
if "feature_extractor_type" in config_dict:
if type(config) in FEATURE_EXTRACTOR_MAPPING.keys():
return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
elif "feature_extractor_type" in config_dict:
feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
return feature_extractor_class.from_dict(config_dict, **kwargs)
else:
# Fallback: use pattern matching on the string.
for pattern, feature_extractor_class in FEATURE_EXTRACTOR_MAPPING.items():
if pattern in str(pretrained_model_name_or_path):
return feature_extractor_class.from_dict(config_dict, **kwargs)
raise ValueError(
f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
"its feature_extraction_config.json, or contain one of the following strings "
f"its {FEATURE_EXTRACTOR_NAME}, or contain one of the following strings "
f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}"
)
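
Putting the new resolution order together — config type first, then the `feature_extractor_type` key, then the name-pattern fallback — a usage sketch (the hub checkpoint is the one used in the tests below; the local directory is illustrative):

from transformers import AutoFeatureExtractor

# Hub checkpoint: AutoConfig is resolved first, and Wav2Vec2Config selects
# Wav2Vec2FeatureExtractor through FEATURE_EXTRACTOR_MAPPING.
extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

# Local directory: if preprocessor_config.json is present, its
# `feature_extractor_type` key selects the class without touching the config.
# extractor = AutoFeatureExtractor.from_pretrained("./my_model_directory/")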

View File

@ -97,7 +97,7 @@ class Speech2TextProcessor:
:meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``.
``./my_model_directory/preprocessor_config.json``.
**kwargs
Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and
:class:`~transformers.PreTrainedTokenizer`

View File

@ -96,7 +96,7 @@ class Wav2Vec2Processor:
:meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g.,
``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``.
``./my_model_directory/preprocessor_config.json``.
**kwargs
Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and
:class:`~transformers.PreTrainedTokenizer`

View File

@ -20,9 +20,12 @@ import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from ..configuration_utils import PretrainedConfig
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.auto.tokenization_auto import AutoTokenizer
from ..models.auto.configuration_auto import AutoConfig
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import logging
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
@ -40,6 +43,7 @@ from .base import (
from .conversational import Conversation, ConversationalPipeline
from .feature_extraction import FeatureExtractionPipeline
from .fill_mask import FillMaskPipeline
from .image_classification import ImageClassificationPipeline
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
@ -79,6 +83,7 @@ if is_torch_available():
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
AutoModel,
AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
@ -198,6 +203,12 @@ SUPPORTED_TASKS = {
"pt": AutoModelForCausalLM if is_torch_available() else None,
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
},
"image-classification": {
"impl": ImageClassificationPipeline,
"tf": None,
"pt": AutoModelForImageClassification if is_torch_available() else None,
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
},
}
@ -252,6 +263,7 @@ def pipeline(
model: Optional = None,
config: Optional[Union[str, PretrainedConfig]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
framework: Optional[str] = None,
revision: Optional[str] = None,
use_fast: bool = True,
@ -309,6 +321,18 @@ def pipeline(
:obj:`model` is not specified or not a string, then the default tokenizer for :obj:`config` is loaded (if
it is a string). However, if :obj:`config` is also not given or not a string, then the default tokenizer
for the given :obj:`task` will be loaded.
feature_extractor (:obj:`str` or :obj:`~transformers.PreTrainedFeatureExtractor`, `optional`):
The feature extractor that will be used by the pipeline to encode data for the model. This can be a model
identifier or an actual pretrained feature extractor inheriting from
:class:`~transformers.PreTrainedFeatureExtractor`.
Feature extractors are used for non-NLP models, such as Speech or Vision models as well as multi-modal
models. Multi-modal models will also require a tokenizer to be passed.
If not provided, the default feature extractor for the given :obj:`model` will be loaded (if it is a
string). If :obj:`model` is not specified or not a string, then the default feature extractor for
:obj:`config` is loaded (if it is a string). However, if :obj:`config` is also not given or not a string,
then the default feature extractor for the given :obj:`task` will be loaded.
framework (:obj:`str`, `optional`):
The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
must be installed.
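
A sketch of the new `feature_extractor` argument in practice, explicitly passing what the factory would otherwise infer from the model name:

from transformers import AutoFeatureExtractor, pipeline

fe = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
pipe = pipeline(
    "image-classification",
    model="google/vit-base-patch16-224",
    feature_extractor=fe,  # a string identifier would work here too
)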
@ -359,19 +383,7 @@ def pipeline(
# At that point framework might still be undetermined
model = get_default_model(targeted_task, framework, task_options)
# Try to infer tokenizer from model or config name (if provided as str)
if tokenizer is None:
if isinstance(model, str):
tokenizer = model
elif isinstance(config, str):
tokenizer = config
else:
# Impossible to guest what is the right tokenizer here
raise Exception(
"Impossible to guess which tokenizer to use. "
"Please provided a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
)
model_name = model if isinstance(model, str) else None
modelcard = None
# Try to infer modelcard from model or config name (if provided as str)
if isinstance(model, str):
@ -388,19 +400,6 @@ def pipeline(
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
# Instantiate tokenizer if needed
if isinstance(tokenizer, (str, tuple)):
if isinstance(tokenizer, tuple):
# For tuple we have (tokenizer name, {kwargs})
use_fast = tokenizer[1].pop("use_fast", use_fast)
tokenizer = AutoTokenizer.from_pretrained(
tokenizer[0], use_fast=use_fast, revision=revision, _from_pipeline=task, **tokenizer[1]
)
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer, revision=revision, use_fast=use_fast, _from_pipeline=task, **model_kwargs
)
# Instantiate config if needed
if isinstance(config, str):
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
@ -434,6 +433,61 @@ def pipeline(
model, config=config, revision=revision, _from_pipeline=task, **model_kwargs
)
model_config = model.config
load_tokenizer = type(model_config) in TOKENIZER_MAPPING
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING
if load_tokenizer:
# Try to infer tokenizer from model or config name (if provided as str)
if tokenizer is None:
if isinstance(model_name, str):
tokenizer = model_name
elif isinstance(config, str):
tokenizer = config
else:
# Impossible to guess what is the right tokenizer here
raise Exception(
"Impossible to guess which tokenizer to use. "
"Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
)
# Instantiate tokenizer if needed
if isinstance(tokenizer, (str, tuple)):
if isinstance(tokenizer, tuple):
# For tuple we have (tokenizer name, {kwargs})
use_fast = tokenizer[1].pop("use_fast", use_fast)
tokenizer_identifier = tokenizer[0]
tokenizer_kwargs = tokenizer[1]
else:
tokenizer_identifier = tokenizer
tokenizer_kwargs = model_kwargs
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_identifier, revision=revision, use_fast=use_fast, _from_pipeline=task, **tokenizer_kwargs
)
if load_feature_extractor:
# Try to infer feature extractor from model or config name (if provided as str)
if feature_extractor is None:
if isinstance(model_name, str):
feature_extractor = model_name
elif isinstance(config, str):
feature_extractor = config
else:
# Impossible to guess what is the right feature_extractor here
raise Exception(
"Impossible to guess which feature extractor to use. "
"Please provide a PreTrainedFeatureExtractor class or a path/identifier "
"to a pretrained feature extractor."
)
# Instantiate feature_extractor if needed
if isinstance(feature_extractor, (str, tuple)):
feature_extractor = AutoFeatureExtractor.from_pretrained(
feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs
)
if task == "translation" and model.config.task_specific_params:
for key in model.config.task_specific_params:
if key.startswith("translation"):
@ -444,4 +498,16 @@ def pipeline(
)
break
return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)
if tokenizer is not None:
kwargs["tokenizer"] = tokenizer
if feature_extractor is not None:
kwargs["feature_extractor"] = feature_extractor
return task_class(
model=model,
modelcard=modelcard,
framework=framework,
task=task,
**kwargs,
)
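
The effect of the new gating above: the config type decides which preprocessors get loaded at all. For a ViT checkpoint, ViTConfig appears in FEATURE_EXTRACTOR_MAPPING but not in TOKENIZER_MAPPING, so only a feature extractor is instantiated — a quick check:

from transformers import pipeline

pipe = pipeline("image-classification", model="google/vit-base-patch16-224")
assert pipe.tokenizer is None              # ViTConfig not in TOKENIZER_MAPPING
assert pipe.feature_extractor is not None  # ViTConfig in FEATURE_EXTRACTOR_MAPPING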

View File

@ -23,6 +23,7 @@ from contextlib import contextmanager
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
@ -522,7 +523,8 @@ class Pipeline(_ScikitCompat):
def __init__(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel"],
tokenizer: PreTrainedTokenizer,
tokenizer: Optional[PreTrainedTokenizer] = None,
feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
task: str = "",
@ -537,6 +539,7 @@ class Pipeline(_ScikitCompat):
self.task = task
self.model = model
self.tokenizer = tokenizer
self.feature_extractor = feature_extractor
self.modelcard = modelcard
self.framework = framework
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
@ -565,7 +568,13 @@ class Pipeline(_ScikitCompat):
os.makedirs(save_directory, exist_ok=True)
self.model.save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(save_directory)
if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(save_directory)
if self.modelcard is not None:
self.modelcard.save_pretrained(save_directory)
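
With both attributes now optional, `save_pretrained` writes only what the pipeline actually holds. A hedged sketch for the tokenizer-less image pipeline (the directory name is illustrative):

from transformers import pipeline

pipe = pipeline("image-classification")
# Saves the model and preprocessor_config.json; no tokenizer files are
# written because pipe.tokenizer is None for this task.
pipe.save_pretrained("./my_image_pipeline")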
@ -630,7 +639,14 @@ class Pipeline(_ScikitCompat):
The list of models supported by the pipeline, or a dictionary with model class values.
"""
if not isinstance(supported_models, list): # Create from a model mapping
supported_models = [item[1].__name__ for item in supported_models.items()]
supported_models_names = []
for config, model in supported_models.items():
# Mapping can now contain tuples of models for the same configuration.
if isinstance(model, tuple):
supported_models_names.extend([_model.__name__ for _model in model])
else:
supported_models_names.append(model.__name__)
supported_models = supported_models_names
if self.model.__class__.__name__ not in supported_models:
raise PipelineException(
self.task,

View File

@ -0,0 +1,129 @@
import os
from typing import TYPE_CHECKING, List, Optional, Union
import requests
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
if is_vision_available():
from PIL import Image
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ImageClassificationPipeline(Pipeline):
"""
Image classification pipeline using any :obj:`AutoModelForImageClassification`. This pipeline predicts the class of
an image.
This image classification pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
task identifier: :obj:`"image-classification"`.
See the list of available models on `huggingface.co/models
<https://huggingface.co/models?filter=image-classification>`__.
"""
def __init__(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel"],
feature_extractor: PreTrainedFeatureExtractor,
framework: Optional[str] = None,
**kwargs
):
super().__init__(model, feature_extractor=feature_extractor, framework=framework, **kwargs)
if self.framework == "tf":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
self.feature_extractor = feature_extractor
@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
return Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
return Image.open(image)
elif isinstance(image, Image.Image):
return image
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], top_k=5):
"""
Assign labels to the image(s) passed as inputs.
Args:
images (:obj:`str`, :obj:`List[str]`, :obj:`PIL.Image` or :obj:`List[PIL.Image]`):
The pipeline handles three types of images:
- A string containing a http link pointing to an image
- A string containing a local path to an image
- An image loaded in PIL directly
The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
images.
top_k (:obj:`int`, `optional`, defaults to 5):
The number of top labels that will be returned by the pipeline.
Return:
A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
dictionary, if the input is a list of several images, will return a list of dictionaries corresponding to
the images.
The dictionaries contain the following keys:
- **label** (:obj:`str`) -- The label identified by the model.
- **score** (:obj:`int`) -- The score attributed by the model for that label.
"""
is_batched = isinstance(images, list)
if not is_batched:
images = [images]
images = [self.load_image(image) for image in images]
with torch.no_grad():
inputs = self.feature_extractor(images=images, return_tensors="pt")
outputs = self.model(**inputs)
probs = outputs.logits.softmax(-1)
scores, ids = probs.topk(top_k)
scores = scores.tolist()
ids = ids.tolist()
if not is_batched:
scores, ids = scores[0], ids[0]
labels = [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
else:
labels = []
for scores, ids in zip(scores, ids):
labels.append(
[{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
)
return labels
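
To make the return shapes above concrete, a usage sketch (the URL comes from this commit's tests; scores and labels are illustrative):

from transformers import pipeline

classifier = pipeline("image-classification")

# Single image -> one list of `top_k` {"score", "label"} dicts.
single = classifier("http://images.cocodataset.org/val2017/000000039769.jpg", top_k=3)
# e.g. [{"score": 0.94, "label": "Egyptian cat"}, ...]

# Batch of images -> one such list per input, in order.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
batch = classifier([url, url])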

View File

@ -376,6 +376,15 @@ class AutoModelForCausalLM:
requires_backends(self, ["torch"])
class AutoModelForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoModelForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

BIN tests/fixtures/coco.jpg (new binary file, 86 KiB; not shown)

View File

@ -0,0 +1,3 @@
{
"feature_extractor_type": "Wav2Vec2FeatureExtractor"
}
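
This fixture exercises the `feature_extractor_type` branch of AutoFeatureExtractor: the class is named directly in the JSON, so no model config is needed. A sketch (the file path is the one used in the test below):

from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor

extractor = AutoFeatureExtractor.from_pretrained(
    "tests/fixtures/dummy_feature_extractor_config.json"
)
assert isinstance(extractor, Wav2Vec2FeatureExtractor)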

View File

@ -16,9 +16,10 @@
import os
import unittest
from transformers import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor, Wav2Vec2FeatureExtractor
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
)
@ -29,16 +30,10 @@ class AutoFeatureExtractorTest(unittest.TestCase):
config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_directory(self):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_file(self):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_pattern_matching_fallback(self):
"""
In cases where config.json doesn't include a model_type,
perform a few safety checks on the config mapping's order.
"""
# no key string should be included in a later key string (typical failure case)
keys = list(FEATURE_EXTRACTOR_MAPPING.keys())
for i, key in enumerate(keys):
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))

View File

@ -0,0 +1,115 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import (
AutoFeatureExtractor,
AutoModelForImageClassification,
PreTrainedTokenizer,
is_vision_available,
)
from transformers.pipelines import ImageClassificationPipeline, pipeline
from transformers.testing_utils import require_torch, require_vision
if is_vision_available():
from PIL import Image
else:
class Image:
@staticmethod
def open(*args, **kwargs):
pass
@require_vision
@require_torch
class ImageClassificationPipelineTests(unittest.TestCase):
pipeline_task = "image-classification"
small_models = ["lysandre/tiny-vit-random"] # Models tested without the @slow decorator
valid_inputs = [
{"images": "http://images.cocodataset.org/val2017/000000039769.jpg"},
{
"images": [
"http://images.cocodataset.org/val2017/000000039769.jpg",
"http://images.cocodataset.org/val2017/000000039769.jpg",
]
},
{"images": "tests/fixtures/coco.jpg"},
{"images": ["tests/fixtures/coco.jpg", "tests/fixtures/coco.jpg"]},
{"images": Image.open("tests/fixtures/coco.jpg")},
{"images": [Image.open("tests/fixtures/coco.jpg"), Image.open("tests/fixtures/coco.jpg")]},
{"images": [Image.open("tests/fixtures/coco.jpg"), "tests/fixtures/coco.jpg"]},
]
def test_small_model_from_factory(self):
for small_model in self.small_models:
image_classifier = pipeline("image-classification", model=small_model)
for valid_input in self.valid_inputs:
output = image_classifier(**valid_input)
top_k = valid_input.get("top_k", 5)
def assert_valid_pipeline_output(pipeline_output):
self.assertTrue(isinstance(pipeline_output, list))
self.assertEqual(len(pipeline_output), top_k)
for label_result in pipeline_output:
self.assertTrue(isinstance(label_result, dict))
self.assertIn("label", label_result)
self.assertIn("score", label_result)
if isinstance(valid_input["images"], list):
self.assertEqual(len(valid_input["images"]), len(output))
for individual_output in output:
assert_valid_pipeline_output(individual_output)
else:
assert_valid_pipeline_output(output)
def test_small_model_from_pipeline(self):
for small_model in self.small_models:
model = AutoModelForImageClassification.from_pretrained(small_model)
feature_extractor = AutoFeatureExtractor.from_pretrained(small_model)
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
for valid_input in self.valid_inputs:
output = image_classifier(**valid_input)
top_k = valid_input.get("top_k", 5)
def assert_valid_pipeline_output(pipeline_output):
self.assertTrue(isinstance(pipeline_output, list))
self.assertEqual(len(pipeline_output), top_k)
for label_result in pipeline_output:
self.assertTrue(isinstance(label_result, dict))
self.assertIn("label", label_result)
self.assertIn("score", label_result)
if isinstance(valid_input["images"], list):
# When images are batched, pipeline output is a list of lists of dictionaries
self.assertEqual(len(valid_input["images"]), len(output))
for individual_output in output:
assert_valid_pipeline_output(individual_output)
else:
# When images are batched, pipeline output is a list of dictionaries
assert_valid_pipeline_output(output)
def test_custom_tokenizer(self):
tokenizer = PreTrainedTokenizer()
# Assert that the pipeline can be initialized with a feature extractor that is not in any mapping
image_classifier = pipeline("image-classification", model=self.small_models[0], tokenizer=tokenizer)
self.assertIs(image_classifier.tokenizer, tokenizer)