diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 3e5b3701de..f81d76a966 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -91,6 +91,45 @@ def is_batched(img): return False +def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]: + """ + Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1. + If the input is a batch of images, it is converted to a list of images. + + Args: + images (`ImageInput`): + Image or images to turn into a list of images. + expected_ndims (`int`, *optional*, defaults to 3): + Expected number of dimensions for a single input image. If the input image has a different number of + dimensions, an error is raised. + """ + if is_batched(images): + return images + + # Either the input is a single image, in which case we create a list of length 1 + if isinstance(images, PIL.Image.Image): + # PIL images are never batched + return [images] + + if is_valid_image(images): + if images.ndim == expected_ndims + 1: + # Batch of images + images = [image for image in images] + elif images.ndim == expected_ndims: + # Single image + images = [images] + else: + raise ValueError( + f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got" + f" {images.ndim} dimensions." + ) + return images + raise ValueError( + "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or " + f"jax.ndarray, but got {type(images)}." + ) + + def to_numpy_array(img) -> np.ndarray: if not is_valid_image(img): raise ValueError(f"Invalid image type: {type(img)}") diff --git a/src/transformers/models/beit/image_processing_beit.py b/src/transformers/models/beit/image_processing_beit.py index 10a6bfad22..4fa8b0aab0 100644 --- a/src/transformers/models/beit/image_processing_beit.py +++ b/src/transformers/models/beit/image_processing_beit.py @@ -30,7 +30,7 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -438,9 +438,9 @@ class BeitImageProcessor(BaseImageProcessor): image_std = image_std if image_std is not None else self.image_std do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels - if not is_batched(images): - images = [images] - segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None + images = make_list_of_images(images) + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/bit/image_processing_bit.py b/src/transformers/models/bit/image_processing_bit.py index f210ad30de..4395d05584 100644 --- a/src/transformers/models/bit/image_processing_bit.py +++ b/src/transformers/models/bit/image_processing_bit.py @@ -30,7 +30,14 @@ from ...image_transforms import ( resize, to_channel_dimension_format, ) -from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + make_list_of_images, + to_numpy_array, + valid_images, +) from ...utils import logging from ...utils.import_utils import is_vision_available @@ -286,8 +293,7 @@ class BitImageProcessor(BaseImageProcessor): image_std = image_std if image_std is not None else self.image_std
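# ---------------------------------------------------------------------------
# Illustrative aside (not part of the patch): a minimal usage sketch of the new
# make_list_of_images helper defined above, assuming plain numpy inputs. It shows
# how a single image, a batched array, and a 2-D segmentation map are all
# normalized to a list, which is why the segmentation-map call sites in this
# patch pass expected_ndims=2.
import numpy as np

from transformers.image_utils import make_list_of_images

single_image = np.zeros((16, 32, 3), dtype=np.uint8)     # one H x W x C image
image_batch = np.zeros((4, 16, 32, 3), dtype=np.uint8)   # N x H x W x C batch
seg_map = np.zeros((16, 32), dtype=np.uint8)             # H x W map, no channel dimension

assert len(make_list_of_images(single_image)) == 1                 # wrapped into a list of length 1
assert len(make_list_of_images(image_batch)) == 4                  # split into a list of 3-D arrays
assert len(make_list_of_images(seg_map, expected_ndims=2)) == 1    # 2-D inputs need expected_ndims=2
# With the default expected_ndims=3, a 2-D array raises ValueError ("Invalid image shape ...").
# ---------------------------------------------------------------------------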
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/blip/image_processing_blip.py b/src/transformers/models/blip/image_processing_blip.py index 4310a073fc..e9b49924ec 100644 --- a/src/transformers/models/blip/image_processing_blip.py +++ b/src/transformers/models/blip/image_processing_blip.py @@ -29,7 +29,7 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -247,8 +247,7 @@ class BlipImageProcessor(BaseImageProcessor): size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py index 593ba05f82..848bc086a1 100644 --- a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py +++ b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py @@ -30,7 +30,14 @@ from ...image_transforms import ( resize, to_channel_dimension_format, ) -from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + make_list_of_images, + to_numpy_array, + valid_images, +) from ...utils import logging from ...utils.import_utils import is_vision_available @@ -284,8 +291,7 @@ class ChineseCLIPImageProcessor(BaseImageProcessor): image_std = image_std if image_std is not None else self.image_std do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/clip/image_processing_clip.py b/src/transformers/models/clip/image_processing_clip.py index 380411b47a..ac99feb54d 100644 --- a/src/transformers/models/clip/image_processing_clip.py +++ b/src/transformers/models/clip/image_processing_clip.py @@ -30,7 +30,14 @@ from ...image_transforms import ( resize, to_channel_dimension_format, ) -from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + make_list_of_images, + to_numpy_array, + valid_images, +) from ...utils import logging from ...utils.import_utils import is_vision_available @@ -286,8 +293,7 @@ class CLIPImageProcessor(BaseImageProcessor): image_std = image_std if image_std is not None else self.image_std do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py index aed1ee65c2..1cd271b59b 100644 --- a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py +++ 
b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py @@ -44,7 +44,7 @@ from transformers.image_utils import ( PILImageResampling, get_image_size, infer_channel_dimension_format, - is_batched, + make_list_of_images, to_numpy_array, valid_coco_detection_annotations, valid_coco_panoptic_annotations, @@ -1172,9 +1172,9 @@ class ConditionalDetrImageProcessor(BaseImageProcessor): if do_normalize is not None and (image_mean is None or image_std is None): raise ValueError("Image mean and std must be specified if do_normalize is True.") - if not is_batched(images): - images = [images] - annotations = [annotations] if annotations is not None else None + images = make_list_of_images(images) + if annotations is not None and isinstance(annotations[0], dict): + annotations = [annotations] if annotations is not None and len(images) != len(annotations): raise ValueError( diff --git a/src/transformers/models/convnext/image_processing_convnext.py b/src/transformers/models/convnext/image_processing_convnext.py index 57382a05a8..2353767df5 100644 --- a/src/transformers/models/convnext/image_processing_convnext.py +++ b/src/transformers/models/convnext/image_processing_convnext.py @@ -36,7 +36,7 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -272,8 +272,7 @@ class ConvNextImageProcessor(BaseImageProcessor): size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py index 12a32ac593..ea974843c3 100644 --- a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py +++ b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py @@ -44,7 +44,7 @@ from transformers.image_utils import ( PILImageResampling, get_image_size, infer_channel_dimension_format, - is_batched, + make_list_of_images, to_numpy_array, valid_coco_detection_annotations, valid_coco_panoptic_annotations, @@ -1170,9 +1170,9 @@ class DeformableDetrImageProcessor(BaseImageProcessor): if do_normalize is not None and (image_mean is None or image_std is None): raise ValueError("Image mean and std must be specified if do_normalize is True.") - if not is_batched(images): - images = [images] - annotations = [annotations] if annotations is not None else None + images = make_list_of_images(images) + if annotations is not None and isinstance(annotations[0], dict): + annotations = [annotations] if annotations is not None and len(images) != len(annotations): raise ValueError( diff --git a/src/transformers/models/deit/image_processing_deit.py b/src/transformers/models/deit/image_processing_deit.py index 6d60a17012..2c0ad59fa5 100644 --- a/src/transformers/models/deit/image_processing_deit.py +++ b/src/transformers/models/deit/image_processing_deit.py @@ -29,7 +29,7 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -276,8 +276,7 @@ class DeiTImageProcessor(BaseImageProcessor): crop_size = crop_size if crop_size is not None else self.crop_size crop_size = get_size_dict(crop_size, param_name="crop_size") - if not is_batched(images): - images = [images] + images 
= make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/detr/image_processing_detr.py b/src/transformers/models/detr/image_processing_detr.py index 0e22eeb893..00391d3819 100644 --- a/src/transformers/models/detr/image_processing_detr.py +++ b/src/transformers/models/detr/image_processing_detr.py @@ -43,7 +43,7 @@ from transformers.image_utils import ( PILImageResampling, get_image_size, infer_channel_dimension_format, - is_batched, + make_list_of_images, to_numpy_array, valid_coco_detection_annotations, valid_coco_panoptic_annotations, @@ -1138,9 +1138,9 @@ class DetrImageProcessor(BaseImageProcessor): if do_normalize is not None and (image_mean is None or image_std is None): raise ValueError("Image mean and std must be specified if do_normalize is True.") - if not is_batched(images): - images = [images] - annotations = [annotations] if annotations is not None else None + images = make_list_of_images(images) + if annotations is not None and isinstance(annotations[0], dict): + annotations = [annotations] if annotations is not None and len(images) != len(annotations): raise ValueError( diff --git a/src/transformers/models/donut/image_processing_donut.py b/src/transformers/models/donut/image_processing_donut.py index 7fe402a09d..835fdb9f64 100644 --- a/src/transformers/models/donut/image_processing_donut.py +++ b/src/transformers/models/donut/image_processing_donut.py @@ -34,7 +34,7 @@ from ...image_utils import ( ImageInput, PILImageResampling, get_image_size, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -396,8 +396,7 @@ class DonutImageProcessor(BaseImageProcessor): image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/dpt/image_processing_dpt.py b/src/transformers/models/dpt/image_processing_dpt.py index 3bfe80c9e8..2bc57c9a2a 100644 --- a/src/transformers/models/dpt/image_processing_dpt.py +++ b/src/transformers/models/dpt/image_processing_dpt.py @@ -31,9 +31,9 @@ from ...image_utils import ( ImageInput, PILImageResampling, get_image_size, - is_batched, is_torch_available, is_torch_tensor, + make_list_of_images, to_numpy_array, valid_images, ) @@ -308,8 +308,7 @@ class DPTImageProcessor(BaseImageProcessor): image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/flava/image_processing_flava.py b/src/transformers/models/flava/image_processing_flava.py index 22e062306f..c41bb37e72 100644 --- a/src/transformers/models/flava/image_processing_flava.py +++ b/src/transformers/models/flava/image_processing_flava.py @@ -26,7 +26,14 @@ from transformers.utils.generic import TensorType from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format -from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + make_list_of_images, + to_numpy_array, + 
valid_images, +) from ...utils import logging @@ -647,8 +654,7 @@ class FlavaImageProcessor(BaseImageProcessor): codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else self.codebook_image_mean codebook_image_std = codebook_image_std if codebook_image_std is not None else self.codebook_image_std - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/glpn/image_processing_glpn.py b/src/transformers/models/glpn/image_processing_glpn.py index 5d5cd8c198..0533d4c242 100644 --- a/src/transformers/models/glpn/image_processing_glpn.py +++ b/src/transformers/models/glpn/image_processing_glpn.py @@ -24,7 +24,7 @@ from transformers.utils.generic import TensorType from ...image_processing_utils import BaseImageProcessor, BatchFeature from ...image_transforms import rescale, resize, to_channel_dimension_format -from ...image_utils import ChannelDimension, get_image_size, is_batched, to_numpy_array, valid_images +from ...image_utils import ChannelDimension, get_image_size, make_list_of_images, to_numpy_array, valid_images from ...utils import logging @@ -166,8 +166,7 @@ class GLPNImageProcessor(BaseImageProcessor): if do_resize and size_divisor is None: raise ValueError("size_divisor is required for resizing") - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError("Invalid image(s)") diff --git a/src/transformers/models/imagegpt/image_processing_imagegpt.py b/src/transformers/models/imagegpt/image_processing_imagegpt.py index e775b50a28..af31fdf191 100644 --- a/src/transformers/models/imagegpt/image_processing_imagegpt.py +++ b/src/transformers/models/imagegpt/image_processing_imagegpt.py @@ -23,7 +23,14 @@ from transformers.utils.generic import TensorType from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import rescale, resize, to_channel_dimension_format -from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + make_list_of_images, + to_numpy_array, + valid_images, +) from ...utils import logging @@ -196,8 +203,7 @@ class ImageGPTImageProcessor(BaseImageProcessor): do_color_quantize = do_color_quantize if do_color_quantize is not None else self.do_color_quantize clusters = clusters if clusters is not None else self.clusters - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py index 454dc50cb4..04547eebd8 100644 --- a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py @@ -28,7 +28,7 @@ from ...image_utils import ( ImageInput, PILImageResampling, infer_channel_dimension_format, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -230,8 +230,7 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor): ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not 
valid_images(images): raise ValueError( diff --git a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py index 2c74d8ed9b..c2cd270846 100644 --- a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py @@ -30,7 +30,7 @@ from ...image_utils import ( ImageInput, PILImageResampling, infer_channel_dimension_format, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -320,8 +320,7 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor): ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/levit/image_processing_levit.py b/src/transformers/models/levit/image_processing_levit.py index 4b2fc85ecd..b369bfd33a 100644 --- a/src/transformers/models/levit/image_processing_levit.py +++ b/src/transformers/models/levit/image_processing_levit.py @@ -35,7 +35,7 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -303,8 +303,7 @@ class LevitImageProcessor(BaseImageProcessor): crop_size = crop_size if crop_size is not None else self.crop_size crop_size = get_size_dict(crop_size, param_name="crop_size") - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/maskformer/image_processing_maskformer.py b/src/transformers/models/maskformer/image_processing_maskformer.py index 89f6c6920a..2a59a0f4db 100644 --- a/src/transformers/models/maskformer/image_processing_maskformer.py +++ b/src/transformers/models/maskformer/image_processing_maskformer.py @@ -37,7 +37,7 @@ from transformers.image_utils import ( PILImageResampling, get_image_size, infer_channel_dimension_format, - is_batched, + make_list_of_images, valid_images, ) from transformers.utils import ( @@ -717,9 +717,9 @@ class MaskFormerImageProcessor(BaseImageProcessor): "torch.Tensor, tf.Tensor or jax.ndarray." 
) - if not is_batched(images): - images = [images] - segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None + images = make_list_of_images(images) + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) if segmentation_maps is not None and len(images) != len(segmentation_maps): raise ValueError("Images and segmentation maps must have the same length.") diff --git a/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py b/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py index 1bf7ccd113..9843f600be 100644 --- a/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +++ b/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py @@ -35,7 +35,7 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -288,8 +288,7 @@ class MobileNetV1ImageProcessor(BaseImageProcessor): image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py index 92fa04081d..343152ebde 100644 --- a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +++ b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py @@ -36,7 +36,7 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -295,8 +295,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit.py b/src/transformers/models/mobilevit/image_processing_mobilevit.py index a7a4a071d9..7cf24216fe 100644 --- a/src/transformers/models/mobilevit/image_processing_mobilevit.py +++ b/src/transformers/models/mobilevit/image_processing_mobilevit.py @@ -28,7 +28,7 @@ from ...image_utils import ( ImageInput, PILImageResampling, infer_channel_dimension_format, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -284,8 +284,7 @@ class MobileViTImageProcessor(BaseImageProcessor): crop_size = crop_size if crop_size is not None else self.crop_size crop_size = get_size_dict(crop_size, param_name="crop_size") - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/oneformer/image_processing_oneformer.py b/src/transformers/models/oneformer/image_processing_oneformer.py index 2a679b95c8..af36dbe0ab 100644 --- a/src/transformers/models/oneformer/image_processing_oneformer.py +++ b/src/transformers/models/oneformer/image_processing_oneformer.py @@ -38,7 +38,7 @@ from transformers.image_utils import ( PILImageResampling, get_image_size, infer_channel_dimension_format, - is_batched, + make_list_of_images, valid_images, ) from transformers.utils import ( @@ -676,9 +676,9 @@ class 
OneFormerImageProcessor(BaseImageProcessor): "torch.Tensor, tf.Tensor or jax.ndarray." ) - if not is_batched(images): - images = [images] - segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None + images = make_list_of_images(images) + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) if segmentation_maps is not None and len(images) != len(segmentation_maps): raise ValueError("Images and segmentation maps must have the same length.") diff --git a/src/transformers/models/owlvit/image_processing_owlvit.py b/src/transformers/models/owlvit/image_processing_owlvit.py index fc3f0fa331..650a2d787e 100644 --- a/src/transformers/models/owlvit/image_processing_owlvit.py +++ b/src/transformers/models/owlvit/image_processing_owlvit.py @@ -29,7 +29,13 @@ from transformers.image_transforms import ( to_channel_dimension_format, to_numpy_array, ) -from transformers.image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, valid_images +from transformers.image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + make_list_of_images, + valid_images, +) from transformers.utils import TensorType, is_torch_available, logging @@ -300,8 +306,7 @@ class OwlViTImageProcessor(BaseImageProcessor): if do_normalize is not None and (image_mean is None or image_std is None): raise ValueError("Image mean and std must be specified if do_normalize is True.") - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/perceiver/image_processing_perceiver.py b/src/transformers/models/perceiver/image_processing_perceiver.py index 18161a97e0..00bf238865 100644 --- a/src/transformers/models/perceiver/image_processing_perceiver.py +++ b/src/transformers/models/perceiver/image_processing_perceiver.py @@ -30,7 +30,7 @@ from ...image_utils import ( ImageInput, PILImageResampling, get_image_size, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -289,8 +289,7 @@ class PerceiverImageProcessor(BaseImageProcessor): image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/poolformer/image_processing_poolformer.py b/src/transformers/models/poolformer/image_processing_poolformer.py index 896465551c..d78bf30327 100644 --- a/src/transformers/models/poolformer/image_processing_poolformer.py +++ b/src/transformers/models/poolformer/image_processing_poolformer.py @@ -36,7 +36,7 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -339,8 +339,7 @@ class PoolFormerImageProcessor(BaseImageProcessor): crop_size = crop_size if crop_size is not None else self.crop_size crop_size = get_size_dict(crop_size, param_name="crop_size") - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/segformer/image_processing_segformer.py b/src/transformers/models/segformer/image_processing_segformer.py index f8122e38c1..f52e7d0d00 100644 --- a/src/transformers/models/segformer/image_processing_segformer.py +++ 
b/src/transformers/models/segformer/image_processing_segformer.py @@ -30,7 +30,7 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -385,9 +385,9 @@ class SegformerImageProcessor(BaseImageProcessor): image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std - if not is_batched(images): - images = [images] - segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None + images = make_list_of_images(images) + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/swin2sr/image_processing_swin2sr.py b/src/transformers/models/swin2sr/image_processing_swin2sr.py index c5c5458d8a..62ec9db16c 100644 --- a/src/transformers/models/swin2sr/image_processing_swin2sr.py +++ b/src/transformers/models/swin2sr/image_processing_swin2sr.py @@ -22,7 +22,7 @@ from transformers.utils.generic import TensorType from ...image_processing_utils import BaseImageProcessor, BatchFeature from ...image_transforms import get_image_size, pad, rescale, to_channel_dimension_format -from ...image_utils import ChannelDimension, ImageInput, is_batched, to_numpy_array, valid_images +from ...image_utils import ChannelDimension, ImageInput, make_list_of_images, to_numpy_array, valid_images from ...utils import logging @@ -148,8 +148,7 @@ class Swin2SRImageProcessor(BaseImageProcessor): do_pad = do_pad if do_pad is not None else self.do_pad pad_size = pad_size if pad_size is not None else self.pad_size - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/vilt/image_processing_vilt.py b/src/transformers/models/vilt/image_processing_vilt.py index e4fbdec032..9e50c1f7d5 100644 --- a/src/transformers/models/vilt/image_processing_vilt.py +++ b/src/transformers/models/vilt/image_processing_vilt.py @@ -32,7 +32,7 @@ from ...image_utils import ( PILImageResampling, get_image_size, infer_channel_dimension_format, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -441,8 +441,7 @@ class ViltImageProcessor(BaseImageProcessor): size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/vit/image_processing_vit.py b/src/transformers/models/vit/image_processing_vit.py index 4287b34b73..4b089443b5 100644 --- a/src/transformers/models/vit/image_processing_vit.py +++ b/src/transformers/models/vit/image_processing_vit.py @@ -28,7 +28,7 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - is_batched, + make_list_of_images, to_numpy_array, valid_images, ) @@ -243,8 +243,7 @@ class ViTImageProcessor(BaseImageProcessor): size = size if size is not None else self.size size_dict = get_size_dict(size) - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py b/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py index 2cd0074708..12bb63fc66 100644 --- 
a/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py @@ -30,7 +30,14 @@ from ...image_transforms import ( resize, to_channel_dimension_format, ) -from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + make_list_of_images, + to_numpy_array, + valid_images, +) from ...utils import logging from ...utils.import_utils import is_vision_available @@ -286,8 +293,7 @@ class ViTHybridImageProcessor(BaseImageProcessor): image_std = image_std if image_std is not None else self.image_std do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - if not is_batched(images): - images = [images] + images = make_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/yolos/image_processing_yolos.py b/src/transformers/models/yolos/image_processing_yolos.py index f21b38283d..814b5602f8 100644 --- a/src/transformers/models/yolos/image_processing_yolos.py +++ b/src/transformers/models/yolos/image_processing_yolos.py @@ -42,7 +42,7 @@ from transformers.image_utils import ( PILImageResampling, get_image_size, infer_channel_dimension_format, - is_batched, + make_list_of_images, to_numpy_array, valid_coco_detection_annotations, valid_coco_panoptic_annotations, @@ -1038,9 +1038,9 @@ class YolosImageProcessor(BaseImageProcessor): if do_normalize is not None and (image_mean is None or image_std is None): raise ValueError("Image mean and std must be specified if do_normalize is True.") - if not is_batched(images): - images = [images] - annotations = [annotations] if annotations is not None else None + images = make_list_of_images(images) + if annotations is not None and isinstance(annotations[0], dict): + annotations = [annotations] if annotations is not None and len(images) != len(annotations): raise ValueError( diff --git a/tests/utils/test_image_utils.py b/tests/utils/test_image_utils.py index 6868e117c4..dd816f5931 100644 --- a/tests/utils/test_image_utils.py +++ b/tests/utils/test_image_utils.py @@ -20,7 +20,7 @@ import numpy as np import pytest from transformers import is_torch_available, is_vision_available -from transformers.image_utils import ChannelDimension, get_channel_dimension_axis +from transformers.image_utils import ChannelDimension, get_channel_dimension_axis, make_list_of_images from transformers.testing_utils import require_torch, require_vision @@ -102,6 +102,58 @@ class ImageFeatureExtractionTester(unittest.TestCase): self.assertEqual(array5.shape, (3, 16, 32)) self.assertTrue(np.array_equal(array5, array1)) + def test_make_list_of_images_numpy(self): + # Test a single image is converted to a list of 1 image + images = np.random.randint(0, 256, (16, 32, 3)) + images_list = make_list_of_images(images) + self.assertEqual(len(images_list), 1) + self.assertTrue(np.array_equal(images_list[0], images)) + self.assertIsInstance(images_list, list) + + # Test a batch of images is converted to a list of images + images = np.random.randint(0, 256, (4, 16, 32, 3)) + images_list = make_list_of_images(images) + self.assertEqual(len(images_list), 4) + self.assertTrue(np.array_equal(images_list[0], images[0])) + self.assertIsInstance(images_list, list) + + # Test a list of images is not modified + images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)] + images_list = 
make_list_of_images(images) + self.assertEqual(len(images_list), 4) + self.assertTrue(np.array_equal(images_list[0], images[0])) + self.assertIsInstance(images_list, list) + + # Test batched masks with no channel dimension are converted to a list of masks + masks = np.random.randint(0, 2, (4, 16, 32)) + masks_list = make_list_of_images(masks, expected_ndims=2) + self.assertEqual(len(masks_list), 4) + self.assertTrue(np.array_equal(masks_list[0], masks[0])) + self.assertIsInstance(masks_list, list) + + @require_torch + def test_make_list_of_images_torch(self): + # Test a single image is converted to a list of 1 image + images = torch.randint(0, 256, (16, 32, 3)) + images_list = make_list_of_images(images) + self.assertEqual(len(images_list), 1) + self.assertTrue(np.array_equal(images_list[0], images)) + self.assertIsInstance(images_list, list) + + # Test a batch of images is converted to a list of images + images = torch.randint(0, 256, (4, 16, 32, 3)) + images_list = make_list_of_images(images) + self.assertEqual(len(images_list), 4) + self.assertTrue(np.array_equal(images_list[0], images[0])) + self.assertIsInstance(images_list, list) + + # Test a list of images is left unchanged + images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)] + images_list = make_list_of_images(images) + self.assertEqual(len(images_list), 4) + self.assertTrue(np.array_equal(images_list[0], images[0])) + self.assertIsInstance(images_list, list) + @require_torch def test_conversion_torch_to_array(self): feature_extractor = ImageFeatureExtractionMixin()
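# ---------------------------------------------------------------------------
# Illustrative aside (not part of the patch): a sketch of the behavioral change at
# a call site, assuming a default-configured ViTImageProcessor (one of the
# processors updated above). Before this change, is_batched() only recognized
# lists and tuples, so the removed `images = [images]` branch wrapped a 4-D numpy
# batch as a single "image"; make_list_of_images() now splits it into a list of
# 3-D arrays before preprocessing. The exact output shape depends on the
# processor's configured size (224 x 224 by default here).
import numpy as np

from transformers import ViTImageProcessor

processor = ViTImageProcessor()  # defaults: resize to 224x224, rescale, normalize
image_batch = np.random.randint(0, 256, (4, 16, 32, 3)).astype(np.uint8)

outputs = processor(image_batch, return_tensors="np")
print(outputs["pixel_values"].shape)  # expected: (4, 3, 224, 224) with the default size
# ---------------------------------------------------------------------------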