Add Image Processors (#19796)
* Add CLIP image processor
* Crop size as dict too
* Update warning
* Actually use logger this time
* Normalize doesn't change dtype of input
* Add perceiver image processor
* Tidy up
* Add DPT image processor
* Add Vilt image processor
* Tidy up
* Add poolformer image processor
* Tidy up
* Add LayoutLM v2 and v3 image processors
* Tidy up
* Add Flava image processor
* Tidy up
* Add deit image processor
* Tidy up
* Add ConvNext image processor
* Tidy up
* Add levit image processor
* Add segformer image processor
* Add in post processing
* Fix up
* Add ImageGPT image processor
* Fixup
* Add mobilevit image processor
* Tidy up
* Add postprocessing
* Fixup
* Add VideoMAE image processor
* Tidy up
* Add ImageGPT image processor
* Fixup
* Add ViT image processor
* Tidy up
* Add beit image processor
* Add mobilevit image processor
* Tidy up
* Add postprocessing
* Fixup
* Fix up
* Fix flava and remove tree module
* Fix image classification pipeline failing tests
* Update feature extractor in trainer scripts
* Update pad_if_smaller to accept tuple and int size
* Update for image segmentation pipeline
* Update src/transformers/models/perceiver/image_processing_perceiver.py
  Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com>
* Update src/transformers/image_processing_utils.py
  Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* Update src/transformers/models/beit/image_processing_beit.py
  Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* PR comments - docstrings; remove accidentally added resize; var names
* Update docstrings
* Add exception if size is not in the right format
* Fix exception check
* Fix up
* Use shortest_edge in tuple in script

Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
parent 2e3452af0f
commit a6b7759880
@@ -361,9 +361,12 @@ For computer vision tasks, it is common to add some type of data augmentation to
 >>> from torchvision.transforms import Compose, Normalize, RandomResizedCrop, ColorJitter, ToTensor

 >>> normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
->>> _transforms = Compose(
-...     [RandomResizedCrop(feature_extractor.size), ColorJitter(brightness=0.5, hue=0.5), ToTensor(), normalize]
+>>> size = (
+...     feature_extractor.size["shortest_edge"]
+...     if "shortest_edge" in feature_extractor.size
+...     else (feature_extractor.size["height"], feature_extractor.size["width"])
 ... )
+>>> _transforms = Compose([RandomResizedCrop(size), ColorJitter(brightness=0.5, hue=0.5), ToTensor(), normalize])
 ```

 2. The model accepts [`pixel_values`](model_doc/visionencoderdecoder#transformers.VisionEncoderDecoderModel.forward.pixel_values) as its input, which is generated by the feature extractor. Create a function that generates `pixel_values` from the transforms:
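This docs change tracks the new convention that `feature_extractor.size` is a dict rather than a bare int, keyed either by `shortest_edge` or by `height`/`width`. A minimal sketch of the two layouts and the branch the docs now use (the dict values are illustrative, not taken from a real checkpoint):

```python
# Hypothetical size dicts showing the two layouts handled above.
size_fixed = {"height": 224, "width": 224}  # resize to an exact (height, width)
size_shortest = {"shortest_edge": 256}      # resize the shortest edge, keep aspect ratio


def resolve_size(size: dict):
    # Mirrors the conditional added to the docs snippet above.
    if "shortest_edge" in size:
        return size["shortest_edge"]
    return (size["height"], size["width"])


assert resolve_size(size_fixed) == (224, 224)
assert resolve_size(size_shortest) == 256
```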
@@ -487,4 +490,4 @@ Load a processor with [`AutoProcessor.from_pretrained`]:
 >>> prepare_dataset(lj_speech[0])
 ```

 The processor has now added `input_values` and `labels`, and the sampling rate has also been correctly downsampled to 16kHz. You can pass your processed dataset to the model now!
@@ -83,7 +83,12 @@ Apply several image transformations to the dataset to make the model more robust
 >>> from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

 >>> normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
->>> _transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])
+>>> size = (
+...     feature_extractor.size["shortest_edge"]
+...     if "shortest_edge" in feature_extractor.size
+...     else (feature_extractor.size["height"], feature_extractor.size["width"])
+... )
+>>> _transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
 ```

 Create a preprocessing function that will apply the transforms and return the `pixel_values` - the inputs to the model - of the image:
@@ -291,10 +291,14 @@ def main():
     )

     # Define torchvision transforms to be applied to each image.
+    if "shortest_edge" in feature_extractor.size:
+        size = feature_extractor.size["shortest_edge"]
+    else:
+        size = (feature_extractor.size["height"], feature_extractor.size["width"])
     normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
     _train_transforms = Compose(
         [
-            RandomResizedCrop(feature_extractor.size),
+            RandomResizedCrop(size),
             RandomHorizontalFlip(),
             ToTensor(),
             normalize,
@@ -302,8 +306,8 @@ def main():
     )
     _val_transforms = Compose(
         [
-            Resize(feature_extractor.size),
-            CenterCrop(feature_extractor.size),
+            Resize(size),
+            CenterCrop(size),
             ToTensor(),
             normalize,
         ]
@@ -315,10 +315,14 @@ def main():
     # Preprocessing the datasets

     # Define torchvision transforms to be applied to each image.
+    if "shortest_edge" in feature_extractor.size:
+        size = feature_extractor.size["shortest_edge"]
+    else:
+        size = (feature_extractor.size["height"], feature_extractor.size["width"])
     normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
     train_transforms = Compose(
         [
-            RandomResizedCrop(feature_extractor.size),
+            RandomResizedCrop(size),
             RandomHorizontalFlip(),
             ToTensor(),
             normalize,
@@ -326,8 +330,8 @@ def main():
     )
     val_transforms = Compose(
         [
-            Resize(feature_extractor.size),
-            CenterCrop(feature_extractor.size),
+            Resize(size),
+            CenterCrop(size),
             ToTensor(),
             normalize,
         ]
@@ -298,10 +298,14 @@ def main():

     # transformations as done in original MAE paper
     # source: https://github.com/facebookresearch/mae/blob/main/main_pretrain.py
+    if "shortest_edge" in feature_extractor.size:
+        size = feature_extractor.size["shortest_edge"]
+    else:
+        size = (feature_extractor.size["height"], feature_extractor.size["width"])
     transforms = Compose(
         [
             Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
-            RandomResizedCrop(feature_extractor.size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
+            RandomResizedCrop(size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
             RandomHorizontalFlip(),
             ToTensor(),
             Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
@@ -57,12 +57,11 @@ require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/sema


 def pad_if_smaller(img, size, fill=0):
-    min_size = min(img.size)
-    if min_size < size:
-        original_width, original_height = img.size
-        pad_height = size - original_height if original_height < size else 0
-        pad_width = size - original_width if original_width < size else 0
-        img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
+    size = (size, size) if isinstance(size, int) else size
+    original_width, original_height = img.size
+    pad_height = size[1] - original_height if original_height < size[1] else 0
+    pad_width = size[0] - original_width if original_width < size[0] else 0
+    img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
     return img

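A quick sketch of the updated function's behavior with both argument styles (the function body is copied from the hunk above; image dimensions are arbitrary):

```python
from PIL import Image
from torchvision.transforms import functional


def pad_if_smaller(img, size, fill=0):
    # Copied from the hunk above: size may now be an int or a tuple,
    # where size[0] pads the width and size[1] pads the height.
    size = (size, size) if isinstance(size, int) else size
    original_width, original_height = img.size
    pad_height = size[1] - original_height if original_height < size[1] else 0
    pad_width = size[0] - original_width if original_width < size[0] else 0
    img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
    return img


img = Image.new("RGB", (100, 80))
assert pad_if_smaller(img, 128).size == (128, 128)       # old int argument still works
assert pad_if_smaller(img, (120, 90)).size == (120, 90)  # new tuple argument
```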
@@ -110,12 +109,12 @@ class RandomResize:

 class RandomCrop:
     def __init__(self, size):
-        self.size = size
+        self.size = size if isinstance(size, tuple) else (size, size)

     def __call__(self, image, target):
         image = pad_if_smaller(image, self.size)
         target = pad_if_smaller(target, self.size, fill=255)
-        crop_params = transforms.RandomCrop.get_params(image, (self.size, self.size))
+        crop_params = transforms.RandomCrop.get_params(image, self.size)
         image = functional.crop(image, *crop_params)
         target = functional.crop(target, *crop_params)
         return image, target
@@ -359,7 +358,7 @@ def main():
             references=labels,
             num_labels=len(id2label),
             ignore_index=0,
-            reduce_labels=feature_extractor.reduce_labels,
+            reduce_labels=feature_extractor.do_reduce_labels,
         )
         # add per category metrics as individual key-value pairs
         per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
@@ -396,10 +395,15 @@ def main():
     # Define torchvision transforms to be applied to each image + target.
     # Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
     # Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
+    if "shortest_edge" in feature_extractor.size:
+        # We instead set the target size as (shortest_edge, shortest_edge) to here to ensure all images are batchable.
+        size = (feature_extractor.size["shortest_edge"], feature_extractor.size["shortest_edge"])
+    else:
+        size = (feature_extractor.size["height"], feature_extractor.size["width"])
     train_transforms = Compose(
         [
             ReduceLabels() if data_args.reduce_labels else Identity(),
-            RandomCrop(size=feature_extractor.size),
+            RandomCrop(size=size),
             RandomHorizontalFlip(flip_prob=0.5),
             PILToTensor(),
             ConvertImageDtype(torch.float),
@@ -411,7 +415,7 @@ def main():
     val_transforms = Compose(
         [
             ReduceLabels() if data_args.reduce_labels else Identity(),
-            Resize(size=(feature_extractor.size, feature_extractor.size)),
+            Resize(size=size),
             PILToTensor(),
             ConvertImageDtype(torch.float),
             Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
@@ -405,10 +405,15 @@ def main():
     # Define torchvision transforms to be applied to each image + target.
     # Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
     # Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
+    if "shortest_edge" in feature_extractor.size:
+        # We instead set the target size as (shortest_edge, shortest_edge) to here to ensure all images are batchable.
+        size = (feature_extractor.size["shortest_edge"], feature_extractor.size["shortest_edge"])
+    else:
+        size = (feature_extractor.size["height"], feature_extractor.size["width"])
     train_transforms = Compose(
         [
             ReduceLabels() if args.reduce_labels else Identity(),
-            RandomCrop(size=feature_extractor.size),
+            RandomCrop(size=size),
             RandomHorizontalFlip(flip_prob=0.5),
             PILToTensor(),
             ConvertImageDtype(torch.float),
@@ -420,7 +425,7 @@ def main():
     val_transforms = Compose(
         [
             ReduceLabels() if args.reduce_labels else Identity(),
-            Resize(size=(feature_extractor.size, feature_extractor.size)),
+            Resize(size=size),
             PILToTensor(),
             ConvertImageDtype(torch.float),
             Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Dict, Iterable, Optional, Union
+
 from .feature_extraction_utils import BatchFeature as BaseBatchFeature
 from .feature_extraction_utils import FeatureExtractionMixin
 from .utils import logging
@@ -48,7 +50,72 @@ class BaseImageProcessor(ImageProcessorMixin):
         super().__init__(**kwargs)

     def __call__(self, images, **kwargs) -> BatchFeature:
+        """Preprocess an image or a batch of images."""
         return self.preprocess(images, **kwargs)

     def preprocess(self, images, **kwargs) -> BatchFeature:
         raise NotImplementedError("Each image processor must implement its own preprocess method")
+
+
+def get_size_dict(
+    size: Union[int, Iterable[int], Dict[str, int]] = None,
+    max_size: Optional[int] = None,
+    height_width_order: bool = True,
+    default_to_square: bool = True,
+) -> dict:
+    """
+    Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
+    compatibility with the old feature extractor configs and removes ambiguity over whether the tuple is in (height,
+    width) or (width, height) format.
+
+    - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
+      size[0]}` if `height_width_order` is `False`.
+    - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
+    - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
+      is set, it is added to the dict as `{"longest_edge": max_size}`.
+
+    Args:
+        size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*):
+            The `size` parameter to be cast into a size dictionary.
+        max_size (`Optional[int]`, *optional*):
+            The `max_size` parameter to be cast into a size dictionary.
+        height_width_order (`bool`, *optional*, defaults to `True`):
+            If `size` is a tuple, whether it's in (height, width) or (width, height) order.
+        default_to_square (`bool`, *optional*, defaults to `True`):
+            If `size` is an int, whether to default to a square image or not.
+    """
+    # If a dict is passed, we check if it's a valid size dict and then return it.
+    if isinstance(size, dict):
+        size_keys = set(size.keys())
+        if (
+            size_keys != set(["height", "width"])
+            and size_keys != set(["shortest_edge"])
+            and size_keys != set(["shortest_edge", "longest_edge"])
+        ):
+            raise ValueError(
+                "The size dict must contain either the keys ('height', 'width') or ('shortest_edge')"
+                f"or ('shortest_edge', 'longest_edge') but got {size_keys}"
+            )
+        return size
+
+    # By default, if size is an int we assume it represents a tuple of (size, size).
+    elif isinstance(size, int) and default_to_square:
+        if max_size is not None:
+            raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
+        size_dict = {"height": size, "width": size}
+    # In other configs, if size is an int and default_to_square is False, size represents the length of the shortest edge after resizing.
+    elif isinstance(size, int) and not default_to_square:
+        if max_size is not None:
+            size_dict = {"shortest_edge": size, "longest_edge": max_size}
+        else:
+            size_dict = {"shortest_edge": size}
+    elif isinstance(size, (tuple, list)) and height_width_order:
+        size_dict = {"height": size[0], "width": size[1]}
+    elif isinstance(size, (tuple, list)) and not height_width_order:
+        size_dict = {"height": size[1], "width": size[0]}
+
+    logger.warning(
+        "The size parameter should be a dictionary with keys ('height', 'width'), ('shortest_edge', 'longest_edge')"
+        f" or ('shortest_edge',) got {size}. Setting as {size_dict}.",
+    )
+    return size_dict
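The docstring above fully determines the mapping, so a few illustrative calls make the backward-compatibility story concrete (a sketch against this commit's `transformers.image_processing_utils`, not part of the diff):

```python
from transformers.image_processing_utils import get_size_dict

# An int defaults to a square (height, width) dict.
assert get_size_dict(224) == {"height": 224, "width": 224}

# With default_to_square=False an int means "shortest edge"; max_size adds "longest_edge".
assert get_size_dict(256, default_to_square=False) == {"shortest_edge": 256}
assert get_size_dict(256, max_size=512, default_to_square=False) == {"shortest_edge": 256, "longest_edge": 512}

# Tuples are interpreted as (height, width) unless height_width_order=False.
assert get_size_dict((480, 640)) == {"height": 480, "width": 640}
assert get_size_dict((640, 480), height_width_order=False) == {"height": 480, "width": 640}

# A dict with valid keys passes through unchanged.
assert get_size_dict({"shortest_edge": 256}) == {"shortest_edge": 256}
```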
@@ -139,6 +139,9 @@ def to_pil_image(
     # If the channel as been moved to first dim, we put it back at the end.
     image = to_channel_dimension_format(image, ChannelDimension.LAST)

+    # If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
+    image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
+
     # PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed.
     do_rescale = isinstance(image.flat[0], float) if do_rescale is None else do_rescale
     if do_rescale:
@@ -259,6 +262,9 @@ def resize(

     if return_numpy:
         resized_image = np.array(resized_image)
+        # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
+        # so we need to add it back if necessary.
+        resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
         resized_image = to_channel_dimension_format(resized_image, data_format)
     return resized_image

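The `to_pil_image` and `resize` changes above are two halves of the same fix: a trailing unit channel is squeezed before handing the array to PIL, and re-added after PIL returns a 2D result. A small numpy-only sketch of the shape handling involved (no library calls):

```python
import numpy as np

image = np.zeros((32, 32, 1))          # single-channel, channels-last
squeezed = np.squeeze(image, axis=-1)  # what to_pil_image now does before conversion
assert squeezed.shape == (32, 32)

resized = np.zeros((16, 16))           # a PIL round trip comes back 2D
restored = np.expand_dims(resized, axis=-1) if resized.ndim == 2 else resized
assert restored.shape == (16, 16, 1)   # what resize now does before reformatting
```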
@@ -303,12 +309,14 @@ def normalize(
         raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
     else:
         mean = [mean] * num_channels
+    mean = np.array(mean, dtype=image.dtype)

     if isinstance(std, Iterable):
         if len(std) != num_channels:
             raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
     else:
         std = [std] * num_channels
+    std = np.array(std, dtype=image.dtype)

     if input_data_format == ChannelDimension.LAST:
         image = (image - mean) / std
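Casting `mean` and `std` to the image's dtype is what makes `normalize` dtype-preserving ("Normalize doesn't change dtype of input" in the commit message). A numpy sketch of the effect, with illustrative values:

```python
import numpy as np

image = np.random.rand(4, 4, 3).astype(np.float32)  # channels-last, float32
mean = np.array([0.5, 0.5, 0.5], dtype=image.dtype)
std = np.array([0.5, 0.5, 0.5], dtype=image.dtype)

normalized = (image - mean) / std  # the channels-last branch shown above
# Without the dtype cast, float64 mean/std would silently promote the result to float64.
assert normalized.dtype == np.float32
```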
@@ -372,6 +380,7 @@ def center_crop(

     orig_height, orig_width = get_image_size(image)
     crop_height, crop_width = size
+    crop_height, crop_width = int(crop_height), int(crop_width)

     # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
     top = (orig_height - crop_height) // 2
@@ -72,7 +72,15 @@ def is_valid_image(img):


 def valid_images(imgs):
-    return all(is_valid_image(img) for img in imgs)
+    # If we have an list of images, make sure every image is valid
+    if isinstance(imgs, (list, tuple)):
+        for img in imgs:
+            if not valid_images(img):
+                return False
+    # If not a list of tuple, we have been given a single image or batched tensor of images
+    elif not is_valid_image(imgs):
+        return False
+    return True


 def is_batched(img):
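The recursive rewrite accepts a single image, a flat batch, or nested lists, where the old one-liner assumed it was handed an iterable of images. A sketch, assuming `valid_images` is imported from `transformers.image_utils` as defined above:

```python
from PIL import Image

from transformers.image_utils import valid_images

img = Image.new("RGB", (8, 8))

assert valid_images(img)                   # a bare image now validates
assert valid_images([img, img])            # a flat batch
assert valid_images([[img, img], [img]])   # nested lists are walked recursively
assert not valid_images(["not an image"])  # non-images are rejected
```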
@@ -14,258 +14,10 @@
 # limitations under the License.
 """Feature extractor class for BEiT."""

-from typing import List, Optional, Tuple, Union
+from ...utils import logging
+from .image_processing_beit import BeitImageProcessor

-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
-    IMAGENET_STANDARD_MEAN,
-    IMAGENET_STANDARD_STD,
-    ImageFeatureExtractionMixin,
-    ImageInput,
-    is_torch_tensor,
-)
-from ...utils import TensorType, is_torch_available, logging
-
-
-if is_torch_available():
-    import torch
-

 logger = logging.get_logger(__name__)


+BeitFeatureExtractor = BeitImageProcessor
-class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a BEiT feature extractor.
-
-    This feature extractor inherits from [`~feature_extraction_utils.FeatureExtractionMixin`] which contains most of
-    the main methods. Users should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input to a certain `size`.
-        size (`int` or `Tuple(int)`, *optional*, defaults to 256):
-            Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
-            set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_center_crop (`bool`, *optional*, defaults to `True`):
-            Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
-            image is padded with 0's and then center cropped.
-        crop_size (`int`, *optional*, defaults to 224):
-            Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with `image_mean` and `image_std`.
-        image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
-            The sequence of means for each channel, to be used when normalizing images.
-        image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images.
-        reduce_labels (`bool`, *optional*, defaults to `False`):
-            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
-            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
-            background label will be replaced by 255.
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=256,
-        resample=PILImageResampling.BICUBIC,
-        do_center_crop=True,
-        crop_size=224,
-        do_normalize=True,
-        image_mean=None,
-        image_std=None,
-        reduce_labels=False,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.do_center_crop = do_center_crop
-        self.crop_size = crop_size
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
-        self.reduce_labels = reduce_labels
-
-    def __call__(
-        self,
-        images: ImageInput,
-        segmentation_maps: ImageInput = None,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-
-            segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
-                Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations.
-
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-            - **labels** -- Optional labels to be fed to a model (when `segmentation_maps` are provided)
-        """
-        # Input type checking for clearer error
-        valid_images = False
-        valid_segmentation_maps = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        # Check that segmentation maps has a valid type
-        if segmentation_maps is not None:
-            if isinstance(segmentation_maps, (Image.Image, np.ndarray)) or is_torch_tensor(segmentation_maps):
-                valid_segmentation_maps = True
-            elif isinstance(segmentation_maps, (list, tuple)):
-                if (
-                    len(segmentation_maps) == 0
-                    or isinstance(segmentation_maps[0], (Image.Image, np.ndarray))
-                    or is_torch_tensor(segmentation_maps[0])
-                ):
-                    valid_segmentation_maps = True
-
-            if not valid_segmentation_maps:
-                raise ValueError(
-                    "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single"
-                    " example),`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of"
-                    " examples)."
-                )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-            if segmentation_maps is not None:
-                segmentation_maps = [segmentation_maps]
-
-        # reduce zero label if needed
-        if self.reduce_labels:
-            if segmentation_maps is not None:
-                for idx, map in enumerate(segmentation_maps):
-                    if not isinstance(map, np.ndarray):
-                        map = np.array(map)
-                    # avoid using underflow conversion
-                    map[map == 0] = 255
-                    map = map - 1
-                    map[map == 254] = 255
-                    segmentation_maps[idx] = Image.fromarray(map.astype(np.uint8))
-
-        # transformations (resizing + center cropping + normalization)
-        if self.do_resize and self.size is not None and self.resample is not None:
-            images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
-            if segmentation_maps is not None:
-                segmentation_maps = [
-                    self.resize(map, size=self.size, resample=self.resample) for map in segmentation_maps
-                ]
-        if self.do_center_crop and self.crop_size is not None:
-            images = [self.center_crop(image, self.crop_size) for image in images]
-            if segmentation_maps is not None:
-                segmentation_maps = [self.center_crop(map, size=self.crop_size) for map in segmentation_maps]
-        if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-
-        if segmentation_maps is not None:
-            labels = []
-            for map in segmentation_maps:
-                if not isinstance(map, np.ndarray):
-                    map = np.array(map)
-                labels.append(map.astype(np.int64))
-            # cast to np.int64
-            data["labels"] = labels
-
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
-
-    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
-        """
-        Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
-
-        Args:
-            outputs ([`BeitForSemanticSegmentation`]):
-                Raw outputs of the model.
-            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
-                List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
-                None, predictions will not be resized.
-        Returns:
-            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
-            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
-            specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
-        """
-        logits = outputs.logits
-
-        # Resize logits and compute semantic segmentation maps
-        if target_sizes is not None:
-            if len(logits) != len(target_sizes):
-                raise ValueError(
-                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
-                )
-
-            if is_torch_tensor(target_sizes):
-                target_sizes = target_sizes.numpy()
-
-            semantic_segmentation = []
-
-            for idx in range(len(logits)):
-                resized_logits = torch.nn.functional.interpolate(
-                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
-                )
-                semantic_map = resized_logits[0].argmax(dim=0)
-                semantic_segmentation.append(semantic_map)
-        else:
-            semantic_segmentation = logits.argmax(dim=1)
-            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
-
-        return semantic_segmentation
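The net effect of this hunk is that the BEiT feature extractor module shrinks to a thin alias, so existing imports keep working while the logic moves into the new image processor. A sketch, assuming both names are exported at the top level of `transformers` in this commit:

```python
from transformers import BeitFeatureExtractor, BeitImageProcessor

# The module-level assignment above makes both names resolve to the same class.
assert BeitFeatureExtractor is BeitImageProcessor
```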
@@ -0,0 +1,525 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Beit."""
+
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from transformers.utils import is_torch_available, is_torch_tensor, is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
+from ...image_utils import (
+    IMAGENET_STANDARD_MEAN,
+    IMAGENET_STANDARD_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    is_batched,
+    to_numpy_array,
+    valid_images,
+)
+from ...utils import logging
+
+
+if is_vision_available():
+    import PIL
+
+if is_torch_available():
+    import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+class BeitImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a BEiT image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+            `do_resize` parameter in the `preprocess` method.
+        size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
+            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+            method.
+        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+            `preprocess` method.
+        do_center_crop (`bool`, *optional*, defaults to `True`):
+            Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
+            is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in the
+            `preprocess` method.
+        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+            Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
+            Can be overridden by the `crop_size` parameter in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+            parameter in the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+            `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            The mean to use if normalizing the image. This is a float or list of floats of length of the number of
+            channels of the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            The standard deviation to use if normalizing the image. This is a float or list of floats of length of the
+            number of channels of the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_reduce_labels (`bool`, *optional*, defaults to `False`):
+            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
+            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
+            background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
+            `preprocess` method.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        do_center_crop: bool = True,
+        crop_size: Dict[str, int] = None,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_rescale: bool = True,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_reduce_labels: bool = False,
+        **kwargs
+    ) -> None:
+        if "reduce_labels" in kwargs:
+            warnings.warn(
+                "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use"
+                " `do_reduce_labels` instead.",
+                FutureWarning,
+            )
+            do_reduce_labels = kwargs.pop("reduce_labels")
+        super().__init__(**kwargs)
+        size = size if size is not None else {"height": 256, "width": 256}
+        size = get_size_dict(size)
+        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+        crop_size = get_size_dict(crop_size)
+        self.do_resize = do_resize
+        self.size = size
+        self.resample = resample
+        self.do_center_crop = do_center_crop
+        self.crop_size = crop_size
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+        self.do_reduce_labels = do_reduce_labels
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs
+    ) -> np.ndarray:
+        """
+        Resize an image to (size["height"], size["width"]).
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BICUBIC`):
+                Resampling filter to use when resiizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+        """
+        size = get_size_dict(size)
+        if "height" not in size or "width" not in size:
+            raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
+        return resize(
+            image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
+        )
+
+    def center_crop(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs
+    ) -> np.ndarray:
+        """
+        Center crop an image to (size["height"], size["width"]). If the input size is smaller than `size` along any
+        edge, the image is padded with 0's and then center cropped.
+
+        Args:
+            image (`np.ndarray`):
+                Image to center crop.
+            size (`Dict[str, int]`):
+                Size of the output image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+        """
+        size = get_size_dict(size)
+        return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
+
+    def rescale(
+        self,
+        image: np.ndarray,
+        scale: Union[int, float],
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs
+    ):
+        """
+        Rescale an image by a scale factor. image = image * scale.
+
+        Args:
+            image (`np.ndarray`):
+                Image to rescale.
+            scale (`int` or `float`):
+                Scale to apply to the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+        """
+        return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+    def normalize(
+        self,
+        image: np.ndarray,
+        mean: Union[float, List[float]],
+        std: Union[float, List[float]],
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs
+    ) -> np.ndarray:
+        """
+        Normalize an image. image = (image - image_mean) / image_std.
+
+        Args:
+            image (`np.ndarray`):
+                Image to normalize.
+            image_mean (`float` or `List[float]`):
+                Image mean.
+            image_std (`float` or `List[float]`):
+                Image standard deviation.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+        """
+        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+    def reduce_label(self, label: ImageInput) -> np.ndarray:
+        label = to_numpy_array(label)
+        # Avoid using underflow conversion
+        label[label == 0] = 255
+        label = label - 1
+        label[label == 254] = 255
+        return label
+
+    def _preprocess(
+        self,
+        image: ImageInput,
+        do_reduce_labels: bool = None,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        do_center_crop: bool = None,
+        crop_size: Dict[str, int] = None,
+        do_rescale: bool = None,
+        rescale_factor: float = None,
+        do_normalize: bool = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+    ):
+        if do_reduce_labels:
+            image = self.reduce_label(image)
+
+        if do_resize:
+            image = self.resize(image=image, size=size, resample=resample)
+
+        if do_center_crop:
+            image = self.center_crop(image=image, size=crop_size)
+
+        if do_rescale:
+            image = self.rescale(image=image, scale=rescale_factor)
+
+        if do_normalize:
+            image = self.normalize(image=image, mean=image_mean, std=image_std)
+
+        return image
+
+    def _preprocess_image(
+        self,
+        image: ImageInput,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        do_center_crop: bool = None,
+        crop_size: Dict[str, int] = None,
+        do_rescale: bool = None,
+        rescale_factor: float = None,
+        do_normalize: bool = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.ndarray:
+        """Preprocesses a single image."""
+        # All transformations expect numpy arrays.
+        image = to_numpy_array(image)
+        image = self._preprocess(
+            image,
+            do_reduce_labels=False,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+            do_center_crop=do_center_crop,
+            crop_size=crop_size,
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+        )
+        if data_format is not None:
+            image = to_channel_dimension_format(image, data_format)
+        return image
+
+    def _preprocess_segmentation_map(
+        self,
+        segmentation_map: ImageInput,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        do_center_crop: bool = None,
+        crop_size: Dict[str, int] = None,
+        do_reduce_labels: bool = None,
+    ):
+        """Preprocesses a single segmentation map."""
+        # All transformations expect numpy arrays.
+        segmentation_map = to_numpy_array(segmentation_map)
+        # Add an axis to the segmentation maps for transformations.
+        if segmentation_map.ndim == 2:
+            segmentation_map = segmentation_map[None, ...]
+            added_dimension = True
+        else:
+            added_dimension = False
+        segmentation_map = self._preprocess(
+            image=segmentation_map,
+            do_reduce_labels=do_reduce_labels,
+            do_resize=do_resize,
+            resample=resample,
+            size=size,
+            do_center_crop=do_center_crop,
+            crop_size=crop_size,
+            do_normalize=False,
+            do_rescale=False,
+        )
+        # Remove extra axis if added
+        if added_dimension:
+            segmentation_map = np.squeeze(segmentation_map, axis=0)
+        segmentation_map = segmentation_map.astype(np.int64)
+        return segmentation_map
+
+    def __call__(self, images, segmentation_maps=None, **kwargs):
+        # Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
+        # be passed in as positional arguments.
+        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
+
+    def preprocess(
+        self,
+        images: ImageInput,
+        segmentation_maps: Optional[ImageInput] = None,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        do_center_crop: bool = None,
+        crop_size: Dict[str, int] = None,
+        do_rescale: bool = None,
+        rescale_factor: float = None,
+        do_normalize: bool = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_reduce_labels: Optional[bool] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+    ) -> PIL.Image.Image:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after resizing.
+            resample (`int`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+                has an effect if `do_resize` is set to `True`.
+            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+                Whether to center crop the image.
+            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+                Size of the image after center crop. If one edge the image is smaller than `crop_size`, it will be
+                padded with zeros and then cropped
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values between [0 - 1].
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean.
+            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation.
+            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+                is used for background, and background itself is not included in all classes of a dataset (e.g.
+                ADE20k). The background label will be replaced by 255.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                - Unset: Return a list of `np.ndarray`.
+                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+        """
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        size = size if size is not None else self.size
+        size = get_size_dict(size)
+        resample = resample if resample is not None else self.resample
+        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+        crop_size = crop_size if crop_size is not None else self.crop_size
+        crop_size = get_size_dict(crop_size)
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+        do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
+
+        if not is_batched(images):
+            images = [images]
+            segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        if segmentation_maps is not None and not valid_images(segmentation_maps):
+            raise ValueError(
+                "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        if do_resize and size is None or resample is None:
+            raise ValueError("Size and resample must be specified if do_resize is True.")
+
+        if do_center_crop and crop_size is None:
+            raise ValueError("Crop size must be specified if do_center_crop is True.")
+
+        if do_rescale and rescale_factor is None:
+            raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+        if do_normalize and (image_mean is None or image_std is None):
+            raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+        images = [
+            self._preprocess_image(
+                image=img,
+                do_resize=do_resize,
+                do_center_crop=do_center_crop,
+                do_rescale=do_rescale,
+                do_normalize=do_normalize,
+                resample=resample,
+                size=size,
+                rescale_factor=rescale_factor,
+                crop_size=crop_size,
+                image_mean=image_mean,
+                image_std=image_std,
+                data_format=data_format,
+            )
+            for img in images
+        ]
+
+        data = {"pixel_values": images}
+
+        if segmentation_maps is not None:
+            segmentation_maps = [
+                self._preprocess_segmentation_map(
+                    segmentation_map=segmentation_map,
+                    do_reduce_labels=do_reduce_labels,
+                    do_resize=do_resize,
+                    resample=resample,
+                    size=size,
+                    do_center_crop=do_center_crop,
+                    crop_size=crop_size,
+                )
+                for segmentation_map in segmentation_maps
+            ]
+            data["labels"] = segmentation_maps
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
|
||||||
|
"""
|
||||||
|
Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs ([`BeitForSemanticSegmentation`]):
|
||||||
|
Raw outputs of the model.
|
||||||
|
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
|
||||||
|
List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
|
||||||
|
None, predictions will not be resized.
|
||||||
|
Returns:
|
||||||
|
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
|
||||||
|
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
|
||||||
|
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
|
||||||
|
"""
|
||||||
|
# TODO: add support for other frameworks
|
||||||
|
logits = outputs.logits
|
||||||
|
|
||||||
|
# Resize logits and compute semantic segmentation maps
|
||||||
|
if target_sizes is not None:
|
||||||
|
if len(logits) != len(target_sizes):
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_torch_tensor(target_sizes):
|
||||||
|
target_sizes = target_sizes.numpy()
|
||||||
|
|
||||||
|
semantic_segmentation = []
|
||||||
|
|
||||||
|
for idx in range(len(logits)):
|
||||||
|
resized_logits = torch.nn.functional.interpolate(
|
||||||
|
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
||||||
|
)
|
||||||
|
semantic_map = resized_logits[0].argmax(dim=0)
|
||||||
|
semantic_segmentation.append(semantic_map)
|
||||||
|
else:
|
||||||
|
semantic_segmentation = logits.argmax(dim=1)
|
||||||
|
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
|
||||||
|
|
||||||
|
return semantic_segmentation
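
For orientation, a minimal sketch of how the two halves of this class are meant to be used together: `preprocess` (via `__call__`) produces `pixel_values`, and `post_process_semantic_segmentation` maps the logits back to per-pixel label maps. The checkpoint name is only illustrative, and the sketch assumes the new class is exported at the top level like the rest of this PR:

```python
# A minimal sketch, assuming torch and PIL are installed; the checkpoint
# name is an example, not part of this diff.
import torch
from PIL import Image

from transformers import BeitForSemanticSegmentation, BeitImageProcessor

name = "microsoft/beit-base-finetuned-ade-640-640"
image_processor = BeitImageProcessor.from_pretrained(name)
model = BeitForSemanticSegmentation.from_pretrained(name)

image = Image.open("scene.jpg")  # any RGB image
inputs = image_processor(image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# target_sizes takes (height, width) tuples; PIL's .size is (width, height).
maps = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])
print(maps[0].shape)  # (height, width); each entry is a semantic class id
```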
src/transformers/models/clip/feature_extraction_clip.py

@@ -14,155 +14,11 @@
 # limitations under the License.
 """Feature extractor class for CLIP."""

-from typing import List, Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_clip import CLIPImageProcessor


 logger = logging.get_logger(__name__)


-class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a CLIP feature extractor.
-
-    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input to a certain `size`.
-        size (`int`, *optional*, defaults to 224):
-            Resize the input to the given size. Only has an effect if `do_resize` is set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_center_crop (`bool`, *optional*, defaults to `True`):
-            Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
-            image is padded with 0's and then center cropped.
-        crop_size (`int`, *optional*, defaults to 224):
-            Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with `image_mean` and `image_std`.
-        image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
-            The sequence of means for each channel, to be used when normalizing images.
-        image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images.
-        convert_rgb (`bool`, defaults to `True`):
-            Whether or not to convert `PIL.Image.Image` into `RGB` format
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=224,
-        resample=PILImageResampling.BICUBIC,
-        do_center_crop=True,
-        crop_size=224,
-        do_normalize=True,
-        image_mean=None,
-        image_std=None,
-        do_convert_rgb=True,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.do_center_crop = do_center_crop
-        self.crop_size = crop_size
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
-        self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
-        self.do_convert_rgb = do_convert_rgb
-
-    def __call__(
-        self,
-        images: Union[
-            Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]  # noqa
-        ],
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model.
-        """
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # transformations (convert rgb + resizing + center cropping + normalization)
-        if self.do_convert_rgb:
-            images = [self.convert_rgb(image) for image in images]
-        if self.do_resize and self.size is not None and self.resample is not None:
-            images = [
-                self.resize(image=image, size=self.size, resample=self.resample, default_to_square=False)
-                for image in images
-            ]
-        if self.do_center_crop and self.crop_size is not None:
-            images = [self.center_crop(image, self.crop_size) for image in images]
-        if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
+CLIPFeatureExtractor = CLIPImageProcessor
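
The old class is removed outright; backwards compatibility comes from the name alias, so existing imports keep working. A quick sketch of what this hunk guarantees, assuming both names stay exported at the top level:

```python
from transformers import CLIPFeatureExtractor, CLIPImageProcessor

# The alias makes the two names refer to the very same class object.
assert CLIPFeatureExtractor is CLIPImageProcessor
```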
src/transformers/models/clip/image_processing_clip.py

@@ -0,0 +1,342 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for CLIP."""

from typing import Any, Dict, List, Optional, Union

import numpy as np

from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
    center_crop,
    get_resize_output_image_size,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
from ...utils import logging
from ...utils.import_utils import is_vision_available


logger = logging.get_logger(__name__)


if is_vision_available():
    import PIL


def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
    """
    Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.

    Args:
        image (`PIL.Image.Image`):
            The image to convert.
    """
    if not isinstance(image, PIL.Image.Image):
        return image

    return image.convert("RGB")


class CLIPImageProcessor(BaseImageProcessor):
    r"""
    Constructs a CLIP image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
            Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
            the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
            `preprocess` method.
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
            method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 224}
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        crop_size = get_size_dict(crop_size)

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
        self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
        self.do_convert_rgb = do_convert_rgb

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
        resized to keep the input aspect ratio.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size, default_to_square=False)
        if "shortest_edge" not in size:
            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
        output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)

    def center_crop(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Center crop an image. If the image is too small to be cropped to the size given, it will be padded (so the
        returned result will always be of size `size`).

        Args:
            image (`np.ndarray`):
                Image to center crop.
            size (`Dict[str, int]`):
                Size of the output image in the form of a dictionary with keys `height` and `width`.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ):
        """
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Normalize an image. image = (image - image_mean) / image_std.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
                the longest edge resized to keep the input aspect ratio.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                    - Unset: defaults to the channel dimension format of the input image.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size, default_to_square=False)
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size)
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        if not is_batched(images):
            images = [images]

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and size is None:
            raise ValueError("Size must be specified if do_resize is True.")

        if do_center_crop and crop_size is None:
            raise ValueError("Crop size must be specified if do_center_crop is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # PIL RGBA images are converted to RGB
        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_resize:
            images = [self.resize(image=image, size=size, resample=resample) for image in images]

        if do_center_crop:
            images = [self.center_crop(image=image, size=crop_size) for image in images]

        if do_rescale:
            images = [self.rescale(image=image, scale=rescale_factor) for image in images]

        if do_normalize:
            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]

        images = [to_channel_dimension_format(image, data_format) for image in images]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
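
End to end, the numpy-based pipeline above behaves as follows with the defaults (shortest-edge-224 resize followed by a 224x224 center crop). A minimal sketch with a dummy image:

```python
# Minimal sketch of the default CLIP pipeline on a dummy image.
import numpy as np

from transformers import CLIPImageProcessor

image_processor = CLIPImageProcessor()  # size={"shortest_edge": 224}, crop_size 224x224

# A random (height, width, channels) uint8 array stands in for a real photo.
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)

encoding = image_processor(image, return_tensors="np")
print(encoding["pixel_values"].shape)  # (1, 3, 224, 224)
```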
src/transformers/models/convnext/feature_extraction_convnext.py

@@ -14,157 +14,11 @@
 # limitations under the License.
 """Feature extractor class for ConvNeXT."""

-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
-    IMAGENET_DEFAULT_MEAN,
-    IMAGENET_DEFAULT_STD,
-    ImageFeatureExtractionMixin,
-    ImageInput,
-    is_torch_tensor,
-)
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_convnext import ConvNextImageProcessor


 logger = logging.get_logger(__name__)


-class ConvNextFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a ConvNeXT feature extractor.
-
-    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize (and optionally center crop) the input to a certain `size`.
-        size (`int`, *optional*, defaults to 224):
-            Resize the input to the given size. If 384 or larger, the image is resized to (`size`, `size`). Else, the
-            smaller edge of the image will be matched to int(`size` / `crop_pct`), after which the image is cropped to
-            `size`. Only has an effect if `do_resize` is set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        crop_pct (`float`, *optional*):
-            The percentage of the image to crop. If `None`, then a cropping percentage of 224 / 256 is used. Only has
-            an effect if `do_resize` is set to `True` and `size` < 384.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with mean and standard deviation.
-        image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
-            The sequence of means for each channel, to be used when normalizing images.
-        image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images.
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=224,
-        resample=PILImageResampling.BICUBIC,
-        crop_pct=None,
-        do_normalize=True,
-        image_mean=None,
-        image_std=None,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.crop_pct = crop_pct
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
-
-    def __call__(
-        self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-        """
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # transformations (resizing and optional center cropping + normalization)
-        if self.do_resize and self.size is not None:
-            if self.size >= 384:
-                # warping (no cropping) when evaluated at 384 or larger
-                images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
-            else:
-                if self.crop_pct is None:
-                    self.crop_pct = 224 / 256
-                size = int(self.size / self.crop_pct)
-                # to maintain same ratio w.r.t. 224 images
-                images = [
-                    self.resize(image=image, size=size, default_to_square=False, resample=self.resample)
-                    for image in images
-                ]
-                images = [self.center_crop(image=image, size=self.size) for image in images]
-
-        if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
+ConvNextFeatureExtractor = ConvNextImageProcessor
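
The alias pattern repeats for each model. Before the new ConvNeXT processor below, it is worth noting the `get_size_dict` helper that all of these files lean on: legacy configs store `size` as a bare int, and the helper converts it to the new dict format. A sketch of the assumed behaviour (warnings that the helper may log are omitted):

```python
# Sketch of the int -> dict conversion relied on throughout this PR.
from transformers.image_processing_utils import get_size_dict

print(get_size_dict(224))                                               # {'height': 224, 'width': 224}
print(get_size_dict(224, default_to_square=False))                      # {'shortest_edge': 224}
print(get_size_dict({"shortest_edge": 384}, default_to_square=False))   # passed through unchanged
```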
src/transformers/models/convnext/image_processing_convnext.py

@@ -0,0 +1,310 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for ConvNeXT."""

from typing import Dict, List, Optional, Union

import numpy as np

from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
    center_crop,
    get_resize_output_image_size,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    is_batched,
    to_numpy_array,
    valid_images,
)
from ...utils import logging


if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)


class ConvNextImageProcessor(BaseImageProcessor):
    r"""
    Constructs a ConvNeXT image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
            overridden by `do_resize` in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 384}`):
            Resolution of the output image after `resize` is applied. If `size["shortest_edge"]` >= 384, the image is
            resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the image will
            be matched to `int(size["shortest_edge"] / crop_pct)`, after which the image is cropped to
            `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. Can
            be overridden by `size` in the `preprocess` method.
        crop_pct (`float`, *optional*, defaults to `224 / 256`):
            Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be
            overridden by `crop_pct` in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        crop_pct: float = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 384}
        size = get_size_dict(size, default_to_square=False)

        self.do_resize = do_resize
        self.size = size
        # Default value set here for backwards compatibility where the value in config is None
        self.crop_pct = crop_pct if crop_pct is not None else 224 / 256
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        crop_pct: float,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If
                `size["shortest_edge"]` >= 384 the image is resized to `(size["shortest_edge"], size["shortest_edge"])`.
                Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`,
                after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`.
            crop_pct (`float`):
                Percentage of the image to crop. Only has an effect if size < 384.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size, default_to_square=False)
        if "shortest_edge" not in size:
            raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}")
        shortest_edge = size["shortest_edge"]

        if shortest_edge < 384:
            # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
            resize_shortest_edge = int(shortest_edge / crop_pct)
            resize_size = get_resize_output_image_size(image, size=resize_shortest_edge, default_to_square=False)
            image = resize(image=image, size=resize_size, resample=resample, data_format=data_format, **kwargs)
            # then crop to (shortest_edge, shortest_edge)
            return center_crop(image=image, size=(shortest_edge, shortest_edge), data_format=data_format, **kwargs)
        else:
            # warping (no cropping) when evaluated at 384 or larger
            return resize(
                image, size=(shortest_edge, shortest_edge), resample=resample, data_format=data_format, **kwargs
            )

    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ):
        """
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Normalize an image. image = (image - image_mean) / image_std.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        crop_pct: float = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
                is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
                image will be matched to `int(size["shortest_edge"] / crop_pct)`, after which the image is cropped to
                `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`.
            crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
                Percentage of the image to crop if size < 384.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        crop_pct = crop_pct if crop_pct is not None else self.crop_pct
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        size = size if size is not None else self.size
        size = get_size_dict(size, default_to_square=False)

        if not is_batched(images):
            images = [images]

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and (size is None or resample is None):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_resize and size["shortest_edge"] < 384 and crop_pct is None:
            raise ValueError("crop_pct must be specified if size < 384.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_resize:
            images = [self.resize(image=image, size=size, crop_pct=crop_pct, resample=resample) for image in images]

        if do_rescale:
            images = [self.rescale(image=image, scale=rescale_factor) for image in images]

        if do_normalize:
            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]

        images = [to_channel_dimension_format(image, data_format) for image in images]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
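
The size < 384 branch in `resize` is easiest to see with the default numbers. A worked example of the arithmetic:

```python
# Worked example of the ConvNeXT resize rule below 384, with the default
# crop_pct of 224 / 256.
shortest_edge = 224
crop_pct = 224 / 256

# Step 1: the shortest edge is first resized to int(224 / (224 / 256)) = 256,
# keeping the aspect ratio.
resize_shortest_edge = int(shortest_edge / crop_pct)
assert resize_shortest_edge == 256

# Step 2: the resized image is center-cropped to (224, 224).
# At 384 and above, step 2 is skipped and the image is warped straight to
# (shortest_edge, shortest_edge).
```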
@@ -308,7 +308,7 @@ def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_fo

     model = CvtForImageClassification(config)
     feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k")
-    feature_extractor.size = image_size
+    feature_extractor.size["shortest_edge"] = image_size
     original_weights = torch.load(cvt_file_name, map_location=torch.device("cpu"))

     huggingface_weights = OrderedDict()
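
The one-line change above is a direct consequence of the new size convention. A sketch of why the old assignment would now break:

```python
# Sketch of the new size convention the conversion script must respect.
size = {"shortest_edge": 224}  # what the loaded feature extractor now stores
image_size = 384               # example value

size["shortest_edge"] = image_size  # new style: update the key
print(size)                         # {'shortest_edge': 384}

# The old `feature_extractor.size = image_size` would replace the dict with a
# bare int, and later code that indexes into size["shortest_edge"] would fail.
```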
|
@@ -14,150 +14,10 @@
 # limitations under the License.
 """Feature extractor class for DeiT."""
 
-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
-    IMAGENET_DEFAULT_MEAN,
-    IMAGENET_DEFAULT_STD,
-    ImageFeatureExtractionMixin,
-    ImageInput,
-    is_torch_tensor,
-)
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_deit import DeiTImageProcessor
 
 
 logger = logging.get_logger(__name__)
 
 
-class DeiTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a DeiT feature extractor.
-
-    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input to a certain `size`.
-        size (`int` or `Tuple(int)`, *optional*, defaults to 256):
-            Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
-            set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_center_crop (`bool`, *optional*, defaults to `True`):
-            Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
-            image is padded with 0's and then center cropped.
-        crop_size (`int`, *optional*, defaults to 224):
-            Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with `image_mean` and `image_std`.
-        image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
-            The sequence of means for each channel, to be used when normalizing images.
-        image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images.
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=256,
-        resample=PILImageResampling.BICUBIC,
-        do_center_crop=True,
-        crop_size=224,
-        do_normalize=True,
-        image_mean=None,
-        image_std=None,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.do_center_crop = do_center_crop
-        self.crop_size = crop_size
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
-
-    def __call__(
-        self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-        """
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must be of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # transformations (resizing + center cropping + normalization)
-        if self.do_resize and self.size is not None and self.resample is not None:
-            images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
-        if self.do_center_crop and self.crop_size is not None:
-            images = [self.center_crop(image, self.crop_size) for image in images]
-        if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
+
+DeiTFeatureExtractor = DeiTImageProcessor
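The net effect of the hunk above is that `feature_extraction_deit.py` shrinks to a thin alias: the old class is deleted and the name is re-pointed at the new `DeiTImageProcessor`, so existing imports keep working. A quick sketch of the compatibility this preserves (assuming the standard `transformers` module paths):

```python
from transformers.models.deit.feature_extraction_deit import DeiTFeatureExtractor
from transformers.models.deit.image_processing_deit import DeiTImageProcessor

# Old code that constructs a feature extractor now gets an image processor.
extractor = DeiTFeatureExtractor()
assert isinstance(extractor, DeiTImageProcessor)
```

The DPT and FLAVA feature extraction modules below get the same treatment.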
@@ -0,0 +1,315 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for DeiT."""

from typing import Dict, List, Optional, Union

import numpy as np

from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    is_batched,
    to_numpy_array,
    valid_images,
)
from ...utils import logging


if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)


class DeiTImageProcessor(BaseImageProcessor):
    r"""
    Constructs a DeiT image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in `preprocess`.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
            Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
        resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
            is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in `preprocess`.
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PIL.Image.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        rescale_factor: Union[int, float] = 1 / 255,
        do_rescale: bool = True,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 256, "width": 256}
        size = get_size_dict(size)
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        crop_size = get_size_dict(crop_size)

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PIL.Image.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])` using the specified resampling filter.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
        return resize(
            image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
        )

    def center_crop(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Center crop an image to `(crop_size["height"], crop_size["width"])`. If the input size is smaller than
        `crop_size` along any edge, the image is padded with 0's and then center cropped.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            size (`Dict[str, int]`):
                Size of the output image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ):
        """
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Normalize an image. image = (image - image_mean) / image_std.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample=None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after `resize`.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                `PILImageResampling` filter to use if resizing the image. Only has an effect if `do_resize` is set to
                `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will be
                padded with zeros and then cropped.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - `None`: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        size = size if size is not None else self.size
        size = get_size_dict(size)
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size)

        if not is_batched(images):
            images = [images]

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and (size is None or resample is None):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_center_crop and crop_size is None:
            raise ValueError("Crop size must be specified if do_center_crop is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_resize:
            images = [self.resize(image=image, size=size, resample=resample) for image in images]

        if do_center_crop:
            images = [self.center_crop(image=image, size=crop_size) for image in images]

        if do_rescale:
            images = [self.rescale(image=image, scale=rescale_factor) for image in images]

        if do_normalize:
            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]

        images = [to_channel_dimension_format(image, data_format) for image in images]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
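As a quick smoke test of the new class, a minimal usage sketch (assumes Pillow and PyTorch are installed; the random image is a stand-in for real data):

```python
import numpy as np
from PIL import Image

from transformers.models.deit.image_processing_deit import DeiTImageProcessor

image_processor = DeiTImageProcessor()  # defaults: resize to 256x256, then center crop to 224x224
image = Image.fromarray(np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8))

inputs = image_processor.preprocess(image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
```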
@@ -14,235 +14,11 @@
 # limitations under the License.
 """Feature extractor class for DPT."""
 
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
-    IMAGENET_STANDARD_MEAN,
-    IMAGENET_STANDARD_STD,
-    ImageFeatureExtractionMixin,
-    ImageInput,
-    is_torch_tensor,
-)
-from ...utils import TensorType, is_torch_available, logging
-
-
-if is_torch_available():
-    import torch
+from ...utils import logging
+from .image_processing_dpt import DPTImageProcessor
 
 
 logger = logging.get_logger(__name__)
 
 
-class DPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a DPT feature extractor.
-
-    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input to a certain `size`.
-        size (`int` or `Tuple(int)`, *optional*, defaults to 384):
-            Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
-            set to `True`.
-        ensure_multiple_of (`int`, *optional*, defaults to 1):
-            Ensure that the input is resized to a multiple of this value. Only has an effect if `do_resize` is set to
-            `True`.
-        keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
-            Whether to keep the aspect ratio of the input. Only has an effect if `do_resize` is set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with mean and standard deviation.
-        image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
-            The sequence of means for each channel, to be used when normalizing images.
-        image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images.
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=384,
-        keep_aspect_ratio=False,
-        ensure_multiple_of=1,
-        resample=PILImageResampling.BILINEAR,
-        do_normalize=True,
-        image_mean=None,
-        image_std=None,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.keep_aspect_ratio = keep_aspect_ratio
-        self.ensure_multiple_of = ensure_multiple_of
-        self.resample = resample
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
-
-    def constrain_to_multiple_of(self, size, min_val=0, max_val=None):
-        y = (np.round(size / self.ensure_multiple_of) * self.ensure_multiple_of).astype(int)
-
-        if max_val is not None and y > max_val:
-            y = (np.floor(size / self.ensure_multiple_of) * self.ensure_multiple_of).astype(int)
-
-        if y < min_val:
-            y = (np.ceil(size / self.ensure_multiple_of) * self.ensure_multiple_of).astype(int)
-
-        return y
-
-    def update_size(self, image):
-        image = self.to_pil_image(image)
-        width, height = image.size
-
-        size = self.size
-
-        if isinstance(size, list):
-            size = tuple(size)
-
-        if isinstance(size, int) or len(size) == 1:
-            size = (size, size)
-
-        # determine new width and height
-        scale_width = size[0] / width
-        scale_height = size[1] / height
-
-        if self.keep_aspect_ratio:
-            # scale as little as possible
-            if abs(1 - scale_width) < abs(1 - scale_height):
-                # fit width
-                scale_height = scale_width
-            else:
-                # fit height
-                scale_width = scale_height
-
-        new_width = self.constrain_to_multiple_of(scale_width * width)
-        new_height = self.constrain_to_multiple_of(scale_height * height)
-
-        return (new_width, new_height)
-
-    def __call__(
-        self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-
-            return_tensors (`str` or [`~file_utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-        """
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must be of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # transformations (resizing + normalization)
-        if self.do_resize and self.size is not None:
-            for idx, image in enumerate(images):
-                size = self.update_size(image)
-                images[idx] = self.resize(image, size=size, resample=self.resample)
-        if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
-
-    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
-        """
-        Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
-
-        Args:
-            outputs ([`DPTForSemanticSegmentation`]):
-                Raw outputs of the model.
-            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
-                List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
-                None, predictions will not be resized.
-        Returns:
-            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
-            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
-            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
-        """
-        logits = outputs.logits
-
-        # Resize logits and compute semantic segmentation maps
-        if target_sizes is not None:
-            if len(logits) != len(target_sizes):
-                raise ValueError(
-                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
-                )
-
-            if is_torch_tensor(target_sizes):
-                target_sizes = target_sizes.numpy()
-
-            semantic_segmentation = []
-
-            for idx in range(len(logits)):
-                resized_logits = torch.nn.functional.interpolate(
-                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
-                )
-                semantic_map = resized_logits[0].argmax(dim=0)
-                semantic_segmentation.append(semantic_map)
-        else:
-            semantic_segmentation = logits.argmax(dim=1)
-            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
-
-        return semantic_segmentation
+
+DPTFeatureExtractor = DPTImageProcessor
@@ -0,0 +1,384 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for DPT."""

import math
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np

from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    is_batched,
    is_torch_available,
    is_torch_tensor,
    to_numpy_array,
    valid_images,
)
from ...utils import logging


if is_torch_available():
    import torch

if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)


def get_resize_output_image_size(
    input_image: np.ndarray, output_size: Union[int, Iterable[int]], keep_aspect_ratio: bool, multiple: int
) -> Tuple[int, int]:
    def constraint_to_multiple_of(val, multiple, min_val=0, max_val=None):
        x = round(val / multiple) * multiple

        if max_val is not None and x > max_val:
            x = math.floor(val / multiple) * multiple

        if x < min_val:
            x = math.ceil(val / multiple) * multiple

        return x

    output_size = (output_size, output_size) if isinstance(output_size, int) else output_size

    input_height, input_width = get_image_size(input_image)
    output_height, output_width = output_size

    # determine new height and width
    scale_height = output_height / input_height
    scale_width = output_width / input_width

    if keep_aspect_ratio:
        # scale as little as possible
        if abs(1 - scale_width) < abs(1 - scale_height):
            # fit width
            scale_height = scale_width
        else:
            # fit height
            scale_width = scale_height

    new_height = constraint_to_multiple_of(scale_height * input_height, multiple=multiple)
    new_width = constraint_to_multiple_of(scale_width * input_width, multiple=multiple)

    return (new_height, new_width)
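A worked example of the helper above, with hypothetical numbers: a 480x640 (height x width) input, `output_size=384`, `keep_aspect_ratio=True`, and `multiple=32`. The height scale 384/480 = 0.8 is closer to 1 than the width scale 384/640 = 0.6, so it is applied to both dimensions, and each result is then snapped to a multiple of 32:

```python
scale_height, scale_width = 384 / 480, 384 / 640  # 0.8 and 0.6

# keep_aspect_ratio: |1 - 0.6| < |1 - 0.8| is False, so "fit height" wins.
scale_width = scale_height

# constraint_to_multiple_of reduces to plain rounding here (no min/max guard triggers).
new_height = round(scale_height * 480 / 32) * 32  # 384
new_width = round(scale_width * 640 / 32) * 32    # 512
print((new_height, new_width))                    # (384, 512)
```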
class DPTImageProcessor(BaseImageProcessor):
    r"""
    Constructs a DPT image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions. Can be overridden by `do_resize` in `preprocess`.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 384, "width": 384}`):
            Size of the image after resizing. Can be overridden by `size` in `preprocess`.
        keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
            If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
            be overridden by `keep_aspect_ratio` in `preprocess`.
        ensure_multiple_of (`int`, *optional*, defaults to `1`):
            If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be
            overridden by `ensure_multiple_of` in `preprocess`.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            Defines the resampling filter to use if resizing the image. Can be overridden by `resample` in
            `preprocess`.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            `preprocess`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in `preprocess`.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        keep_aspect_ratio: bool = False,
        ensure_multiple_of: int = 1,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 384, "width": 384}
        size = get_size_dict(size)
        self.do_resize = do_resize
        self.size = size
        self.keep_aspect_ratio = keep_aspect_ratio
        self.ensure_multiple_of = ensure_multiple_of
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        keep_aspect_ratio: bool = False,
        ensure_multiple_of: int = 1,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
        is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
        set, the image is resized to a size that is a multiple of this value.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Target size of the output image.
            keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
                If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
            ensure_multiple_of (`int`, *optional*, defaults to `1`):
                The image is resized to a size that is a multiple of this value.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
        output_size = get_resize_output_image_size(
            image,
            output_size=(size["height"], size["width"]),
            keep_aspect_ratio=keep_aspect_ratio,
            multiple=ensure_multiple_of,
        )
        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)

    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ):
        """
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Normalize an image. image = (image - image_mean) / image_std.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        keep_aspect_ratio: bool = None,
        ensure_multiple_of: int = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. If `keep_aspect_ratio` is `True`, the image is resized to the largest
                possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is set, the image is
                resized to a size that is a multiple of this value.
            keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
                Whether to keep the aspect ratio of the image. If `False`, the image will be resized to
                `(size["height"], size["width"])`. If `True`, the image will be resized to keep the aspect ratio and
                the size will be the maximum possible.
            ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
                Ensure that the image size is a multiple of this value.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size)
        keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio
        ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        if not is_batched(images):
            images = [images]

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and (size is None or resample is None):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_resize:
            images = [
                self.resize(
                    image=image,
                    size=size,
                    keep_aspect_ratio=keep_aspect_ratio,
                    ensure_multiple_of=ensure_multiple_of,
                    resample=resample,
                )
                for image in images
            ]

        if do_rescale:
            images = [self.rescale(image=image, scale=rescale_factor) for image in images]

        if do_normalize:
            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]

        images = [to_channel_dimension_format(image, data_format) for image in images]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
        """
        Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

        Args:
            outputs ([`DPTForSemanticSegmentation`]):
                Raw outputs of the model.
            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
                predictions will not be resized.
        Returns:
            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
        """
        # TODO: add support for other frameworks
        logits = outputs.logits

        # Resize logits and compute semantic segmentation maps
        if target_sizes is not None:
            if len(logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

            if is_torch_tensor(target_sizes):
                target_sizes = target_sizes.numpy()

            semantic_segmentation = []

            for idx in range(len(logits)):
                resized_logits = torch.nn.functional.interpolate(
                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
                )
                semantic_map = resized_logits[0].argmax(dim=0)
                semantic_segmentation.append(semantic_map)
        else:
            semantic_segmentation = logits.argmax(dim=1)
            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

        return semantic_segmentation
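A usage sketch for the post-processing hook above (assumes PyTorch; the checkpoint name is illustrative, and the random tensor stands in for real `pixel_values`):

```python
import torch
from transformers import DPTForSemanticSegmentation
from transformers.models.dpt.image_processing_dpt import DPTImageProcessor

image_processor = DPTImageProcessor()
model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")  # illustrative checkpoint

pixel_values = torch.randn(1, 3, 384, 384)  # stand-in for image_processor.preprocess(...) output
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)

# Resize each prediction back to its original (height, width) and take the per-pixel argmax.
maps = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[(480, 640)])
print(maps[0].shape)  # torch.Size([480, 640])
```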
@@ -14,344 +14,10 @@
 # limitations under the License.
 """Feature extractor class for FLAVA."""
 
-import math
-import random
-from functools import lru_cache
-from typing import Any, List, Optional, Tuple, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_flava import FlavaImageProcessor
 
 
 logger = logging.get_logger(__name__)
 
 
+FlavaFeatureExtractor = FlavaImageProcessor
-# These values are taken from CLIP
-FLAVA_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073]
-FLAVA_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711]
-FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
-FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
-LOGIT_LAPLACE_EPS: float = 0.1
-
-
-# Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
-class FlavaMaskingGenerator:
-    def __init__(
-        self,
-        input_size: Union[int, Tuple[int, int]] = 14,
-        total_mask_patches: int = 75,
-        mask_group_max_patches: Optional[int] = None,
-        mask_group_min_patches: int = 16,
-        mask_group_min_aspect_ratio: Optional[float] = 0.3,
-        mask_group_max_aspect_ratio: float = None,
-    ):
-        if not isinstance(input_size, tuple):
-            input_size = (input_size,) * 2
-        self.height, self.width = input_size
-
-        self.num_patches = self.height * self.width
-        self.total_mask_patches = total_mask_patches
-
-        self.mask_group_min_patches = mask_group_min_patches
-        self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches
-
-        mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
-        self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
-
-    def __repr__(self):
-        repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
-            self.height,
-            self.width,
-            self.mask_group_min_patches,
-            self.mask_group_max_patches,
-            self.total_mask_patches,
-            self.log_aspect_ratio[0],
-            self.log_aspect_ratio[1],
-        )
-        return repr_str
-
-    def get_shape(self):
-        return self.height, self.width
-
-    def _mask(self, mask, max_mask_patches):
-        delta = 0
-        for _attempt in range(10):
-            target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
-            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
-            height = int(round(math.sqrt(target_area * aspect_ratio)))
-            width = int(round(math.sqrt(target_area / aspect_ratio)))
-            if width < self.width and height < self.height:
-                top = random.randint(0, self.height - height)
-                left = random.randint(0, self.width - width)
-
-                num_masked = mask[top : top + height, left : left + width].sum()
-                # Overlap
-                if 0 < height * width - num_masked <= max_mask_patches:
-                    for i in range(top, top + height):
-                        for j in range(left, left + width):
-                            if mask[i, j] == 0:
-                                mask[i, j] = 1
-                                delta += 1
-
-                if delta > 0:
-                    break
-        return delta
-
-    def __call__(self):
-        mask = np.zeros(shape=self.get_shape(), dtype=int)
-        mask_count = 0
-        while mask_count < self.total_mask_patches:
-            max_mask_patches = self.total_mask_patches - mask_count
-            max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)
-
-            delta = self._mask(mask, max_mask_patches)
-            if delta == 0:
-                break
-            else:
-                mask_count += delta
-
-        return mask
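The generator being removed above samples rectangular groups of patches until roughly `total_mask_patches` of the grid are masked, stopping early if ten placement attempts in a row add nothing. A sketch of how it was driven, assuming the class definition above (and its `math`, `random`, and `numpy` imports) is in scope:

```python
mask_generator = FlavaMaskingGenerator(input_size=14, total_mask_patches=75)

mask = mask_generator()  # 14x14 array of 0/1 flags, one per image patch
print(mask.shape, int(mask.sum()))  # (14, 14) and at most 75 masked patches
```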
|
||||||
|
|
||||||
class FlavaFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
    r"""
    Constructs a FLAVA feature extractor.

    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
    should refer to this superclass for more information regarding those methods.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the input to a certain `size`.
        size (`int`, *optional*, defaults to 224):
            Resize the input to the given size. Only has an effect if `do_resize` is set to `True`.
        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
            to `True`.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
            image is padded with 0's and then center cropped.
        crop_size (`int`, *optional*, defaults to 224):
            Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether or not to normalize the input with `image_mean` and `image_std`.
        image_mean (`Tuple[float, float, float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
            The sequence of means for each channel, to be used when normalizing images.
        image_std (`Tuple[float, float, float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
            The sequence of standard deviations for each channel, to be used when normalizing images.
        input_size_patches (`int`, *optional*, defaults to 14):
            Number of patches in the image in height and width direction. 14x14 = 196 total patches.
        total_mask_patches (`int`, *optional*, defaults to 75):
            Total number of patches that should be masked.
        mask_group_min_patches (`int`, *optional*, defaults to 16):
            Minimum number of patches that should be masked.
        mask_group_max_patches (`int`, *optional*, defaults to None):
            Maximum number of patches that should be masked.
        mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
            Minimum aspect ratio of the mask window.
        mask_group_max_aspect_ratio (`float`, *optional*, defaults to None):
            Maximum aspect ratio of the mask window.
        codebook_do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the input for codebook to a certain `codebook_size`.
        codebook_size (`int`, *optional*, defaults to 112):
            Resize the input for codebook to the given size. Only has an effect if `codebook_do_resize` is set to
            `True`.
        codebook_resample (`int`, *optional*, defaults to `PIL.Image.Resampling.LANCZOS`):
            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if
            `codebook_do_resize` is set to `True`.
        codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to crop the input for codebook at the center. If the input size is smaller than
            `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped.
        codebook_crop_size (`int`, *optional*, defaults to 112):
            Desired output size for codebook input when applying center-cropping. Only has an effect if
            `codebook_do_center_crop` is set to `True`.
        codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
            Whether to map the pixel values of the codebook input to (1 - 2e)x + e.
        codebook_do_normalize (`bool`, *optional*, defaults to `True`):
            Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`.
        codebook_image_mean (`Tuple[float, float, float]`, *optional*, defaults to `[0, 0, 0]`):
            The sequence of means for each channel, to be used when normalizing images for codebook.
        codebook_image_std (`Tuple[float, float, float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
            The sequence of standard deviations for each channel, to be used when normalizing images for codebook.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Union[int, Tuple[int, int]] = 224,
        resample: int = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Union[int, Tuple[int, int]] = 224,
        do_normalize: bool = True,
        image_mean: Tuple[float, float, float] = FLAVA_IMAGE_MEAN,
        image_std: Tuple[float, float, float] = FLAVA_IMAGE_STD,
        # Mask related params
        input_size_patches: int = 14,
        total_mask_patches: int = 75,
        mask_group_min_patches: int = 16,
        mask_group_max_patches: Optional[int] = None,
        mask_group_min_aspect_ratio: float = 0.3,
        mask_group_max_aspect_ratio: Optional[float] = None,
        # Codebook related params
        codebook_do_resize: bool = True,
        codebook_size: int = 112,
        codebook_resample: int = PILImageResampling.LANCZOS,
        codebook_do_center_crop: bool = True,
        codebook_crop_size: int = 112,
        codebook_do_map_pixels: bool = True,
        codebook_do_normalize: bool = True,
        codebook_image_mean: Tuple[float, float, float] = FLAVA_CODEBOOK_MEAN,
        codebook_image_std: Tuple[float, float, float] = FLAVA_CODEBOOK_STD,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std

        self.input_size_patches = input_size_patches
        self.total_mask_patches = total_mask_patches
        self.mask_group_min_patches = mask_group_min_patches
        self.mask_group_max_patches = mask_group_max_patches
        self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
        self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio

        self.codebook_do_resize = codebook_do_resize
        self.codebook_size = codebook_size
        self.codebook_resample = codebook_resample
        self.codebook_do_center_crop = codebook_do_center_crop
        self.codebook_crop_size = codebook_crop_size
        self.codebook_do_map_pixels = codebook_do_map_pixels
        self.codebook_do_normalize = codebook_do_normalize
        self.codebook_image_mean = codebook_image_mean
        self.codebook_image_std = codebook_image_std

    @property
    @lru_cache()
    def masking_generator(self):
        return FlavaMaskingGenerator(
            input_size=self.input_size_patches,
            total_mask_patches=self.total_mask_patches,
            mask_group_min_patches=self.mask_group_min_patches,
            mask_group_max_patches=self.mask_group_max_patches,
            mask_group_min_aspect_ratio=self.mask_group_min_aspect_ratio,
            mask_group_max_aspect_ratio=self.mask_group_max_aspect_ratio,
        )

    def map_pixels(self, x):
        return (1 - 2 * LOGIT_LAPLACE_EPS) * x + LOGIT_LAPLACE_EPS

    def __call__(
        self,
        images: Union[
            Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]  # noqa
        ],
        return_image_mask: Optional[bool] = None,
        return_codebook_pixels: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs: Any
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several image(s).

        <Tip warning={true}>

        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
        PIL images.

        </Tip>

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.

            return_image_mask (`bool`, *optional*, defaults to None):
                If True, the processor will return `bool_masked_pos` suggesting masks for image's patch version.

            return_codebook_pixels (`bool`, *optional*, defaults to None):
                If True, the processor will return `codebook_pixel_values` providing image pixels to be used with the
                default FLAVA codebook. Used in pretraining by Masked Image Modeling (MIM) loss.

            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **pixel_values** -- Pixel values to be fed to a model.
        """
        # Input type checking for clearer error
        if isinstance(images, (list, tuple)) and len(images) != 0:
            self._ensure_format_supported(images[0])
        else:
            self._ensure_format_supported(images)

        is_batched = bool(
            isinstance(images, (list, tuple))
            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
        )

        if not is_batched:
            images = [images]

        images_for_codebook = images

        # transformations (resizing + center cropping + normalization)
        if self.do_resize and self.size is not None and self.resample is not None:
            images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
        if self.do_center_crop and self.crop_size is not None:
            images = [self.center_crop(image, self.crop_size) for image in images]
        if self.do_normalize:
            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
        # return as BatchFeature
        data = {"pixel_values": images}

        if return_codebook_pixels:
            images = images_for_codebook
            if self.codebook_do_resize and self.codebook_size is not None and self.codebook_resample is not None:
                images = [
                    self.resize(image=image, size=self.codebook_size, resample=self.codebook_resample)
                    for image in images
                ]
            if self.codebook_do_center_crop and self.codebook_crop_size is not None:
                images = [self.center_crop(image, self.codebook_crop_size) for image in images]
            if self.codebook_do_normalize:
                images = [
                    self.normalize(image=image, mean=self.codebook_image_mean, std=self.codebook_image_std)
                    for image in images
                ]
            if self.codebook_do_map_pixels:
                images = [self.map_pixels(image) for image in images]

            data["codebook_pixel_values"] = images

        if return_image_mask:
            masks = [self.masking_generator() for _ in images]
            data["bool_masked_pos"] = masks

        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)

        return encoded_inputs

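Before moving to the new image processor below, here is a minimal, illustrative call of the legacy feature extractor above; the placeholder image and flag values are assumptions made for the example, not taken from the PR:

```python
from PIL import Image

feature_extractor = FlavaFeatureExtractor()
image = Image.new("RGB", (256, 256))  # placeholder image

inputs = feature_extractor(
    image,
    return_image_mask=True,       # adds `bool_masked_pos` for masked image modeling
    return_codebook_pixels=True,  # adds `codebook_pixel_values` for the image codebook
    return_tensors="np",
)
# inputs now contains: pixel_values, codebook_pixel_values, bool_masked_pos
```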
@ -0,0 +1,696 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Flava."""

import math
import random
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np

from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
from ...utils import logging


if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)


# These values are taken from CLIP
FLAVA_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073]
FLAVA_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711]
FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
LOGIT_LAPLACE_EPS: float = 0.1

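# Note: the main image pipeline normalizes with the CLIP statistics above, while
# the codebook statistics are mean 0.0 / std 1.0, so "normalizing" the codebook
# input leaves the rescaled pixel values unchanged before `map_pixels` runs.
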
# Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
class FlavaMaskingGenerator:
    def __init__(
        self,
        input_size: Union[int, Tuple[int, int]] = 14,
        total_mask_patches: int = 75,
        mask_group_max_patches: Optional[int] = None,
        mask_group_min_patches: int = 16,
        mask_group_min_aspect_ratio: Optional[float] = 0.3,
        mask_group_max_aspect_ratio: Optional[float] = None,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.total_mask_patches = total_mask_patches

        self.mask_group_min_patches = mask_group_min_patches
        self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches

        mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
        self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))

    def __repr__(self):
        repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.mask_group_min_patches,
            self.mask_group_max_patches,
            self.total_mask_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )
        return repr_str

    def get_shape(self):
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        delta = 0
        for _attempt in range(10):
            target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            height = int(round(math.sqrt(target_area * aspect_ratio)))
            width = int(round(math.sqrt(target_area / aspect_ratio)))
            if width < self.width and height < self.height:
                top = random.randint(0, self.height - height)
                left = random.randint(0, self.width - width)

                num_masked = mask[top : top + height, left : left + width].sum()
                # Overlap
                if 0 < height * width - num_masked <= max_mask_patches:
                    for i in range(top, top + height):
                        for j in range(left, left + width):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self):
        mask = np.zeros(shape=self.get_shape(), dtype=int)
        mask_count = 0
        while mask_count < self.total_mask_patches:
            max_mask_patches = self.total_mask_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            else:
                mask_count += delta

        return mask

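# Geometry of each masked group in `_mask` above: for a target area A drawn
# uniformly from [mask_group_min_patches, max_mask_patches] and an aspect ratio
# r drawn log-uniformly from [0.3, 1 / 0.3], the rectangle gets
# height = round(sqrt(A * r)) and width = round(sqrt(A / r)), so that
# height * width ~ A and height / width ~ r.
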
class FlavaImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Flava image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in `preprocess`.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the image after resizing. Can be overridden by the `size` parameter in `preprocess`.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in
            `preprocess`.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center crop the images. Can be overridden by the `do_center_crop` parameter in `preprocess`.
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of image after the center crop `(crop_size["height"], crop_size["width"])`. Can be overridden by the
            `crop_size` parameter in `preprocess`.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in `preprocess`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in
            `preprocess`.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in `preprocess`.
        image_mean (`float` or `List[float]`, *optional*, defaults to `FLAVA_IMAGE_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `FLAVA_IMAGE_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        return_image_mask (`bool`, *optional*, defaults to `False`):
            Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
        input_size_patches (`int`, *optional*, defaults to 14):
            Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
            by the `input_size_patches` parameter in `preprocess`.
        total_mask_patches (`int`, *optional*, defaults to 75):
            Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
            `preprocess`.
        mask_group_min_patches (`int`, *optional*, defaults to 16):
            Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches`
            parameter in `preprocess`.
        mask_group_max_patches (`int`, *optional*):
            Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches`
            parameter in `preprocess`.
        mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
            Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
            in `preprocess`.
        mask_group_max_aspect_ratio (`float`, *optional*):
            Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
            in `preprocess`.
        return_codebook_pixels (`bool`, *optional*, defaults to `False`):
            Whether to return the codebook pixel values. Can be overridden by the `return_codebook_pixels` parameter
            in `preprocess`.
        codebook_do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the input for codebook to a certain `codebook_size`. Can be overridden by the
            `codebook_do_resize` parameter in `preprocess`.
        codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
            Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
            `preprocess`.
        codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
            Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample`
            parameter in `preprocess`.
        codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to crop the input for codebook at the center. If the input size is smaller than
            `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
            overridden by the `codebook_do_center_crop` parameter in `preprocess`.
        codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
            Desired output size for codebook input when applying center-cropping. Can be overridden by the
            `codebook_crop_size` parameter in `preprocess`.
        codebook_do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
            overridden by the `codebook_do_rescale` parameter in `preprocess`.
        codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
            `codebook_rescale_factor` parameter in `preprocess`.
        codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
            Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
            `codebook_do_map_pixels` parameter in `preprocess`.
        codebook_do_normalize (`bool`, *optional*, defaults to `True`):
            Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
            be overridden by the `codebook_do_normalize` parameter in `preprocess`.
        codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`):
            The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
            by the `codebook_image_mean` parameter in `preprocess`.
        codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[1, 1, 1]`):
            The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
            be overridden by the `codebook_image_std` parameter in `preprocess`.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, Iterable[float]]] = None,
        image_std: Optional[Union[float, Iterable[float]]] = None,
        # Mask related params
        return_image_mask: bool = False,
        input_size_patches: int = 14,
        total_mask_patches: int = 75,
        mask_group_min_patches: int = 16,
        mask_group_max_patches: Optional[int] = None,
        mask_group_min_aspect_ratio: float = 0.3,
        mask_group_max_aspect_ratio: Optional[float] = None,
        # Codebook related params
        return_codebook_pixels: bool = False,
        codebook_do_resize: bool = True,
        codebook_size: Dict[str, int] = None,
        codebook_resample: int = PILImageResampling.LANCZOS,
        codebook_do_center_crop: bool = True,
        codebook_crop_size: Dict[str, int] = None,
        codebook_do_rescale: bool = True,
        codebook_rescale_factor: Union[int, float] = 1 / 255,
        codebook_do_map_pixels: bool = True,
        codebook_do_normalize: bool = True,
        codebook_image_mean: Optional[Union[float, Iterable[float]]] = None,
        codebook_image_std: Optional[Union[float, Iterable[float]]] = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 224, "width": 224}
        size = get_size_dict(size)
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        crop_size = get_size_dict(crop_size)

        codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
        codebook_size = get_size_dict(codebook_size)
        codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
        codebook_crop_size = get_size_dict(codebook_crop_size)

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else FLAVA_IMAGE_MEAN
        self.image_std = image_std if image_std is not None else FLAVA_IMAGE_STD

        self.return_image_mask = return_image_mask
        self.input_size_patches = input_size_patches
        self.total_mask_patches = total_mask_patches
        self.mask_group_min_patches = mask_group_min_patches
        self.mask_group_max_patches = mask_group_max_patches
        self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
        self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio

        self.return_codebook_pixels = return_codebook_pixels
        self.codebook_do_resize = codebook_do_resize
        self.codebook_size = codebook_size
        self.codebook_resample = codebook_resample
        self.codebook_do_center_crop = codebook_do_center_crop
        self.codebook_crop_size = codebook_crop_size
        self.codebook_do_rescale = codebook_do_rescale
        self.codebook_rescale_factor = codebook_rescale_factor
        self.codebook_do_map_pixels = codebook_do_map_pixels
        self.codebook_do_normalize = codebook_do_normalize
        self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN
        self.codebook_image_std = codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD

    @lru_cache()
    def masking_generator(
        self,
        input_size_patches,
        total_mask_patches,
        mask_group_min_patches,
        mask_group_max_patches,
        mask_group_min_aspect_ratio,
        mask_group_max_aspect_ratio,
    ) -> FlavaMaskingGenerator:
        return FlavaMaskingGenerator(
            input_size=input_size_patches,
            total_mask_patches=total_mask_patches,
            mask_group_min_patches=mask_group_min_patches,
            mask_group_max_patches=mask_group_max_patches,
            mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
            mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
        )

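    # `masking_generator` is wrapped in `lru_cache`, so repeated `preprocess`
    # calls with identical mask arguments reuse a single FlavaMaskingGenerator
    # instance rather than rebuilding it for every batch.
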
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The size dictionary must contain 'height' and 'width' keys. Got {size.keys()}")
        return resize(
            image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
        )

    def center_crop(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
        any edge, the image is padded with 0's and then center cropped.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            size (`Dict[str, int]`):
                Size of the output image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ):
        """
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Normalize an image. image = (image - image_mean) / image_std.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)

    def map_pixels(self, image: np.ndarray) -> np.ndarray:
        return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS

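    # With LOGIT_LAPLACE_EPS = 0.1, `map_pixels` squeezes values from [0, 1] into
    # [0.1, 0.9] before they are fed to the image codebook; the inverse mapping
    # is (x - LOGIT_LAPLACE_EPS) / (1 - 2 * LOGIT_LAPLACE_EPS).
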
    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_map_pixels: bool = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
    ) -> np.ndarray:
        """Preprocesses a single image."""
        if do_resize and (size is None or resample is None):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # All transformations expect numpy arrays.
        image = to_numpy_array(image)

        if do_resize:
            image = self.resize(image=image, size=size, resample=resample)

        if do_center_crop:
            image = self.center_crop(image=image, size=crop_size)

        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std)

        if do_map_pixels:
            image = self.map_pixels(image)

        if data_format is not None:
            image = to_channel_dimension_format(image, data_format)
        return image

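    # Worked example with the class defaults: a uint8 image is resized to
    # 224x224, center cropped, rescaled by 1/255 into [0, 1], then normalized
    # with the CLIP statistics, e.g. a red-channel value of 255 becomes
    # (1.0 - 0.48145466) / 0.26862954 ~= 1.93.
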
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[Dict[str, int]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        # Mask related params
        return_image_mask: Optional[bool] = None,
        input_size_patches: Optional[int] = None,
        total_mask_patches: Optional[int] = None,
        mask_group_min_patches: Optional[int] = None,
        mask_group_max_patches: Optional[int] = None,
        mask_group_min_aspect_ratio: Optional[float] = None,
        mask_group_max_aspect_ratio: Optional[float] = None,
        # Codebook related params
        return_codebook_pixels: Optional[bool] = None,
        codebook_do_resize: Optional[bool] = None,
        codebook_size: Optional[Dict[str, int]] = None,
        codebook_resample: Optional[int] = None,
        codebook_do_center_crop: Optional[bool] = None,
        codebook_crop_size: Optional[Dict[str, int]] = None,
        codebook_do_rescale: Optional[bool] = None,
        codebook_rescale_factor: Optional[float] = None,
        codebook_do_map_pixels: Optional[bool] = None,
        codebook_do_normalize: Optional[bool] = None,
        codebook_image_mean: Optional[Iterable[float]] = None,
        codebook_image_std: Optional[Iterable[float]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            return_image_mask (`bool`, *optional*, defaults to `self.return_image_mask`):
                Whether to return the image mask.
            input_size_patches (`int`, *optional*, defaults to `self.input_size_patches`):
                Number of patches in the image in height and width direction. 14x14 = 196 total patches.
            total_mask_patches (`int`, *optional*, defaults to `self.total_mask_patches`):
                Total number of patches that should be masked.
            mask_group_min_patches (`int`, *optional*, defaults to `self.mask_group_min_patches`):
                Minimum number of patches that should be masked.
            mask_group_max_patches (`int`, *optional*, defaults to `self.mask_group_max_patches`):
                Maximum number of patches that should be masked.
            mask_group_min_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_min_aspect_ratio`):
                Minimum aspect ratio of the mask window.
            mask_group_max_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_max_aspect_ratio`):
                Maximum aspect ratio of the mask window.
            return_codebook_pixels (`bool`, *optional*, defaults to `self.return_codebook_pixels`):
                Whether to return the codebook pixels.
            codebook_do_resize (`bool`, *optional*, defaults to `self.codebook_do_resize`):
                Whether to resize the codebook pixels.
            codebook_size (`Dict[str, int]`, *optional*, defaults to `self.codebook_size`):
                Size of the codebook pixels.
            codebook_resample (`int`, *optional*, defaults to `self.codebook_resample`):
                Resampling filter to use if resizing the codebook pixels. This can be one of the enum
                `PILImageResampling`. Only has an effect if `codebook_do_resize` is set to `True`.
            codebook_do_center_crop (`bool`, *optional*, defaults to `self.codebook_do_center_crop`):
                Whether to center crop the codebook pixels.
            codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `self.codebook_crop_size`):
                Size of the center crop of the codebook pixels. Only has an effect if `codebook_do_center_crop` is set
                to `True`.
            codebook_do_rescale (`bool`, *optional*, defaults to `self.codebook_do_rescale`):
                Whether to rescale the codebook pixel values between [0 - 1].
            codebook_rescale_factor (`float`, *optional*, defaults to `self.codebook_rescale_factor`):
                Rescale factor to rescale the codebook pixels by if `codebook_do_rescale` is set to `True`.
            codebook_do_map_pixels (`bool`, *optional*, defaults to `self.codebook_do_map_pixels`):
                Whether to map the codebook pixel values.
            codebook_do_normalize (`bool`, *optional*, defaults to `self.codebook_do_normalize`):
                Whether to normalize the codebook pixels.
            codebook_image_mean (`float` or `List[float]`, *optional*, defaults to `self.codebook_image_mean`):
                Codebook pixels mean to normalize the codebook pixels by if `codebook_do_normalize` is set to `True`.
            codebook_image_std (`float` or `List[float]`, *optional*, defaults to `self.codebook_image_std`):
                Codebook pixels standard deviation to normalize the codebook pixels by if `codebook_do_normalize` is
                set to `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size)
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size)
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        return_image_mask = return_image_mask if return_image_mask is not None else self.return_image_mask
        input_size_patches = input_size_patches if input_size_patches is not None else self.input_size_patches
        total_mask_patches = total_mask_patches if total_mask_patches is not None else self.total_mask_patches
        mask_group_min_patches = (
            mask_group_min_patches if mask_group_min_patches is not None else self.mask_group_min_patches
        )
        mask_group_max_patches = (
            mask_group_max_patches if mask_group_max_patches is not None else self.mask_group_max_patches
        )
        mask_group_min_aspect_ratio = (
            mask_group_min_aspect_ratio
            if mask_group_min_aspect_ratio is not None
            else self.mask_group_min_aspect_ratio
        )
        mask_group_max_aspect_ratio = (
            mask_group_max_aspect_ratio
            if mask_group_max_aspect_ratio is not None
            else self.mask_group_max_aspect_ratio
        )

        return_codebook_pixels = (
            return_codebook_pixels if return_codebook_pixels is not None else self.return_codebook_pixels
        )
        codebook_do_resize = codebook_do_resize if codebook_do_resize is not None else self.codebook_do_resize
        codebook_size = codebook_size if codebook_size is not None else self.codebook_size
        codebook_size = get_size_dict(codebook_size)
        codebook_resample = codebook_resample if codebook_resample is not None else self.codebook_resample
        codebook_do_rescale = codebook_do_rescale if codebook_do_rescale is not None else self.codebook_do_rescale
        codebook_rescale_factor = (
            codebook_rescale_factor if codebook_rescale_factor is not None else self.codebook_rescale_factor
        )
        codebook_do_center_crop = (
            codebook_do_center_crop if codebook_do_center_crop is not None else self.codebook_do_center_crop
        )
        codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else self.codebook_crop_size
        codebook_crop_size = get_size_dict(codebook_crop_size)
        codebook_do_map_pixels = (
            codebook_do_map_pixels if codebook_do_map_pixels is not None else self.codebook_do_map_pixels
        )
        codebook_do_normalize = (
            codebook_do_normalize if codebook_do_normalize is not None else self.codebook_do_normalize
        )
        codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else self.codebook_image_mean
        codebook_image_std = codebook_image_std if codebook_image_std is not None else self.codebook_image_std

        if not is_batched(images):
            images = [images]

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        processed_images = [
            self._preprocess_image(
                image=img,
                do_resize=do_resize,
                size=size,
                resample=resample,
                do_center_crop=do_center_crop,
                crop_size=crop_size,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                do_map_pixels=False,
                data_format=data_format,
            )
            for img in images
        ]
        data = {"pixel_values": processed_images}

        if return_codebook_pixels:
            codebook_images = [
                self._preprocess_image(
                    image=img,
                    do_resize=codebook_do_resize,
                    size=codebook_size,
                    resample=codebook_resample,
                    do_center_crop=codebook_do_center_crop,
                    crop_size=codebook_crop_size,
                    do_rescale=codebook_do_rescale,
                    rescale_factor=codebook_rescale_factor,
                    do_normalize=codebook_do_normalize,
                    image_mean=codebook_image_mean,
                    image_std=codebook_image_std,
                    do_map_pixels=codebook_do_map_pixels,
                    data_format=data_format,
                )
                for img in images
            ]
            data["codebook_pixel_values"] = codebook_images

        if return_image_mask:
            mask_generator = self.masking_generator(
                input_size_patches=input_size_patches,
                total_mask_patches=total_mask_patches,
                mask_group_min_patches=mask_group_min_patches,
                mask_group_max_patches=mask_group_max_patches,
                mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
                mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
            )
            masks = [mask_generator() for _ in images]
            data["bool_masked_pos"] = masks

        return BatchFeature(data=data, tensor_type=return_tensors)

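A minimal end-to-end sketch of the new processor above; the `facebook/flava-full` checkpoint is the released FLAVA model, and the placeholder image is only for illustration:

```python
from PIL import Image
from transformers import FlavaImageProcessor

image_processor = FlavaImageProcessor.from_pretrained("facebook/flava-full")
image = Image.new("RGB", (480, 360))  # placeholder image

inputs = image_processor.preprocess(
    image,
    return_image_mask=True,
    return_codebook_pixels=True,
    return_tensors="pt",
)
print(inputs["pixel_values"].shape)           # torch.Size([1, 3, 224, 224])
print(inputs["codebook_pixel_values"].shape)  # torch.Size([1, 3, 112, 112])
print(inputs["bool_masked_pos"].shape)        # torch.Size([1, 14, 14])
```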
@ -37,16 +37,16 @@ class GLPNImageProcessor(BaseImageProcessor):

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
-            Set the class default for the `do_resize` parameter. Controls whether to resize the image's (height, width)
-            dimensions, rounding them down to the closest multiple of `size_divisor`.
+            Whether to resize the image's (height, width) dimensions, rounding them down to the closest multiple of
+            `size_divisor`. Can be overridden by `do_resize` in `preprocess`.
        size_divisor (`int`, *optional*, defaults to 32):
-            Set the class default for the `size_divisor` parameter. When `do_resize` is `True`, images are resized so
-            their height and width are rounded down to the closest multiple of `size_divisor`.
+            When `do_resize` is `True`, images are resized so their height and width are rounded down to the closest
+            multiple of `size_divisor`. Can be overridden by `size_divisor` in `preprocess`.
        resample (`PIL.Image` resampling filter, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
-            Set the class default for `resample`. Defines the resampling filter to use if resizing the image.
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
        do_rescale (`bool`, *optional*, defaults to `True`):
-            Set the class default for the `do_rescale` parameter. Controls whether or not to apply the scaling factor
-            (to make pixel values floats between 0. and 1.).
+            Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Can be
+            overridden by `do_rescale` in `preprocess`.
    """

    model_input_names = ["pixel_values"]

@ -81,7 +81,7 @@ class GLPNImageProcessor(BaseImageProcessor):
                `size_divisor`.
            resample:
                `PIL.Image` resampling filter to use when resizing the image e.g. `PIL.Image.Resampling.BILINEAR`.
-            data_format (`ChannelDimension`, *optional*):
+            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If `None`, the channel dimension format of the input
                image is used. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.

@ -108,7 +108,7 @@ class GLPNImageProcessor(BaseImageProcessor):
                The image to rescale.
            scale (`float`):
                The scaling factor to rescale pixel values by.
-            data_format (`ChannelDimension`, *optional*):
+            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If `None`, the channel dimension format of the input
                image is used. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.

@ -146,14 +146,14 @@ class GLPNImageProcessor(BaseImageProcessor):
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.).
-            return_tensors (`str`, *optional*):
+            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - `None`: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
-            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
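The GLPN docstring changes above describe resizing that rounds each dimension down to the nearest multiple of `size_divisor`; that rule reduces to integer floor division, sketched standalone here (not the library's internal helper):

```python
def round_down(value: int, size_divisor: int = 32) -> int:
    """Round `value` down to the closest multiple of `size_divisor`."""
    return (value // size_divisor) * size_divisor

print(round_down(518))  # 512
print(round_down(480))  # 480, already a multiple of 32
print(round_down(31))   # 0 -- inputs smaller than `size_divisor` collapse
```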
@@ -14,168 +14,11 @@
 # limitations under the License.
 """Feature extractor class for ImageGPT."""

-from typing import List, Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_imagegpt import ImageGPTImageProcessor


 logger = logging.get_logger(__name__)


-def squared_euclidean_distance(a, b):
-    b = b.T
-    a2 = np.sum(np.square(a), axis=1)
-    b2 = np.sum(np.square(b), axis=0)
-    ab = np.matmul(a, b)
-    d = a2[:, None] - 2 * ab + b2[None, :]
-    return d
-
-
-def color_quantize(x, clusters):
-    x = x.reshape(-1, 3)
-    d = squared_euclidean_distance(x, clusters)
-    return np.argmin(d, axis=1)
-
-
-class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs an ImageGPT feature extractor. This feature extractor can be used to resize images to a smaller
-    resolution (such as 32x32 or 64x64), normalize them and finally color quantize them to obtain sequences of "pixel
-    values" (color clusters).
-
-    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        clusters (`np.ndarray`):
-            The color clusters to use, as a `np.ndarray` of shape `(n_clusters, 3)`.
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input to a certain `size`.
-        size (`int` or `Tuple(int)`, *optional*, defaults to 32):
-            Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
-            set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input to the range between -1 and +1.
-    """
-
-    model_input_names = ["input_ids"]
-
-    def __init__(
-        self, clusters, do_resize=True, size=32, resample=PILImageResampling.BILINEAR, do_normalize=True, **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.clusters = np.asarray(clusters)
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.do_normalize = do_normalize
-
-    def normalize(self, image):
-        """
-        Normalizes `image` into the range -1 to +1.
-
-        Args:
-            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
-                The image to normalize.
-
-        Returns:
-            `np.ndarray`: The normalized image.
-        """
-        image = self.to_numpy_array(image, rescale=False, channel_first=False)
-        return image / 127.5 - 1
-
-    def __call__(
-        self,
-        images: Union[
-            Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]  # noqa
-        ],
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **input_ids** -- Input IDs to be fed to a model, of shape `(batch_size, height * width)`.
-        """
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # transformations (resizing + normalization)
-        if self.do_resize and self.size is not None:
-            images = [self.resize(image, size=self.size, resample=self.resample) for image in images]
-
-        if self.do_normalize:
-            images = [self.normalize(image) for image in images]
-
-        # color quantize from (batch_size, height, width, 3) to (batch_size, height, width)
-        images = np.array(images)
-        images = color_quantize(images, self.clusters).reshape(images.shape[:-1])
-
-        # flatten to (batch_size, height*width)
-        batch_size = images.shape[0]
-        images = images.reshape(batch_size, -1)
-
-        # return as BatchFeature
-        data = {"input_ids": images}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
+ImageGPTFeatureExtractor = ImageGPTImageProcessor
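The hunk above reduces the old feature extractor module to a thin backwards-compatibility shim. A quick sanity check of the aliasing, assuming both names are re-exported at the top level of `transformers`:

```python
from transformers import ImageGPTFeatureExtractor, ImageGPTImageProcessor

# The feature extractor name now resolves to the image processor class,
# so code written against the old API keeps working unchanged.
assert ImageGPTFeatureExtractor is ImageGPTImageProcessor
```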
@@ -0,0 +1,239 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for ImageGPT."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import rescale, resize, to_channel_dimension_format
+from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
+from ...utils import logging
+
+
+if is_vision_available():
+    import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+def squared_euclidean_distance(a, b):
+    b = b.T
+    a2 = np.sum(np.square(a), axis=1)
+    b2 = np.sum(np.square(b), axis=0)
+    ab = np.matmul(a, b)
+    d = a2[:, None] - 2 * ab + b2[None, :]
+    return d
+
+
+def color_quantize(x, clusters):
+    x = x.reshape(-1, 3)
+    d = squared_euclidean_distance(x, clusters)
+    return np.argmin(d, axis=1)
+
+
+class ImageGPTImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs an ImageGPT image processor. This image processor can be used to resize images to a smaller resolution
+    (such as 32x32 or 64x64), normalize them and finally color quantize them to obtain sequences of "pixel values"
+    (color clusters).
+
+    Args:
+        clusters (`np.ndarray`, *optional*):
+            The color clusters to use, as a `np.ndarray` of shape `(n_clusters, 3)` when color quantizing. Can be
+            overridden by `clusters` in `preprocess`.
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's dimensions to `(size["height"], size["width"])`. Can be overridden by
+            `do_resize` in `preprocess`.
+        size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
+            Size of the image after resizing. Can be overridden by `size` in `preprocess`.
+        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image pixel value to between [-1, 1]. Can be overridden by `do_normalize` in
+            `preprocess`.
+        do_color_quantize (`bool`, *optional*, defaults to `True`):
+            Whether to color quantize the image. Can be overridden by `do_color_quantize` in `preprocess`.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        # clusters is a first argument to maintain backwards compatibility with the old ImageGPTFeatureExtractor
+        clusters: Optional[np.ndarray] = None,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        do_normalize: bool = True,
+        do_color_quantize: bool = True,
+        **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"height": 256, "width": 256}
+        size = get_size_dict(size)
+        self.clusters = clusters
+        self.do_resize = do_resize
+        self.size = size
+        self.resample = resample
+        self.do_normalize = do_normalize
+        self.do_color_quantize = do_color_quantize
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs
+    ) -> np.ndarray:
+        """
+        Resize an image to `(size["height"], size["width"])`.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+        """
+        size = get_size_dict(size)
+        if "height" not in size or "width" not in size:
+            raise ValueError(f"Size dictionary must contain both height and width keys. Got {size.keys()}")
+        return resize(
+            image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
+        )
+
+    def normalize(
+        self,
+        image: np.ndarray,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.ndarray:
+        """
+        Normalizes an image's pixel values to between [-1, 1].
+
+        Args:
+            image (`np.ndarray`):
+                Image to normalize.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+        """
+        image = rescale(image=image, scale=1 / 127.5, data_format=data_format)
+        image = image - 1
+        return image
+
+    def preprocess(
+        self,
+        images: ImageInput,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        do_normalize: bool = None,
+        do_color_quantize: Optional[bool] = None,
+        clusters: Optional[Union[int, List[int]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
+    ) -> PIL.Image.Image:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after resizing.
+            resample (`int`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`.
+                Only has an effect if `do_resize` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            do_color_quantize (`bool`, *optional*, defaults to `self.do_color_quantize`):
+                Whether to color quantize the image.
+            clusters (`np.ndarray`, *optional*, defaults to `self.clusters`):
+                Clusters used to quantize the image of shape `(n_clusters, 3)`. Only has an effect if
+                `do_color_quantize` is set to `True`.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                Only has an effect if `do_color_quantize` is set to `False`.
+        """
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        size = size if size is not None else self.size
+        size = get_size_dict(size)
+        resample = resample if resample is not None else self.resample
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        do_color_quantize = do_color_quantize if do_color_quantize is not None else self.do_color_quantize
+        clusters = clusters if clusters is not None else self.clusters
+
+        if not is_batched(images):
+            images = [images]
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        if do_resize and (size is None or resample is None):
+            raise ValueError("Size and resample must be specified if do_resize is True.")
+
+        if do_color_quantize and clusters is None:
+            raise ValueError("Clusters must be specified if do_color_quantize is True.")
+
+        # All transformations expect numpy arrays.
+        images = [to_numpy_array(image) for image in images]
+
+        if do_resize:
+            images = [self.resize(image=image, size=size, resample=resample) for image in images]
+
+        if do_normalize:
+            images = [self.normalize(image=image) for image in images]
+
+        if do_color_quantize:
+            images = [to_channel_dimension_format(image, ChannelDimension.LAST) for image in images]
+            # color quantize from (batch_size, height, width, 3) to (batch_size, height, width)
+            images = np.array(images)
+            clusters = np.array(clusters)
+            images = color_quantize(images, clusters).reshape(images.shape[:-1])
+
+            # flatten to (batch_size, height*width)
+            batch_size = images.shape[0]
+            images = images.reshape(batch_size, -1)
+
+            # We need to convert back to a list of images to keep consistent behaviour across processors.
+            images = list(images)
+        else:
+            images = [to_channel_dimension_format(image, data_format) for image in images]
+
+        data = {"input_ids": images}
+        return BatchFeature(data=data, tensor_type=return_tensors)
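As a smoke test for the new processor, the sketch below quantizes a dummy image against made-up clusters; real uses should load the 512 learned clusters from an ImageGPT checkpoint (for example via `ImageGPTImageProcessor.from_pretrained`), so treat the values here as illustrative only:

```python
import numpy as np
from transformers import ImageGPTImageProcessor

# 16 illustrative color clusters in the [-1, 1] range the normalize step produces;
# real checkpoints ship 512 learned clusters.
clusters = np.random.uniform(-1, 1, size=(16, 3))
image_processor = ImageGPTImageProcessor(clusters=clusters, size={"height": 32, "width": 32})

image = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
encoding = image_processor(image, return_tensors="np")

# Each pixel of the 32x32 resized image becomes one cluster index.
print(encoding["input_ids"].shape)  # (1, 1024)
```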
@@ -992,11 +992,12 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
         ... )

         >>> clusters = feature_extractor.clusters
-        >>> n_px = feature_extractor.size
+        >>> height = feature_extractor.size["height"]
+        >>> width = feature_extractor.size["width"]

         >>> samples = output[:, 1:].cpu().detach().numpy()
         >>> samples_img = [
-        ...     np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8) for s in samples
+        ...     np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [height, width, 3]).astype(np.uint8) for s in samples
         ... ]  # convert color cluster tokens back to pixels
         >>> f, axes = plt.subplots(1, batch_size, dpi=300)
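The docstring fix above follows from the new `size` convention: processors now store `size` as a dict, so callers index `"height"`/`"width"` (or `"shortest_edge"`) rather than treating it as a scalar. The hypothetical helper below re-implements, for illustration only, the kind of coercion `get_size_dict` performs on legacy values; the real logic lives in `image_processing_utils` and may differ in edge cases:

```python
def legacy_size_to_dict(size, default_to_square=True):
    """Illustrative re-implementation: map legacy int/tuple sizes to the new dict form."""
    if isinstance(size, int):
        if default_to_square:
            return {"height": size, "width": size}
        return {"shortest_edge": size}
    if isinstance(size, (tuple, list)):
        height, width = size
        return {"height": height, "width": width}
    return dict(size)  # already a dict

print(legacy_size_to_dict(32))          # {'height': 32, 'width': 32}
print(legacy_size_to_dict((256, 192)))  # {'height': 256, 'width': 192}
```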
@@ -16,226 +16,10 @@
 Feature extractor class for LayoutLMv2.
 """

-from typing import List, Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, is_pytesseract_available, logging, requires_backends
-
-
-# soft dependency
-if is_pytesseract_available():
-    import pytesseract
+from ...utils import logging
+from .image_processing_layoutlmv2 import LayoutLMv2ImageProcessor


 logger = logging.get_logger(__name__)


-ImageInput = Union[
-    Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]  # noqa
-]
-
-
-def normalize_box(box, width, height):
-    return [
-        int(1000 * (box[0] / width)),
-        int(1000 * (box[1] / height)),
-        int(1000 * (box[2] / width)),
-        int(1000 * (box[3] / height)),
-    ]
-
-
-def apply_tesseract(image: Image.Image, lang: Optional[str], tesseract_config: Optional[str]):
-    """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
-
-    # apply OCR
-    data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config)
-    words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
-
-    # filter empty words and corresponding coordinates
-    irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
-    words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
-    left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
-    top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
-    width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
-    height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
-
-    # turn coordinates into (left, top, left+width, top+height) format
-    actual_boxes = []
-    for x, y, w, h in zip(left, top, width, height):
-        actual_box = [x, y, x + w, y + h]
-        actual_boxes.append(actual_box)
-
-    image_width, image_height = image.size
-
-    # finally, normalize the bounding boxes
-    normalized_boxes = []
-    for box in actual_boxes:
-        normalized_boxes.append(normalize_box(box, image_width, image_height))
-
-    assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
-
-    return words, normalized_boxes
-
-
-class LayoutLMv2FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a LayoutLMv2 feature extractor. This can be used to resize document images to the same size, as well as
-    to apply OCR on them in order to get a list of words and normalized bounding boxes.
-
-    This feature extractor inherits from [`~feature_extraction_utils.PreTrainedFeatureExtractor`] which contains most
-    of the main methods. Users should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input to a certain `size`.
-        size (`int` or `Tuple(int)`, *optional*, defaults to 224):
-            Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
-            set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        apply_ocr (`bool`, *optional*, defaults to `True`):
-            Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
-        ocr_lang (`str`, *optional*):
-            The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
-            used.
-        tesseract_config (`str`, *optional*):
-            Any additional custom configuration flags that are forwarded to the `config` parameter when calling
-            Tesseract. For example: '--psm 6'.
-
-    <Tip>
-
-    LayoutLMv2FeatureExtractor uses Google's Tesseract OCR engine under the hood.
-
-    </Tip>"""
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=224,
-        resample=PILImageResampling.BILINEAR,
-        apply_ocr=True,
-        ocr_lang=None,
-        tesseract_config="",
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.apply_ocr = apply_ocr
-        self.ocr_lang = ocr_lang
-        self.tesseract_config = tesseract_config
-
-    def __call__(
-        self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-            - **words** -- Optional words as identified by Tesseract OCR (only when [`LayoutLMv2FeatureExtractor`] was
-              initialized with `apply_ocr` set to `True`).
-            - **boxes** -- Optional bounding boxes as identified by Tesseract OCR, normalized based on the image size
-              (only when [`LayoutLMv2FeatureExtractor`] was initialized with `apply_ocr` set to `True`).
-
-        Examples:
-
-        ```python
-        >>> from transformers import LayoutLMv2FeatureExtractor
-        >>> from PIL import Image
-
-        >>> # Document can be a png, jpg, etc. PDFs must be converted to images.
-        >>> image = Image.open(name_of_your_document).convert("RGB")
-
-        >>> # option 1: with apply_ocr=True (default)
-        >>> feature_extractor = LayoutLMv2FeatureExtractor()
-        >>> encoding = feature_extractor(image, return_tensors="pt")
-        >>> print(encoding.keys())
-        >>> # dict_keys(['pixel_values', 'words', 'boxes'])
-
-        >>> # option 2: with apply_ocr=False
-        >>> feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
-        >>> encoding = feature_extractor(image, return_tensors="pt")
-        >>> print(encoding.keys())
-        >>> # dict_keys(['pixel_values'])
-        ```"""
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples), "
-                f"but is of type {type(images)}."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # Tesseract OCR to get words + normalized bounding boxes
-        if self.apply_ocr:
-            requires_backends(self, "pytesseract")
-            words_batch = []
-            boxes_batch = []
-            for image in images:
-                words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang, self.tesseract_config)
-                words_batch.append(words)
-                boxes_batch.append(boxes)
-
-        # transformations (resizing)
-        if self.do_resize and self.size is not None:
-            images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
-
-        images = [self.to_numpy_array(image, rescale=False) for image in images]
-        # flip color channels from RGB to BGR (as Detectron2 requires this)
-        images = [image[::-1, :, :] for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        if self.apply_ocr:
-            encoded_inputs["words"] = words_batch
-            encoded_inputs["boxes"] = boxes_batch
-
-        return encoded_inputs
+LayoutLMv2FeatureExtractor = LayoutLMv2ImageProcessor
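The OCR plumbing that moved in the hunk above is easiest to verify with concrete numbers: Tesseract reports each word as `(left, top, width, height)` in pixels, which gets converted to `(left, top, right, bottom)` corners and rescaled onto the 0-1000 grid LayoutLM models expect. A small worked example, pure arithmetic with no Tesseract required:

```python
def normalize_box(box, width, height):
    # Same arithmetic as the processor: rescale pixel corners to a 0-1000 grid.
    return [
        int(1000 * (box[0] / width)),
        int(1000 * (box[1] / height)),
        int(1000 * (box[2] / width)),
        int(1000 * (box[3] / height)),
    ]

# A word reported by Tesseract at x=100, y=50, 40px wide, 20px tall,
# on a 1000x500 page.
x, y, w, h = 100, 50, 40, 20
corners = [x, y, x + w, y + h]                         # [100, 50, 140, 70]
print(normalize_box(corners, width=1000, height=500))  # [100, 100, 140, 140]
```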
@@ -0,0 +1,268 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for LayoutLMv2."""
+
+from typing import Dict, Optional, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import resize, to_channel_dimension_format, to_pil_image
+from ...image_utils import (
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    infer_channel_dimension_format,
+    is_batched,
+    to_numpy_array,
+    valid_images,
+)
+from ...utils import is_pytesseract_available, logging, requires_backends
+
+
+if is_vision_available():
+    import PIL
+
+# soft dependency
+if is_pytesseract_available():
+    import pytesseract
+
+logger = logging.get_logger(__name__)
+
+
+def normalize_box(box, width, height):
+    return [
+        int(1000 * (box[0] / width)),
+        int(1000 * (box[1] / height)),
+        int(1000 * (box[2] / width)),
+        int(1000 * (box[3] / height)),
+    ]
+
+
+def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Optional[str] = None):
+    """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
+    tesseract_config = tesseract_config if tesseract_config is not None else ""
+
+    # apply OCR
+    pil_image = to_pil_image(image)
+    image_width, image_height = pil_image.size
+    data = pytesseract.image_to_data(pil_image, lang=lang, output_type="dict", config=tesseract_config)
+    words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
+
+    # filter empty words and corresponding coordinates
+    irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
+    words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
+    left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
+    top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
+    width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
+    height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
+
+    # turn coordinates into (left, top, left+width, top+height) format
+    actual_boxes = []
+    for x, y, w, h in zip(left, top, width, height):
+        actual_box = [x, y, x + w, y + h]
+        actual_boxes.append(actual_box)
+
+    # finally, normalize the bounding boxes
+    normalized_boxes = []
+    for box in actual_boxes:
+        normalized_boxes.append(normalize_box(box, image_width, image_height))
+
+    assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
+
+    return words, normalized_boxes
+
+
+def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
+    input_data_format = infer_channel_dimension_format(image)
+    if input_data_format == ChannelDimension.LAST:
+        image = image[..., ::-1]
+    elif input_data_format == ChannelDimension.FIRST:
+        image = image[:, ::-1, ...]
+    else:
+        raise ValueError(f"Unsupported channel dimension: {input_data_format}")
+
+    if data_format is not None:
+        image = to_channel_dimension_format(image, data_format)
+    return image
+
+
+class LayoutLMv2ImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a LayoutLMv2 image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's (height, width) dimensions to `(size["height"], size["width"])`. Can be
+            overridden by `do_resize` in `preprocess`.
+        size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+            Size of the image after resizing. Can be overridden by `size` in `preprocess`.
+        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+            `preprocess` method.
+        apply_ocr (`bool`, *optional*, defaults to `True`):
+            Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. Can be overridden by
+            `apply_ocr` in `preprocess`.
+        ocr_lang (`str`, *optional*):
+            The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
+            used. Can be overridden by `ocr_lang` in `preprocess`.
+        tesseract_config (`str`, *optional*):
+            Any additional custom configuration flags that are forwarded to the `config` parameter when calling
+            Tesseract. For example: '--psm 6'. Can be overridden by `tesseract_config` in `preprocess`.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        apply_ocr: bool = True,
+        ocr_lang: Optional[str] = None,
+        tesseract_config: Optional[str] = "",
+        **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"height": 224, "width": 224}
+        size = get_size_dict(size)
+
+        self.do_resize = do_resize
+        self.size = size
+        self.resample = resample
+        self.apply_ocr = apply_ocr
+        self.ocr_lang = ocr_lang
+        self.tesseract_config = tesseract_config
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs
+    ) -> np.ndarray:
+        """
+        Resize an image to `(size["height"], size["width"])`.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+        """
+        size = get_size_dict(size)
+        if "height" not in size or "width" not in size:
+            raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
+        output_size = (size["height"], size["width"])
+        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+
+    def preprocess(
+        self,
+        images: ImageInput,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        apply_ocr: bool = None,
+        ocr_lang: Optional[str] = None,
+        tesseract_config: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+    ) -> PIL.Image.Image:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Desired size of the output image after resizing.
+            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the `PIL.Image` resampling
+                filters. Only has an effect if `do_resize` is set to `True`.
+            apply_ocr (`bool`, *optional*, defaults to `self.apply_ocr`):
+                Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
+            ocr_lang (`str`, *optional*, defaults to `self.ocr_lang`):
+                The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English
+                is used.
+            tesseract_config (`str`, *optional*, defaults to `self.tesseract_config`):
+                Any additional custom configuration flags that are forwarded to the `config` parameter when calling
+                Tesseract.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+        """
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        size = size if size is not None else self.size
+        size = get_size_dict(size)
+        resample = resample if resample is not None else self.resample
+        apply_ocr = apply_ocr if apply_ocr is not None else self.apply_ocr
+        ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang
+        tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config
+
+        if not is_batched(images):
+            images = [images]
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        if do_resize and size is None:
+            raise ValueError("Size must be specified if do_resize is True.")
+
+        # All transformations expect numpy arrays.
+        images = [to_numpy_array(image) for image in images]
+
+        if apply_ocr:
+            requires_backends(self, "pytesseract")
+            words_batch = []
+            boxes_batch = []
+            for image in images:
+                words, boxes = apply_tesseract(image, ocr_lang, tesseract_config)
+                words_batch.append(words)
+                boxes_batch.append(boxes)
+
+        if do_resize:
+            images = [self.resize(image=image, size=size, resample=resample) for image in images]
+
+        # flip color channels from RGB to BGR (as Detectron2 requires this)
+        images = [flip_channel_order(image) for image in images]
+        images = [to_channel_dimension_format(image, data_format) for image in images]
+
+        data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+
+        if apply_ocr:
+            data["words"] = words_batch
+            data["boxes"] = boxes_batch
+        return data
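One LayoutLMv2-specific detail in the new file: pixel values feed a Detectron2 visual backbone, which wants BGR rather than RGB, and `flip_channel_order` reverses the inferred channel axis accordingly. A minimal numpy check of the channels-last path:

```python
import numpy as np

# A 2x2 channels-last "image" whose red channel is 1, green 2, blue 3.
rgb = np.stack([np.full((2, 2), c) for c in (1, 2, 3)], axis=-1)

bgr = rgb[..., ::-1]  # same slice flip_channel_order uses for ChannelDimension.LAST
print(rgb[0, 0], "->", bgr[0, 0])  # [1 2 3] -> [3 2 1]
```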
@@ -16,235 +16,11 @@
 Feature extractor class for LayoutLMv3.
 """

-from typing import List, Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, is_pytesseract_available, logging, requires_backends
-
-
-# soft dependency
-if is_pytesseract_available():
-    import pytesseract
+from ...utils import logging
+from .image_processing_layoutlmv3 import LayoutLMv3ImageProcessor


 logger = logging.get_logger(__name__)


-ImageInput = Union[
-    Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]  # noqa
-]
-
-
-def normalize_box(box, width, height):
-    return [
-        int(1000 * (box[0] / width)),
-        int(1000 * (box[1] / height)),
-        int(1000 * (box[2] / width)),
-        int(1000 * (box[3] / height)),
-    ]
-
-
-def apply_tesseract(image: Image.Image, lang: Optional[str], tesseract_config: Optional[str]):
-    """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
-
-    # apply OCR
-    data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config)
-    words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
-
-    # filter empty words and corresponding coordinates
-    irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
-    words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
-    left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
-    top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
-    width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
-    height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
-
-    # turn coordinates into (left, top, left+width, top+height) format
-    actual_boxes = []
-    for x, y, w, h in zip(left, top, width, height):
-        actual_box = [x, y, x + w, y + h]
-        actual_boxes.append(actual_box)
-
-    image_width, image_height = image.size
-
-    # finally, normalize the bounding boxes
-    normalized_boxes = []
-    for box in actual_boxes:
-        normalized_boxes.append(normalize_box(box, image_width, image_height))
-
-    assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
-
-    return words, normalized_boxes
-
-
-class LayoutLMv3FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a LayoutLMv3 feature extractor. This can be used to resize + normalize document images, as well as to
-    apply OCR on them in order to get a list of words and normalized bounding boxes.
-
-    This feature extractor inherits from [`~feature_extraction_utils.PreTrainedFeatureExtractor`] which contains most
-    of the main methods. Users should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input to a certain `size`.
-        size (`int` or `Tuple(int)`, *optional*, defaults to 224):
-            Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
-            set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with mean and standard deviation.
-        image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
-            The sequence of means for each channel, to be used when normalizing images.
-        image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images.
-        apply_ocr (`bool`, *optional*, defaults to `True`):
-            Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
-        ocr_lang (`str`, *optional*):
-            The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
-            used.
-        tesseract_config (`str`, *optional*):
-            Any additional custom configuration flags that are forwarded to the `config` parameter when calling
-            Tesseract. For example: '--psm 6'.
-
-    <Tip>
-
-    LayoutLMv3FeatureExtractor uses Google's Tesseract OCR engine under the hood.
-
-    </Tip>"""
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=224,
-        resample=PILImageResampling.BILINEAR,
-        do_normalize=True,
-        image_mean=None,
-        image_std=None,
-        apply_ocr=True,
-        ocr_lang=None,
-        tesseract_config="",
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
-        self.apply_ocr = apply_ocr
-        self.ocr_lang = ocr_lang
-        self.tesseract_config = tesseract_config
-
-    def __call__(
-        self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-            - **words** -- Optional words as identified by Tesseract OCR (only when [`LayoutLMv3FeatureExtractor`] was
-              initialized with `apply_ocr` set to `True`).
-            - **boxes** -- Optional bounding boxes as identified by Tesseract OCR, normalized based on the image size
-              (only when [`LayoutLMv3FeatureExtractor`] was initialized with `apply_ocr` set to `True`).
-
-        Examples:
-
-        ```python
-        >>> from transformers import LayoutLMv3FeatureExtractor
-        >>> from PIL import Image
-
-        >>> # Document can be a png, jpg, etc. PDFs must be converted to images.
-        >>> image = Image.open(name_of_your_document).convert("RGB")
-
-        >>> # option 1: with apply_ocr=True (default)
-        >>> feature_extractor = LayoutLMv3FeatureExtractor()
-        >>> encoding = feature_extractor(image, return_tensors="pt")
-        >>> print(encoding.keys())
-        >>> # dict_keys(['pixel_values', 'words', 'boxes'])
-
-        >>> # option 2: with apply_ocr=False
-        >>> feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
-        >>> encoding = feature_extractor(image, return_tensors="pt")
-        >>> print(encoding.keys())
-        >>> # dict_keys(['pixel_values'])
-        ```"""
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples), "
-                f"but is of type {type(images)}."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # Tesseract OCR to get words + normalized bounding boxes
-        if self.apply_ocr:
-            requires_backends(self, "pytesseract")
-            words_batch = []
-            boxes_batch = []
-            for image in images:
-                words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang, self.tesseract_config)
-                words_batch.append(words)
-                boxes_batch.append(boxes)
-
-        # transformations (resizing + normalization)
-        if self.do_resize and self.size is not None:
-            images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
-        if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        if self.apply_ocr:
-            encoded_inputs["words"] = words_batch
-            encoded_inputs["boxes"] = boxes_batch
-
-        return encoded_inputs
+LayoutLMv3FeatureExtractor = LayoutLMv3ImageProcessor
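Unlike v2, the v3 pipeline also normalizes pixel values, defaulting both `image_mean` and `image_std` to `IMAGENET_STANDARD_MEAN`/`IMAGENET_STANDARD_STD` (each `[0.5, 0.5, 0.5]`), which maps rescaled pixels from [0, 1] onto [-1, 1]. A quick numeric check of that mapping:

```python
import numpy as np

mean = std = np.array([0.5, 0.5, 0.5])  # IMAGENET_STANDARD_MEAN / _STD values

pixels = np.array([0.0, 0.5, 1.0])  # rescaled pixel values in [0, 1]
normalized = (pixels - mean) / std
print(normalized)  # [-1.  0.  1.]
```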
@@ -0,0 +1,371 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for LayoutLMv3."""

from typing import Dict, Iterable, Optional, Union

import numpy as np

from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format, to_pil_image
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    is_batched,
    to_numpy_array,
    valid_images,
)
from ...utils import is_pytesseract_available, logging, requires_backends


if is_vision_available():
    import PIL

# soft dependency
if is_pytesseract_available():
    import pytesseract

logger = logging.get_logger(__name__)


def normalize_box(box, width, height):
    # Boxes are expressed on a 0-1000 scale, independent of the page size
    return [
        int(1000 * (box[0] / width)),
        int(1000 * (box[1] / height)),
        int(1000 * (box[2] / width)),
        int(1000 * (box[3] / height)),
    ]


def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Optional[str]):
    """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""

    # apply OCR
    pil_image = to_pil_image(image)
    image_width, image_height = pil_image.size
    data = pytesseract.image_to_data(pil_image, lang=lang, output_type="dict", config=tesseract_config)
    words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]

    # filter empty words and corresponding coordinates
    irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
    words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
    left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
    top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
    width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
    height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]

    # turn coordinates into (left, top, left+width, top+height) format
    actual_boxes = []
    for x, y, w, h in zip(left, top, width, height):
        actual_box = [x, y, x + w, y + h]
        actual_boxes.append(actual_box)

    # finally, normalize the bounding boxes
    normalized_boxes = []
    for box in actual_boxes:
        normalized_boxes.append(normalize_box(box, image_width, image_height))

    assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"

    return words, normalized_boxes


def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
    input_data_format = infer_channel_dimension_format(image)
    if input_data_format == ChannelDimension.LAST:
        image = image[..., ::-1]
    elif input_data_format == ChannelDimension.FIRST:
        image = image[:, ::-1, ...]
    else:
        raise ValueError(f"Unsupported channel dimension: {input_data_format}")

    if data_format is not None:
        image = to_channel_dimension_format(image, data_format)
    return image


class LayoutLMv3ImageProcessor(BaseImageProcessor):
    r"""
    Constructs a LayoutLMv3 image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to `(size["height"], size["width"])`. Can be
            overridden by `do_resize` in `preprocess`.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the image after resizing. Can be overridden by `size` in `preprocess`.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image's pixel values by the specified `rescale_value`. Can be overridden by
            `do_rescale` in `preprocess`.
        rescale_value (`float`, *optional*, defaults to 1 / 255):
            Value by which the image's pixel values are rescaled. Can be overridden by `rescale_factor` in
            `preprocess`.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`Iterable[float]` or `float`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`Iterable[float]` or `float`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        apply_ocr (`bool`, *optional*, defaults to `True`):
            Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. Can be overridden by
            the `apply_ocr` parameter in the `preprocess` method.
        ocr_lang (`str`, *optional*):
            The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
            used. Can be overridden by the `ocr_lang` parameter in the `preprocess` method.
        tesseract_config (`str`, *optional*):
            Any additional custom configuration flags that are forwarded to the `config` parameter when calling
            Tesseract. For example: '--psm 6'. Can be overridden by the `tesseract_config` parameter in the
            `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_value: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: Union[float, Iterable[float]] = None,
        image_std: Union[float, Iterable[float]] = None,
        apply_ocr: bool = True,
        ocr_lang: Optional[str] = None,
        tesseract_config: Optional[str] = "",
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 224, "width": 224}
        size = get_size_dict(size)

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_value
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.apply_ocr = apply_ocr
        self.ocr_lang = ocr_lang
        self.tesseract_config = tesseract_config

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image to (size["height"], size["width"]) dimensions.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
        output_size = (size["height"], size["width"])
        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)

    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Normalize an image.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `Iterable[float]`):
                Mean values to be used for normalization.
            std (`float` or `Iterable[float]`):
                Standard deviation values to be used for normalization.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample=None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Union[float, Iterable[float]] = None,
        image_std: Union[float, Iterable[float]] = None,
        apply_ocr: bool = None,
        ocr_lang: Optional[str] = None,
        tesseract_config: Optional[str] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Desired size of the output image after applying `resize`.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` filters.
                Only has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image pixel values between [0, 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to apply to the image pixel values. Only has an effect if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `Iterable[float]`, *optional*, defaults to `self.image_mean`):
                Mean values to be used for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `Iterable[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation values to be used for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            apply_ocr (`bool`, *optional*, defaults to `self.apply_ocr`):
                Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
            ocr_lang (`str`, *optional*, defaults to `self.ocr_lang`):
                The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
                used.
            tesseract_config (`str`, *optional*, defaults to `self.tesseract_config`):
                Any additional custom configuration flags that are forwarded to the `config` parameter when calling
                Tesseract.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size)
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        apply_ocr = apply_ocr if apply_ocr is not None else self.apply_ocr
        ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang
        tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config

        if not is_batched(images):
            images = [images]

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and size is None:
            raise ValueError("Size must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("If do_normalize is True, image_mean and image_std must be specified.")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        # Tesseract OCR to get words + normalized bounding boxes
        if apply_ocr:
            requires_backends(self, "pytesseract")
            words_batch = []
            boxes_batch = []
            for image in images:
                words, boxes = apply_tesseract(image, ocr_lang, tesseract_config)
                words_batch.append(words)
                boxes_batch.append(boxes)

        if do_resize:
            images = [self.resize(image=image, size=size, resample=resample) for image in images]

        if do_rescale:
            images = [self.rescale(image=image, scale=rescale_factor) for image in images]

        if do_normalize:
            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]

        # flip color channels from RGB to BGR (as Detectron2 requires this)
        images = [flip_channel_order(image) for image in images]
        images = [to_channel_dimension_format(image, data_format) for image in images]

        data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)

        if apply_ocr:
            data["words"] = words_batch
            data["boxes"] = boxes_batch
        return data
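To make the 0-1000 box normalization in `normalize_box` concrete, here is a worked example with made-up page dimensions:

```
# Worked example of normalize_box (values are illustrative):
# a box from (50, 100) to (150, 140) pixels on an 800x600 page.
box = [50, 100, 150, 140]  # (left, top, right, bottom) in pixels
width, height = 800, 600

normalized = [
    int(1000 * (box[0] / width)),   # left   -> 62
    int(1000 * (box[1] / height)),  # top    -> 166
    int(1000 * (box[2] / width)),   # right  -> 187
    int(1000 * (box[3] / height)),  # bottom -> 233
]
print(normalized)  # [62, 166, 187, 233] -- resolution-independent coordinates
```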
@@ -14,148 +14,12 @@
 # limitations under the License.
 """Feature extractor class for LeViT."""
 
-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
-    IMAGENET_DEFAULT_MEAN,
-    IMAGENET_DEFAULT_STD,
-    ImageFeatureExtractionMixin,
-    ImageInput,
-    is_torch_tensor,
-)
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_levit import LevitImageProcessor
 
 
 logger = logging.get_logger(__name__)
 
 
-class LevitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a LeViT feature extractor.
-
-    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the shortest edge of the input to int(256/224 *`size`).
-        size (`int` or `Tuple(int)`, *optional*, defaults to 224):
-            Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then shorter side of input will be resized to 'size'.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_center_crop (`bool`, *optional*, defaults to `True`):
-            Whether or not to center crop the input to `size`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with mean and standard deviation.
-        image_mean (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
-            The sequence of means for each channel, to be used when normalizing images.
-        image_std (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images.
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=224,
-        resample=PILImageResampling.BICUBIC,
-        do_center_crop=True,
-        do_normalize=True,
-        image_mean=IMAGENET_DEFAULT_MEAN,
-        image_std=IMAGENET_DEFAULT_STD,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.do_center_crop = do_center_crop
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean
-        self.image_std = image_std
-
-    def __call__(
-        self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-        """
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # transformations (resizing + center cropping + normalization)
-        if self.do_resize and self.size is not None:
-            size_ = int((256 / 224) * self.size)
-            images = [
-                self.resize(image=image, size=size_, resample=self.resample, default_to_square=False)
-                for image in images
-            ]
-        if self.do_center_crop:
-            images = [self.center_crop(image=image, size=self.size) for image in images]
-        if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
+# Feature extractor for Levit is being replaced by image processor
+LevitFeatureExtractor = LevitImageProcessor
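The effect of the replacement above is that the old class name keeps working as a plain alias for the new `LevitImageProcessor` (defined in the next file). A small sketch of what that means in practice:

```
# Sketch of the alias introduced above: the legacy name now constructs
# the new image processor class.
from transformers.models.levit.feature_extraction_levit import LevitFeatureExtractor
from transformers.models.levit.image_processing_levit import LevitImageProcessor

assert LevitFeatureExtractor is LevitImageProcessor
extractor = LevitFeatureExtractor()  # returns a LevitImageProcessor instance
print(type(extractor).__name__)      # "LevitImageProcessor"
```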
@@ -0,0 +1,342 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for LeViT."""

from typing import Dict, Iterable, List, Optional, Union

import numpy as np

from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
    center_crop,
    get_resize_output_image_size,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    is_batched,
    to_numpy_array,
    valid_images,
)
from ...utils import logging


logger = logging.get_logger(__name__)


class LevitImageProcessor(BaseImageProcessor):
    r"""
    Constructs a LeViT image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the shortest edge of the input to int(256/224 * `size`). Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
            Size of the output image after resizing. If size is a dict with keys "width" and "height", the image will
            be resized to `(size["height"], size["width"])`. If size is a dict with key "shortest_edge", the shortest
            edge value `c` is rescaled to `int(c * (256/224))`. The smaller edge of the image will be matched to this
            value, i.e., if height > width, then the image will be rescaled to `(size["shortest_edge"] * height / width,
            size["shortest_edge"])`. Can be overridden by the `size` parameter in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
            `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether or not to center crop the input to `(crop_size["height"], crop_size["width"])`. Can be overridden
            by the `do_center_crop` parameter in the `preprocess` method.
        crop_size (`Dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Desired image size after `center_crop`. Can be overridden by the `crop_size` parameter in the `preprocess`
            method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
            `preprocess` method.
        image_mean (`List[float]`, defaults to `[0.485, 0.456, 0.406]`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`List[float]`, defaults to `[0.229, 0.224, 0.225]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, Iterable[float]]] = IMAGENET_DEFAULT_MEAN,
        image_std: Optional[Union[float, Iterable[float]]] = IMAGENET_DEFAULT_STD,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 224}
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        crop_size = get_size_dict(crop_size)

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image.

        If size is a dict with keys "width" and "height", the image will be resized to `(size["height"],
        size["width"])`.

        If size is a dict with key "shortest_edge", the shortest edge value `c` is rescaled to `int(c * (256/224))`.
        The smaller edge of the image will be matched to this value, i.e., if height > width, then the image will be
        rescaled to `(size["shortest_edge"] * height / width, size["shortest_edge"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image after resizing. If size is a dict with keys "width" and "height", the image
                will be resized to (height, width). If size is a dict with key "shortest_edge", the shortest edge value
                `c` is rescaled to int(`c` * (256/224)). The smaller edge of the image will be matched to this value,
                i.e., if height > width, then the image will be rescaled to (size * height / width, size).
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size_dict = get_size_dict(size, default_to_square=False)
        # size_dict is a dict with either keys "height" and "width" or "shortest_edge"
        if "shortest_edge" in size:
            shortest_edge = int((256 / 224) * size["shortest_edge"])
            output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False)
            size_dict = {"height": output_size[0], "width": output_size[1]}
        if "height" not in size_dict or "width" not in size_dict:
            raise ValueError(
                f"Size dict must have keys 'height' and 'width' or 'shortest_edge'. Got {size_dict.keys()}"
            )
        return resize(
            image, size=(size_dict["height"], size_dict["width"]), resample=resample, data_format=data_format, **kwargs
        )

    def center_crop(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Center crop an image.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            size (`Dict[str, int]`):
                Dict `{"height": int, "width": int}` specifying the size of the output image after cropping.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Normalize an image. image = (image - image_mean) / image_std.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[Dict[str, int]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, Iterable[float]]] = None,
        image_std: Optional[Union[float, Iterable[float]]] = None,
        return_tensors: Optional[TensorType] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images to be used as input to a LeViT model.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the output image after resizing. If size is a dict with keys "width" and "height", the image
                will be resized to (height, width). If size is a dict with key "shortest_edge", the shortest edge value
                `c` is rescaled to int(`c` * (256/224)). The smaller edge of the image will be matched to this value,
                i.e., if height > width, then the image will be rescaled to (size * height / width, size).
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the output image after center cropping. Crops images to (crop_size["height"],
                crop_size["width"]).
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image pixel values by `rescale_factor`, typically to values between 0 and 1.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Factor to rescale the image pixel values by.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image pixel values by `image_mean` and `image_std`.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Mean to normalize the image pixel values by.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation to normalize the image pixel values by.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`str` or `ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        size = size if size is not None else self.size
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size)

        if not is_batched(images):
            images = [images]

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and size is None:
            raise ValueError("Size must be specified if do_resize is True.")

        if do_center_crop and crop_size is None:
            raise ValueError("Crop size must be specified if do_center_crop is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_resize:
            images = [self.resize(image, size, resample) for image in images]

        if do_center_crop:
            images = [self.center_crop(image, crop_size) for image in images]

        if do_rescale:
            images = [self.rescale(image, rescale_factor) for image in images]

        if do_normalize:
            images = [self.normalize(image, image_mean, image_std) for image in images]

        images = [to_channel_dimension_format(image, data_format) for image in images]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
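A quick check of the `shortest_edge` arithmetic in `LevitImageProcessor.resize`, assuming the default `size` and a 480x640 input. The exact output of `get_resize_output_image_size` may round slightly differently, so treat this as a sketch of the rule rather than a reproduction of it:

```
# Sketch of the default LeViT resize rule for a 480 (h) x 640 (w) input.
size = {"shortest_edge": 224}
shortest_edge = int((256 / 224) * size["shortest_edge"])  # 256

height, width = 480, 640
scale = shortest_edge / min(height, width)                # 256 / 480
new_height, new_width = int(height * scale), int(width * scale)
print(new_height, new_width)  # 256 341 -- then center-cropped to 224x224
```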
@@ -14,189 +14,11 @@
 # limitations under the License.
 """Feature extractor class for MobileViT."""
 
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
-from ...utils import TensorType, is_torch_available, logging
-
-
-if is_torch_available():
-    import torch
+from ...utils import logging
+from .image_processing_mobilevit import MobileViTImageProcessor
 
 
 logger = logging.get_logger(__name__)
 
 
-class MobileViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a MobileViT feature extractor.
-
-    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input to a certain `size`.
-        size (`int` or `Tuple(int)`, *optional*, defaults to 288):
-            Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then the input will be resized to match the shorter side. Only has an effect if
-            `do_resize` is set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_center_crop (`bool`, *optional*, defaults to `True`):
-            Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
-            image is padded with 0's and then center cropped.
-        crop_size (`int`, *optional*, defaults to 256):
-            Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
-        do_flip_channel_order (`bool`, *optional*, defaults to `True`):
-            Whether to flip the color channels from RGB to BGR.
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=288,
-        resample=PILImageResampling.BILINEAR,
-        do_center_crop=True,
-        crop_size=256,
-        do_flip_channel_order=True,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.do_center_crop = do_center_crop
-        self.crop_size = crop_size
-        self.do_flip_channel_order = do_flip_channel_order
-
-    def __call__(
-        self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-        """
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # transformations (resizing + normalization)
-        if self.do_resize and self.size is not None:
-            images = [
-                self.resize(image=image, size=self.size, resample=self.resample, default_to_square=False)
-                for image in images
-            ]
-        if self.do_center_crop and self.crop_size is not None:
-            images = [self.center_crop(image, self.crop_size) for image in images]
-
-        images = [self.to_numpy_array(image) for image in images]
-
-        # the pretrained checkpoints assume images are BGR, not RGB
-        if self.do_flip_channel_order:
-            images = [self.flip_channel_order(image) for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
-
-    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
-        """
-        Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports
-        PyTorch.
-
-        Args:
-            outputs ([`MobileViTForSemanticSegmentation`]):
-                Raw outputs of the model.
-            target_sizes (`List[Tuple]`, *optional*):
-                A list of length `batch_size`, where each item is a `Tuple[int, int]` corresponding to the requested
-                final size (height, width) of each prediction. If left to None, predictions will not be resized.
-        Returns:
-            `List[torch.Tensor]`:
-                A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
-                corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
-                `torch.Tensor` correspond to a semantic class id.
-        """
-        logits = outputs.logits
-
-        # Resize logits and compute semantic segmentation maps
-        if target_sizes is not None:
-            if len(logits) != len(target_sizes):
-                raise ValueError(
-                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
-                )
-
-            if is_torch_tensor(target_sizes):
-                target_sizes = target_sizes.numpy()
-
-            semantic_segmentation = []
-
-            for idx in range(len(logits)):
-                resized_logits = torch.nn.functional.interpolate(
-                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
-                )
-                semantic_map = resized_logits[0].argmax(dim=0)
-                semantic_segmentation.append(semantic_map)
-        else:
-            semantic_segmentation = logits.argmax(dim=1)
-            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
-
-        return semantic_segmentation
+MobileViTFeatureExtractor = MobileViTImageProcessor
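For context, a minimal sketch of how the post-processing removed above is typically used; the method also lives on the new `MobileViTImageProcessor` (this commit adds it there) and therefore stays reachable through the alias. The checkpoint name and file name are only examples, and PyTorch is required:

```
# Minimal sketch (illustrative checkpoint and file name; requires PyTorch).
import torch
from PIL import Image
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation

feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-small")
model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")

image = Image.open("scene.png").convert("RGB")  # hypothetical input
inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Resize logits back to the original (height, width) and take per-pixel argmax
segmentation = feature_extractor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
print(segmentation.shape)  # (height, width) tensor of semantic class ids
```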
@ -0,0 +1,364 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Image processor class for MobileViT."""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.utils import is_torch_available, is_torch_tensor, is_vision_available
|
||||||
|
from transformers.utils.generic import TensorType
|
||||||
|
|
||||||
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
|
from ...image_transforms import center_crop, get_resize_output_image_size, rescale, resize, to_channel_dimension_format
|
||||||
|
from ...image_utils import (
|
||||||
|
ChannelDimension,
|
||||||
|
ImageInput,
|
||||||
|
PILImageResampling,
|
||||||
|
infer_channel_dimension_format,
|
||||||
|
is_batched,
|
||||||
|
to_numpy_array,
|
||||||
|
valid_images,
|
||||||
|
)
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
import PIL
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension]) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Flip the color channels from RGB to BGR or vice versa.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
The image, represented as a numpy array.
|
||||||
|
data_format (`ChannelDimension`, *`optional`*):
|
||||||
|
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`np.ndarray`: The image with the flipped color channels.
|
||||||
|
"""
|
||||||
|
input_data_format = infer_channel_dimension_format(image)
|
||||||
|
|
||||||
|
if input_data_format == ChannelDimension.LAST:
|
||||||
|
image = image[..., ::-1]
|
||||||
|
elif input_data_format == ChannelDimension.FIRST:
|
||||||
|
image = image[:, ::-1, ...]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid input channel dimension format: {input_data_format}")
|
||||||
|
|
||||||
|
if data_format is not None:
|
||||||
|
image = to_channel_dimension_format(image, data_format)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class MobileViTImageProcessor(BaseImageProcessor):
    r"""
    Constructs a MobileViT image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
            Controls the size of the output image after resizing. Can be overridden by the `size` parameter in the
            `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            Defines the resampling filter to use if resizing the image. Can be overridden by the `resample` parameter
            in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
            image is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in
            the `preprocess` method.
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
            Desired output size `(size["height"], size["width"])` when applying center-cropping. Can be overridden by
            the `crop_size` parameter in the `preprocess` method.
        do_flip_channel_order (`bool`, *optional*, defaults to `True`):
            Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order`
            parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_flip_channel_order: bool = True,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 224}
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
        crop_size = get_size_dict(crop_size)

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_flip_channel_order = do_flip_channel_order

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PIL.Image.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Controls the size of the output image. The shortest edge of the image will be resized to
                `size["shortest_edge"]` while maintaining the aspect ratio.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size, default_to_square=False)
        if "shortest_edge" not in size:
            raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
        output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
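
    # Note on the resize arithmetic above (illustrative numbers, not from this diff): with the default
    # size={"shortest_edge": 224}, a (height, width) = (480, 640) image keeps its aspect ratio and is
    # resized to roughly (224, 299); the fixed (256, 256) output shape comes from the later center crop,
    # which pads smaller inputs with zeros.
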
    def center_crop(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Center crop an image to size `(size["height"], size["width"])`. If the input size is smaller than `size` along
        any edge, the image is padded with 0's and then center cropped.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            size (`Dict[str, int]`):
                Size of the output image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ):
        """
        Rescale an image by a scale factor: `image = image * scale`.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)

    def flip_channel_order(
        self, image: np.ndarray, data_format: Optional[Union[str, ChannelDimension]] = None
    ) -> np.ndarray:
        """
        Flip the color channels from RGB to BGR or vice versa.

        Args:
            image (`np.ndarray`):
                The image, represented as a numpy array.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return flip_channel_order(image, data_format=data_format)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_flip_channel_order: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> PIL.Image.Image:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image by rescale factor.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the center crop if `do_center_crop` is set to `True`.
            do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
                Whether to flip the channel order of the image.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        do_flip_channel_order = (
            do_flip_channel_order if do_flip_channel_order is not None else self.do_flip_channel_order
        )

        size = size if size is not None else self.size
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size)

        if not is_batched(images):
            images = [images]

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and size is None:
            raise ValueError("Size must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_center_crop and crop_size is None:
            raise ValueError("Crop size must be specified if do_center_crop is True.")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_resize:
            images = [self.resize(image=image, size=size, resample=resample) for image in images]

        if do_center_crop:
            images = [self.center_crop(image=image, size=crop_size) for image in images]

        if do_rescale:
            images = [self.rescale(image=image, scale=rescale_factor) for image in images]

        # the pretrained checkpoints assume images are BGR, not RGB
        if do_flip_channel_order:
            images = [self.flip_channel_order(image=image) for image in images]

        images = [to_channel_dimension_format(image, data_format) for image in images]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
        """
        Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports
        PyTorch.

        Args:
            outputs ([`MobileViTForSemanticSegmentation`]):
                Raw outputs of the model.
            target_sizes (`List[Tuple]`, *optional*):
                A list of length `batch_size`, where each item is a `Tuple[int, int]` corresponding to the requested
                final size (height, width) of each prediction. If left to `None`, predictions will not be resized.

        Returns:
            `List[torch.Tensor]`:
                A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
                corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
                `torch.Tensor` corresponds to a semantic class id.
        """
        # TODO: add support for other frameworks
        logits = outputs.logits

        # Resize logits and compute semantic segmentation maps
        if target_sizes is not None:
            if len(logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

            if is_torch_tensor(target_sizes):
                target_sizes = target_sizes.numpy()

            semantic_segmentation = []

            for idx in range(len(logits)):
                resized_logits = torch.nn.functional.interpolate(
                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
                )
                semantic_map = resized_logits[0].argmax(dim=0)
                semantic_segmentation.append(semantic_map)
        else:
            semantic_segmentation = logits.argmax(dim=1)
            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

        return semantic_segmentation
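
That closes out the new MobileViT processor. A minimal usage sketch of the class above; the local file name, default construction, and printed shape are illustrative assumptions rather than something this diff ships:

```
from PIL import Image

from transformers import MobileViTImageProcessor

image_processor = MobileViTImageProcessor()  # defaults: resize shortest edge to 224, crop to 256x256

image = Image.open("cat.png")  # hypothetical local file
inputs = image_processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # expected: torch.Size([1, 3, 256, 256])

# After a forward pass of MobileViTForSemanticSegmentation, predictions can be mapped
# back to the original resolution (note PIL's .size is (width, height)):
# maps = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])
```
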
@@ -14,179 +14,11 @@
 # limitations under the License.
 """Feature extractor class for Perceiver."""

-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
-    IMAGENET_DEFAULT_MEAN,
-    IMAGENET_DEFAULT_STD,
-    ImageFeatureExtractionMixin,
-    ImageInput,
-    is_torch_tensor,
-)
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_perceiver import PerceiverImageProcessor


 logger = logging.get_logger(__name__)


-class PerceiverFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a Perceiver feature extractor.
-
-    This feature extractor inherits from [`ImageFeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_center_crop (`bool`, *optional*, defaults to `True`):
-            Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
-            image is padded with 0's and then center cropped.
-        crop_size (`int`, *optional*, defaults to 256):
-            Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input to a certain `size`.
-        size (`int` or `Tuple(int)`, *optional*, defaults to 224):
-            Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
-            set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with `image_mean` and `image_std`.
-        image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
-            The sequence of means for each channel, to be used when normalizing images.
-        image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images.
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_center_crop=True,
-        crop_size=256,
-        do_resize=True,
-        size=224,
-        resample=PILImageResampling.BICUBIC,
-        do_normalize=True,
-        image_mean=None,
-        image_std=None,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_center_crop = do_center_crop
-        self.crop_size = crop_size
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
-
-    def center_crop(self, image):
-        """
-        Crops `image` to *self.crop_size* using a center crop. Note that if the image is too small to be cropped to the
-        size given, it will be padded (so the returned result has the size asked).
-
-        Args:
-            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
-                The image to resize.
-        """
-        if isinstance(image, Image.Image):
-            image = self.to_numpy_array(image)
-
-        image_height, image_width = image.shape[-2:]
-
-        padded_center_crop_size = (
-            (self.size / (self.crop_size)) * np.minimum(image_height, image_width).astype(np.float32)
-        ).astype(np.int32)
-
-        offset_height = ((image_height - padded_center_crop_size) + 1) // 2
-        offset_width = ((image_width - padded_center_crop_size) + 1) // 2
-        crop_window = [offset_height, offset_width, padded_center_crop_size, padded_center_crop_size]
-
-        image = image[
-            :, crop_window[0] : crop_window[0] + crop_window[2], crop_window[1] : crop_window[1] + crop_window[3]
-        ]
-
-        return image
-
-    def __call__(
-        self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-        """
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # transformations (center cropping + resizing + normalization)
-        if self.do_center_crop and self.crop_size is not None:
-            images = [self.center_crop(image) for image in images]
-        if self.do_resize and self.size is not None and self.resample is not None:
-            images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
-        if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
+PerceiverFeatureExtractor = PerceiverImageProcessor
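
The single added line above is the whole backward-compatibility story: the legacy name is rebound to the new class, so existing imports and `isinstance` checks keep working with no deprecation shim. A minimal sketch of the pattern, with hypothetical names:

```
class ImageProcessor:
    """Stand-in for the new processor class."""


# Legacy name kept as a module-level alias; no subclassing needed.
FeatureExtractor = ImageProcessor

assert FeatureExtractor is ImageProcessor
assert isinstance(FeatureExtractor(), ImageProcessor)
```
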
@@ -0,0 +1,330 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Perceiver."""

from typing import Dict, List, Optional, Union

import numpy as np

from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    is_batched,
    to_numpy_array,
    valid_images,
)
from ...utils import logging


if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)


class PerceiverImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Perceiver image processor.

    Args:
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether or not to center crop the image. If the input size is smaller than `crop_size` along any edge, the
            image will be padded with zeros and then center cropped. Can be overridden by the `do_center_crop`
            parameter in the `preprocess` method.
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
            Desired output size when applying center-cropping. Can be overridden by the `crop_size` parameter in the
            `preprocess` method.
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image to `(size["height"], size["width"])`. Can be overridden by the `do_resize`
            parameter in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the image after resizing. Can be overridden by the `size` parameter in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            Defines the resampling filter to use if resizing the image. Can be overridden by the `resample` parameter
            in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter
            in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
        crop_size = get_size_dict(crop_size)
        size = size if size is not None else {"height": 224, "width": 224}
        size = get_size_dict(size)

        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def center_crop(
        self,
        image: np.ndarray,
        crop_size: Dict[str, int],
        size: Optional[Dict[str, int]] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Center crop an image to `(size["height"] / crop_size["height"] * min_dim, size["width"] / crop_size["width"] *
        min_dim)`, where `min_dim = min(height, width)` of the input image.

        If the input size is smaller than `crop_size` along any edge, the image will be padded with zeros and then
        center cropped.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            crop_size (`Dict[str, int]`):
                Desired output size after applying the center crop.
            size (`Dict[str, int]`, *optional*):
                Size of the image after resizing. If not provided, the `self.size` attribute will be used.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = self.size if size is None else size
        size = get_size_dict(size)
        crop_size = get_size_dict(crop_size)

        height, width = get_image_size(image)
        min_dim = min(height, width)
        cropped_height = (size["height"] / crop_size["height"]) * min_dim
        cropped_width = (size["width"] / crop_size["width"]) * min_dim
        return center_crop(image, size=(cropped_height, cropped_width), data_format=data_format, **kwargs)
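
    # Worked example of the dynamic crop above (illustrative numbers, not from this diff): with the
    # defaults size=224 and crop_size=256, a 480x640 input has min_dim=480, so the crop window is
    # (224 / 256) * 480 = 420 pixels per side; the `resize` step below then brings it to 224x224.
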
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PIL.Image.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BILINEAR`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
        return resize(
            image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
        )

    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ):
        """
        Rescale an image by a scale factor: `image = image * scale`.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Normalize an image: `image = (image - image_mean) / image_std`.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[Dict[str, int]] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> PIL.Image.Image:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image to `crop_size`.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Desired output size after applying the center crop.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size)
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size)
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        if not is_batched(images):
            images = [images]

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_center_crop and crop_size is None:
            raise ValueError("If `do_center_crop` is set to `True`, `crop_size` must be provided.")

        if do_resize and size is None:
            raise ValueError("Size must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and image standard deviation must be specified if do_normalize is True.")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_center_crop:
            images = [self.center_crop(image, crop_size, size=size) for image in images]

        if do_resize:
            images = [self.resize(image=image, size=size, resample=resample) for image in images]

        if do_rescale:
            images = [self.rescale(image=image, scale=rescale_factor) for image in images]

        if do_normalize:
            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]

        images = [to_channel_dimension_format(image, data_format) for image in images]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
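
As with MobileViT, the new Perceiver processor is usable on its own. A hedged usage sketch of the class above, assuming the class is exported at the top level of `transformers` (as the feature-extractor alias earlier implies); the random input and expected shape are illustrative:

```
import numpy as np

from transformers import PerceiverImageProcessor

image_processor = PerceiverImageProcessor()

# A random HxWxC uint8 image stands in for real data.
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)

# Dynamic center crop -> resize to 224x224 -> rescale to [0, 1] -> normalize.
inputs = image_processor(images=image, return_tensors="np")
print(inputs["pixel_values"].shape)  # expected: (1, 3, 224, 224)
```
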
@@ -14,161 +14,11 @@
 # limitations under the License.
 """Feature extractor class for PoolFormer."""

-import math
-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
-    IMAGENET_DEFAULT_MEAN,
-    IMAGENET_DEFAULT_STD,
-    ImageFeatureExtractionMixin,
-    ImageInput,
-    is_torch_tensor,
-)
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_poolformer import PoolFormerImageProcessor


 logger = logging.get_logger(__name__)


-class PoolFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a PoolFormer feature extractor.
-
-    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize_and_center_crop (`bool`, *optional*, defaults to `True`):
-            Whether to resize the shortest edge of the image and center crop the input to a certain `size`.
-        size (`int` or `Tuple(int)`, *optional*, defaults to 224):
-            Center crop the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then the input will be center cropped to (size, size). Only has an effect if
-            `do_resize_and_center_crop` is set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        crop_pct (`float`, *optional*, defaults to `0.9`):
-            The percentage of the image to crop from the center. Only has an effect if `do_resize_and_center_crop` is
-            set to `True`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with `image_mean` and `image_std`.
-        image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
-            The sequence of means for each channel, to be used when normalizing images.
-        image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images.
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize_and_center_crop=True,
-        size=224,
-        resample=PILImageResampling.BICUBIC,
-        crop_pct=0.9,
-        do_normalize=True,
-        image_mean=None,
-        image_std=None,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize_and_center_crop = do_resize_and_center_crop
-        self.size = size
-        self.resample = resample
-        self.crop_pct = crop_pct
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
-
-    def __call__(
-        self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-        """
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # transformations (resizing + center cropping + normalization)
-        if self.do_resize_and_center_crop and self.size is not None and self.crop_pct is not None:
-            if isinstance(self.size, (tuple, list)):
-                assert len(self.size) == 2
-                if self.size[-1] == self.size[-2]:
-                    scale_size = int(math.floor(self.size[0] / self.crop_pct))
-                else:
-                    scale_size = tuple([int(x / self.crop_pct) for x in self.size])
-            else:
-                scale_size = int(math.floor(self.size / self.crop_pct))
-
-            # resize shortest edge of the image
-            images = [
-                self.resize(image=image, size=scale_size, resample=self.resample, default_to_square=False)
-                for image in images
-            ]
-            # center crop
-            images = [self.center_crop(image, size=self.size) for image in images]
-
-        if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
+PoolFormerFeatureExtractor = PoolFormerImageProcessor
@@ -0,0 +1,382 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for PoolFormer."""

import math
from typing import Dict, List, Optional, Union

import numpy as np

from transformers import is_vision_available
from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
    center_crop,
    get_resize_output_image_size,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    is_batched,
    to_numpy_array,
    valid_images,
)
from ...utils import logging


if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)


class PoolFormerImageProcessor(BaseImageProcessor):
    r"""
    Constructs a PoolFormer image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 256}`):
            Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. If crop_pct is
            unset:
            - size is `{"height": h, "width": w}`: the image is resized to `(h, w)`.
            - size is `{"shortest_edge": s}`: the shortest edge of the image is resized to s whilst maintaining the
              aspect ratio.

            If crop_pct is set:
            - size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h / crop_pct)),
              int(floor(w / crop_pct)))`.
            - size is `{"height": c, "width": c}`: the shortest edge of the image is resized to
              `int(floor(c / crop_pct))` whilst maintaining the aspect ratio.
            - size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c / crop_pct))`
              whilst maintaining the aspect ratio.
        crop_pct (`float`, *optional*, defaults to `0.9`):
            Percentage of the image to crop from the center. Can be overridden by `crop_pct` in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
            is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in the `preprocess`
            method.
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
            Size of the image after applying center crop. Only has an effect if `do_center_crop` is set to `True`. Can
            be overridden by the `crop_size` parameter in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
            `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        crop_pct: float = 0.9,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        rescale_factor: Union[int, float] = 1 / 255,
        do_rescale: bool = True,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 256}
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
        crop_size = get_size_dict(crop_size)

        self.do_resize = do_resize
        self.size = size
        self.crop_pct = crop_pct
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        crop_pct: Optional[float] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image.

        If crop_pct is unset:
        - size is `{"height": h, "width": w}`: the image is resized to `(h, w)`.
        - size is `{"shortest_edge": s}`: the shortest edge of the image is resized to s whilst maintaining the
          aspect ratio.

        If crop_pct is set:
        - size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h / crop_pct)),
          int(floor(w / crop_pct)))`.
        - size is `{"height": c, "width": c}`: the shortest edge of the image is resized to `int(floor(c / crop_pct))`
          whilst maintaining the aspect ratio.
        - size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c / crop_pct))`
          whilst maintaining the aspect ratio.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            crop_pct (`float`, *optional*):
                Percentage of the image that will be cropped from the center. If set, the image is resized to
                `floor(size / crop_pct)` so that the subsequent center crop to `size` keeps the central portion of
                the image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size, default_to_square=False)
        if "shortest_edge" not in size and ("height" not in size or "width" not in size):
            raise ValueError(f"size must contain 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
        if crop_pct is not None:
            if "shortest_edge" in size:
                scale_size = int(math.floor(size["shortest_edge"] / crop_pct))
            elif "height" in size and "width" in size:
                if size["height"] == size["width"]:
                    scale_size = int(math.floor(size["height"] / crop_pct))
                else:
                    scale_size = (
                        int(math.floor(size["height"] / crop_pct)),
                        int(math.floor(size["width"] / crop_pct)),
                    )
            else:
                raise ValueError("Invalid size for resize: {}".format(size))

            output_size = get_resize_output_image_size(image, size=scale_size, default_to_square=False)
        else:
            if "shortest_edge" in size:
                output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
            elif "height" in size and "width" in size:
                output_size = (size["height"], size["width"])
            else:
                raise ValueError("Invalid size for resize: {}".format(size))

        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
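
    # Worked example of the crop_pct rescaling above (illustrative numbers, not from this diff):
    # with size={"shortest_edge": 256} and crop_pct=0.9, the shortest edge is first resized to
    # floor(256 / 0.9) = 284, so the later 256x256 center crop keeps the central ~90% of the image
    # rather than its borders.
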
def center_crop(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
size: Dict[str, int],
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Center crop an image to (size["height"], size["width"]). If the input size is smaller than `crop_size` along
|
||||||
|
any edge, the image is padded with 0's and then center cropped.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to center crop.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
Size of the output image.
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||||
|
"""
|
||||||
|
size = get_size_dict(size)
|
||||||
|
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
|
||||||
|
|
||||||
|
def rescale(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
scale: Union[int, float],
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Rescale an image by a scale factor. image = image * scale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to rescale.
|
||||||
|
scale (`int` or `float`):
|
||||||
|
Scale to apply to the image.
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||||
|
"""
|
||||||
|
return rescale(image, scale=scale, data_format=data_format, **kwargs)
|
||||||
|
|
||||||
|
def normalize(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
mean: Union[float, List[float]],
|
||||||
|
std: Union[float, List[float]],
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Normalize an image. image = (image - image_mean) / image_std.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to normalize.
|
||||||
|
image_mean (`float` or `List[float]`):
|
||||||
|
Image mean.
|
||||||
|
image_std (`float` or `List[float]`):
|
||||||
|
Image standard deviation.
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||||
|
"""
|
||||||
|
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
|
||||||
|
|
    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        crop_pct: float = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> PIL.Image.Image:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after applying resize.
            crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
                Percentage of the image to crop. Only has an effect if `do_resize` is set to `True`.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the image after applying center crop.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values to [0, 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        crop_pct = crop_pct if crop_pct is not None else self.crop_pct
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        size = size if size is not None else self.size
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size)

        if not is_batched(images):
            images = [images]

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and (size is None or resample is None):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_center_crop and crop_size is None:
            raise ValueError("Crop size must be specified if do_center_crop is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_resize:
            images = [self.resize(image=image, size=size, crop_pct=crop_pct, resample=resample) for image in images]

        if do_center_crop:
            images = [self.center_crop(image=image, size=crop_size) for image in images]

        if do_rescale:
            images = [self.rescale(image=image, scale=rescale_factor) for image in images]

        if do_normalize:
            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]

        images = [to_channel_dimension_format(image, data_format) for image in images]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
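To make the `crop_pct` arithmetic in the `resize` method above concrete, here is a small sketch of how the intermediate scale size is derived before the final center crop. The helper name and the example values are illustrative only, not part of this diff:

```python
import math

def scale_size_for_crop(size: dict, crop_pct: float):
    # Mirrors the branches in `resize` above: the image is first resized to
    # size / crop_pct, then center cropped back down to the target size.
    if "shortest_edge" in size:
        return int(math.floor(size["shortest_edge"] / crop_pct))
    if size["height"] == size["width"]:
        return int(math.floor(size["height"] / crop_pct))
    return (int(math.floor(size["height"] / crop_pct)), int(math.floor(size["width"] / crop_pct)))

print(scale_size_for_crop({"shortest_edge": 224}, 0.9))         # 248
print(scale_size_for_crop({"height": 384, "width": 384}, 0.9))  # 426
```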
@@ -14,248 +14,11 @@
 # limitations under the License.
 """Feature extractor class for SegFormer."""
 
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
-    IMAGENET_DEFAULT_MEAN,
-    IMAGENET_DEFAULT_STD,
-    ImageFeatureExtractionMixin,
-    ImageInput,
-    is_torch_tensor,
-)
-from ...utils import TensorType, is_torch_available, logging
-
-
-if is_torch_available():
-    import torch
+from ...utils import logging
+from .image_processing_segformer import SegformerImageProcessor
 
 
 logger = logging.get_logger(__name__)
 
 
-class SegformerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a SegFormer feature extractor.
-
-    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input based on a certain `size`.
-        size (`int` or `Tuple(int)`, *optional*, defaults to 512):
-            Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
-            set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with mean and standard deviation.
-        image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
-            The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
-        image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
-            ImageNet std.
-        reduce_labels (`bool`, *optional*, defaults to `False`):
-            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
-            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
-            background label will be replaced by 255.
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=512,
-        resample=PILImageResampling.BILINEAR,
-        do_normalize=True,
-        image_mean=None,
-        image_std=None,
-        reduce_labels=False,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
-        self.reduce_labels = reduce_labels
-
-    def __call__(
-        self,
-        images: ImageInput,
-        segmentation_maps: ImageInput = None,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s) and optional corresponding segmentation maps.
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is
-                the number of channels, H and W are image height and width.
-
-            segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
-                Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations.
-
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-            - **labels** -- Optional labels to be fed to a model (when `segmentation_maps` are provided)
-        """
-        # Input type checking for clearer error
-        valid_images = False
-        valid_segmentation_maps = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        # Check that segmentation maps has a valid type
-        if segmentation_maps is not None:
-            if isinstance(segmentation_maps, (Image.Image, np.ndarray)) or is_torch_tensor(segmentation_maps):
-                valid_segmentation_maps = True
-            elif isinstance(segmentation_maps, (list, tuple)):
-                if (
-                    len(segmentation_maps) == 0
-                    or isinstance(segmentation_maps[0], (Image.Image, np.ndarray))
-                    or is_torch_tensor(segmentation_maps[0])
-                ):
-                    valid_segmentation_maps = True
-
-            if not valid_segmentation_maps:
-                raise ValueError(
-                    "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single"
-                    " example),`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of"
-                    " examples)."
-                )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-            if segmentation_maps is not None:
-                segmentation_maps = [segmentation_maps]
-
-        # reduce zero label if needed
-        if self.reduce_labels:
-            if segmentation_maps is not None:
-                for idx, map in enumerate(segmentation_maps):
-                    if not isinstance(map, np.ndarray):
-                        map = np.array(map)
-                    # avoid using underflow conversion
-                    map[map == 0] = 255
-                    map = map - 1
-                    map[map == 254] = 255
-                    segmentation_maps[idx] = Image.fromarray(map.astype(np.uint8))
-
-        # transformations (resizing + normalization)
-        if self.do_resize and self.size is not None:
-            images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
-            if segmentation_maps is not None:
-                segmentation_maps = [
-                    self.resize(map, size=self.size, resample=Image.NEAREST) for map in segmentation_maps
-                ]
-
-        if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-
-        if segmentation_maps is not None:
-            labels = []
-            for map in segmentation_maps:
-                if not isinstance(map, np.ndarray):
-                    map = np.array(map)
-                labels.append(map.astype(np.int64))
-            # cast to np.int64
-            data["labels"] = labels
-
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
-
-    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
-        """
-        Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports
-        PyTorch.
-
-        Args:
-            outputs ([`SegformerForSemanticSegmentation`]):
-                Raw outputs of the model.
-            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
-                List of tuples corresponding to the requested final size (height, width) of each prediction. If left
-                to None, predictions will not be resized.
-        Returns:
-            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
-            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
-            specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
-        """
-        logits = outputs.logits
-
-        # Resize logits and compute semantic segmentation maps
-        if target_sizes is not None:
-            if len(logits) != len(target_sizes):
-                raise ValueError(
-                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
-                )
-
-            if is_torch_tensor(target_sizes):
-                target_sizes = target_sizes.numpy()
-
-            semantic_segmentation = []
-
-            for idx in range(len(logits)):
-                resized_logits = torch.nn.functional.interpolate(
-                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
-                )
-                semantic_map = resized_logits[0].argmax(dim=0)
-                semantic_segmentation.append(semantic_map)
-        else:
-            semantic_segmentation = logits.argmax(dim=1)
-            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
-
-        return semantic_segmentation
+SegformerFeatureExtractor = SegformerImageProcessor
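The net effect of this hunk is that the hand-written feature extractor is replaced by a one-line alias, so existing user code that imports the old class keeps working. A minimal sketch of that backward-compatibility behaviour, assuming the package exports are otherwise unchanged:

```python
from transformers import SegformerFeatureExtractor

# The old name now resolves to the new image processor class.
feature_extractor = SegformerFeatureExtractor()
print(type(feature_extractor).__name__)  # SegformerImageProcessor
```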
@@ -0,0 +1,488 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Segformer."""


import warnings
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

from transformers.utils import is_torch_available, is_torch_tensor, is_vision_available
from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    is_batched,
    to_numpy_array,
    valid_images,
)
from ...utils import logging


if is_vision_available():
    import PIL.Image

if is_torch_available():
    import torch


logger = logging.get_logger(__name__)

class SegformerImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Segformer image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
            size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 512, "width": 512}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
            `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_reduce_labels (`bool`, *optional*, defaults to `False`):
            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
            background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
            `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_reduce_labels: bool = False,
        **kwargs
    ) -> None:
        if "reduce_labels" in kwargs:
            warnings.warn(
                "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
                "`do_reduce_labels` instead.",
                FutureWarning,
            )
            do_reduce_labels = kwargs.pop("reduce_labels")

        super().__init__(**kwargs)
        size = size if size is not None else {"height": 512, "width": 512}
        size = get_size_dict(size)
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.do_reduce_labels = do_reduce_labels
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        return resize(
            image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
        )

    def center_crop(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `size` along
        any edge, the image is padded with 0's and then center cropped.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            size (`Dict[str, int]`):
                Size of the output image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ):
        """
        Rescale an image by a scale factor: image = image * scale.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Normalize an image: image = (image - mean) / std.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)

    def reduce_label(self, label: ImageInput) -> np.ndarray:
        label = to_numpy_array(label)
        # Avoid using underflow conversion
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255
        return label
    def _preprocess(
        self,
        image: ImageInput,
        do_reduce_labels: bool,
        do_resize: bool,
        do_rescale: bool,
        do_normalize: bool,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        rescale_factor: Optional[float] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
    ):
        if do_reduce_labels:
            image = self.reduce_label(image)

        if do_resize:
            image = self.resize(image=image, size=size, resample=resample)

        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std)

        return image

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single image."""
        # All transformations expect numpy arrays.
        image = to_numpy_array(image)
        image = self._preprocess(
            image=image,
            do_reduce_labels=False,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
        )
        if data_format is not None:
            image = to_channel_dimension_format(image, data_format)
        return image

    def _preprocess_mask(
        self,
        segmentation_map: ImageInput,
        do_reduce_labels: bool = None,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
    ) -> np.ndarray:
        """Preprocesses a single mask."""
        segmentation_map = to_numpy_array(segmentation_map)
        # Add channel dimension if missing - needed for certain transformations
        added_channel_dim = False
        if segmentation_map.ndim == 2:
            added_channel_dim = True
            segmentation_map = segmentation_map[None, ...]
        # reduce zero label if needed
        segmentation_map = self._preprocess(
            image=segmentation_map,
            do_reduce_labels=do_reduce_labels,
            do_resize=do_resize,
            resample=PIL.Image.NEAREST,
            size=size,
            do_rescale=False,
            do_normalize=False,
        )
        # Remove extra channel dimension if added for processing
        if added_channel_dim:
            segmentation_map = segmentation_map.squeeze(0)
        segmentation_map = segmentation_map.astype(np.int64)
        return segmentation_map

    def __call__(self, images, segmentation_maps=None, **kwargs):
        """
        Preprocesses a batch of images and optionally segmentation maps.

        Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
        passed in as positional arguments.
        """
        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_reduce_labels: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> PIL.Image.Image:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            segmentation_maps (`ImageInput`, *optional*):
                Segmentation map to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after `resize` is applied.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values to [0, 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
                is used for background, and background itself is not included in all classes of a dataset (e.g.
                ADE20k). The background label will be replaced by 255.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
        resample = resample if resample is not None else self.resample
        size = size if size is not None else self.size
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        if not is_batched(images):
            images = [images]
            segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if segmentation_maps is not None and not valid_images(segmentation_maps):
            raise ValueError(
                "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and (size is None or resample is None):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        images = [
            self._preprocess_image(
                image=img,
                do_resize=do_resize,
                resample=resample,
                size=size,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
            )
            for img in images
        ]

        data = {"pixel_values": images}

        if segmentation_maps is not None:
            segmentation_maps = [
                self._preprocess_mask(
                    segmentation_map=segmentation_map,
                    do_reduce_labels=do_reduce_labels,
                    do_resize=do_resize,
                    resample=PIL.Image.NEAREST,
                    size=size,
                )
                for segmentation_map in segmentation_maps
            ]
            data["labels"] = segmentation_maps

        return BatchFeature(data=data, tensor_type=return_tensors)
    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
        """
        Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports
        PyTorch.

        Args:
            outputs ([`SegformerForSemanticSegmentation`]):
                Raw outputs of the model.
            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
                List of tuples corresponding to the requested final size (height, width) of each prediction. If left
                to None, predictions will not be resized.
        Returns:
            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
        """
        # TODO: add support for other frameworks
        logits = outputs.logits

        # Resize logits and compute semantic segmentation maps
        if target_sizes is not None:
            if len(logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

            if is_torch_tensor(target_sizes):
                target_sizes = target_sizes.numpy()

            semantic_segmentation = []

            for idx in range(len(logits)):
                resized_logits = torch.nn.functional.interpolate(
                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
                )
                semantic_map = resized_logits[0].argmax(dim=0)
                semantic_segmentation.append(semantic_map)
        else:
            semantic_segmentation = logits.argmax(dim=1)
            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

        return semantic_segmentation
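A minimal end-to-end sketch of the new `SegformerImageProcessor`, using dummy data so it runs without a model checkpoint. The import path assumes the file lands at the usual `models/segformer` location shown above, and PyTorch is assumed to be installed:

```python
import numpy as np
from PIL import Image

from transformers.models.segformer.image_processing_segformer import SegformerImageProcessor

# Dummy image and ADE20k-style mask (label 0 = background).
image = Image.fromarray(np.random.randint(0, 256, (640, 480, 3), dtype=np.uint8))
mask = Image.fromarray(np.random.randint(0, 150, (640, 480), dtype=np.uint8))

image_processor = SegformerImageProcessor(do_reduce_labels=True)
inputs = image_processor(image, segmentation_maps=mask, return_tensors="pt")

print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 512, 512])
print(inputs["labels"].shape)        # torch.Size([1, 512, 512])
```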
@@ -14,159 +14,11 @@
 # limitations under the License.
 """Feature extractor class for VideoMAE."""
 
-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
-from ...utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, TensorType, logging
+from ...utils import logging
+from .image_processing_videomae import VideoMAEImageProcessor
 
 
 logger = logging.get_logger(__name__)
 
 
-class VideoMAEFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a VideoMAE feature extractor. This feature extractor can be used to prepare videos for the model.
-
-    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the shorter edge of the input to a certain `size`.
-        size (`int`, *optional*, defaults to 224):
-            Resize the shorter edge of the input to the given size. Only has an effect if `do_resize` is set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_center_crop (`bool`, *optional*, defaults to `True`):
-            Whether to center crop the input to a certain `size`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with mean and standard deviation.
-        image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
-            The sequence of means for each channel, to be used when normalizing images.
-        image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images.
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=224,
-        resample=PILImageResampling.BILINEAR,
-        do_center_crop=True,
-        do_normalize=True,
-        image_mean=None,
-        image_std=None,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.do_center_crop = do_center_crop
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
-
-    def resize_video(self, video, size, resample="bilinear"):
-        return [self.resize(frame, size, resample, default_to_square=False) for frame in video]
-
-    def crop_video(self, video, size):
-        return [self.center_crop(frame, size) for frame in video]
-
-    def normalize_video(self, video, mean, std):
-        # video can be a list of PIL images, list of NumPy arrays or list of PyTorch tensors
-        # first: convert to list of NumPy arrays
-        video = [self.to_numpy_array(frame) for frame in video]
-
-        # second: stack to get (num_frames, num_channels, height, width)
-        video = np.stack(video, axis=0)
-
-        # third: normalize
-        if not isinstance(mean, np.ndarray):
-            mean = np.array(mean).astype(video.dtype)
-        if not isinstance(std, np.ndarray):
-            std = np.array(std).astype(video.dtype)
-
-        return (video - mean[None, :, None, None]) / std[None, :, None, None]
-
-    def __call__(
-        self, videos: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several video(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays are converted to PIL images when resizing, so the most efficient is to pass PIL images.
-
-        </Tip>
-
-        Args:
-            videos (`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, `List[List[PIL.Image.Image]]`, `List[List[np.ndarray]]`, `List[List[torch.Tensor]]`):
-                The video or batch of videos to be prepared. Each video should be a list of frames, which can be
-                either PIL images or NumPy arrays. In case of NumPy arrays/PyTorch tensors, each frame should be of
-                shape (H, W, C), where H and W are frame height and width, and C is a number of channels.
-
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, num_frames,
-              height, width).
-        """
-        # Input type checking for clearer error
-        valid_videos = False
-        is_batched = False
-
-        # Check that videos have a valid type
-        if isinstance(videos, (list, tuple)):
-            if isinstance(videos[0], (Image.Image, np.ndarray)) or is_torch_tensor(videos[0]):
-                valid_videos = True
-            elif isinstance(videos[0], (list, tuple)) and (
-                isinstance(videos[0][0], (Image.Image, np.ndarray)) or is_torch_tensor(videos[0][0])
-            ):
-                valid_videos = True
-                is_batched = True
-
-        if not valid_videos:
-            raise ValueError(
-                "Videos must of type `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]` (single"
-                " example), `List[List[PIL.Image.Image]]`, `List[List[np.ndarray]]`, `List[List[torch.Tensor]]` (batch"
-                " of examples)."
-            )
-
-        if not is_batched:
-            videos = [videos]
-
-        # transformations (resizing + center cropping + normalization)
-        if self.do_resize and self.size is not None:
-            videos = [self.resize_video(video, size=self.size, resample=self.resample) for video in videos]
-        if self.do_center_crop and self.size is not None:
-            videos = [self.crop_video(video, size=self.size) for video in videos]
-        if self.do_normalize:
-            videos = [self.normalize_video(video, mean=self.image_mean, std=self.image_std) for video in videos]
-
-        # return as BatchFeature
-        data = {"pixel_values": videos}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
+VideoMAEFeatureExtractor = VideoMAEImageProcessor
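As with SegFormer, the hand-written video feature extractor above is replaced by an alias onto the new image processor defined in the file below. One detail worth noting in the removed code is the broadcasting trick in `normalize_video`; a minimal sketch of the shape arithmetic it performed, with illustrative shapes:

```python
import numpy as np

# A clip stacked as (num_frames, num_channels, height, width).
video = np.random.rand(16, 3, 224, 224)
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# mean[None, :, None, None] has shape (1, 3, 1, 1), so it broadcasts across the
# frame and spatial axes, normalizing every frame in one vectorized operation.
normalized = (video - mean[None, :, None, None]) / std[None, :, None, None]
print(normalized.shape)  # (16, 3, 224, 224)
```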
@@ -0,0 +1,380 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for VideoMAE."""


from typing import Dict, List, Optional, Union

import numpy as np

from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
    center_crop,
    get_resize_output_image_size,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    is_valid_image,
    to_numpy_array,
    valid_images,
)
from ...utils import logging


if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)


def make_batched(videos) -> List[List[ImageInput]]:
    if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
        return videos

    elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
        return [videos]

    elif is_valid_image(videos):
        return [[videos]]

    raise ValueError(f"Could not make batched video from {videos}")
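`make_batched` normalizes every accepted input shape to a list of videos, each a list of frames. A quick sketch of the three branches with a dummy frame (illustrative values only):

```python
import numpy as np

frame = np.zeros((224, 224, 3), dtype=np.uint8)

make_batched([[frame, frame]])  # already a batch of videos -> returned unchanged
make_batched([frame, frame])    # a single video -> wrapped as [[frame, frame]]
make_batched(frame)             # a single frame -> wrapped as [[frame]]
```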
class VideoMAEImageProcessor(BaseImageProcessor):
    r"""
    Constructs a VideoMAE image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
            Size of the output image after resizing. The shortest edge of the image will be resized to
            `size["shortest_edge"]` while maintaining the aspect ratio of the original image. Can be overridden by
            `size` in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
            `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center crop the image to the specified `crop_size`. Can be overridden by the `do_center_crop`
            parameter in the `preprocess` method.
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the image after applying the center crop. Can be overridden by the `crop_size` parameter in the
            `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter
            in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 224}
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        crop_size = get_size_dict(crop_size)

        self.do_resize = do_resize
        self.size = size
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image. If `size` is of the form `{"height": h, "width": w}`, the output image will
                have the size `(h, w)`. If `size` is of the form `{"shortest_edge": s}`, the output image will have its
                shortest edge of length `s` while keeping the aspect ratio of the original image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        if "shortest_edge" in size:
            output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False)
        elif "height" in size and "width" in size:
            output_size = (size["height"], size["width"])
        else:
            raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
    def center_crop(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `size` along any
        edge, the image is padded with 0's and then center cropped.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            size (`Dict[str, int]`):
                Size of the output image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ):
        """
        Rescale an image by a scale factor: `image = image * scale`.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)
    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Normalize an image: `image = (image - mean) / std`.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
    ) -> np.ndarray:
        """Preprocesses a single image."""
        if do_resize and (size is None or resample is None):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_center_crop and crop_size is None:
            raise ValueError("Crop size must be specified if do_center_crop is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # All transformations expect numpy arrays.
        image = to_numpy_array(image)

        if do_resize:
            image = self.resize(image=image, size=size, resample=resample)

        if do_center_crop:
            image = self.center_crop(image, size=crop_size)

        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std)

        image = to_channel_dimension_format(image, data_format)
        return image
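As a minimal numeric sketch (not part of the diff) of what the rescale and normalize steps above do to pixel values, assuming the default `rescale_factor=1/255` and the 0.5 ImageNet-standard mean/std:

```python
import numpy as np

# A toy 2x2 single-channel uint8 "image" with values in [0, 255].
image = np.array([[0, 128], [192, 255]], dtype=np.uint8)

# do_rescale: multiply by 1/255 so values land in [0, 1].
rescaled = image * (1 / 255)

# do_normalize: (x - mean) / std; IMAGENET_STANDARD mean and std are 0.5 per channel.
normalized = (rescaled - 0.5) / 0.5

print(normalized)  # values now lie in [-1, 1]
```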
    def preprocess(
        self,
        videos: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> BatchFeature:
        """
        Preprocess a video or batch of videos.

        Args:
            videos (`ImageInput`):
                Video frames to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after applying resize.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`.
                Only has an effect if `do_resize` is set to `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the image after applying the center crop.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                    - Unset: Use the inferred channel dimension format of the input image.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        size = size if size is not None else self.size
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size)

        if not valid_images(videos):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        videos = make_batched(videos)

        videos = [
            [
                self._preprocess_image(
                    image=img,
                    do_resize=do_resize,
                    size=size,
                    resample=resample,
                    do_center_crop=do_center_crop,
                    crop_size=crop_size,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    data_format=data_format,
                )
                for img in video
            ]
            for video in videos
        ]

        data = {"pixel_values": videos}
        return BatchFeature(data=data, tensor_type=return_tensors)
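For context, a usage sketch of the video `preprocess` above. The class name `VideoMAEImageProcessor` is inferred from the commit message rather than from this excerpt, so treat it as an assumption:

```python
import numpy as np
from transformers import VideoMAEImageProcessor  # name assumed from the commit message

# A "video" is a list of frames; here, 8 random RGB frames in (H, W, C) layout.
video = [np.random.randint(0, 256, (360, 640, 3), dtype=np.uint8) for _ in range(8)]

processor = VideoMAEImageProcessor()  # shortest_edge=224, 224x224 center crop by default
inputs = processor.preprocess(video, return_tensors="np")

# One video in the batch, 8 frames, 3 channels, cropped to 224x224.
print(inputs["pixel_values"].shape)  # expected: (1, 8, 3, 224, 224)
```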
@ -14,282 +14,11 @@
 # limitations under the License.
 """Feature extractor class for ViLT."""

-from typing import List, Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
-    IMAGENET_STANDARD_MEAN,
-    IMAGENET_STANDARD_STD,
-    ImageFeatureExtractionMixin,
-    ImageInput,
-    is_torch_tensor,
-)
-from ...utils import TensorType, is_torch_available, logging
-
-
-if is_torch_available():
-    import torch
+from ...utils import logging
+from .image_processing_vilt import ViltImageProcessor


 logger = logging.get_logger(__name__)

-class ViltFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a ViLT feature extractor.
-
-    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input based on `size`.
-        size (`int`, *optional*, defaults to 384):
-            Resize the shorter side of the input to the given size. Should be an integer. The longer side will be
-            limited to under int((1333 / 800) * size) while preserving the aspect ratio. Only has an effect if
-            `do_resize` is set to `True`.
-        size_divisor (`int`, *optional*, defaults to 32):
-            The size by which to make sure both the height and width can be divided.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with mean and standard deviation.
-        image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
-            The sequence of means for each channel, to be used when normalizing images.
-        image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images.
-    """
-
-    model_input_names = ["pixel_values", "pixel_mask"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=384,
-        size_divisor=32,
-        resample=PILImageResampling.BICUBIC,
-        do_normalize=True,
-        image_mean=None,
-        image_std=None,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.size_divisor = size_divisor
-        self.resample = resample
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
-
-    def _resize(self, image, shorter=800, longer=1333, size_divisor=32, resample=PILImageResampling.BICUBIC):
-        """
-        Resizes the shorter edge of `image` to `shorter` and limits the longer edge to under `longer`, while preserving
-        the aspect ratio. Also makes sure that both the height and width can be divided by `size_divisor`.
-
-        Based on original implementation:
-        https://github.com/dandelin/ViLT/blob/3db8b5035464afee84d951bf6322e1b27f1d072d/vilt/transforms/utils.py#L5
-
-        Args:
-            image (`PIL.Image`):
-                The image to resize.
-            shorter (`int`, *optional*, defaults to `800`):
-                The size to which to resize the shorter side of the image.
-            longer (`int`, *optional*, defaults to `1333`):
-                The size by which to limit the longer side of the image, while preserving the aspect ratio.
-            size_divisor (`int`, *optional*, defaults to `32`):
-                The size by which both the height and the width must be divisible.
-            resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
-                An optional resampling filter.
-        """
-        if not isinstance(image, Image.Image):
-            image = self.to_pil_image(image)
-
-        w, h = image.size
-        min_size = shorter
-        max_size = longer
-        scale = min_size / min(w, h)
-        if h < w:
-            newh, neww = min_size, scale * w
-        else:
-            newh, neww = scale * h, min_size
-
-        if max(newh, neww) > max_size:
-            scale = max_size / max(newh, neww)
-            newh = newh * scale
-            neww = neww * scale
-
-        newh, neww = int(newh + 0.5), int(neww + 0.5)
-        newh, neww = newh // size_divisor * size_divisor, neww // size_divisor * size_divisor
-
-        return self.resize(image, size=(neww, newh), resample=resample)
-
-    def _max_by_axis(self, the_list):
-        # type: (List[List[int]]) -> List[int]
-        maxes = the_list[0]
-        for sublist in the_list[1:]:
-            for index, item in enumerate(sublist):
-                maxes[index] = max(maxes[index], item)
-        return maxes
-
-    def pad_and_create_pixel_mask(
-        self, pixel_values_list: List["torch.Tensor"], return_tensors: Optional[Union[str, TensorType]] = None
-    ):
-        """
-        Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
-
-        Args:
-            pixel_values_list (`List[torch.Tensor]`):
-                List of images (pixel values) to be padded. Each image should be a tensor of shape (C, H, W).
-            return_tensors (`str` or [`~utils.TensorType`], *optional*):
-                If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
-                objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model.
-            - **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
-              *"pixel_mask"* is in `self.model_input_names`).
-        """
-
-        max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list])
-        c, h, w = max_size
-        padded_images = []
-        pixel_mask = []
-        for image in pixel_values_list:
-            # create padded image
-            padded_image = np.zeros((c, h, w), dtype=np.float32)
-            padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
-            padded_images.append(padded_image)
-            # create pixel mask
-            mask = np.zeros((h, w), dtype=np.int64)
-            mask[: image.shape[1], : image.shape[2]] = True
-            pixel_mask.append(mask)
-
-        # return as BatchFeature
-        data = {"pixel_values": padded_images, "pixel_mask": pixel_mask}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
-
-    def __call__(
-        self,
-        images: ImageInput,
-        pad_and_return_pixel_mask: Optional[bool] = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-
-            pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
-                Whether or not to pad images up to the largest image in a batch and create a pixel mask.
-
-                If left to the default, will return a pixel mask that is:
-
-                - 1 for pixels that are real (i.e. **not masked**),
-                - 0 for pixels that are padding (i.e. **masked**).
-
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-            - **pixel_mask** -- Pixel mask to be fed to a model (when `return_pixel_mask=True` or if *"pixel_mask"* is
-              in `self.model_input_names`).
-        """
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # transformations (resizing + normalization)
-        if self.do_resize and self.size is not None:
-            longer = int((1333 / 800) * self.size)
-            images = [
-                self._resize(
-                    image=image,
-                    shorter=self.size,
-                    longer=longer,
-                    size_divisor=self.size_divisor,
-                    resample=self.resample,
-                )
-                for image in images
-            ]
-        if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
-        if pad_and_return_pixel_mask:
-            # pad images up to largest image in batch and create pixel_mask
-            max_size = self._max_by_axis([list(image.shape) for image in images])
-            c, h, w = max_size
-            padded_images = []
-            pixel_mask = []
-            for image in images:
-                # create padded image
-                padded_image = np.zeros((c, h, w), dtype=np.float32)
-                padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
-                padded_images.append(padded_image)
-                # create pixel mask
-                mask = np.zeros((h, w), dtype=np.int64)
-                mask[: image.shape[1], : image.shape[2]] = True
-                pixel_mask.append(mask)
-            images = padded_images
-
-        # return as BatchFeature
-        data = {}
-        data["pixel_values"] = images
-        if pad_and_return_pixel_mask:
-            data["pixel_mask"] = pixel_mask
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
+ViltFeatureExtractor = ViltImageProcessor
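The assignment above turns `ViltFeatureExtractor` into a plain alias of the new image processor, so existing imports keep working while all logic moves to `image_processing_vilt.py`. A sketch of what this pattern preserves:

```python
from transformers import ViltFeatureExtractor, ViltImageProcessor

# The old name is now just another reference to the new class.
assert ViltFeatureExtractor is ViltImageProcessor

# Legacy code constructing a feature extractor transparently gets the new processor.
feature_extractor = ViltFeatureExtractor(size={"shortest_edge": 384})
```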
@ -0,0 +1,487 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Vilt."""

import warnings
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np

from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_batched,
    to_numpy_array,
    valid_images,
)
from ...utils import logging


if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)
def max_across_indices(values: Iterable[Any]) -> List[Any]:
    """
    Return the maximum value across all indices of an iterable of values.
    """
    return [max(values_i) for values_i in zip(*values)]


def pad(
    image: np.ndarray,
    output_size: Tuple[int, int],
    input_channel_dimension: Optional[ChannelDimension] = None,
    data_format: Optional[ChannelDimension] = None,
) -> np.ndarray:
    """
    Pad the bottom and right of the image with zeros to the output size.

    Args:
        image (`np.ndarray`):
            Image to pad.
        output_size (`Tuple[int, int]`):
            Output size of the image.
        input_channel_dimension (`ChannelDimension`, *optional*):
            The channel dimension format of the image. If not provided, it will be inferred from the input image.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format of the image. If not provided, it will be the same as the input image.
    """
    if input_channel_dimension is None:
        input_channel_dimension = infer_channel_dimension_format(image)

    output_height, output_width = output_size
    input_height, input_width = get_image_size(image)
    pad_bottom = output_height - input_height
    pad_right = output_width - input_width

    if input_channel_dimension == ChannelDimension.FIRST:
        padded_image = np.pad(image, [(0, 0), (0, pad_bottom), (0, pad_right)], mode="constant", constant_values=0)
    elif input_channel_dimension == ChannelDimension.LAST:
        padded_image = np.pad(image, [(0, pad_bottom), (0, pad_right), (0, 0)], mode="constant", constant_values=0)
    else:
        raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")

    if data_format is not None:
        padded_image = to_channel_dimension_format(padded_image, data_format)

    return padded_image


def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
    """
    Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

    Args:
        image (`np.ndarray`):
            Image to make the pixel mask for.
        output_size (`Tuple[int, int]`):
            Output size of the mask.
    """
    input_height, input_width = get_image_size(image)
    mask = np.zeros(output_size, dtype=np.int64)
    mask[:input_height, :input_width] = 1
    return mask


def get_max_dimensions(images: List[np.ndarray]) -> Tuple[int, int]:
    """
    Get the maximum height and width across all images in a batch.
    """
    input_channel_dimension = infer_channel_dimension_format(images[0])

    if input_channel_dimension == ChannelDimension.FIRST:
        _, max_height, max_width = max_across_indices([img.shape for img in images])
    elif input_channel_dimension == ChannelDimension.LAST:
        max_height, max_width, _ = max_across_indices([img.shape for img in images])
    else:
        raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
    return (max_height, max_width)


def get_resize_output_image_size(
    input_image: np.ndarray, shorter: int = 800, longer: int = 1333, size_divisor: int = 32
) -> Tuple[int, int]:
    # Resizes the shorter side to `shorter`, caps the longer side at `longer`,
    # then snaps both dimensions down to a multiple of `size_divisor`.
    input_height, input_width = get_image_size(input_image)
    min_size, max_size = shorter, longer

    scale = min_size / min(input_height, input_width)

    if input_height < input_width:
        new_height = min_size
        new_width = scale * input_width
    else:
        new_height = scale * input_height
        new_width = min_size

    if max(new_height, new_width) > max_size:
        scale = max_size / max(new_height, new_width)
        new_height = scale * new_height
        new_width = scale * new_width

    new_height, new_width = int(new_height + 0.5), int(new_width + 0.5)
    new_height = new_height // size_divisor * size_divisor
    new_width = new_width // size_divisor * size_divisor

    return new_height, new_width
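To sanity-check the arithmetic in `get_resize_output_image_size`, here is a standalone walk-through (an illustrative re-implementation, not the diff's code) for a 480x640 input with `shorter=384`:

```python
height, width = 480, 640
shorter, longer, size_divisor = 384, int(1333 / 800 * 384), 32  # longer cap = 639

scale = shorter / min(height, width)            # 384 / 480 = 0.8
new_height, new_width = shorter, scale * width  # height is the shorter side -> 384, 512.0

# The longer side (512) stays under the 639 cap, so no second rescale is needed.
new_height, new_width = int(new_height + 0.5), int(new_width + 0.5)
new_height = new_height // size_divisor * size_divisor  # 384
new_width = new_width // size_divisor * size_divisor    # 512

print(new_height, new_width)  # 384 512
```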
class ViltImageProcessor(BaseImageProcessor):
    r"""
    Constructs a ViLT image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 384}`):
            Resize the shorter side of the input to `size["shortest_edge"]`. The longer side will be limited to under
            `int((1333 / 800) * size["shortest_edge"])` while preserving the aspect ratio. Only has an effect if
            `do_resize` is set to `True`. Can be overridden by the `size` parameter in the `preprocess` method.
        size_divisor (`int`, *optional*, defaults to 32):
            The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
            is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
            overridden by the `resample` parameter in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by
            the `do_pad` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]
    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        size_divisor: int = 32,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: bool = True,
        **kwargs
    ) -> None:
        if "pad_and_return_pixel_mask" in kwargs:
            do_pad = kwargs.pop("pad_and_return_pixel_mask")

        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 384}
        size = get_size_dict(size, default_to_square=False)

        self.do_resize = do_resize
        self.size = size
        self.size_divisor = size_divisor
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.do_pad = do_pad
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        size_divisor: int = 32,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image.

        Resizes the shorter side of the image to `size["shortest_edge"]` while preserving the aspect ratio. If the
        longer side is larger than the max size `int(size["shortest_edge"] * 1333 / 800)`, the longer side is then
        resized to the max size while preserving the aspect ratio.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Controls the size of the output image. Should be of the form `{"shortest_edge": int}`.
            size_divisor (`int`, defaults to 32):
                The image is resized to a size that is a multiple of this value.
            resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size, default_to_square=False)
        if "shortest_edge" not in size:
            raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
        shorter = size["shortest_edge"]
        longer = int(1333 / 800 * shorter)
        output_size = get_resize_output_image_size(image, shorter=shorter, longer=longer, size_divisor=size_divisor)
        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ):
        """
        Rescale an image by a scale factor: `image = image * scale`.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)
    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Normalize an image: `image = (image - mean) / std`.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
    def pad(
        self,
        images: List[np.ndarray],
        return_pixel_mask: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = None,
    ) -> BatchFeature:
        """
        Pads a batch of images with zeros to the size of the largest height and width in the batch and optionally
        returns their corresponding pixel mask.

        Args:
            images (`List[np.ndarray]`):
                Batch of images to pad.
            return_pixel_mask (`bool`, *optional*, defaults to `True`):
                Whether to return the pixel mask.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        pad_size = get_max_dimensions(images)
        padded_images = [pad(image=image, output_size=pad_size, data_format=data_format) for image in images]
        data = {"pixel_values": padded_images}
        if return_pixel_mask:
            masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
            data["pixel_mask"] = masks

        return BatchFeature(data=data, tensor_type=return_tensors)
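A small standalone illustration (toy arrays, not the diff's code) of the padding and pixel-mask logic `pad` relies on: every image is zero-padded to the batch-wide maximum height and width, and the mask records which pixels are real:

```python
import numpy as np

# Two channels-first "images" with different heights and widths.
images = [np.ones((3, 2, 4)), np.ones((3, 3, 2))]

# Batch pad size = per-axis maximum over the batch: height 3, width 4.
pad_h = max(img.shape[1] for img in images)
pad_w = max(img.shape[2] for img in images)

masks = []
for img in images:
    mask = np.zeros((pad_h, pad_w), dtype=np.int64)
    mask[: img.shape[1], : img.shape[2]] = 1  # 1 = real pixel, 0 = padding
    masks.append(mask)

print(masks[1])
# [[1 1 0 0]
#  [1 1 0 0]
#  [1 1 0 0]]
```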
    def pad_and_create_pixel_mask(
        self,
        pixel_values_list: List[ImageInput],
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = None,
    ) -> BatchFeature:
        """
        Pads a batch of images with zeros to the size of the largest height and width in the batch and returns their
        corresponding pixel mask.

        Args:
            pixel_values_list (`List[ImageInput]`):
                Batch of images to pad.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        warnings.warn(
            "This method is deprecated and will be removed in v4.26.0. Please use pad instead.", FutureWarning
        )
        # pad expects a list of np.ndarray, but the previous feature extractors expected torch tensors
        images = [to_numpy_array(image) for image in pixel_values_list]
        return self.pad(
            images=images,
            return_pixel_mask=True,
            return_tensors=return_tensors,
            data_format=data_format,
        )
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        size_divisor: Optional[int] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Controls the size of the image after `resize`. The shortest edge of the image is resized to
                `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
                is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
                edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
            size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
                The image is resized to a size that is a multiple of this value.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to normalize the image by if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
                Whether to pad the image to the (max_height, max_width) in the batch. If `True`, a pixel mask is also
                created and returned.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size_divisor = size_divisor if size_divisor is not None else self.size_divisor
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_pad = do_pad if do_pad is not None else self.do_pad

        size = size if size is not None else self.size
        size = get_size_dict(size, default_to_square=False)

        if not is_batched(images):
            images = [images]

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and (size is None or resample is None):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_resize:
            images = [
                self.resize(image=image, size=size, size_divisor=size_divisor, resample=resample) for image in images
            ]

        if do_rescale:
            images = [self.rescale(image=image, scale=rescale_factor) for image in images]

        if do_normalize:
            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]

        images = [to_channel_dimension_format(image, data_format) for image in images]

        if do_pad:
            encoded_outputs = self.pad(images, return_pixel_mask=True, return_tensors=return_tensors)
        else:
            encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)

        return encoded_outputs
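End to end, the new processor might be used like this (a sketch; it assumes the top-level `ViltImageProcessor` export added elsewhere in this PR):

```python
import numpy as np
from PIL import Image
from transformers import ViltImageProcessor

processor = ViltImageProcessor()  # shortest_edge=384, do_pad=True by default

# Two images of different sizes get resized, then padded to a common (height, width).
images = [
    Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)),
    Image.fromarray(np.random.randint(0, 256, (512, 384, 3), dtype=np.uint8)),
]
inputs = processor.preprocess(images, return_tensors="pt")

print(inputs["pixel_values"].shape)  # torch.Size([2, 3, H_max, W_max])
print(inputs["pixel_mask"].shape)    # torch.Size([2, H_max, W_max])
```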
@ -14,139 +14,12 @@
 # limitations under the License.
 """Feature extractor class for ViT."""

-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
-    IMAGENET_STANDARD_MEAN,
-    IMAGENET_STANDARD_STD,
-    ImageFeatureExtractionMixin,
-    ImageInput,
-    is_torch_tensor,
-)
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_vit import ViTImageProcessor


 logger = logging.get_logger(__name__)

-class ViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    r"""
-    Constructs a ViT feature extractor.
-
-    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
-    should refer to this superclass for more information regarding those methods.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input to a certain `size`.
-        size (`int` or `Tuple(int)`, *optional*, defaults to 224):
-            Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
-            set to `True`.
-        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
-            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
-            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
-            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
-            to `True`.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether or not to normalize the input with mean and standard deviation.
-        image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
-            The sequence of means for each channel, to be used when normalizing images.
-        image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
-            The sequence of standard deviations for each channel, to be used when normalizing images.
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        do_resize=True,
-        size=224,
-        resample=PILImageResampling.BILINEAR,
-        do_normalize=True,
-        image_mean=None,
-        image_std=None,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.do_resize = do_resize
-        self.size = size
-        self.resample = resample
-        self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
-
-    def __call__(
-        self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
-    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several image(s).
-
-        <Tip warning={true}>
-
-        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
-        PIL images.
-
-        </Tip>
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
-              width).
-        """
-        # Input type checking for clearer error
-        valid_images = False
-
-        # Check that images has a valid type
-        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
-            valid_images = True
-        elif isinstance(images, (list, tuple)):
-            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
-                valid_images = True
-
-        if not valid_images:
-            raise ValueError(
-                "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
-                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
-            )
-
-        is_batched = bool(
-            isinstance(images, (list, tuple))
-            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
-        )
-
-        if not is_batched:
-            images = [images]
-
-        # transformations (resizing + normalization)
-        if self.do_resize and self.size is not None:
-            images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
-        if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
-        # return as BatchFeature
-        data = {"pixel_values": images}
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
-        return encoded_inputs
+# Feature extractor for ViT is being replaced by image processor
+ViTFeatureExtractor = ViTImageProcessor
@ -0,0 +1,275 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for ViT."""

from typing import Dict, List, Optional, Union

import numpy as np

from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    is_batched,
    to_numpy_array,
    valid_images,
)
from ...utils import logging


logger = logging.get_logger(__name__)
class ViTImageProcessor(BaseImageProcessor):
|
||||||
|
r"""
|
||||||
|
Constructs a ViT image processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
do_resize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
|
||||||
|
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
|
||||||
|
size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
|
||||||
|
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
||||||
|
method.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||||
|
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
||||||
|
`preprocess` method.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
||||||
|
parameter in the `preprocess` method.
|
||||||
|
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||||
|
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
|
||||||
|
`preprocess` method.
|
||||||
|
do_normalize:
|
||||||
|
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
||||||
|
method.
|
||||||
|
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
|
||||||
|
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||||
|
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||||
|
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
|
||||||
|
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||||
|
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_input_names = ["pixel_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
do_resize: bool = True,
|
||||||
|
size: Optional[Dict[str, int]] = None,
|
||||||
|
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||||
|
do_rescale: bool = True,
|
||||||
|
rescale_factor: Union[int, float] = 1 / 255,
|
||||||
|
do_normalize: bool = True,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
size = size if size is not None else {"height": 224, "width": 224}
|
||||||
|
size = get_size_dict(size)
|
||||||
|
self.do_resize = do_resize
|
||||||
|
self.do_rescale = do_rescale
|
||||||
|
self.do_normalize = do_normalize
|
||||||
|
self.size = size
|
||||||
|
self.resample = resample
|
||||||
|
self.rescale_factor = rescale_factor
|
||||||
|
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||||
|
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||||
|
|
||||||
|
def resize(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
size: Dict[str, int],
|
||||||
|
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Resize an image to `(size["height"], size["width"])`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to resize.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||||
|
resample:
|
||||||
|
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
|
||||||
|
data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||||
|
image is used. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`np.ndarray`: The resized image.
|
||||||
|
"""
|
||||||
|
size = get_size_dict(size)
|
||||||
|
if "height" not in size or "width" not in size:
|
||||||
|
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||||
|
return resize(
|
||||||
|
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def rescale(
|
||||||
|
self, image: np.ndarray, scale: float, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Rescale an image by a scale factor. image = image * scale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to rescale.
|
||||||
|
scale (`float`):
|
||||||
|
The scaling factor to rescale pixel values by.
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||||
|
image is used. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`np.ndarray`: The rescaled image.
|
||||||
|
"""
|
||||||
|
return rescale(image, scale=scale, data_format=data_format, **kwargs)
|
||||||
|
|
||||||
|
def normalize(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
mean: Union[float, List[float]],
|
||||||
|
std: Union[float, List[float]],
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Normalize an image. image = (image - image_mean) / image_std.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to normalize.
|
||||||
|
mean (`float` or `List[float]`):
|
||||||
|
Image mean to use for normalization.
|
||||||
|
std (`float` or `List[float]`):
|
||||||
|
Image standard deviation to use for normalization.
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||||
|
image is used. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`np.ndarray`: The normalized image.
|
||||||
|
"""
|
||||||
|
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
images: ImageInput,
|
||||||
|
do_resize: Optional[bool] = None,
|
||||||
|
size: Dict[str, int] = None,
|
||||||
|
resample: PILImageResampling = None,
|
||||||
|
do_rescale: Optional[bool] = None,
|
||||||
|
rescale_factor: Optional[float] = None,
|
||||||
|
do_normalize: Optional[bool] = None,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Preprocess an image or batch of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`ImageInput`):
|
||||||
|
Image to preprocess.
|
||||||
|
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||||
|
Whether to resize the image.
|
||||||
|
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||||
|
Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
|
||||||
|
resizing.
|
||||||
|
resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
|
||||||
|
`PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
|
||||||
|
an effect if `do_resize` is set to `True`.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||||
|
Whether to rescale the image values between [0 - 1].
|
||||||
|
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||||
|
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||||
|
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||||
|
Whether to normalize the image.
|
||||||
|
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||||
|
Image mean to use if `do_normalize` is set to `True`.
|
||||||
|
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||||
|
Image standard deviation to use if `do_normalize` is set to `True`.
|
||||||
|
return_tensors (`str` or `TensorType`, *optional*):
|
||||||
|
The type of tensors to return. Can be one of:
|
||||||
|
- Unset: Return a list of `np.ndarray`.
|
||||||
|
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||||
|
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||||
|
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||||
|
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||||
|
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||||
|
The channel dimension format for the output image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- Unset: Use the channel dimension format of the input image.
|
||||||
|
"""
|
||||||
|
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||||
|
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||||
|
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||||
|
resample = resample if resample is not None else self.resample
|
||||||
|
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||||
|
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||||
|
image_std = image_std if image_std is not None else self.image_std
|
||||||
|
|
||||||
|
size = size if size is not None else self.size
|
||||||
|
size_dict = get_size_dict(size)
|
||||||
|
|
||||||
|
if not is_batched(images):
|
||||||
|
images = [images]
|
||||||
|
|
||||||
|
if not valid_images(images):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||||
|
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_resize and size is None:
|
||||||
|
raise ValueError("Size must be specified if do_resize is True.")
|
||||||
|
|
||||||
|
if do_rescale and rescale_factor is None:
|
||||||
|
raise ValueError("Rescale factor must be specified if do_rescale is True.")
|
||||||
|
|
||||||
|
# All transformations expect numpy arrays.
|
||||||
|
images = [to_numpy_array(image) for image in images]
|
||||||
|
|
||||||
|
if do_resize:
|
||||||
|
images = [self.resize(image=image, size=size_dict, resample=resample) for image in images]
|
||||||
|
|
||||||
|
if do_rescale:
|
||||||
|
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||||
|
|
||||||
|
if do_normalize:
|
||||||
|
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||||
|
|
||||||
|
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||||
|
|
||||||
|
data = {"pixel_values": images}
|
||||||
|
return BatchFeature(data=data, tensor_type=return_tensors)
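The class above is exercised end to end roughly as follows (a sketch under assumptions, not part of the diff; per-call arguments override the values stored at construction time):

```python
# Sketch of the new ViTImageProcessor API (assumption-level example).
# Defaults: resize to 224x224, rescale by 1/255, normalize with
# IMAGENET_STANDARD_MEAN/STD.
import numpy as np
from transformers import ViTImageProcessor

image_processor = ViTImageProcessor()
image = np.random.randint(0, 256, (3, 300, 400), dtype=np.uint8)  # channels-first dummy image

# Per-call `size` overrides the stored {"height": 224, "width": 224}.
outputs = image_processor.preprocess(image, size={"height": 192, "width": 192}, return_tensors="np")
print(outputs["pixel_values"].shape)  # (1, 3, 192, 192)
```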
@@ -44,14 +44,16 @@ class BeitFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=20,
+        size=None,
         do_center_crop=True,
-        crop_size=18,
+        crop_size=None,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
-        reduce_labels=False,
+        do_reduce_labels=False,
     ):
+        size = size if size is not None else {"height": 20, "width": 20}
+        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels

@@ -65,7 +67,7 @@ class BeitFeatureExtractionTester(unittest.TestCase):
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
-        self.reduce_labels = reduce_labels
+        self.do_reduce_labels = do_reduce_labels

     def prepare_feat_extract_dict(self):
         return {

@@ -76,7 +78,7 @@ class BeitFeatureExtractionTester(unittest.TestCase):
             "do_normalize": self.do_normalize,
             "image_mean": self.image_mean,
             "image_std": self.image_std,
-            "reduce_labels": self.reduce_labels,
+            "do_reduce_labels": self.do_reduce_labels,
         }


@@ -141,8 +143,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -153,8 +155,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -173,8 +175,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -185,8 +187,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -205,8 +207,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -217,8 +219,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -239,16 +241,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
         self.assertEqual(
             encoding["labels"].shape,
             (
                 1,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
         self.assertEqual(encoding["labels"].dtype, torch.long)

@@ -262,16 +264,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
         self.assertEqual(
             encoding["labels"].shape,
             (
                 self.feature_extract_tester.batch_size,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
         self.assertEqual(encoding["labels"].dtype, torch.long)

@@ -287,16 +289,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
         self.assertEqual(
             encoding["labels"].shape,
             (
                 1,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
         self.assertEqual(encoding["labels"].dtype, torch.long)

@@ -312,16 +314,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 2,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
         self.assertEqual(
             encoding["labels"].shape,
             (
                 2,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
         self.assertEqual(encoding["labels"].dtype, torch.long)
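All of the tester updates in this commit follow the pattern above: integer `size`/`crop_size` defaults become explicit dicts, because the new image processors store sizes as dicts (`get_size_dict` converts legacy ints). A rough sketch of the conversion these tests now rely on (illustrative; not the actual `get_size_dict` implementation):

```python
# Rough sketch of the int -> dict size conversion (illustrative; the real
# logic lives in transformers.image_processing_utils.get_size_dict).
def as_size_dict(size, default_to_square=True):
    if isinstance(size, dict):
        return size  # already explicit, e.g. {"height": 18, "width": 18}
    if default_to_square:
        return {"height": size, "width": size}   # BEiT/DeiT/DPT-style testers
    return {"shortest_edge": size}               # CLIP/ConvNext/LeViT-style testers

print(as_size_dict(18))                           # {'height': 18, 'width': 18}
print(as_size_dict(20, default_to_square=False))  # {'shortest_edge': 20}
```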
@@ -43,14 +43,16 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=20,
+        size=None,
         do_center_crop=True,
-        crop_size=18,
+        crop_size=None,
         do_normalize=True,
         image_mean=[0.48145466, 0.4578275, 0.40821073],
         image_std=[0.26862954, 0.26130258, 0.27577711],
         do_convert_rgb=True,
     ):
+        size = size if size is not None else {"shortest_edge": 20}
+        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels

@@ -151,8 +153,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -163,8 +165,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -183,8 +185,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -195,8 +197,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -215,8 +217,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -227,8 +229,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -276,8 +278,8 @@ class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.expected_encoded_image_num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -288,7 +290,7 @@ class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.expected_encoded_image_num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -43,12 +43,13 @@ class ConvNextFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=20,
+        size=None,
         crop_pct=0.875,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
     ):
+        size = size if size is not None else {"shortest_edge": 20}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels

@@ -113,8 +114,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["shortest_edge"],
+                self.feature_extract_tester.size["shortest_edge"],
             ),
         )

@@ -125,8 +126,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["shortest_edge"],
+                self.feature_extract_tester.size["shortest_edge"],
             ),
         )

@@ -145,8 +146,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["shortest_edge"],
+                self.feature_extract_tester.size["shortest_edge"],
             ),
         )

@@ -157,8 +158,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["shortest_edge"],
+                self.feature_extract_tester.size["shortest_edge"],
             ),
         )

@@ -177,8 +178,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["shortest_edge"],
+                self.feature_extract_tester.size["shortest_edge"],
             ),
         )

@@ -189,7 +190,7 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["shortest_edge"],
+                self.feature_extract_tester.size["shortest_edge"],
             ),
         )
@@ -43,13 +43,16 @@ class DeiTFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=20,
+        size=None,
         do_center_crop=True,
-        crop_size=18,
+        crop_size=None,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
     ):
+        size = size if size is not None else {"height": 20, "width": 20}
+        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
+
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels

@@ -117,8 +120,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -129,8 +132,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -149,8 +152,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -161,8 +164,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -181,8 +184,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )

@@ -193,7 +196,7 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -43,11 +43,12 @@ class DPTFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=18,
+        size=None,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
     ):
+        size = size if size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels

@@ -106,8 +107,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -118,8 +119,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -138,8 +139,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -150,8 +151,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -170,8 +171,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -182,7 +183,7 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -28,11 +28,10 @@ if is_torch_available():
     import torch

 if is_vision_available():
-    from PIL import Image
+    import PIL

     from transformers import FlavaFeatureExtractor
-    from transformers.image_utils import PILImageResampling
-    from transformers.models.flava.feature_extraction_flava import (
+    from transformers.models.flava.image_processing_flava import (
         FLAVA_CODEBOOK_MEAN,
         FLAVA_CODEBOOK_STD,
         FLAVA_IMAGE_MEAN,

@@ -51,10 +50,12 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=224,
+        size=None,
         do_center_crop=True,
-        crop_size=224,
+        crop_size=None,
         resample=None,
+        do_rescale=True,
+        rescale_factor=1 / 255,
         do_normalize=True,
         image_mean=FLAVA_IMAGE_MEAN,
         image_std=FLAVA_IMAGE_STD,

@@ -65,23 +66,30 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
         mask_group_min_aspect_ratio=0.3,
         mask_group_max_aspect_ratio=None,
         codebook_do_resize=True,
-        codebook_size=112,
+        codebook_size=None,
         codebook_resample=None,
         codebook_do_center_crop=True,
-        codebook_crop_size=112,
+        codebook_crop_size=None,
         codebook_do_map_pixels=True,
         codebook_do_normalize=True,
         codebook_image_mean=FLAVA_CODEBOOK_MEAN,
         codebook_image_std=FLAVA_CODEBOOK_STD,
     ):
+        size = size if size is not None else {"height": 224, "width": 224}
+        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+        codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
+        codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
+
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
         self.do_resize = do_resize
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
         self.min_resolution = min_resolution
         self.max_resolution = max_resolution
         self.size = size
-        self.resample = resample if resample is not None else PILImageResampling.BICUBIC
+        self.resample = resample if resample is not None else PIL.Image.Resampling.BICUBIC
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std

@@ -97,7 +105,7 @@ class FlavaFeatureExtractionTester(unittest.TestCase):

         self.codebook_do_resize = codebook_do_resize
         self.codebook_size = codebook_size
-        self.codebook_resample = codebook_resample if codebook_resample is not None else PILImageResampling.LANCZOS
+        self.codebook_resample = codebook_resample if codebook_resample is not None else PIL.Image.Resampling.LANCZOS
         self.codebook_do_center_crop = codebook_do_center_crop
         self.codebook_crop_size = codebook_crop_size
         self.codebook_do_map_pixels = codebook_do_map_pixels

@@ -113,6 +121,8 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
             "do_resize": self.do_resize,
             "size": self.size,
             "resample": self.resample,
+            "do_rescale": self.do_rescale,
+            "rescale_factor": self.rescale_factor,
             "do_center_crop": self.do_center_crop,
             "crop_size": self.crop_size,
             "input_size_patches": self.input_size_patches,

@@ -133,7 +143,7 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
         }

     def get_expected_image_size(self):
-        return (self.size, self.size) if not isinstance(self.size, tuple) else self.size
+        return (self.size["height"], self.size["width"])

     def get_expected_mask_size(self):
         return (

@@ -143,10 +153,7 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
         )

     def get_expected_codebook_image_size(self):
-        if not isinstance(self.codebook_size, tuple):
-            return (self.codebook_size, self.codebook_size)
-        else:
-            return self.codebook_size
+        return (self.codebook_size["height"], self.codebook_size["width"])


 @require_torch

@@ -172,6 +179,8 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
         self.assertTrue(hasattr(feature_extractor, "resample"))
         self.assertTrue(hasattr(feature_extractor, "crop_size"))
         self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+        self.assertTrue(hasattr(feature_extractor, "do_rescale"))
+        self.assertTrue(hasattr(feature_extractor, "rescale_factor"))
         self.assertTrue(hasattr(feature_extractor, "masking_generator"))
         self.assertTrue(hasattr(feature_extractor, "codebook_do_resize"))
         self.assertTrue(hasattr(feature_extractor, "codebook_size"))

@@ -192,7 +201,7 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
         # create random PIL images
         image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
         for image in image_inputs:
-            self.assertIsInstance(image, Image.Image)
+            self.assertIsInstance(image, PIL.Image.Image)

         # Test not batched input
         encoded_images = feature_extractor(image_inputs[0], return_tensors="pt")

@@ -324,7 +333,7 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
         # create random PIL images
         image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
         for image in image_inputs:
-            self.assertIsInstance(image, Image.Image)
+            self.assertIsInstance(image, PIL.Image.Image)

         # Test not batched input
         encoded_images = feature_extractor(image_inputs[0], return_codebook_pixels=True, return_tensors="pt")
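The switch from the `PILImageResampling` alias to Pillow's own enum is cosmetic: both name the same integer filter codes. A quick check (assumes Pillow >= 9.1, where `Image.Resampling` exists):

```python
# The transformers alias and Pillow's enum share the same filter codes
# (assumes Pillow >= 9.1).
import PIL.Image

print(int(PIL.Image.Resampling.BICUBIC))  # 3
print(int(PIL.Image.Resampling.LANCZOS))  # 1
```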
@@ -32,7 +32,7 @@ if is_vision_available():
     from PIL import Image

     from transformers import FlavaFeatureExtractor, FlavaProcessor
-    from transformers.models.flava.feature_extraction_flava import (
+    from transformers.models.flava.image_processing_flava import (
         FLAVA_CODEBOOK_MEAN,
         FLAVA_CODEBOOK_STD,
         FLAVA_IMAGE_MEAN,

@@ -69,7 +69,6 @@ class FlavaProcessorTest(unittest.TestCase):
             "mask_group_max_aspect_ratio": None,
             "codebook_do_resize": True,
             "codebook_size": 112,
-            "codebook_resample": None,
             "codebook_do_center_crop": True,
             "codebook_crop_size": 112,
             "codebook_do_map_pixels": True,
@@ -47,9 +47,10 @@ class ImageGPTFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=18,
+        size=None,
         do_normalize=True,
     ):
+        size = size if size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -43,9 +43,10 @@ class LayoutLMv2FeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=18,
+        size=None,
         apply_ocr=True,
     ):
+        size = size if size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels

@@ -97,8 +98,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -112,8 +113,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -132,8 +133,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -144,8 +145,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -164,8 +165,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -176,8 +177,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -210,12 +211,4 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):

         encoding = feature_extractor(image, return_tensors="pt")

-        self.assertEqual(
-            encoding.pixel_values.shape,
-            (
-                1,
-                3,
-                224,
-                224,
-            ),
-        )
+        self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))
@@ -43,9 +43,10 @@ class LayoutLMv3FeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=18,
+        size=None,
         apply_ocr=True,
     ):
+        size = size if size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels

@@ -97,8 +98,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -112,8 +113,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -132,8 +133,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -144,8 +145,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -164,8 +165,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )

@@ -176,8 +177,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -43,12 +43,15 @@ class LevitFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=18,
+        size=None,
         do_center_crop=True,
+        crop_size=None,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
     ):
+        size = size if size is not None else {"shortest_edge": 18}
+        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -58,6 +61,7 @@ class LevitFeatureExtractionTester(unittest.TestCase):
         self.do_resize = do_resize
         self.size = size
         self.do_center_crop = do_center_crop
+        self.crop_size = crop_size
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
@@ -70,6 +74,7 @@ class LevitFeatureExtractionTester(unittest.TestCase):
             "do_resize": self.do_resize,
             "do_center_crop": self.do_center_crop,
             "size": self.size,
+            "crop_size": self.crop_size,
         }
 
 
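With the new image processor, LeViT's `size` only drives the resize (its shortest edge), while the final spatial dimensions come from the center crop; the shape assertions below therefore switch from `size` to `crop_size`. A rough sketch of the geometry, assuming the tester defaults above:

```python
def resized_dims(h: int, w: int, shortest_edge: int) -> tuple:
    # resize so the shorter side equals shortest_edge, keeping aspect ratio
    scale = shortest_edge / min(h, w)
    return round(h * scale), round(w * scale)

def cropped_dims(crop_size: dict) -> tuple:
    # the center crop fixes the output regardless of the resized dims
    return crop_size["height"], crop_size["width"]

assert resized_dims(40, 30, 18) == (24, 18)
assert cropped_dims({"height": 18, "width": 18}) == (18, 18)
```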
@@ -113,8 +118,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -125,8 +130,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -145,8 +150,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -157,8 +162,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -177,8 +182,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -189,7 +194,7 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -43,11 +43,13 @@ class MobileViTFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=20,
+        size=None,
         do_center_crop=True,
-        crop_size=18,
+        crop_size=None,
         do_flip_channel_order=True,
     ):
+        size = size if size is not None else {"shortest_edge": 20}
+        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -109,8 +111,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -121,8 +123,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -141,8 +143,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -153,8 +155,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -173,8 +175,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -185,7 +187,7 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
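Only the size handling changes for MobileViT; its `do_flip_channel_order` step is untouched (the upstream model was reportedly trained on BGR input). For context, a sketch of what that flag does, assuming a channels-last array:

```python
import numpy as np

def flip_channel_order(image: np.ndarray) -> np.ndarray:
    # reverse the channel axis of an HWC image: RGB -> BGR
    return image[..., ::-1]

rgb = np.zeros((2, 2, 3))
rgb[..., 0] = 1.0                             # an all-red image
assert flip_channel_order(rgb)[..., 2].all()  # red ends up in the B slot
```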
@@ -41,12 +41,15 @@ class PoolFormerFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize_and_center_crop=True,
-        size=30,
+        size=None,
         crop_pct=0.9,
+        crop_size=None,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
     ):
+        size = size if size is not None else {"shortest_edge": 30}
+        crop_size = crop_size if crop_size is not None else {"height": 30, "width": 30}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -55,6 +58,7 @@ class PoolFormerFeatureExtractionTester(unittest.TestCase):
         self.do_resize_and_center_crop = do_resize_and_center_crop
         self.size = size
         self.crop_pct = crop_pct
+        self.crop_size = crop_size
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
@@ -64,6 +68,7 @@ class PoolFormerFeatureExtractionTester(unittest.TestCase):
             "size": self.size,
             "do_resize_and_center_crop": self.do_resize_and_center_crop,
             "crop_pct": self.crop_pct,
+            "crop_size": self.crop_size,
             "do_normalize": self.do_normalize,
             "image_mean": self.image_mean,
             "image_std": self.image_std,
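PoolFormer keeps its timm-style `crop_pct` and gains an explicit `crop_size`. A sketch of how the two interact under the usual timm recipe, i.e. resize to `crop_size / crop_pct`, then center-crop back down (the flooring detail is an assumption, not pinned by these tests):

```python
import math

def resize_shortest_edge(crop_edge: int, crop_pct: float) -> int:
    # resize target: the crop edge inflated by 1 / crop_pct
    return int(math.floor(crop_edge / crop_pct))

assert resize_shortest_edge(30, 0.9) == 33
assert resize_shortest_edge(224, 0.875) == 256  # the classic ImageNet eval recipe
```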
@@ -111,8 +116,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -123,8 +128,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -143,8 +148,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -155,8 +160,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -175,8 +180,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -187,7 +192,7 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -43,12 +43,13 @@ class SegformerFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=30,
+        size=None,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
-        reduce_labels=False,
+        do_reduce_labels=False,
     ):
+        size = size if size is not None else {"height": 30, "width": 30}
        self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -59,7 +60,7 @@ class SegformerFeatureExtractionTester(unittest.TestCase):
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
-        self.reduce_labels = reduce_labels
+        self.do_reduce_labels = do_reduce_labels
 
     def prepare_feat_extract_dict(self):
         return {
@@ -68,7 +69,7 @@ class SegformerFeatureExtractionTester(unittest.TestCase):
             "do_normalize": self.do_normalize,
             "image_mean": self.image_mean,
             "image_std": self.image_std,
-            "reduce_labels": self.reduce_labels,
+            "do_reduce_labels": self.do_reduce_labels,
         }
 
 
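Beyond the dict-valued `size`, SegFormer renames `reduce_labels` to `do_reduce_labels`, matching the `do_*` convention of the new image processors. What the flag does, per the ADE20k-style convention it is documented for (a sketch, not the library code):

```python
import numpy as np

def reduce_labels(segmentation_map: np.ndarray) -> np.ndarray:
    # 0 means "background": send it to the ignore index 255 and shift
    # every real class down by one so class ids start at 0
    label = segmentation_map.astype(np.int64)
    label[label == 0] = 255
    label = label - 1
    label[label == 254] = 255
    return label

assert (reduce_labels(np.array([0, 1, 150])) == np.array([255, 0, 149])).all()
```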
@@ -112,7 +113,7 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
         self.assertTrue(hasattr(feature_extractor, "do_normalize"))
         self.assertTrue(hasattr(feature_extractor, "image_mean"))
         self.assertTrue(hasattr(feature_extractor, "image_std"))
-        self.assertTrue(hasattr(feature_extractor, "reduce_labels"))
+        self.assertTrue(hasattr(feature_extractor, "do_reduce_labels"))
 
     def test_batch_feature(self):
         pass
@@ -132,8 +133,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
 
@@ -144,8 +145,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
 
@@ -164,8 +165,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
 
@@ -176,8 +177,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
 
@@ -196,8 +197,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
 
@@ -208,8 +209,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
 
@@ -230,16 +231,16 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
         self.assertEqual(
             encoding["labels"].shape,
             (
                 1,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
         self.assertEqual(encoding["labels"].dtype, torch.long)
@@ -253,16 +254,16 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
         self.assertEqual(
             encoding["labels"].shape,
             (
                 self.feature_extract_tester.batch_size,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
         self.assertEqual(encoding["labels"].dtype, torch.long)
@@ -278,16 +279,16 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
         self.assertEqual(
             encoding["labels"].shape,
             (
                 1,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
         self.assertEqual(encoding["labels"].dtype, torch.long)
@@ -303,16 +304,16 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 2,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
         self.assertEqual(
             encoding["labels"].shape,
             (
                 2,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
         self.assertEqual(encoding["labels"].dtype, torch.long)
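The label assertions above change in lockstep with `pixel_values`, but note the shapes differ by one axis: a segmentation map carries no channel dimension and must be `torch.long` so it can serve as a cross-entropy target. A minimal illustration with the tester's 30x30 default:

```python
import torch

batch_size, num_channels, height, width = 2, 3, 30, 30
pixel_values = torch.rand(batch_size, num_channels, height, width)
labels = torch.randint(0, 150, (batch_size, height, width), dtype=torch.long)

assert pixel_values.shape == (2, 3, 30, 30)
assert labels.shape == (2, 30, 30)   # no channel axis
assert labels.dtype == torch.long
```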
@@ -44,11 +44,15 @@ class VideoMAEFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=18,
+        size=None,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
+        crop_size=None,
     ):
+        size = size if size is not None else {"shortest_edge": 18}
+        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
+
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -61,6 +65,7 @@ class VideoMAEFeatureExtractionTester(unittest.TestCase):
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
+        self.crop_size = crop_size
 
     def prepare_feat_extract_dict(self):
         return {
@@ -69,6 +74,7 @@ class VideoMAEFeatureExtractionTester(unittest.TestCase):
             "do_normalize": self.do_normalize,
             "do_resize": self.do_resize,
             "size": self.size,
+            "crop_size": self.crop_size,
         }
 
 
@@ -91,6 +97,7 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
         self.assertTrue(hasattr(feature_extractor, "image_std"))
         self.assertTrue(hasattr(feature_extractor, "do_normalize"))
         self.assertTrue(hasattr(feature_extractor, "do_resize"))
+        self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
         self.assertTrue(hasattr(feature_extractor, "size"))
 
     def test_batch_feature(self):
@@ -113,8 +120,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
                 1,
                 self.feature_extract_tester.num_frames,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -126,8 +133,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_frames,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -148,8 +155,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
                 1,
                 self.feature_extract_tester.num_frames,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -161,8 +168,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_frames,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -183,8 +190,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
                 1,
                 self.feature_extract_tester.num_frames,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
 
@@ -196,7 +203,7 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_frames,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
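The VideoMAE assertions above apply the same `crop_size` indexing, but to a 5-D tensor: videos stack a frames axis in front of channels. A sketch with made-up batch and frame counts (the 18x18 crop matches the tester default):

```python
import torch

batch_size, num_frames, num_channels = 2, 10, 3
crop_size = {"height": 18, "width": 18}

pixel_values = torch.rand(
    batch_size, num_frames, num_channels, crop_size["height"], crop_size["width"]
)
assert pixel_values.shape == (2, 10, 3, 18, 18)
```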
@@ -43,12 +43,13 @@ class ViltFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=30,
+        size=None,
         size_divisor=2,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
     ):
+        size = size if size is not None else {"shortest_edge": 30}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -78,18 +79,19 @@ class ViltFeatureExtractionTester(unittest.TestCase):
         assuming do_resize is set to True with a scalar size and size_divisor.
         """
         if not batched:
+            size = self.size["shortest_edge"]
             image = image_inputs[0]
             if isinstance(image, Image.Image):
                 w, h = image.size
             else:
                 h, w = image.shape[1], image.shape[2]
-            scale = self.size / min(w, h)
+            scale = size / min(w, h)
             if h < w:
-                newh, neww = self.size, scale * w
+                newh, neww = size, scale * w
             else:
-                newh, neww = scale * h, self.size
+                newh, neww = scale * h, size
 
-            max_size = int((1333 / 800) * self.size)
+            max_size = int((1333 / 800) * size)
             if max(newh, neww) > max_size:
                 scale = max_size / max(newh, neww)
                 newh = newh * scale
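The only behavioral change in the hunk above is reading the shortest edge out of the new `size` dict; the DETR-style 800/1333 shortest-to-longest-edge ratio is untouched. The same logic as a standalone function (simplified: the tester additionally rounds the result and snaps it to `size_divisor`):

```python
def expected_resize(h: int, w: int, size: dict) -> tuple:
    shortest = size["shortest_edge"]
    scale = shortest / min(w, h)
    if h < w:
        newh, neww = shortest, scale * w
    else:
        newh, neww = scale * h, shortest

    # cap the longer side at (1333 / 800) * shortest_edge
    max_size = int((1333 / 800) * shortest)
    if max(newh, neww) > max_size:
        scale = max_size / max(newh, neww)
        newh, neww = newh * scale, neww * scale
    return int(newh), int(neww)

assert expected_resize(60, 120, {"shortest_edge": 30}) == (24, 49)
```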
@@ -233,7 +235,7 @@ class ViltFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
     def test_equivalence_pad_and_create_pixel_mask(self):
         # Initialize feature_extractors
         feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
-        feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
+        feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False, do_rescale=False)
         # create random PyTorch tensors
         image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
         for image in image_inputs:
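The extra `do_rescale=False` reflects a split in the new processors: what the old extractors folded into `do_normalize` is now rescaling (multiply by 1/255) plus mean/std normalization as separate flags, so a comparison extractor that disables normalization presumably needs to disable rescaling too for the padded tensors to match. Sketch of the rescale half (the 1/255 factor is the processors' documented default):

```python
import numpy as np

def rescale(image: np.ndarray, rescale_factor: float = 1 / 255) -> np.ndarray:
    # map 0..255 pixel values into 0..1 ahead of mean/std normalization
    return image * rescale_factor

assert np.allclose(rescale(np.array([0.0, 127.5, 255.0])), [0.0, 0.5, 1.0])
```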
@@ -43,11 +43,12 @@ class ViTFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=18,
+        size=None,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
     ):
+        size = size if size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -109,8 +110,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
 
@@ -121,8 +122,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
 
@@ -141,8 +142,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
 
@@ -153,8 +154,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
 
@@ -173,8 +174,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
 
@@ -185,7 +186,7 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -0,0 +1,71 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers.image_processing_utils import get_size_dict
+
+
+class ImageProcessingUtilsTester(unittest.TestCase):
+    def test_get_size_dict(self):
+        # Test a dict with the wrong keys raises an error
+        inputs = {"wrong_key": 224}
+        with self.assertRaises(ValueError):
+            get_size_dict(inputs)
+
+        inputs = {"height": 224}
+        with self.assertRaises(ValueError):
+            get_size_dict(inputs)
+
+        inputs = {"width": 224, "shortest_edge": 224}
+        with self.assertRaises(ValueError):
+            get_size_dict(inputs)
+
+        # Test a dict with the correct keys is returned as is
+        inputs = {"height": 224, "width": 224}
+        outputs = get_size_dict(inputs)
+        self.assertEqual(outputs, inputs)
+
+        inputs = {"shortest_edge": 224}
+        outputs = get_size_dict(inputs)
+        self.assertEqual(outputs, {"shortest_edge": 224})
+
+        inputs = {"longest_edge": 224, "shortest_edge": 224}
+        outputs = get_size_dict(inputs)
+        self.assertEqual(outputs, {"longest_edge": 224, "shortest_edge": 224})
+
+        # Test a single int value which represents (size, size)
+        outputs = get_size_dict(224)
+        self.assertEqual(outputs, {"height": 224, "width": 224})
+
+        # Test a single int value which represents the shortest edge
+        outputs = get_size_dict(224, default_to_square=False)
+        self.assertEqual(outputs, {"shortest_edge": 224})
+
+        # Test a tuple of ints which represents (height, width)
+        outputs = get_size_dict((150, 200))
+        self.assertEqual(outputs, {"height": 150, "width": 200})
+
+        # Test a tuple of ints which represents (width, height)
+        outputs = get_size_dict((150, 200), height_width_order=False)
+        self.assertEqual(outputs, {"height": 200, "width": 150})
+
+        # Test an int representing the shortest edge and max_size which represents the longest edge
+        outputs = get_size_dict(224, max_size=256, default_to_square=False)
+        self.assertEqual(outputs, {"shortest_edge": 224, "longest_edge": 256})
+
+        # Test int with default_to_square=True and max_size fails
+        with self.assertRaises(ValueError):
+            get_size_dict(224, max_size=256, default_to_square=True)
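Taken together, these assertions pin down `get_size_dict`'s contract. A behavioral sketch inferred from them (the real helper in `transformers.image_processing_utils` also emits deprecation warnings and handles further legacy cases; this covers only what the tests above require):

```python
def get_size_dict_sketch(size, max_size=None, default_to_square=True, height_width_order=True):
    if isinstance(size, dict):
        allowed = ({"height", "width"}, {"shortest_edge"}, {"shortest_edge", "longest_edge"})
        if set(size) not in allowed:
            raise ValueError(f"Invalid size dict keys: {set(size)}")
        return size
    if isinstance(size, int):
        if default_to_square:
            # a bare int means a square (size, size) target; max_size only
            # makes sense for shortest-edge resizing
            if max_size is not None:
                raise ValueError("Cannot pass max_size when default_to_square=True")
            return {"height": size, "width": size}
        out = {"shortest_edge": size}
        if max_size is not None:
            out["longest_edge"] = max_size
        return out
    if isinstance(size, (tuple, list)):
        h, w = size if height_width_order else (size[1], size[0])
        return {"height": h, "width": w}
    raise ValueError(f"Unsupported size specification: {size!r}")
```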