Add AutoModelForZeroShotImageClassification (#22087)
Adds AutoModelForZeroShotImageClassification to transformers
This commit is contained in:
parent
b90fbc7e0b
commit
32e3466d38
|
@ -258,6 +258,14 @@ The following auto classes are available for the following computer vision tasks
|
|||
|
||||
[[autodoc]] AutoModelForUniversalSegmentation
|
||||
|
||||
### AutoModelForZeroShotImageClassification
|
||||
|
||||
[[autodoc]] AutoModelForZeroShotImageClassification
|
||||
|
||||
### TFAutoModelForZeroShotImageClassification
|
||||
|
||||
[[autodoc]] TFAutoModelForZeroShotImageClassification
|
||||
|
||||
### AutoModelForZeroShotObjectDetection
|
||||
|
||||
[[autodoc]] AutoModelForZeroShotObjectDetection
|
||||
|
|
|
@ -1001,6 +1001,7 @@ else:
|
|||
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
|
||||
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
|
||||
"MODEL_MAPPING",
|
||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||
|
@ -1033,6 +1034,7 @@ else:
|
|||
"AutoModelForVideoClassification",
|
||||
"AutoModelForVision2Seq",
|
||||
"AutoModelForVisualQuestionAnswering",
|
||||
"AutoModelForZeroShotImageClassification",
|
||||
"AutoModelForZeroShotObjectDetection",
|
||||
"AutoModelWithLMHead",
|
||||
]
|
||||
|
@ -2785,6 +2787,7 @@ else:
|
|||
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_MAPPING",
|
||||
"TF_MODEL_WITH_LM_HEAD_MAPPING",
|
||||
"TFAutoModel",
|
||||
|
@ -2803,6 +2806,7 @@ else:
|
|||
"TFAutoModelForTableQuestionAnswering",
|
||||
"TFAutoModelForTokenClassification",
|
||||
"TFAutoModelForVision2Seq",
|
||||
"TFAutoModelForZeroShotImageClassification",
|
||||
"TFAutoModelWithLMHead",
|
||||
]
|
||||
)
|
||||
|
@ -4514,6 +4518,7 @@ if TYPE_CHECKING:
|
|||
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
|
@ -4546,6 +4551,7 @@ if TYPE_CHECKING:
|
|||
AutoModelForVideoClassification,
|
||||
AutoModelForVision2Seq,
|
||||
AutoModelForVisualQuestionAnswering,
|
||||
AutoModelForZeroShotImageClassification,
|
||||
AutoModelForZeroShotObjectDetection,
|
||||
AutoModelWithLMHead,
|
||||
)
|
||||
|
@ -5971,6 +5977,7 @@ if TYPE_CHECKING:
|
|||
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_MAPPING,
|
||||
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
||||
TFAutoModel,
|
||||
|
@ -5989,6 +5996,7 @@ if TYPE_CHECKING:
|
|||
TFAutoModelForTableQuestionAnswering,
|
||||
TFAutoModelForTokenClassification,
|
||||
TFAutoModelForVision2Seq,
|
||||
TFAutoModelForZeroShotImageClassification,
|
||||
TFAutoModelWithLMHead,
|
||||
)
|
||||
from .models.bart import (
|
||||
|
|
|
@ -43,6 +43,7 @@ from .models.auto.modeling_auto import (
|
|||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
)
|
||||
from .training_args import ParallelMode
|
||||
from .utils import (
|
||||
|
@ -70,6 +71,7 @@ TASK_MAPPING = {
|
|||
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||
"automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
|
||||
"zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
}
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
|
|
@ -69,6 +69,7 @@ else:
|
|||
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
|
||||
"MODEL_MAPPING",
|
||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
|
||||
"AutoModel",
|
||||
"AutoBackbone",
|
||||
|
@ -100,6 +101,7 @@ else:
|
|||
"AutoModelForVisualQuestionAnswering",
|
||||
"AutoModelForDocumentQuestionAnswering",
|
||||
"AutoModelWithLMHead",
|
||||
"AutoModelForZeroShotImageClassification",
|
||||
"AutoModelForZeroShotObjectDetection",
|
||||
]
|
||||
|
||||
|
@ -126,6 +128,7 @@ else:
|
|||
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_MAPPING",
|
||||
"TF_MODEL_WITH_LM_HEAD_MAPPING",
|
||||
"TFAutoModel",
|
||||
|
@ -144,6 +147,7 @@ else:
|
|||
"TFAutoModelForTableQuestionAnswering",
|
||||
"TFAutoModelForTokenClassification",
|
||||
"TFAutoModelForVision2Seq",
|
||||
"TFAutoModelForZeroShotImageClassification",
|
||||
"TFAutoModelWithLMHead",
|
||||
]
|
||||
|
||||
|
@ -226,6 +230,7 @@ if TYPE_CHECKING:
|
|||
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
|
@ -258,6 +263,7 @@ if TYPE_CHECKING:
|
|||
AutoModelForVideoClassification,
|
||||
AutoModelForVision2Seq,
|
||||
AutoModelForVisualQuestionAnswering,
|
||||
AutoModelForZeroShotImageClassification,
|
||||
AutoModelForZeroShotObjectDetection,
|
||||
AutoModelWithLMHead,
|
||||
)
|
||||
|
@ -285,6 +291,7 @@ if TYPE_CHECKING:
|
|||
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_MAPPING,
|
||||
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
||||
TFAutoModel,
|
||||
|
@ -303,6 +310,7 @@ if TYPE_CHECKING:
|
|||
TFAutoModelForTableQuestionAnswering,
|
||||
TFAutoModelForTokenClassification,
|
||||
TFAutoModelForVision2Seq,
|
||||
TFAutoModelForZeroShotImageClassification,
|
||||
TFAutoModelWithLMHead,
|
||||
)
|
||||
|
||||
|
|
|
@ -920,7 +920,7 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
|
|||
]
|
||||
)
|
||||
|
||||
_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Zero Shot Image Classification mapping
|
||||
("align", "AlignModel"),
|
||||
|
@ -955,6 +955,9 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
|
|||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
|
||||
)
|
||||
|
@ -1142,6 +1145,15 @@ class AutoModelForImageClassification(_BaseAutoModelClass):
|
|||
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
|
||||
|
||||
|
||||
class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
|
||||
AutoModelForZeroShotImageClassification = auto_class_update(
|
||||
AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForImageSegmentation(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
|
||||
|
||||
|
|
|
@ -209,6 +209,15 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||
]
|
||||
)
|
||||
|
||||
|
||||
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Zero Shot Image Classification mapping
|
||||
("clip", "TFCLIPModel"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Semantic Segmentation mapping
|
||||
|
@ -424,6 +433,9 @@ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
|
|||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
|
||||
)
|
||||
|
@ -505,6 +517,15 @@ TFAutoModelForImageClassification = auto_class_update(
|
|||
)
|
||||
|
||||
|
||||
class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
|
||||
TFAutoModelForZeroShotImageClassification = auto_class_update(
|
||||
TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
|
||||
)
|
||||
|
||||
|
||||
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
|
||||
|
||||
|
|
|
@ -103,6 +103,7 @@ if is_tf_available():
|
|||
TFAutoModelForTableQuestionAnswering,
|
||||
TFAutoModelForTokenClassification,
|
||||
TFAutoModelForVision2Seq,
|
||||
TFAutoModelForZeroShotImageClassification,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
|
@ -135,6 +136,7 @@ if is_torch_available():
|
|||
AutoModelForVideoClassification,
|
||||
AutoModelForVision2Seq,
|
||||
AutoModelForVisualQuestionAnswering,
|
||||
AutoModelForZeroShotImageClassification,
|
||||
AutoModelForZeroShotObjectDetection,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
|
@ -290,8 +292,8 @@ SUPPORTED_TASKS = {
|
|||
},
|
||||
"zero-shot-image-classification": {
|
||||
"impl": ZeroShotImageClassificationPipeline,
|
||||
"tf": (TFAutoModel,) if is_tf_available() else (),
|
||||
"pt": (AutoModel,) if is_torch_available() else (),
|
||||
"tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": ("openai/clip-vit-base-patch32", "f4881ba"),
|
||||
|
|
|
@ -18,9 +18,10 @@ if is_vision_available():
|
|||
from ..image_utils import load_image
|
||||
|
||||
if is_torch_available():
|
||||
pass
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
if is_tf_available():
|
||||
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||
from ..tf_utils import stable_softmax
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -64,8 +65,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
|
|||
super().__init__(**kwargs)
|
||||
|
||||
requires_backends(self, "vision")
|
||||
# No specific FOR_XXX available yet
|
||||
# self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
|
||||
self.check_model_type(
|
||||
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||
if self.framework == "tf"
|
||||
else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||
)
|
||||
|
||||
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
|
||||
"""
|
||||
|
@ -137,9 +141,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
|
|||
if self.framework == "pt":
|
||||
probs = logits.softmax(dim=-1).squeeze(-1)
|
||||
scores = probs.tolist()
|
||||
else:
|
||||
elif self.framework == "tf":
|
||||
probs = stable_softmax(logits, axis=-1)
|
||||
scores = probs.numpy().tolist()
|
||||
else:
|
||||
raise ValueError(f"Unsupported framework: {self.framework}")
|
||||
|
||||
result = [
|
||||
{"score": score, "label": candidate_label}
|
||||
|
|
|
@ -526,6 +526,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = None
|
|||
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = None
|
||||
|
||||
|
||||
|
@ -738,6 +741,13 @@ class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
|
|||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForZeroShotImageClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForZeroShotObjectDetection(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
|
|
@ -316,6 +316,9 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
|
|||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = None
|
||||
|
||||
|
||||
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
TF_MODEL_MAPPING = None
|
||||
|
||||
|
||||
|
@ -434,6 +437,13 @@ class TFAutoModelForVision2Seq(metaclass=DummyObject):
|
|||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFAutoModelForZeroShotImageClassification(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFAutoModelWithLMHead(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ from ..models.auto.modeling_auto import (
|
|||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_MAPPING_NAMES,
|
||||
)
|
||||
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
|
||||
|
@ -79,6 +80,7 @@ def _generate_supported_model_class_names(
|
|||
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
||||
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
"zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
"ctc": MODEL_FOR_CTC_MAPPING_NAMES,
|
||||
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||
"semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
|
||||
|
|
|
@ -93,8 +93,8 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
|||
("image-to-text", "MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"),
|
||||
(
|
||||
"zero-shot-image-classification",
|
||||
"_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
|
||||
"AutoModel",
|
||||
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
|
||||
"AutoModelForZeroShotImageClassification",
|
||||
),
|
||||
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
|
||||
("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
|
||||
|
|
Loading…
Reference in New Issue