Make image processors more general (#27690)

* Make image processors more general

* Add backwards compatibility for KOSMOS-2

* Remove use_square_size everywhere

* Remove script
NielsRogge 2023-12-05 10:45:39 +01:00 committed by GitHub
parent 96f9caa10b
commit df40edfb00
7 changed files with 83 additions and 79 deletions
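
Taken together, the diffs below drop the `use_square_size` flag from the CLIP-style image processors and make `resize` interpret the `size` dict itself: a `shortest_edge` key gives an aspect-ratio-preserving resize, while `height`/`width` keys give an exact output shape. A minimal usage sketch of the two call patterns (the dummy image and the chosen sizes are illustrative, not part of this commit):

```python
import numpy as np

from transformers import CLIPImageProcessor

# Any RGB image works; this dummy HWC array stands in for a real photo.
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

# Aspect-ratio-preserving resize: the shortest edge is scaled to 224.
shortest_edge = CLIPImageProcessor(size={"shortest_edge": 224})

# Exact-shape resize: what `use_square_size=True` used to produce, now stated directly.
square = CLIPImageProcessor(size={"height": 224, "width": 224})

# Both end up as (1, 3, 224, 224) here because of the default 224x224 center crop.
print(shortest_edge(images=image, return_tensors="np").pixel_values.shape)
print(square(images=image, return_tensors="np").pixel_values.shape)
```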


@@ -84,10 +84,6 @@ class BitImageProcessor(BaseImageProcessor):
             Can be overridden by the `image_std` parameter in the `preprocess` method.
         do_convert_rgb (`bool`, *optional*, defaults to `True`):
             Whether to convert the image to RGB.
-        use_square_size (`bool`, *optional*, defaults to `False`):
-            The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
-            `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not.
-            Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
     """

     model_input_names = ["pixel_values"]
@@ -105,12 +101,11 @@ class BitImageProcessor(BaseImageProcessor):
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
         do_convert_rgb: bool = True,
-        use_square_size: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         size = size if size is not None else {"shortest_edge": 224}
-        size = get_size_dict(size, default_to_square=use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
         crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
@@ -125,7 +120,6 @@ class BitImageProcessor(BaseImageProcessor):
         self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
         self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
         self.do_convert_rgb = do_convert_rgb
-        self.use_square_size = use_square_size

     # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
     def resize(
@@ -153,13 +147,19 @@ class BitImageProcessor(BaseImageProcessor):
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        size = get_size_dict(size, default_to_square=self.use_square_size)
-        if "shortest_edge" not in size:
-            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
         output_size = get_resize_output_image_size(
             image,
-            size=size["shortest_edge"],
-            default_to_square=self.use_square_size,
+            size=size,
+            default_to_square=default_to_square,
             input_data_format=input_data_format,
         )
         return resize(
@@ -243,7 +243,7 @@ class BitImageProcessor(BaseImageProcessor):
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
-        size = get_size_dict(size, param_name="size", default_to_square=self.use_square_size)
+        size = get_size_dict(size, param_name="size", default_to_square=False)
         resample = resample if resample is not None else self.resample
         do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
         crop_size = crop_size if crop_size is not None else self.crop_size


@@ -84,10 +84,6 @@ class CLIPImageProcessor(BaseImageProcessor):
             Can be overridden by the `image_std` parameter in the `preprocess` method.
         do_convert_rgb (`bool`, *optional*, defaults to `True`):
             Whether to convert the image to RGB.
-        use_square_size (`bool`, *optional*, defaults to `False`):
-            The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
-            `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not.
-            Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
     """

     model_input_names = ["pixel_values"]
@@ -105,12 +101,11 @@ class CLIPImageProcessor(BaseImageProcessor):
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
         do_convert_rgb: bool = True,
-        use_square_size: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         size = size if size is not None else {"shortest_edge": 224}
-        size = get_size_dict(size, default_to_square=use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
         crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
@@ -125,7 +120,10 @@ class CLIPImageProcessor(BaseImageProcessor):
         self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
         self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
         self.do_convert_rgb = do_convert_rgb
-        self.use_square_size = use_square_size
+
+        # for backwards compatibility of KOSMOS-2
+        if "use_square_size" in kwargs:
+            self.size = {"height": size["shortest_edge"], "width": size["shortest_edge"]}

     def resize(
         self,
@@ -152,13 +150,19 @@ class CLIPImageProcessor(BaseImageProcessor):
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        size = get_size_dict(size, default_to_square=self.use_square_size)
-        if "shortest_edge" not in size:
-            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
         output_size = get_resize_output_image_size(
             image,
-            size=size["shortest_edge"],
-            default_to_square=self.use_square_size,
+            size=size,
+            default_to_square=default_to_square,
             input_data_format=input_data_format,
         )
         return resize(
@@ -242,7 +246,7 @@ class CLIPImageProcessor(BaseImageProcessor):
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
-        size = get_size_dict(size, param_name="size", default_to_square=self.use_square_size)
+        size = get_size_dict(size, param_name="size", default_to_square=False)
         resample = resample if resample is not None else self.resample
         do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
         crop_size = crop_size if crop_size is not None else self.crop_size
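
The branch added above keeps existing KOSMOS-2 checkpoints working: configs that still pass `use_square_size` now have their `size` translated into an explicit square spec at construction time. A small sketch of the expected behaviour, assuming the legacy kwarg is forwarded the way those saved configs do:

```python
from transformers import CLIPImageProcessor

# Legacy-style kwargs, as older KOSMOS-2 preprocessor configs still provide them.
legacy = CLIPImageProcessor(size={"shortest_edge": 224}, use_square_size=True)

# The backwards-compatibility branch rewrites `size` into a square height/width dict,
# so images are resized to exactly 224x224 as before.
assert legacy.size == {"height": 224, "width": 224}
```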


@@ -79,10 +79,6 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
         image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
             Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
-        use_square_size (`bool`, *optional*, defaults to `False`):
-            The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
-            `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not.
-            Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
     """

     model_input_names = ["pixel_values"]
@@ -99,12 +95,11 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
         do_normalize: bool = True,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
-        use_square_size: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         size = size if size is not None else {"shortest_edge": 256}
-        size = get_size_dict(size, default_to_square=use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
         crop_size = get_size_dict(crop_size)
         self.do_resize = do_resize
@@ -117,7 +112,6 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
         self.do_normalize = do_normalize
         self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
         self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
-        self.use_square_size = use_square_size

     # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
     def resize(
@@ -145,13 +139,19 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        size = get_size_dict(size, default_to_square=self.use_square_size)
-        if "shortest_edge" not in size:
-            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
         output_size = get_resize_output_image_size(
             image,
-            size=size["shortest_edge"],
-            default_to_square=self.use_square_size,
+            size=size,
+            default_to_square=default_to_square,
             input_data_format=input_data_format,
         )
         return resize(
@@ -231,7 +231,7 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
-        size = get_size_dict(size, default_to_square=self.use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         resample = resample if resample is not None else self.resample
         do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
         crop_size = crop_size if crop_size is not None else self.crop_size


@@ -83,10 +83,6 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
         image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
             Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
-        use_square_size (`bool`, *optional*, defaults to `False`):
-            The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
-            `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not.
-            Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
     """

     model_input_names = ["pixel_values"]
@@ -103,12 +99,11 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
         do_normalize: bool = True,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
-        use_square_size: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         size = size if size is not None else {"shortest_edge": 256}
-        size = get_size_dict(size, default_to_square=use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
         crop_size = get_size_dict(crop_size, param_name="crop_size")
         self.do_resize = do_resize
@@ -121,7 +116,6 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
         self.do_normalize = do_normalize
         self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
         self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
-        self.use_square_size = use_square_size

     # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize
     def resize(
@@ -149,13 +143,19 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        size = get_size_dict(size, default_to_square=self.use_square_size)
-        if "shortest_edge" not in size:
-            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
         output_size = get_resize_output_image_size(
             image,
-            size=size["shortest_edge"],
-            default_to_square=self.use_square_size,
+            size=size,
+            default_to_square=default_to_square,
             input_data_format=input_data_format,
         )
         return resize(
@@ -235,7 +235,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
-        size = get_size_dict(size, default_to_square=self.use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         resample = resample if resample is not None else self.resample
         do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
         crop_size = crop_size if crop_size is not None else self.crop_size


@@ -78,10 +78,6 @@ class MobileViTImageProcessor(BaseImageProcessor):
         do_flip_channel_order (`bool`, *optional*, defaults to `True`):
             Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order`
             parameter in the `preprocess` method.
-        use_square_size (`bool`, *optional*, defaults to `False`):
-            The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
-            `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not.
-            Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
     """

     model_input_names = ["pixel_values"]
@@ -96,12 +92,11 @@ class MobileViTImageProcessor(BaseImageProcessor):
         do_center_crop: bool = True,
         crop_size: Dict[str, int] = None,
         do_flip_channel_order: bool = True,
-        use_square_size: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         size = size if size is not None else {"shortest_edge": 224}
-        size = get_size_dict(size, default_to_square=use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
         crop_size = get_size_dict(crop_size, param_name="crop_size")
@@ -113,7 +108,6 @@ class MobileViTImageProcessor(BaseImageProcessor):
         self.do_center_crop = do_center_crop
         self.crop_size = crop_size
         self.do_flip_channel_order = do_flip_channel_order
-        self.use_square_size = use_square_size

     # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize with PILImageResampling.BICUBIC->PILImageResampling.BILINEAR
     def resize(
@@ -141,13 +135,19 @@ class MobileViTImageProcessor(BaseImageProcessor):
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        size = get_size_dict(size, default_to_square=self.use_square_size)
-        if "shortest_edge" not in size:
-            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
         output_size = get_resize_output_image_size(
             image,
-            size=size["shortest_edge"],
-            default_to_square=self.use_square_size,
+            size=size,
+            default_to_square=default_to_square,
             input_data_format=input_data_format,
         )
         return resize(
@@ -246,7 +246,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
             )
         size = size if size is not None else self.size
-        size = get_size_dict(size, default_to_square=self.use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else self.crop_size
         crop_size = get_size_dict(crop_size, param_name="crop_size")


@@ -84,10 +84,6 @@ class ViTHybridImageProcessor(BaseImageProcessor):
             Can be overridden by the `image_std` parameter in the `preprocess` method.
         do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
-        use_square_size (`bool`, *optional*, defaults to `False`):
-            The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
-            `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not.
-            Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
     """

     model_input_names = ["pixel_values"]
@@ -105,12 +101,11 @@ class ViTHybridImageProcessor(BaseImageProcessor):
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
         do_convert_rgb: bool = True,
-        use_square_size: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         size = size if size is not None else {"shortest_edge": 224}
-        size = get_size_dict(size, default_to_square=use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
         crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
@@ -125,7 +120,6 @@ class ViTHybridImageProcessor(BaseImageProcessor):
         self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
         self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
         self.do_convert_rgb = do_convert_rgb
-        self.use_square_size = use_square_size

     # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
     def resize(
@@ -153,13 +147,19 @@ class ViTHybridImageProcessor(BaseImageProcessor):
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        size = get_size_dict(size, default_to_square=self.use_square_size)
-        if "shortest_edge" not in size:
-            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
         output_size = get_resize_output_image_size(
             image,
-            size=size["shortest_edge"],
-            default_to_square=self.use_square_size,
+            size=size,
+            default_to_square=default_to_square,
             input_data_format=input_data_format,
         )
         return resize(
@@ -243,7 +243,7 @@ class ViTHybridImageProcessor(BaseImageProcessor):
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
-        size = get_size_dict(size, param_name="size", default_to_square=self.use_square_size)
+        size = get_size_dict(size, param_name="size", default_to_square=False)
         resample = resample if resample is not None else self.resample
         do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
         crop_size = crop_size if crop_size is not None else self.crop_size


@@ -55,7 +55,7 @@ class Kosmos2ProcessorTest(unittest.TestCase):
     def setUp(self):
         self.tmpdirname = tempfile.mkdtemp()

-        image_processor = CLIPImageProcessor(use_square_size=True)
+        image_processor = CLIPImageProcessor()

         # We have a SentencePiece fixture for testing
         slow_tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB)
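
With the flag gone, the test simply builds a default `CLIPImageProcessor()`; the square behaviour KOSMOS-2 relies on is instead requested through `size` itself. A sketch of how the generalized `resize` treats the two size formats (the reported width for the shortest-edge case is approximate and only illustrative):

```python
import numpy as np

from transformers import CLIPImageProcessor

image = np.zeros((480, 640, 3), dtype=np.uint8)
processor = CLIPImageProcessor()

# "shortest_edge" keeps the aspect ratio: 480x640 scales to roughly 224x298.
by_edge = processor.resize(image, size={"shortest_edge": 224})

# "height"/"width" yields the exact requested shape, which is what KOSMOS-2 needs.
exact = processor.resize(image, size={"height": 224, "width": 224})

print(by_edge.shape, exact.shape)  # ~(224, 298, 3) and (224, 224, 3)
```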