Image Segmentation pipeline (#13828)
* Implement img seg pipeline * Update src/transformers/pipelines/image_segmentation.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/pipelines/image_segmentation.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update output shape with individual masks * Rm dev change * Remove loops in test Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
parent
be71ac3bcb
commit
026866df92
|
@ -29,6 +29,7 @@ There are two categories of pipeline abstractions to be aware about:
|
|||
- :class:`~transformers.FeatureExtractionPipeline`
|
||||
- :class:`~transformers.FillMaskPipeline`
|
||||
- :class:`~transformers.ImageClassificationPipeline`
|
||||
- :class:`~transformers.ImageSegmentationPipeline`
|
||||
- :class:`~transformers.ObjectDetectionPipeline`
|
||||
- :class:`~transformers.QuestionAnsweringPipeline`
|
||||
- :class:`~transformers.SummarizationPipeline`
|
||||
|
@ -137,6 +138,13 @@ ImageClassificationPipeline
|
|||
:special-members: __call__
|
||||
:members:
|
||||
|
||||
ImageSegmentationPipeline
|
||||
=======================================================================================================================
|
||||
|
||||
.. autoclass:: transformers.ImageSegmentationPipeline
|
||||
:special-members: __call__
|
||||
:members:
|
||||
|
||||
NerPipeline
|
||||
=======================================================================================================================
|
||||
|
||||
|
|
|
@ -163,6 +163,13 @@ AutoModelForObjectDetection
|
|||
:members:
|
||||
|
||||
|
||||
AutoModelForImageSegmentation
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.AutoModelForImageSegmentation
|
||||
:members:
|
||||
|
||||
|
||||
TFAutoModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -294,6 +294,7 @@ _import_structure = {
|
|||
"FeatureExtractionPipeline",
|
||||
"FillMaskPipeline",
|
||||
"ImageClassificationPipeline",
|
||||
"ImageSegmentationPipeline",
|
||||
"JsonPipelineDataFormat",
|
||||
"NerPipeline",
|
||||
"ObjectDetectionPipeline",
|
||||
|
@ -544,6 +545,7 @@ if is_torch_available():
|
|||
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
|
@ -561,6 +563,7 @@ if is_torch_available():
|
|||
"AutoModelForCausalLM",
|
||||
"AutoModelForCTC",
|
||||
"AutoModelForImageClassification",
|
||||
"AutoModelForImageSegmentation",
|
||||
"AutoModelForMaskedLM",
|
||||
"AutoModelForMultipleChoice",
|
||||
"AutoModelForNextSentencePrediction",
|
||||
|
@ -2113,6 +2116,7 @@ if TYPE_CHECKING:
|
|||
FeatureExtractionPipeline,
|
||||
FillMaskPipeline,
|
||||
ImageClassificationPipeline,
|
||||
ImageSegmentationPipeline,
|
||||
JsonPipelineDataFormat,
|
||||
NerPipeline,
|
||||
ObjectDetectionPipeline,
|
||||
|
@ -2320,6 +2324,7 @@ if TYPE_CHECKING:
|
|||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
|
@ -2337,6 +2342,7 @@ if TYPE_CHECKING:
|
|||
AutoModelForCausalLM,
|
||||
AutoModelForCTC,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForImageSegmentation,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoModelForNextSentencePrediction,
|
||||
|
|
|
@ -45,6 +45,7 @@ from .models.auto.modeling_auto import (
|
|||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
|
||||
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
||||
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||
|
@ -60,6 +61,7 @@ from .utils import logging
|
|||
TASK_MAPPING = {
|
||||
"text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
"image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
|
||||
"fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
||||
"object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
|
||||
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||
|
@ -273,6 +275,7 @@ should probably proofread and complete it, then remove this comment. -->
|
|||
TASK_TAG_TO_NAME_MAPPING = {
|
||||
"fill-mask": "Masked Language Modeling",
|
||||
"image-classification": "Image Classification",
|
||||
"image-segmentation": "Image Segmentation",
|
||||
"multiple-choice": "Multiple Choice",
|
||||
"object-detection": "Object Detection",
|
||||
"question-answering": "Question Answering",
|
||||
|
|
|
@ -34,6 +34,7 @@ if is_torch_available():
|
|||
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"MODEL_FOR_CTC_MAPPING",
|
||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
|
@ -52,6 +53,7 @@ if is_torch_available():
|
|||
"AutoModelForCausalLM",
|
||||
"AutoModelForCTC",
|
||||
"AutoModelForImageClassification",
|
||||
"AutoModelForImageSegmentation",
|
||||
"AutoModelForMaskedLM",
|
||||
"AutoModelForMultipleChoice",
|
||||
"AutoModelForNextSentencePrediction",
|
||||
|
@ -130,6 +132,7 @@ if TYPE_CHECKING:
|
|||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_CTC_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
|
@ -148,6 +151,7 @@ if TYPE_CHECKING:
|
|||
AutoModelForCausalLM,
|
||||
AutoModelForCTC,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForImageSegmentation,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoModelForNextSentencePrediction,
|
||||
|
|
|
@ -228,6 +228,13 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Image Segmentation mapping
|
||||
("detr", "DetrForSegmentation"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Masked LM mapping
|
||||
|
@ -484,6 +491,9 @@ MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_C
|
|||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
|
||||
)
|
||||
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
|
||||
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
||||
|
@ -614,6 +624,13 @@ class AutoModelForImageClassification(_BaseAutoModelClass):
|
|||
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
|
||||
|
||||
|
||||
class AutoModelForImageSegmentation(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
|
||||
|
||||
|
||||
AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")
|
||||
|
||||
|
||||
class AutoModelForObjectDetection(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
|
||||
|
||||
|
|
|
@ -713,8 +713,50 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||
|
||||
return results
|
||||
|
||||
def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5):
    """
    Converts the output of :class:`~transformers.DetrForSegmentation` into image segmentation predictions. Only
    supports PyTorch.

    Parameters:
        outputs (:class:`~transformers.DetrSegmentationOutput`):
            Raw outputs of the model.
        target_sizes (:obj:`torch.Tensor` of shape :obj:`(batch_size, 2)` or :obj:`List[Tuple]` of length :obj:`batch_size`):
            Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction.
        threshold (:obj:`float`, `optional`, defaults to 0.9):
            Threshold to use to filter out queries.
        mask_threshold (:obj:`float`, `optional`, defaults to 0.5):
            Threshold to use when turning the predicted masks into binary values.

    Returns:
        :obj:`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an
        image in the batch as predicted by the model.
    """
    out_logits, raw_masks = outputs.logits, outputs.pred_masks
    preds = []

    def to_tuple(tup):
        # Accept either a plain (h, w) tuple or a tensor holding the two sizes.
        if isinstance(tup, tuple):
            return tup
        return tuple(tup.cpu().tolist())

    for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
        # Compute per-query scores/labels once (the original recomputed this softmax a second time).
        cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
        # Filter out the "no object" class (last label index) and detections below threshold.
        keep = cur_classes.ne(cur_logits.shape[-1] - 1) & (cur_scores > threshold)
        cur_scores = cur_scores[keep]
        cur_classes = cur_classes[keep]
        cur_masks = cur_masks[keep]
        # Upsample the kept masks to the requested final size, then binarize them.
        cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
        cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1

        preds.append({"scores": cur_scores, "labels": cur_classes, "masks": cur_masks})
    return preds
|
||||
|
||||
# inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218
|
||||
def post_process_segmentation(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
|
||||
def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
|
||||
"""
|
||||
Converts the output of :class:`~transformers.DetrForSegmentation` into actual instance segmentation
|
||||
predictions. Only supports PyTorch.
|
||||
|
|
|
@ -44,6 +44,7 @@ from .conversational import Conversation, ConversationalPipeline
|
|||
from .feature_extraction import FeatureExtractionPipeline
|
||||
from .fill_mask import FillMaskPipeline
|
||||
from .image_classification import ImageClassificationPipeline
|
||||
from .image_segmentation import ImageSegmentationPipeline
|
||||
from .object_detection import ObjectDetectionPipeline
|
||||
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
|
||||
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
|
||||
|
@ -92,6 +93,7 @@ if is_torch_available():
|
|||
AutoModelForCausalLM,
|
||||
AutoModelForCTC,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForImageSegmentation,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForObjectDetection,
|
||||
AutoModelForQuestionAnswering,
|
||||
|
@ -231,6 +233,12 @@ SUPPORTED_TASKS = {
|
|||
"pt": (AutoModelForImageClassification,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
|
||||
},
|
||||
"image-segmentation": {
|
||||
"impl": ImageSegmentationPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForImageSegmentation,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "facebook/detr-resnet-50-panoptic"}},
|
||||
},
|
||||
"object-detection": {
|
||||
"impl": ObjectDetectionPipeline,
|
||||
"tf": (),
|
||||
|
|
|
@ -0,0 +1,165 @@
|
|||
import base64
|
||||
import io
|
||||
import os
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import requests
|
||||
|
||||
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
||||
from ..utils import logging
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
Prediction = Dict[str, Any]
|
||||
Predictions = List[Prediction]
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ImageSegmentationPipeline(Pipeline):
    """
    Image segmentation pipeline using any :obj:`AutoModelForImageSegmentation`. This pipeline predicts masks of objects
    and their classes.

    This image segmentation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"image-segmentation"`.

    See the list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=image-segmentation>`__.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.framework == "tf":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        requires_backends(self, "vision")
        self.check_model_type(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING)

    @staticmethod
    def load_image(image: Union[str, "Image.Image"]) -> "Image.Image":
        """Load an image from an HTTP(S) URL, a local path, or a PIL image, and convert it to RGB."""
        if isinstance(image, str):
            if image.startswith("http://") or image.startswith("https://"):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                image = Image.open(requests.get(image, stream=True).raw)
            elif os.path.isfile(image):
                image = Image.open(image)
            else:
                raise ValueError(
                    f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
                )
        elif isinstance(image, Image.Image):
            pass
        else:
            raise ValueError(
                "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
            )
        image = image.convert("RGB")
        return image

    def _sanitize_parameters(self, **kwargs):
        # `threshold` and `mask_threshold` are postprocess-only parameters; preprocess and
        # forward take no extra arguments.
        postprocess_kwargs = {}
        if "threshold" in kwargs:
            postprocess_kwargs["threshold"] = kwargs["threshold"]
        if "mask_threshold" in kwargs:
            postprocess_kwargs["mask_threshold"] = kwargs["mask_threshold"]
        return {}, {}, postprocess_kwargs

    def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:
        """
        Perform segmentation (detect masks & classes) in the image(s) passed as inputs.

        Args:
            images (:obj:`str`, :obj:`List[str]`, :obj:`PIL.Image` or :obj:`List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing an HTTP(S) link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
                same format: all as HTTP(S) links, all as local paths, or all as PIL images.
            threshold (:obj:`float`, `optional`, defaults to 0.9):
                The probability necessary to make a prediction.
            mask_threshold (:obj:`float`, `optional`, defaults to 0.5):
                Threshold to use when turning the predicted masks into binary values.

        Return:
            A dictionary or a list of dictionaries containing the result. If the input is a single image, will return a
            dictionary, if the input is a list of several images, will return a list of dictionaries corresponding to
            each image.

            The dictionaries contain the following keys:

            - **label** (:obj:`str`) -- The class label identified by the model.
            - **score** (:obj:`float`) -- The score attributed by the model for that label.
            - **mask** (:obj:`str`) -- base64 string of a single-channel PNG image that contain masks information. The
              PNG image has size (height, width) of the original image. Pixel values in the image are either 0 or 255
              (i.e. mask is absent VS mask is present).
        """

        return super().__call__(*args, **kwargs)

    def preprocess(self, image):
        image = self.load_image(image)
        # Keep the original (height, width) around so postprocessing can resize masks back.
        target_size = torch.IntTensor([[image.height, image.width]])
        inputs = self.feature_extractor(images=[image], return_tensors="pt")
        inputs["target_size"] = target_size
        return inputs

    def _forward(self, model_inputs):
        # `target_size` is pipeline bookkeeping, not a model input — pop it before the forward pass.
        target_size = model_inputs.pop("target_size")
        outputs = self.model(**model_inputs)
        model_outputs = {"outputs": outputs, "target_size": target_size}
        return model_outputs

    def postprocess(self, model_outputs, threshold=0.9, mask_threshold=0.5):
        # Bug fix: forward the caller-supplied `mask_threshold` instead of a hard-coded 0.5,
        # otherwise the `mask_threshold` pipeline parameter is silently ignored.
        raw_annotations = self.feature_extractor.post_process_segmentation(
            model_outputs["outputs"], model_outputs["target_size"], threshold=threshold, mask_threshold=mask_threshold
        )
        raw_annotation = raw_annotations[0]

        raw_annotation["masks"] *= 255  # [0,1] -> [0,255] black and white pixels

        raw_annotation["scores"] = raw_annotation["scores"].tolist()
        raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in raw_annotation["labels"]]
        raw_annotation["masks"] = [self._get_mask_str(mask) for mask in raw_annotation["masks"].cpu().numpy()]

        # {"scores": [...], ...} --> [{"score":x, ...}, ...]
        keys = ["score", "label", "mask"]
        annotation = [
            dict(zip(keys, vals))
            for vals in zip(raw_annotation["scores"], raw_annotation["labels"], raw_annotation["masks"])
        ]

        return annotation

    def _get_mask_str(self, mask: np.ndarray) -> str:
        """
        Turns mask numpy array into mask base64 str.

        Args:
            mask (np.ndarray): Numpy array (with shape (height, width) of the original image) containing masks
                information. Values in the array are either 0 or 255 (i.e. mask is absent VS mask is present).

        Returns:
            A base64 string of a single-channel PNG image that contain masks information.
        """
        # Bug fix: use `np.uint8`, not `np.int8` — pixel values reach 255, which overflows a signed
        # byte (255 -> -1), and an unsigned byte array is what PIL maps to the documented
        # single-channel 8-bit ("L" mode) PNG.
        img = Image.fromarray(mask.astype(np.uint8))
        with io.BytesIO() as out:
            img.save(out, format="PNG")
            png_string = out.getvalue()
        return base64.b64encode(png_string).decode("utf-8")
|
|
@ -316,6 +316,9 @@ MODEL_FOR_CAUSAL_LM_MAPPING = None
|
|||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_MASKED_LM_MAPPING = None
|
||||
|
||||
|
||||
|
@ -397,6 +400,15 @@ class AutoModelForImageClassification:
|
|||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForImageSegmentation:
    """
    Placeholder for :obj:`AutoModelForImageSegmentation` used when PyTorch is not installed.

    Both instantiation and :meth:`from_pretrained` delegate to ``requires_backends``, which raises an
    error explaining that the ``torch`` backend is required.
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForMaskedLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
|
|
@ -0,0 +1,241 @@
|
|||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||
AutoFeatureExtractor,
|
||||
AutoModelForImageSegmentation,
|
||||
ImageSegmentationPipeline,
|
||||
is_vision_available,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
nested_simplify,
|
||||
require_datasets,
|
||||
require_tf,
|
||||
require_timm,
|
||||
require_torch,
|
||||
require_vision,
|
||||
slow,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
else:
|
||||
|
||||
class Image:
|
||||
@staticmethod
|
||||
def open(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@require_vision
@require_timm
@require_torch
@is_pipeline_test
class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    """Tests for :obj:`ImageSegmentationPipeline` (task ``image-segmentation``)."""

    # Model classes the PipelineTestCaseMeta metaclass iterates over to generate per-model tests.
    model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING

    @require_datasets
    def run_pipeline_test(self, model, tokenizer, feature_extractor):
        # Generic per-model test driven by the metaclass: single image, then a mixed batch
        # (local path, URL, and several PIL modes) with threshold=0.0 so all queries are kept.
        image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor)
        outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
        self.assertEqual(outputs, [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12)

        import datasets

        dataset = datasets.load_dataset("Narsil/image_dummy", "image", split="test")

        batch = [
            Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
            "http://images.cocodataset.org/val2017/000000039769.jpg",
            # RGBA
            dataset[0]["file"],
            # LA
            dataset[1]["file"],
            # L
            dataset[2]["file"],
        ]
        outputs = image_segmenter(batch, threshold=0.0)

        # One result list per input image, each with 12 query predictions.
        self.assertEqual(len(batch), len(outputs))
        self.assertEqual(
            outputs,
            [
                [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12,
                [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12,
                [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12,
                [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12,
                [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12,
            ],
        )

    @require_tf
    @unittest.skip("Image segmentation not implemented in TF")
    def test_small_model_tf(self):
        pass

    @require_torch
    def test_small_model_pt(self):
        # Tiny panoptic DETR checkpoint: fast, deterministic outputs pinned by sha1 of the mask strings.
        model_id = "mishig/tiny-detr-mobilenetsv3-panoptic"

        model = AutoModelForImageSegmentation.from_pretrained(model_id)
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
        image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor)

        outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg", threshold=0.0)
        for o in outputs:
            # shortening by hashing
            o["mask"] = hashlib.sha1(o["mask"].encode("UTF-8")).hexdigest()

        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                {
                    "score": 0.004,
                    "label": "LABEL_0",
                    "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc",
                },
                {
                    "score": 0.004,
                    "label": "LABEL_0",
                    "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc",
                },
            ],
        )

        # Batched input: the same image twice must give two identical result lists.
        outputs = image_segmenter(
            [
                "http://images.cocodataset.org/val2017/000000039769.jpg",
                "http://images.cocodataset.org/val2017/000000039769.jpg",
            ],
            threshold=0.0,
        )
        for output in outputs:
            for o in output:
                o["mask"] = hashlib.sha1(o["mask"].encode("UTF-8")).hexdigest()

        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {
                        "score": 0.004,
                        "label": "LABEL_0",
                        "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc",
                    },
                    {
                        "score": 0.004,
                        "label": "LABEL_0",
                        "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc",
                    },
                ],
                [
                    {
                        "score": 0.004,
                        "label": "LABEL_0",
                        "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc",
                    },
                    {
                        "score": 0.004,
                        "label": "LABEL_0",
                        "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc",
                    },
                ],
            ],
        )

    @require_torch
    @slow
    def test_integration_torch_image_segmentation(self):
        # Full-size panoptic DETR through the `pipeline()` factory; default threshold (0.9).
        model_id = "facebook/detr-resnet-50-panoptic"

        image_segmenter = pipeline("image-segmentation", model=model_id)

        outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg")
        for o in outputs:
            # shortening by hashing
            o["mask"] = hashlib.sha1(o["mask"].encode("UTF-8")).hexdigest()

        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                {"score": 0.9094, "label": "blanket", "mask": "f939d943609821ad27cdb92844f2754ad3735b52"},
                {"score": 0.9941, "label": "cat", "mask": "32913606de3958812ced0090df7b699abb6e2644"},
                {"score": 0.9987, "label": "remote", "mask": "f3988d35f3065f591fa6a0a9414614d98a9ca13e"},
                {"score": 0.9995, "label": "remote", "mask": "ff0d541ace4fe386fc14ced0c546490a8e7001d7"},
                {"score": 0.9722, "label": "couch", "mask": "543c3244b291c4aec134f1d8f92af553da795529"},
                {"score": 0.9994, "label": "cat", "mask": "891313e21290200e6169613e6a9cb7aff9e7b22f"},
            ],
        )

        # Batched input: same image twice, expect identical per-image result lists.
        outputs = image_segmenter(
            [
                "http://images.cocodataset.org/val2017/000000039769.jpg",
                "http://images.cocodataset.org/val2017/000000039769.jpg",
            ],
            threshold=0.0,
        )
        for output in outputs:
            for o in output:
                o["mask"] = hashlib.sha1(o["mask"].encode("UTF-8")).hexdigest()

        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {"score": 0.9094, "label": "blanket", "mask": "f939d943609821ad27cdb92844f2754ad3735b52"},
                    {"score": 0.9941, "label": "cat", "mask": "32913606de3958812ced0090df7b699abb6e2644"},
                    {"score": 0.9987, "label": "remote", "mask": "f3988d35f3065f591fa6a0a9414614d98a9ca13e"},
                    {"score": 0.9995, "label": "remote", "mask": "ff0d541ace4fe386fc14ced0c546490a8e7001d7"},
                    {"score": 0.9722, "label": "couch", "mask": "543c3244b291c4aec134f1d8f92af553da795529"},
                    {"score": 0.9994, "label": "cat", "mask": "891313e21290200e6169613e6a9cb7aff9e7b22f"},
                ],
                [
                    {"score": 0.9094, "label": "blanket", "mask": "f939d943609821ad27cdb92844f2754ad3735b52"},
                    {"score": 0.9941, "label": "cat", "mask": "32913606de3958812ced0090df7b699abb6e2644"},
                    {"score": 0.9987, "label": "remote", "mask": "f3988d35f3065f591fa6a0a9414614d98a9ca13e"},
                    {"score": 0.9995, "label": "remote", "mask": "ff0d541ace4fe386fc14ced0c546490a8e7001d7"},
                    {"score": 0.9722, "label": "couch", "mask": "543c3244b291c4aec134f1d8f92af553da795529"},
                    {"score": 0.9994, "label": "cat", "mask": "891313e21290200e6169613e6a9cb7aff9e7b22f"},
                ],
            ],
        )

    @require_torch
    @slow
    def test_threshold(self):
        # A high threshold should keep only the two most confident detections.
        threshold = 0.999
        model_id = "facebook/detr-resnet-50-panoptic"

        image_segmenter = pipeline("image-segmentation", model=model_id)

        outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg", threshold=threshold)

        for o in outputs:
            # shortening by hashing
            o["mask"] = hashlib.sha1(o["mask"].encode("UTF-8")).hexdigest()

        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                {"score": 0.9995, "label": "remote", "mask": "ff0d541ace4fe386fc14ced0c546490a8e7001d7"},
                {"score": 0.9994, "label": "cat", "mask": "891313e21290200e6169613e6a9cb7aff9e7b22f"},
            ],
        )
|
Loading…
Reference in New Issue