Add AutoModelForZeroShotImageClassification (#22087)

Adds AutoModelForZeroShotImageClassification to transformers
Alara Dirik authored 2023-03-13 12:46:14 +03:00, committed by GitHub
parent b90fbc7e0b
commit 32e3466d38
12 changed files with 98 additions and 9 deletions
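
For context, a minimal usage sketch of the new PyTorch auto class. Hedged: the checkpoint is the pipeline default registered later in this diff, and `logits_per_image` follows the existing CLIP-style contrastive API rather than anything this commit defines.

```python
import requests
import torch
from PIL import Image

from transformers import AutoModelForZeroShotImageClassification, AutoProcessor

# Checkpoint taken from the pipeline default added below; any checkpoint whose
# model_type appears in MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES works.
checkpoint = "openai/clip-vit-base-patch32"
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
candidate_labels = ["a photo of a cat", "a photo of a dog"]

inputs = processor(text=candidate_labels, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# CLIP-style models return one similarity score per (image, text) pair.
probs = outputs.logits_per_image.softmax(dim=-1)
print(dict(zip(candidate_labels, probs[0].tolist())))
```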

View File

@ -258,6 +258,14 @@ The following auto classes are available for the following computer vision tasks
[[autodoc]] AutoModelForUniversalSegmentation
### AutoModelForZeroShotImageClassification
[[autodoc]] AutoModelForZeroShotImageClassification
### TFAutoModelForZeroShotImageClassification
[[autodoc]] TFAutoModelForZeroShotImageClassification
### AutoModelForZeroShotObjectDetection
[[autodoc]] AutoModelForZeroShotObjectDetection

View File

@ -1001,6 +1001,7 @@ else:
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
@ -1033,6 +1034,7 @@ else:
"AutoModelForVideoClassification",
"AutoModelForVision2Seq",
"AutoModelForVisualQuestionAnswering",
"AutoModelForZeroShotImageClassification",
"AutoModelForZeroShotObjectDetection",
"AutoModelWithLMHead",
]
@ -2785,6 +2787,7 @@ else:
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
"TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
@ -2803,6 +2806,7 @@ else:
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
"TFAutoModelForZeroShotImageClassification",
"TFAutoModelWithLMHead",
]
)
@ -4514,6 +4518,7 @@ if TYPE_CHECKING:
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
@ -4546,6 +4551,7 @@ if TYPE_CHECKING:
AutoModelForVideoClassification,
AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering,
AutoModelForZeroShotImageClassification,
AutoModelForZeroShotObjectDetection,
AutoModelWithLMHead,
)
@ -5971,6 +5977,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
@ -5989,6 +5996,7 @@ if TYPE_CHECKING:
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelForZeroShotImageClassification,
TFAutoModelWithLMHead,
)
from .models.bart import (

View File

@ -43,6 +43,7 @@ from .models.auto.modeling_auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
from .training_args import ParallelMode
from .utils import (
@ -70,6 +71,7 @@ TASK_MAPPING = {
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
"automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
"zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
}
logger = logging.get_logger(__name__)

View File

@ -69,6 +69,7 @@ else:
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
"AutoModel",
"AutoBackbone",
@ -100,6 +101,7 @@ else:
"AutoModelForVisualQuestionAnswering",
"AutoModelForDocumentQuestionAnswering",
"AutoModelWithLMHead",
"AutoModelForZeroShotImageClassification",
"AutoModelForZeroShotObjectDetection",
]
@ -126,6 +128,7 @@ else:
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
"TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
@ -144,6 +147,7 @@ else:
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
"TFAutoModelForZeroShotImageClassification",
"TFAutoModelWithLMHead",
]
@ -226,6 +230,7 @@ if TYPE_CHECKING:
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
@ -258,6 +263,7 @@ if TYPE_CHECKING:
AutoModelForVideoClassification,
AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering,
AutoModelForZeroShotImageClassification,
AutoModelForZeroShotObjectDetection,
AutoModelWithLMHead,
)
@ -285,6 +291,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
@ -303,6 +310,7 @@ if TYPE_CHECKING:
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelForZeroShotImageClassification,
TFAutoModelWithLMHead,
)

View File

@ -920,7 +920,7 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
]
)
_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Zero Shot Image Classification mapping
("align", "AlignModel"),
@ -955,6 +955,9 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
@ -1142,6 +1145,15 @@ class AutoModelForImageClassification(_BaseAutoModelClass):
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
AutoModelForZeroShotImageClassification = auto_class_update(
AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)
class AutoModelForImageSegmentation(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
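
A short sketch of what the new mapping buys: `from_pretrained` reads the checkpoint's `config.model_type` and instantiates the concrete class registered in the mapping above. The checkpoint name here is an assumption used only for illustration.

```python
from transformers import AutoModelForZeroShotImageClassification

# The ("align", "AlignModel") entry shown above means an ALIGN checkpoint is
# resolved to AlignModel; other model types in the mapping dispatch the same way.
# "kakaobrain/align-base" is assumed here purely for illustration.
model = AutoModelForZeroShotImageClassification.from_pretrained("kakaobrain/align-base")
print(type(model).__name__)  # expected: "AlignModel"
```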

View File

@ -209,6 +209,15 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
]
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Zero Shot Image Classification mapping
("clip", "TFCLIPModel"),
]
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
# Model for Semantic Segmentation mapping
@ -424,6 +433,9 @@ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
@ -505,6 +517,15 @@ TFAutoModelForImageClassification = auto_class_update(
)
class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
TFAutoModelForZeroShotImageClassification = auto_class_update(
TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
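
The TensorFlow side mirrors this. A minimal sketch, assuming a CLIP checkpoint since `("clip", "TFCLIPModel")` is the only entry shown in the TF mapping above:

```python
import requests
import tensorflow as tf
from PIL import Image

from transformers import AutoProcessor, TFAutoModelForZeroShotImageClassification

checkpoint = "openai/clip-vit-base-patch32"  # same default the pipeline uses; TF weights assumed to be on the Hub
processor = AutoProcessor.from_pretrained(checkpoint)
model = TFAutoModelForZeroShotImageClassification.from_pretrained(checkpoint)  # resolves to TFCLIPModel

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
labels = ["a photo of a cat", "a photo of a dog"]
inputs = processor(text=labels, images=image, return_tensors="tf", padding=True)

outputs = model(**inputs)
probs = tf.nn.softmax(outputs.logits_per_image, axis=-1)
print(dict(zip(labels, probs[0].numpy().tolist())))
```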

View File

@ -103,6 +103,7 @@ if is_tf_available():
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelForZeroShotImageClassification,
)
if is_torch_available():
@ -135,6 +136,7 @@ if is_torch_available():
AutoModelForVideoClassification,
AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering,
AutoModelForZeroShotImageClassification,
AutoModelForZeroShotObjectDetection,
)
if TYPE_CHECKING:
@ -290,8 +292,8 @@ SUPPORTED_TASKS = {
},
"zero-shot-image-classification": {
"impl": ZeroShotImageClassificationPipeline,
"tf": (TFAutoModel,) if is_tf_available() else (),
"pt": (AutoModel,) if is_torch_available() else (),
"tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),
"pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),
"default": {
"model": {
"pt": ("openai/clip-vit-base-patch32", "f4881ba"),

View File

@ -18,9 +18,10 @@ if is_vision_available():
from ..image_utils import load_image
if is_torch_available():
pass
from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
from ..tf_utils import stable_softmax
logger = logging.get_logger(__name__)
@ -64,8 +65,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
super().__init__(**kwargs)
requires_backends(self, "vision")
# No specific FOR_XXX available yet
# self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
self.check_model_type(
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
if self.framework == "tf"
else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
)
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
"""
@ -137,9 +141,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
if self.framework == "pt":
probs = logits.softmax(dim=-1).squeeze(-1)
scores = probs.tolist()
else:
elif self.framework == "tf":
probs = stable_softmax(logits, axis=-1)
scores = probs.numpy().tolist()
else:
raise ValueError(f"Unsupported framework: {self.framework}")
result = [
{"score": score, "label": candidate_label}

View File

@ -526,6 +526,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = None
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = None
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = None
@ -738,6 +741,13 @@ class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
requires_backends(self, ["torch"])
class AutoModelForZeroShotImageClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoModelForZeroShotObjectDetection(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -316,6 +316,9 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = None
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = None
TF_MODEL_MAPPING = None
@ -434,6 +437,13 @@ class TFAutoModelForVision2Seq(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFAutoModelForZeroShotImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFAutoModelWithLMHead(metaclass=DummyObject):
_backends = ["tf"]

View File

@ -50,6 +50,7 @@ from ..models.auto.modeling_auto import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
@ -79,6 +80,7 @@ def _generate_supported_model_class_names(
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
"zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
"ctc": MODEL_FOR_CTC_MAPPING_NAMES,
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
"semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,

View File

@ -93,8 +93,8 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
("image-to-text", "MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"),
(
"zero-shot-image-classification",
"_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
"AutoModel",
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
"AutoModelForZeroShotImageClassification",
),
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),