Add Image Processors (#19796)

* Add CLIP image processor

* Crop size as dict too

* Update warning

* Actually use logger this time

* Normalize doesn't change dtype of input

* Add perceiver image processor

* Tidy up

* Add DPT image processor

* Add Vilt image processor

* Tidy up

* Add poolformer image processor

* Tidy up

* Add LayoutLM v2 and v3 image processors

* Tidy up

* Add Flava image processor

* Tidy up

* Add deit image processor

* Tidy up

* Add ConvNext image processor

* Tidy up

* Add levit image processor

* Add segformer image processor

* Add in post processing

* Fix up

* Add ImageGPT image processor

* Fixup

* Add mobilevit image processor

* Tidy up

* Add postprocessing

* Fixup

* Add VideoMAE image processor

* Tidy up

* Add ImageGPT image processor

* Fixup

* Add ViT image processor

* Tidy up

* Add beit image processor

* Add mobilevit image processor

* Tidy up

* Add postprocessing

* Fixup

* Fix up

* Fix flava and remove tree module

* Fix image classification pipeline failing tests

* Update feature extractor in trainer scripts

* Update pad_if_smaller to accept tuple and int size

* Update for image segmentation pipeline

* Update src/transformers/models/perceiver/image_processing_perceiver.py

Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com>

* Update src/transformers/image_processing_utils.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/beit/image_processing_beit.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* PR comments - docstrings; remove accidentally added resize; var names

* Update docstrings

* Add exception if size is not in the right format

* Fix exception check

* Fix up

* Use shortest_edge in tuple in script

Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
amyeroberts 2022-11-02 11:57:36 +00:00 committed by GitHub
parent 2e3452af0f
commit a6b7759880
65 changed files with 7060 additions and 3590 deletions

View File

@ -361,9 +361,12 @@ For computer vision tasks, it is common to add some type of data augmentation to
>>> from torchvision.transforms import Compose, Normalize, RandomResizedCrop, ColorJitter, ToTensor
>>> normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
>>> _transforms = Compose(
... [RandomResizedCrop(feature_extractor.size), ColorJitter(brightness=0.5, hue=0.5), ToTensor(), normalize]
>>> size = (
... feature_extractor.size["shortest_edge"]
... if "shortest_edge" in feature_extractor.size
... else (feature_extractor.size["height"], feature_extractor.size["width"])
... )
>>> _transforms = Compose([RandomResizedCrop(size), ColorJitter(brightness=0.5, hue=0.5), ToTensor(), normalize])
```
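For reference, the new `feature_extractor.size` attribute is a dictionary whose keys depend on the checkpoint. A minimal sketch of how the check above resolves (the `resolve_size` helper is illustrative, not part of the diff):

```python
def resolve_size(size_dict):
    # Shortest-edge configs give an int; height/width configs give a (height, width) tuple.
    # Both forms are accepted by torchvision's RandomResizedCrop.
    if "shortest_edge" in size_dict:
        return size_dict["shortest_edge"]
    return (size_dict["height"], size_dict["width"])

print(resolve_size({"shortest_edge": 224}))         # 224
print(resolve_size({"height": 256, "width": 256}))  # (256, 256)
```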
2. The model accepts [`pixel_values`](model_doc/visionencoderdecoder#transformers.VisionEncoderDecoderModel.forward.pixel_values) as its input, which is generated by the feature extractor. Create a function that generates `pixel_values` from the transforms:

View File

@ -83,7 +83,12 @@ Apply several image transformations to the dataset to make the model more robust
>>> from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
>>> normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
>>> _transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])
>>> size = (
... feature_extractor.size["shortest_edge"]
... if "shortest_edge" in feature_extractor.size
... else (feature_extractor.size["height"], feature_extractor.size["width"])
... )
>>> _transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
```
Create a preprocessing function that will apply the transforms and return the `pixel_values` - the inputs to the model - of the image:

View File

@ -291,10 +291,14 @@ def main():
)
# Define torchvision transforms to be applied to each image.
if "shortest_edge" in feature_extractor.size:
size = feature_extractor.size["shortest_edge"]
else:
size = (feature_extractor.size["height"], feature_extractor.size["width"])
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
_train_transforms = Compose(
[
RandomResizedCrop(feature_extractor.size),
RandomResizedCrop(size),
RandomHorizontalFlip(),
ToTensor(),
normalize,
@ -302,8 +306,8 @@ def main():
)
_val_transforms = Compose(
[
Resize(feature_extractor.size),
CenterCrop(feature_extractor.size),
Resize(size),
CenterCrop(size),
ToTensor(),
normalize,
]

View File

@ -315,10 +315,14 @@ def main():
# Preprocessing the datasets
# Define torchvision transforms to be applied to each image.
if "shortest_edge" in feature_extractor.size:
size = feature_extractor.size["shortest_edge"]
else:
size = (feature_extractor.size["height"], feature_extractor.size["width"])
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
train_transforms = Compose(
[
RandomResizedCrop(feature_extractor.size),
RandomResizedCrop(size),
RandomHorizontalFlip(),
ToTensor(),
normalize,
@ -326,8 +330,8 @@ def main():
)
val_transforms = Compose(
[
Resize(feature_extractor.size),
CenterCrop(feature_extractor.size),
Resize(size),
CenterCrop(size),
ToTensor(),
normalize,
]

View File

@ -298,10 +298,14 @@ def main():
# transformations as done in original MAE paper
# source: https://github.com/facebookresearch/mae/blob/main/main_pretrain.py
if "shortest_edge" in feature_extractor.size:
size = feature_extractor.size["shortest_edge"]
else:
size = (feature_extractor.size["height"], feature_extractor.size["width"])
transforms = Compose(
[
Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
RandomResizedCrop(feature_extractor.size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
RandomResizedCrop(size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
RandomHorizontalFlip(),
ToTensor(),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),

View File

@ -57,11 +57,10 @@ require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/sema
def pad_if_smaller(img, size, fill=0):
min_size = min(img.size)
if min_size < size:
size = (size, size) if isinstance(size, int) else size
original_width, original_height = img.size
pad_height = size - original_height if original_height < size else 0
pad_width = size - original_width if original_width < size else 0
pad_height = size[1] - original_height if original_height < size[1] else 0
pad_width = size[0] - original_width if original_width < size[0] else 0
img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
return img
@ -110,12 +109,12 @@ class RandomResize:
class RandomCrop:
def __init__(self, size):
self.size = size
self.size = size if isinstance(size, tuple) else (size, size)
def __call__(self, image, target):
image = pad_if_smaller(image, self.size)
target = pad_if_smaller(target, self.size, fill=255)
crop_params = transforms.RandomCrop.get_params(image, (self.size, self.size))
crop_params = transforms.RandomCrop.get_params(image, self.size)
image = functional.crop(image, *crop_params)
target = functional.crop(target, *crop_params)
return image, target
@ -359,7 +358,7 @@ def main():
references=labels,
num_labels=len(id2label),
ignore_index=0,
reduce_labels=feature_extractor.reduce_labels,
reduce_labels=feature_extractor.do_reduce_labels,
)
# add per category metrics as individual key-value pairs
per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
@ -396,10 +395,15 @@ def main():
# Define torchvision transforms to be applied to each image + target.
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
if "shortest_edge" in feature_extractor.size:
# We instead set the target size as (shortest_edge, shortest_edge) here to ensure all images are batchable.
size = (feature_extractor.size["shortest_edge"], feature_extractor.size["shortest_edge"])
else:
size = (feature_extractor.size["height"], feature_extractor.size["width"])
train_transforms = Compose(
[
ReduceLabels() if data_args.reduce_labels else Identity(),
RandomCrop(size=feature_extractor.size),
RandomCrop(size=size),
RandomHorizontalFlip(flip_prob=0.5),
PILToTensor(),
ConvertImageDtype(torch.float),
@ -411,7 +415,7 @@ def main():
val_transforms = Compose(
[
ReduceLabels() if data_args.reduce_labels else Identity(),
Resize(size=(feature_extractor.size, feature_extractor.size)),
Resize(size=size),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),

View File

@ -405,10 +405,15 @@ def main():
# Define torchvision transforms to be applied to each image + target.
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
if "shortest_edge" in feature_extractor.size:
# We instead set the target size as (shortest_edge, shortest_edge) here to ensure all images are batchable.
size = (feature_extractor.size["shortest_edge"], feature_extractor.size["shortest_edge"])
else:
size = (feature_extractor.size["height"], feature_extractor.size["width"])
train_transforms = Compose(
[
ReduceLabels() if args.reduce_labels else Identity(),
RandomCrop(size=feature_extractor.size),
RandomCrop(size=size),
RandomHorizontalFlip(flip_prob=0.5),
PILToTensor(),
ConvertImageDtype(torch.float),
@ -420,7 +425,7 @@ def main():
val_transforms = Compose(
[
ReduceLabels() if args.reduce_labels else Identity(),
Resize(size=(feature_extractor.size, feature_extractor.size)),
Resize(size=size),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, Optional, Union
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .feature_extraction_utils import FeatureExtractionMixin
from .utils import logging
@ -48,7 +50,72 @@ class BaseImageProcessor(ImageProcessorMixin):
super().__init__(**kwargs)
def __call__(self, images, **kwargs) -> BatchFeature:
"""Preprocess an image or a batch of images."""
return self.preprocess(images, **kwargs)
def preprocess(self, images, **kwargs) -> BatchFeature:
raise NotImplementedError("Each image processor must implement its own preprocess method")
def get_size_dict(
size: Union[int, Iterable[int], Dict[str, int]] = None,
max_size: Optional[int] = None,
height_width_order: bool = True,
default_to_square: bool = True,
) -> dict:
"""
Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
compatibility with the old feature extractor configs and removes ambiguity over whether the tuple is in (height,
width) or (width, height) format.
- If `size` is a tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
size[0]}` if `height_width_order` is `False`.
- If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
- If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
is set, it is added to the dict as `{"longest_edge": max_size}`.
Args:
size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*):
The `size` parameter to be cast into a size dictionary.
max_size (`Optional[int]`, *optional*):
The `max_size` parameter to be cast into a size dictionary.
height_width_order (`bool`, *optional*, defaults to `True`):
If `size` is a tuple, whether it's in (height, width) or (width, height) order.
default_to_square (`bool`, *optional*, defaults to `True`):
If `size` is an int, whether to default to a square image or not.
"""
# If a dict is passed, we check if it's a valid size dict and then return it.
if isinstance(size, dict):
size_keys = set(size.keys())
if (
size_keys != set(["height", "width"])
and size_keys != set(["shortest_edge"])
and size_keys != set(["shortest_edge", "longest_edge"])
):
raise ValueError(
"The size dict must contain either the keys ('height', 'width') or ('shortest_edge')"
f"or ('shortest_edge', 'longest_edge') but got {size_keys}"
)
return size
# By default, if size is an int we assume it represents a tuple of (size, size).
elif isinstance(size, int) and default_to_square:
if max_size is not None:
raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
size_dict = {"height": size, "width": size}
# In other configs, if size is an int and default_to_square is False, size represents the length of the shortest edge after resizing.
elif isinstance(size, int) and not default_to_square:
if max_size is not None:
size_dict = {"shortest_edge": size, "longest_edge": max_size}
else:
size_dict = {"shortest_edge": size}
elif isinstance(size, (tuple, list)) and height_width_order:
size_dict = {"height": size[0], "width": size[1]}
elif isinstance(size, (tuple, list)) and not height_width_order:
size_dict = {"height": size[1], "width": size[0]}
logger.warning(
"The size parameter should be a dictionary with keys ('height', 'width'), ('shortest_edge', 'longest_edge')"
f" or ('shortest_edge',) got {size}. Setting as {size_dict}.",
)
return size_dict
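
Assuming the helper is imported from `transformers.image_processing_utils` as the new image processors do, a minimal sketch of the conversions it performs (values are illustrative; non-dict inputs are converted and may log a warning):

```python
from transformers.image_processing_utils import get_size_dict

get_size_dict({"height": 224, "width": 224})                # validated and returned as-is
get_size_dict(224)                                          # {'height': 224, 'width': 224}
get_size_dict(224, default_to_square=False)                 # {'shortest_edge': 224}
get_size_dict(256, max_size=480, default_to_square=False)   # {'shortest_edge': 256, 'longest_edge': 480}
get_size_dict((300, 200))                                   # {'height': 300, 'width': 200}
```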

View File

@ -139,6 +139,9 @@ def to_pil_image(
# If the channel as been moved to first dim, we put it back at the end.
image = to_channel_dimension_format(image, ChannelDimension.LAST)
# If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
# PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed.
do_rescale = isinstance(image.flat[0], float) if do_rescale is None else do_rescale
if do_rescale:
@ -259,6 +262,9 @@ def resize(
if return_numpy:
resized_image = np.array(resized_image)
# If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
# so we need to add it back if necessary.
resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
resized_image = to_channel_dimension_format(resized_image, data_format)
return resized_image
@ -303,12 +309,14 @@ def normalize(
raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
else:
mean = [mean] * num_channels
mean = np.array(mean, dtype=image.dtype)
if isinstance(std, Iterable):
if len(std) != num_channels:
raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
else:
std = [std] * num_channels
std = np.array(std, dtype=image.dtype)
if input_data_format == ChannelDimension.LAST:
image = (image - mean) / std
@ -372,6 +380,7 @@ def center_crop(
orig_height, orig_width = get_image_size(image)
crop_height, crop_width = size
crop_height, crop_width = int(crop_height), int(crop_width)
# In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
top = (orig_height - crop_height) // 2
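
A minimal sketch of the dtype behaviour these casts preserve, assuming `normalize` is called directly from `transformers.image_transforms` on a channels-last array:

```python
import numpy as np
from transformers.image_transforms import normalize

image = np.random.rand(224, 224, 3).astype(np.float32)
normalized = normalize(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
# mean and std are cast to the image's dtype, so a float32 input stays float32.
print(normalized.dtype)  # float32
```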

View File

@ -72,7 +72,15 @@ def is_valid_image(img):
def valid_images(imgs):
return all(is_valid_image(img) for img in imgs)
# If we have a list of images, make sure every image is valid
if isinstance(imgs, (list, tuple)):
for img in imgs:
if not valid_images(img):
return False
# If not a list or tuple, we have been given a single image or a batched tensor of images
elif not is_valid_image(imgs):
return False
return True
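
A short sketch of the recursive check, assuming `valid_images` is imported directly from `transformers.image_utils`:

```python
import numpy as np
from transformers.image_utils import valid_images

img = np.zeros((3, 32, 32), dtype=np.uint8)

valid_images(img)                    # True: a single image
valid_images([img, img])             # True: a flat batch
valid_images([[img], [img]])         # True: nested lists are checked element-wise
valid_images([img, "not an image"])  # False: one invalid entry fails the whole input
```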
def is_batched(img):

View File

@ -14,258 +14,10 @@
# limitations under the License.
"""Feature extractor class for BEiT."""
from typing import List, Optional, Tuple, Union
from ...utils import logging
from .image_processing_beit import BeitImageProcessor
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a BEiT feature extractor.
This feature extractor inherits from [`~feature_extraction_utils.FeatureExtractionMixin`] which contains most of
the main methods. Users should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 256):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
image is padded with 0's and then center cropped.
crop_size (`int`, *optional*, defaults to 224):
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with `image_mean` and `image_std`.
image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
reduce_labels (`bool`, *optional*, defaults to `False`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
background label will be replaced by 255.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=256,
resample=PILImageResampling.BICUBIC,
do_center_crop=True,
crop_size=224,
do_normalize=True,
image_mean=None,
image_std=None,
reduce_labels=False,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.reduce_labels = reduce_labels
def __call__(
self,
images: ImageInput,
segmentation_maps: ImageInput = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
- **labels** -- Optional labels to be fed to a model (when `segmentation_maps` are provided)
"""
# Input type checking for clearer error
valid_images = False
valid_segmentation_maps = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
# Check that segmentation maps has a valid type
if segmentation_maps is not None:
if isinstance(segmentation_maps, (Image.Image, np.ndarray)) or is_torch_tensor(segmentation_maps):
valid_segmentation_maps = True
elif isinstance(segmentation_maps, (list, tuple)):
if (
len(segmentation_maps) == 0
or isinstance(segmentation_maps[0], (Image.Image, np.ndarray))
or is_torch_tensor(segmentation_maps[0])
):
valid_segmentation_maps = True
if not valid_segmentation_maps:
raise ValueError(
"Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single"
" example),`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of"
" examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
if segmentation_maps is not None:
segmentation_maps = [segmentation_maps]
# reduce zero label if needed
if self.reduce_labels:
if segmentation_maps is not None:
for idx, map in enumerate(segmentation_maps):
if not isinstance(map, np.ndarray):
map = np.array(map)
# avoid using underflow conversion
map[map == 0] = 255
map = map - 1
map[map == 254] = 255
segmentation_maps[idx] = Image.fromarray(map.astype(np.uint8))
# transformations (resizing + center cropping + normalization)
if self.do_resize and self.size is not None and self.resample is not None:
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
if segmentation_maps is not None:
segmentation_maps = [
self.resize(map, size=self.size, resample=self.resample) for map in segmentation_maps
]
if self.do_center_crop and self.crop_size is not None:
images = [self.center_crop(image, self.crop_size) for image in images]
if segmentation_maps is not None:
segmentation_maps = [self.center_crop(map, size=self.crop_size) for map in segmentation_maps]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
if segmentation_maps is not None:
labels = []
for map in segmentation_maps:
if not isinstance(map, np.ndarray):
map = np.array(map)
labels.append(map.astype(np.int64))
# cast to np.int64
data["labels"] = labels
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
Args:
outputs ([`BeitForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
None, predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
BeitFeatureExtractor = BeitImageProcessor
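
With the alias in place, existing code that loads `BeitFeatureExtractor` keeps working and now receives the new image processor. A minimal sketch (the checkpoint name is illustrative; assumes `BeitImageProcessor` is exported at the top level as elsewhere in this PR):

```python
from transformers import BeitFeatureExtractor, BeitImageProcessor

# The old class name now resolves to the new image processor.
feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224")
assert isinstance(feature_extractor, BeitImageProcessor)
```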

View File

@ -0,0 +1,525 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Beit."""
import warnings
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from transformers.utils import is_torch_available, is_torch_tensor, is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_vision_available():
import PIL
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class BeitImageProcessor(BaseImageProcessor):
r"""
Constructs a BEiT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in the
`preprocess` method.
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
Can be overridden by the `crop_size` parameter in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
The mean to use if normalizing the image. This is a float or list of floats of length of the number of
channels of the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
The standard deviation to use if normalizing the image. This is a float or list of floats of length of the
number of channels of the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_reduce_labels (`bool`, *optional*, defaults to `False`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
`preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
rescale_factor: Union[int, float] = 1 / 255,
do_rescale: bool = True,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_reduce_labels: bool = False,
**kwargs
) -> None:
if "reduce_labels" in kwargs:
warnings.warn(
"The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use"
" `do_reduce_labels` instead.",
FutureWarning,
)
do_reduce_labels = kwargs.pop("reduce_labels")
super().__init__(**kwargs)
size = size if size is not None else {"height": 256, "width": 256}
size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_reduce_labels = do_reduce_labels
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image to (size["height"], size["width"]).
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
return resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
)
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Center crop an image to (size["height"], size["width"]). If the input size is smaller than `size` along any
edge, the image is padded with 0's and then center cropped.
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
image_mean (`float` or `List[float]`):
Image mean.
image_std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def reduce_label(self, label: ImageInput) -> np.ndarray:
label = to_numpy_array(label)
# Avoid using underflow conversion
label[label == 0] = 255
label = label - 1
label[label == 254] = 255
return label
def _preprocess(
self,
image: ImageInput,
do_reduce_labels: bool = None,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
):
if do_reduce_labels:
image = self.reduce_label(image)
if do_resize:
image = self.resize(image=image, size=size, resample=resample)
if do_center_crop:
image = self.center_crop(image=image, size=crop_size)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor)
if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std)
return image
def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
image = self._preprocess(
image,
do_reduce_labels=False,
do_resize=do_resize,
size=size,
resample=resample,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
)
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
return image
def _preprocess_segmentation_map(
self,
segmentation_map: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_reduce_labels: bool = None,
):
"""Preprocesses a single segmentation map."""
# All transformations expect numpy arrays.
segmentation_map = to_numpy_array(segmentation_map)
# Add an axis to the segmentation maps for transformations.
if segmentation_map.ndim == 2:
segmentation_map = segmentation_map[None, ...]
added_dimension = True
else:
added_dimension = False
segmentation_map = self._preprocess(
image=segmentation_map,
do_reduce_labels=do_reduce_labels,
do_resize=do_resize,
resample=resample,
size=size,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_normalize=False,
do_rescale=False,
)
# Remove extra axis if added
if added_dimension:
segmentation_map = np.squeeze(segmentation_map, axis=0)
segmentation_map = segmentation_map.astype(np.int64)
return segmentation_map
def __call__(self, images, segmentation_maps=None, **kwargs):
# Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
# be passed in as positional arguments.
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will be
padded with zeros and then cropped.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g.
ADE20k). The background label will be replaced by 255.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
if not is_batched(images):
images = [images]
segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
images = [
self._preprocess_image(
image=img,
do_resize=do_resize,
do_center_crop=do_center_crop,
do_rescale=do_rescale,
do_normalize=do_normalize,
resample=resample,
size=size,
rescale_factor=rescale_factor,
crop_size=crop_size,
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
)
for img in images
]
data = {"pixel_values": images}
if segmentation_maps is not None:
segmentation_maps = [
self._preprocess_segmentation_map(
segmentation_map=segmentation_map,
do_reduce_labels=do_reduce_labels,
do_resize=do_resize,
resample=resample,
size=size,
do_center_crop=do_center_crop,
crop_size=crop_size,
)
for segmentation_map in segmentation_maps
]
data["labels"] = segmentation_maps
return BatchFeature(data=data, tensor_type=return_tensors)
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
Args:
outputs ([`BeitForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
None, predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
# TODO: add support for other frameworks
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
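
A minimal usage sketch of the new processor with its default configuration (input arrays are synthetic; shapes in the comments are what the defaults are expected to produce):

```python
import numpy as np
from transformers import BeitImageProcessor

image_processor = BeitImageProcessor()  # resize to 256x256, center crop to 224x224, rescale, normalize

image = np.random.randint(0, 256, (3, 500, 400), dtype=np.uint8)          # channels-first RGB image
segmentation_map = np.random.randint(0, 150, (500, 400), dtype=np.uint8)  # per-pixel class ids

inputs = image_processor(image, segmentation_maps=segmentation_map, return_tensors="np")
print(inputs["pixel_values"].shape)  # (1, 3, 224, 224)
print(inputs["labels"].shape)        # (1, 224, 224)
```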

View File

@ -14,155 +14,11 @@
# limitations under the License.
"""Feature extractor class for CLIP."""
from typing import List, Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import TensorType, logging
from ...utils import logging
from .image_processing_clip import CLIPImageProcessor
logger = logging.get_logger(__name__)
class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a CLIP feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int`, *optional*, defaults to 224):
Resize the input to the given size. Only has an effect if `do_resize` is set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
image is padded with 0's and then center cropped.
crop_size (`int`, *optional*, defaults to 224):
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with `image_mean` and `image_std`.
image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
convert_rgb (`bool`, defaults to `True`):
Whether or not to convert `PIL.Image.Image` into `RGB` format
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=224,
resample=PILImageResampling.BICUBIC,
do_center_crop=True,
crop_size=224,
do_normalize=True,
image_mean=None,
image_std=None,
do_convert_rgb=True,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
self.do_convert_rgb = do_convert_rgb
def __call__(
self,
images: Union[
Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
],
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model.
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (convert rgb + resizing + center cropping + normalization)
if self.do_convert_rgb:
images = [self.convert_rgb(image) for image in images]
if self.do_resize and self.size is not None and self.resample is not None:
images = [
self.resize(image=image, size=self.size, resample=self.resample, default_to_square=False)
for image in images
]
if self.do_center_crop and self.crop_size is not None:
images = [self.center_crop(image, self.crop_size) for image in images]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
CLIPFeatureExtractor = CLIPImageProcessor
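
Because the new constructor runs old-style values through `get_size_dict`, configs that stored plain integers keep working. A minimal sketch (assumes `CLIPImageProcessor` is exported at the top level):

```python
from transformers import CLIPImageProcessor

# Old CLIP feature extractor configs stored `size=224` and `crop_size=224` as ints.
image_processor = CLIPImageProcessor(size=224, crop_size=224)
print(image_processor.size)       # {'shortest_edge': 224}
print(image_processor.crop_size)  # {'height': 224, 'width': 224}
```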

View File

@ -0,0 +1,342 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for CLIP."""
from typing import Any, Dict, List, Optional, Union
import numpy as np
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
get_resize_output_image_size,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
from ...utils import logging
from ...utils.import_utils import is_vision_available
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if not isinstance(image, PIL.Image.Image):
return image
return image.convert("RGB")
class CLIPImageProcessor(BaseImageProcessor):
r"""
Constructs a CLIP image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
`do_resize` in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
`preprocess` method.
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
self.do_convert_rgb = do_convert_rgb
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
resized to keep the input aspect ratio.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size:
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Center crop an image. If the image is too small to be cropped to the size given, it will be padded (so the
returned result will always be of size `size`).
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image in the form of a dictionary with keys `height` and `width`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `List[float]`):
Image mean.
std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
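# A quick numeric sketch of rescale followed by normalize with the defaults above:
# a uint8 value of 255 becomes 255 * (1 / 255) = 1.0 after `rescale`, and `normalize`
# then maps it to (1.0 - 0.48145466) / 0.26862954 ≈ 1.93 on the first channel.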
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: int = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
the longest edge resized to keep the input aspect ratio.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: defaults to the channel dimension format of the input image.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
if do_center_crop:
images = [self.center_crop(image=image, size=crop_size) for image in images]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
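# Minimal usage sketch (assumes an instance of the processor defined above is bound
# to `image_processor` and a PIL image `image` is in scope):
#
#     batch = image_processor.preprocess(image, return_tensors="pt")
#     batch["pixel_values"].shape  # torch.Size([1, 3, 224, 224]) with the default 224x224 crop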

View File

@ -14,157 +14,11 @@
# limitations under the License.
"""Feature extractor class for ConvNeXT."""
from typing import Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, logging
from ...utils import logging
from .image_processing_convnext import ConvNextImageProcessor
logger = logging.get_logger(__name__)
class ConvNextFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a ConvNeXT feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize (and optionally center crop) the input to a certain `size`.
size (`int`, *optional*, defaults to 224):
Resize the input to the given size. If 384 or larger, the image is resized to (`size`, `size`). Else, the
smaller edge of the image will be matched to int(`size`/ `crop_pct`), after which the image is cropped to
`size`. Only has an effect if `do_resize` is set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
crop_pct (`float`, *optional*):
The percentage of the image to crop. If `None`, then a cropping percentage of 224 / 256 is used. Only has
an effect if `do_resize` is set to `True` and `size` < 384.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=224,
resample=PILImageResampling.BICUBIC,
crop_pct=None,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.crop_pct = crop_pct
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (resizing and optional center cropping + normalization)
if self.do_resize and self.size is not None:
if self.size >= 384:
# warping (no cropping) when evaluated at 384 or larger
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
else:
if self.crop_pct is None:
self.crop_pct = 224 / 256
size = int(self.size / self.crop_pct)
# to maintain same ratio w.r.t. 224 images
images = [
self.resize(image=image, size=size, default_to_square=False, resample=self.resample)
for image in images
]
images = [self.center_crop(image=image, size=self.size) for image in images]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
ConvNextFeatureExtractor = ConvNextImageProcessor

View File

@ -0,0 +1,310 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for ConvNeXT."""
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
get_resize_output_image_size,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
class ConvNextImageProcessor(BaseImageProcessor):
r"""
Constructs a ConvNeXT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden
by `do_resize` in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 384}`):
Resolution of the output image after `resize` is applied. If `size["shortest_edge"]` >= 384, the image is
resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the image will
be matched to `int(size["shortest_edge"]/crop_pct)`, after which the image is cropped to
`(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. Can
be overridden by `size` in the `preprocess` method.
crop_pct (`float` *optional*, defaults to 224 / 256):
Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be
overridden by `crop_pct` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
crop_pct: float = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 384}
size = get_size_dict(size, default_to_square=False)
self.do_resize = do_resize
self.size = size
# Default value set here for backwards compatibility where the value in config is None
self.crop_pct = crop_pct if crop_pct is not None else 224 / 256
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
crop_pct: float,
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If
`size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`.
Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`,
after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`.
crop_pct (`float`):
Percentage of the image to crop. Only has an effect if size < 384.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size:
raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}")
shortest_edge = size["shortest_edge"]
if shortest_edge < 384:
# maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
resize_shortest_edge = int(shortest_edge / crop_pct)
resize_size = get_resize_output_image_size(image, size=resize_shortest_edge, default_to_square=False)
image = resize(image=image, size=resize_size, resample=resample, data_format=data_format, **kwargs)
# then crop to (shortest_edge, shortest_edge)
return center_crop(image=image, size=(shortest_edge, shortest_edge), data_format=data_format, **kwargs)
else:
# warping (no cropping) when evaluated at 384 or larger
return resize(
image, size=(shortest_edge, shortest_edge), resample=resample, data_format=data_format, **kwargs
)
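# Worked example of the two branches above: with size={"shortest_edge": 224} and the
# default crop_pct of 224 / 256, the shortest edge is first resized to
# int(224 / (224 / 256)) = 256 (aspect ratio preserved) and the result is then center
# cropped to 224x224; with size={"shortest_edge": 384} or larger, the image is warped
# straight to 384x384 with no crop.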
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `List[float]`):
Image mean.
std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
crop_pct: float = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to
`(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`.
crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
Percentage of the image to crop if size < 384.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` filters. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
crop_pct = crop_pct if crop_pct is not None else self.crop_pct
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_resize and size["shortest_edge"] < 384 and crop_pct is None:
raise ValueError("crop_pct must be specified if size < 384.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [self.resize(image=image, size=size, crop_pct=crop_pct, resample=resample) for image in images]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
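# Minimal usage sketch (assumes `ConvNextImageProcessor` is importable and a PIL
# image `image` is in scope):
#
#     processor = ConvNextImageProcessor(size={"shortest_edge": 224})
#     batch = processor.preprocess(image, return_tensors="np")
#     batch["pixel_values"][0].shape  # (3, 224, 224) with the default channels-first output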

View File

@ -308,7 +308,7 @@ def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_fo
model = CvtForImageClassification(config)
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k")
feature_extractor.size = image_size
feature_extractor.size["shortest_edge"] = image_size
original_weights = torch.load(cvt_file_name, map_location=torch.device("cpu"))
huggingface_weights = OrderedDict()

View File

@ -14,150 +14,10 @@
# limitations under the License.
"""Feature extractor class for DeiT."""
from typing import Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, logging
from ...utils import logging
from .image_processing_deit import DeiTImageProcessor
logger = logging.get_logger(__name__)
class DeiTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a DeiT feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 256):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
image is padded with 0's and then center cropped.
crop_size (`int`, *optional*, defaults to 224):
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with `image_mean` and `image_std`.
image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=256,
resample=PILImageResampling.BICUBIC,
do_center_crop=True,
crop_size=224,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (resizing + center cropping + normalization)
if self.do_resize and self.size is not None and self.resample is not None:
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
if self.do_center_crop and self.crop_size is not None:
images = [self.center_crop(image, self.crop_size) for image in images]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
DeiTFeatureExtractor = DeiTImageProcessor

View File

@ -0,0 +1,315 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for DeiT."""
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
class DeiTImageProcessor(BaseImageProcessor):
r"""
Constructs a DeiT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
`do_resize` in `preprocess`.
size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in `preprocess`.
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PIL.Image.BICUBIC,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
rescale_factor: Union[int, float] = 1 / 255,
do_rescale: bool = True,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 256, "width": 256}
size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PIL.Image.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])` using the specified resampling filter.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
return resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
)
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Center crop an image to `(crop_size["height"], crop_size["width"])`. If the input size is smaller than
`crop_size` along any edge, the image is padded with 0's and then center cropped.
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `List[float]`):
Image mean.
std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample=None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after `resize`.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
PILImageResampling filter to use if resizing the image. Only has an effect if `do_resize` is set to
`True`.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will be
padded with zeros and then cropped.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- `None`: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
if do_center_crop:
images = [self.center_crop(image=image, size=crop_size) for image in images]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
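# Minimal usage sketch (assumes `DeiTImageProcessor` is importable and a PIL image
# `image` is in scope). With the defaults the image is resized to 256x256, center
# cropped to 224x224, rescaled to [0, 1] and normalized:
#
#     processor = DeiTImageProcessor()
#     batch = processor.preprocess(image, return_tensors="pt")
#     batch["pixel_values"].shape  # torch.Size([1, 3, 224, 224])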

View File

@ -14,235 +14,11 @@
# limitations under the License.
"""Feature extractor class for DPT."""
from typing import List, Optional, Tuple, Union
from ...utils import logging
from .image_processing_dpt import DPTImageProcessor
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class DPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a DPT feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 384):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
set to `True`.
ensure_multiple_of (`int`, *optional*, defaults to 1):
Ensure that the input is resized to a multiple of this value. Only has an effect if `do_resize` is set to
`True`.
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
Whether to keep the aspect ratio of the input. Only has an effect if `do_resize` is set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=384,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resample=PILImageResampling.BILINEAR,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.keep_aspect_ratio = keep_aspect_ratio
self.ensure_multiple_of = ensure_multiple_of
self.resample = resample
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def constrain_to_multiple_of(self, size, min_val=0, max_val=None):
y = (np.round(size / self.ensure_multiple_of) * self.ensure_multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(size / self.ensure_multiple_of) * self.ensure_multiple_of).astype(int)
if y < min_val:
y = (np.ceil(size / self.ensure_multiple_of) * self.ensure_multiple_of).astype(int)
return y
def update_size(self, image):
image = self.to_pil_image(image)
width, height = image.size
size = self.size
if isinstance(size, list):
size = tuple(size)
if isinstance(size, int) or len(size) == 1:
size = (size, size)
# determine new width and height
scale_width = size[0] / width
scale_height = size[1] / height
if self.keep_aspect_ratio:
# scale as little as possible
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
new_width = self.constrain_to_multiple_of(scale_width * width)
new_height = self.constrain_to_multiple_of(scale_height * height)
return (new_width, new_height)
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~file_utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (resizing + normalization)
if self.do_resize and self.size is not None:
for idx, image in enumerate(images):
size = self.update_size(image)
images[idx] = self.resize(image, size=size, resample=self.resample)
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
Args:
outputs ([`DPTForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
None, predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
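# Usage sketch for the post-processing step (assumes `outputs` holds the raw outputs
# of a `DPTForSemanticSegmentation` forward pass on a single 480x640 image):
#
#     seg_maps = feature_extractor.post_process_semantic_segmentation(outputs, target_sizes=[(480, 640)])
#     seg_maps[0].shape  # torch.Size([480, 640]); each entry is a semantic class id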
DPTFeatureExtractor = DPTImageProcessor

View File

@ -0,0 +1,384 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for DPT."""
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
is_batched,
is_torch_available,
is_torch_tensor,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_torch_available():
import torch
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
def get_resize_output_image_size(
input_image: np.ndarray, output_size: Union[int, Iterable[int]], keep_aspect_ratio: bool, multiple: int
) -> Tuple[int, int]:
def constraint_to_multiple_of(val, multiple, min_val=0, max_val=None):
x = round(val / multiple) * multiple
if max_val is not None and x > max_val:
x = math.floor(val / multiple) * multiple
if x < min_val:
x = math.ceil(val / multiple) * multiple
return x
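# e.g. constraint_to_multiple_of(250, multiple=32) rounds 250 / 32 = 7.8125 to 8 and
# returns 256; with max_val=240 it would floor to 7 * 32 = 224 instead.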
output_size = (output_size, output_size) if isinstance(output_size, int) else output_size
input_height, input_width = get_image_size(input_image)
output_height, output_width = output_size
# determine new height and width
scale_height = output_height / input_height
scale_width = output_width / input_width
if keep_aspect_ratio:
# scale as little as possible
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
new_height = constraint_to_multiple_of(scale_height * input_height, multiple=multiple)
new_width = constraint_to_multiple_of(scale_width * input_width, multiple=multiple)
return (new_height, new_width)
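# Worked example (a sketch of the helper above): a 480x640 input with output_size=384,
# keep_aspect_ratio=True and multiple=32 gives scale_height = 384 / 480 = 0.8 and
# scale_width = 384 / 640 = 0.6; "fit height" wins, both dimensions use the 0.8 scale,
# and the result is (constraint_to_multiple_of(384, 32), constraint_to_multiple_of(512, 32)) = (384, 512).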
class DPTImageProcessor(BaseImageProcessor):
r"""
Constructs a DPT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions. Can be overridden by `do_resize` in `preprocess`.
size (`Dict[str, int]` *optional*, defaults to `{"height": 384, "width": 384}`):
Size of the image after resizing. Can be overridden by `size` in `preprocess`.
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
be overridden by `keep_aspect_ratio` in `preprocess`.
ensure_multiple_of (`int`, *optional*, defaults to `1`):
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overridden
by `ensure_multiple_of` in `preprocess`.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Defines the resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
`preprocess`.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in `preprocess`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
keep_aspect_ratio: bool = False,
ensure_multiple_of: int = 1,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 384, "width": 384}
size = get_size_dict(size)
self.do_resize = do_resize
self.size = size
self.keep_aspect_ratio = keep_aspect_ratio
self.ensure_multiple_of = ensure_multiple_of
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
keep_aspect_ratio: bool = False,
ensure_multiple_of: int = 1,
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
set, the image is resized to a size that is a multiple of this value.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Target size of the output image.
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
ensure_multiple_of (`int`, *optional*, defaults to `1`):
The image is resized to a size that is a multiple of this value.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
output_size = get_resize_output_image_size(
image,
output_size=(size["height"], size["width"]),
keep_aspect_ratio=keep_aspect_ratio,
multiple=ensure_multiple_of,
)
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
image_mean (`float` or `List[float]`):
Image mean.
image_std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
        size: Dict[str, int] = None,
keep_aspect_ratio: bool = None,
ensure_multiple_of: int = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. If `keep_aspect_ratio` is `True`, the image is resized to the largest
possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is set, the image is
resized to a size that is a multiple of this value.
keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
                Whether to keep the aspect ratio of the image. If `False`, the image will be resized to
                `(size["height"], size["width"])`. If `True`, the aspect ratio is preserved and the image is resized
                to the largest size that matches it.
ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
Ensure that the image size is a multiple of this value.
resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)
keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio
ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
        if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Args:
Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
outputs ([`DPTForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
"""
# TODO: add support for other frameworks
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
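# Illustrative sketch (not part of the original module): end-to-end use of the processor defined above. The input
# shape and the dummy `SimpleNamespace` outputs object are assumptions for illustration only; a real call would pass
# the outputs of `DPTForSemanticSegmentation`.
#
# >>> import numpy as np
# >>> import torch
# >>> from types import SimpleNamespace
# >>> image_processor = DPTImageProcessor(size={"height": 384, "width": 384}, keep_aspect_ratio=True, ensure_multiple_of=32)
# >>> image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
# >>> inputs = image_processor(images=image, return_tensors="np")
# >>> inputs["pixel_values"].shape
# (1, 3, 384, 512)
# >>> fake_outputs = SimpleNamespace(logits=torch.randn(1, 150, 96, 96))
# >>> maps = image_processor.post_process_semantic_segmentation(fake_outputs, target_sizes=[(480, 640)])
# >>> maps[0].shape
# torch.Size([480, 640])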

View File

@ -14,344 +14,10 @@
# limitations under the License.
"""Feature extractor class for FLAVA."""
import math
import random
from functools import lru_cache
from typing import Any, List, Optional, Tuple, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import TensorType, logging
from ...utils import logging
from .image_processing_flava import FlavaImageProcessor
logger = logging.get_logger(__name__)
# These values are taken from CLIP
FLAVA_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073]
FLAVA_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711]
FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
LOGIT_LAPLACE_EPS: float = 0.1
# Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
class FlavaMaskingGenerator:
def __init__(
self,
input_size: Union[int, Tuple[int, int]] = 14,
total_mask_patches: int = 75,
mask_group_max_patches: Optional[int] = None,
mask_group_min_patches: int = 16,
mask_group_min_aspect_ratio: Optional[float] = 0.3,
mask_group_max_aspect_ratio: float = None,
):
if not isinstance(input_size, tuple):
input_size = (input_size,) * 2
self.height, self.width = input_size
self.num_patches = self.height * self.width
self.total_mask_patches = total_mask_patches
self.mask_group_min_patches = mask_group_min_patches
self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches
mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
def __repr__(self):
repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
self.height,
self.width,
self.mask_group_min_patches,
self.mask_group_max_patches,
self.total_mask_patches,
self.log_aspect_ratio[0],
self.log_aspect_ratio[1],
)
return repr_str
def get_shape(self):
return self.height, self.width
def _mask(self, mask, max_mask_patches):
delta = 0
for _attempt in range(10):
target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
height = int(round(math.sqrt(target_area * aspect_ratio)))
width = int(round(math.sqrt(target_area / aspect_ratio)))
if width < self.width and height < self.height:
top = random.randint(0, self.height - height)
left = random.randint(0, self.width - width)
num_masked = mask[top : top + height, left : left + width].sum()
# Overlap
if 0 < height * width - num_masked <= max_mask_patches:
for i in range(top, top + height):
for j in range(left, left + width):
if mask[i, j] == 0:
mask[i, j] = 1
delta += 1
if delta > 0:
break
return delta
def __call__(self):
mask = np.zeros(shape=self.get_shape(), dtype=int)
mask_count = 0
while mask_count < self.total_mask_patches:
max_mask_patches = self.total_mask_patches - mask_count
max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)
delta = self._mask(mask, max_mask_patches)
if delta == 0:
break
else:
mask_count += delta
return mask
class FlavaFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a FLAVA feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int`, *optional*, defaults to 224):
Resize the input to the given size. Only has an effect if `do_resize` is set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
image is padded with 0's and then center cropped.
crop_size (`int`, *optional*, defaults to 224):
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with `image_mean` and `image_std`.
image_mean (`Tuple[float, float, float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`Tuple[float, float, float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
input_size_patches (`int`, *optional*, defaults to 14):
Number of patches in the image in height and width direction. 14x14 = 196 total patches.
total_mask_patches (`int`, *optional*, defaults to 75):
Total number of patches that should be masked.
mask_group_min_patches (`int`, *optional*, defaults to 16):
Minimum number of patches that should be masked.
mask_group_max_patches (`int`, *optional*, defaults to None):
Maximum number of patches that should be masked.
mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
Minimum aspect ratio of the mask window.
mask_group_max_aspect_ratio (`float`, *optional*, defaults to None):
Maximum aspect ratio of the mask window
codebook_do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input for codebook to a certain `codebook_size`.
codebook_size (`int`, *optional*, defaults to 224):
Resize the input for codebook to the given size. Only has an effect if `codebook_do_resize` is set to
`True`.
codebook_resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to crop the input for codebook at the center. If the input size is smaller than
`codebook_crop_size` along any edge, the image is padded with 0's and then center cropped.
codebook_crop_size (`int`, *optional*, defaults to 224):
Desired output size for codebook input when applying center-cropping. Only has an effect if
`codebook_do_center_crop` is set to `True`.
codebook_do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`.
codebook_image_mean (`Tuple[float, float, float]`, *optional*, defaults to `[0, 0, 0]`):
The sequence of means for each channel, to be used when normalizing images for codebook.
codebook_image_std (`Tuple[float, float, float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
The sequence of standard deviations for each channel, to be used when normalizing images for codebook.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Union[int, Tuple[int, int]] = 224,
resample: int = PILImageResampling.BICUBIC,
do_center_crop: bool = True,
crop_size: Union[int, Tuple[int, int]] = 224,
do_normalize: bool = True,
image_mean: Tuple[float, float, float] = FLAVA_IMAGE_MEAN,
image_std: Tuple[float, float, float] = FLAVA_IMAGE_STD,
# Mask related params
input_size_patches: int = 14,
total_mask_patches: int = 75,
mask_group_min_patches: int = 16,
mask_group_max_patches: Optional[int] = None,
mask_group_min_aspect_ratio: float = 0.3,
mask_group_max_aspect_ratio: Optional[float] = None,
# Codebook related params
codebook_do_resize: bool = True,
codebook_size: bool = 112,
codebook_resample: int = PILImageResampling.LANCZOS,
codebook_do_center_crop: bool = True,
codebook_crop_size: int = 112,
codebook_do_map_pixels: bool = True,
codebook_do_normalize: bool = True,
codebook_image_mean: Tuple[float, float, float] = FLAVA_CODEBOOK_MEAN,
codebook_image_std: Tuple[float, float, float] = FLAVA_CODEBOOK_STD,
**kwargs: Any,
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.input_size_patches = input_size_patches
self.total_mask_patches = total_mask_patches
self.mask_group_min_patches = mask_group_min_patches
self.mask_group_max_patches = mask_group_max_patches
self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio
self.codebook_do_resize = codebook_do_resize
self.codebook_size = codebook_size
self.codebook_resample = codebook_resample
self.codebook_do_center_crop = codebook_do_center_crop
self.codebook_crop_size = codebook_crop_size
self.codebook_do_map_pixels = codebook_do_map_pixels
self.codebook_do_normalize = codebook_do_normalize
self.codebook_image_mean = codebook_image_mean
self.codebook_image_std = codebook_image_std
@property
@lru_cache()
def masking_generator(self):
return FlavaMaskingGenerator(
input_size=self.input_size_patches,
total_mask_patches=self.total_mask_patches,
mask_group_min_patches=self.mask_group_min_patches,
mask_group_max_patches=self.mask_group_max_patches,
mask_group_min_aspect_ratio=self.mask_group_min_aspect_ratio,
mask_group_max_aspect_ratio=self.mask_group_max_aspect_ratio,
)
def map_pixels(self, x):
return (1 - 2 * LOGIT_LAPLACE_EPS) * x + LOGIT_LAPLACE_EPS
def __call__(
self,
images: Union[
Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
],
return_image_mask: Optional[bool] = None,
return_codebook_pixels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs: Any
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_image_mask (`bool`, *optional*, defaults to None):
If True, the processor will return `bool_masked_pos` suggesting masks for image's patch version.
return_codebook_pixels (`bool`, *optional*, defaults to None):
If True, the processor will return `codebook_pixel_values` providing image pixels to be used with the
default FLAVA codebook. Used in pretraining by Masked Image Modeling (MIM) loss.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model.
"""
# Input type checking for clearer error
if isinstance(images, (list, tuple)) and len(images) != 0:
self._ensure_format_supported(images[0])
else:
self._ensure_format_supported(images)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
images_for_codebook = images
# transformations (resizing + center cropping + normalization)
if self.do_resize and self.size is not None and self.resample is not None:
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
if self.do_center_crop and self.crop_size is not None:
images = [self.center_crop(image, self.crop_size) for image in images]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
if return_codebook_pixels:
images = images_for_codebook
if self.codebook_do_resize and self.codebook_size is not None and self.codebook_resample is not None:
images = [
self.resize(image=image, size=self.codebook_size, resample=self.codebook_resample)
for image in images
]
if self.codebook_do_center_crop and self.codebook_crop_size is not None:
images = [self.center_crop(image, self.codebook_crop_size) for image in images]
if self.codebook_do_normalize:
images = [
self.normalize(image=image, mean=self.codebook_image_mean, std=self.codebook_image_std)
for image in images
]
if self.codebook_do_map_pixels:
images = [self.map_pixels(image) for image in images]
data["codebook_pixel_values"] = images
if return_image_mask:
masks = [self.masking_generator() for _ in images]
data["bool_masked_pos"] = masks
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
FlavaFeatureExtractor = FlavaImageProcessor

View File

@ -0,0 +1,696 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Flava."""
import math
import random
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
from ...utils import logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
# These values are taken from CLIP
FLAVA_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073]
FLAVA_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711]
FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
LOGIT_LAPLACE_EPS: float = 0.1
# Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
class FlavaMaskingGenerator:
def __init__(
self,
input_size: Union[int, Tuple[int, int]] = 14,
total_mask_patches: int = 75,
mask_group_max_patches: Optional[int] = None,
mask_group_min_patches: int = 16,
mask_group_min_aspect_ratio: Optional[float] = 0.3,
mask_group_max_aspect_ratio: float = None,
):
if not isinstance(input_size, tuple):
input_size = (input_size,) * 2
self.height, self.width = input_size
self.num_patches = self.height * self.width
self.total_mask_patches = total_mask_patches
self.mask_group_min_patches = mask_group_min_patches
self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches
mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
def __repr__(self):
repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
self.height,
self.width,
self.mask_group_min_patches,
self.mask_group_max_patches,
self.total_mask_patches,
self.log_aspect_ratio[0],
self.log_aspect_ratio[1],
)
return repr_str
def get_shape(self):
return self.height, self.width
def _mask(self, mask, max_mask_patches):
delta = 0
for _attempt in range(10):
target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
height = int(round(math.sqrt(target_area * aspect_ratio)))
width = int(round(math.sqrt(target_area / aspect_ratio)))
if width < self.width and height < self.height:
top = random.randint(0, self.height - height)
left = random.randint(0, self.width - width)
num_masked = mask[top : top + height, left : left + width].sum()
# Overlap
if 0 < height * width - num_masked <= max_mask_patches:
for i in range(top, top + height):
for j in range(left, left + width):
if mask[i, j] == 0:
mask[i, j] = 1
delta += 1
if delta > 0:
break
return delta
def __call__(self):
mask = np.zeros(shape=self.get_shape(), dtype=int)
mask_count = 0
while mask_count < self.total_mask_patches:
max_mask_patches = self.total_mask_patches - mask_count
max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)
delta = self._mask(mask, max_mask_patches)
if delta == 0:
break
else:
mask_count += delta
return mask
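# Illustrative sketch (not part of the original module): the generator above draws rectangular groups of patches
# until the masking budget is spent, so the returned mask never has more than `total_mask_patches` set positions.
#
# >>> generator = FlavaMaskingGenerator(input_size=14, total_mask_patches=75)
# >>> mask = generator()
# >>> mask.shape
# (14, 14)
# >>> int(mask.sum()) <= 75
# True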
class FlavaImageProcessor(BaseImageProcessor):
r"""
Constructs a Flava image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in `preprocess`.
size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the image after resizing. Can be overridden by the `size` parameter in `preprocess`.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in
`preprocess`.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the images. Can be overridden by the `do_center_crop` parameter in `preprocess`.
crop_size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
Size of image after the center crop `(crop_size["height"], crop_size["width"])`. Can be overridden by the
`crop_size` parameter in `preprocess`.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in `preprocess`.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in
`preprocess`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in `preprocess`.
        image_mean (`float` or `List[float]`, *optional*, defaults to `FLAVA_IMAGE_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `FLAVA_IMAGE_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
return_image_mask (`bool`, *optional*, defaults to `False`):
Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
input_size_patches (`int`, *optional*, defaults to 14):
Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
by the `input_size_patches` parameter in `preprocess`.
total_mask_patches (`int`, *optional*, defaults to 75):
Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
`preprocess`.
mask_group_min_patches (`int`, *optional*, defaults to 16):
Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches`
parameter in `preprocess`.
mask_group_max_patches (`int`, *optional*):
Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches`
parameter in `preprocess`.
mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
in `preprocess`.
mask_group_max_aspect_ratio (`float`, *optional*):
Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
in `preprocess`.
codebook_do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the input for codebook to a certain `codebook_size`. Can be overridden by the
            `codebook_do_resize` parameter in `preprocess`.
        codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
`preprocess`.
codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample`
parameter in `preprocess`.
codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to crop the input for codebook at the center. If the input size is smaller than
`codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
overridden by the `codebook_do_center_crop` parameter in `preprocess`.
        codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
Desired output size for codebook input when applying center-cropping. Can be overridden by the
`codebook_crop_size` parameter in `preprocess`.
codebook_do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
overridden by the `codebook_do_rescale` parameter in `preprocess`.
codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
`codebook_rescale_factor` parameter in `preprocess`.
codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
`codebook_do_map_pixels` parameter in `preprocess`.
codebook_do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
be overridden by the `codebook_do_normalize` parameter in `preprocess`.
codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`):
The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
by the `codebook_image_mean` parameter in `preprocess`.
        codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[1, 1, 1]`):
The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
be overridden by the `codebook_image_std` parameter in `preprocess`.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, Iterable[float]]] = None,
image_std: Optional[Union[float, Iterable[float]]] = None,
# Mask related params
return_image_mask: bool = False,
input_size_patches: int = 14,
total_mask_patches: int = 75,
mask_group_min_patches: int = 16,
mask_group_max_patches: Optional[int] = None,
mask_group_min_aspect_ratio: float = 0.3,
mask_group_max_aspect_ratio: Optional[float] = None,
# Codebook related params
return_codebook_pixels: bool = False,
codebook_do_resize: bool = True,
        codebook_size: Dict[str, int] = None,
codebook_resample: int = PILImageResampling.LANCZOS,
codebook_do_center_crop: bool = True,
        codebook_crop_size: Dict[str, int] = None,
codebook_do_rescale: bool = True,
codebook_rescale_factor: Union[int, float] = 1 / 255,
codebook_do_map_pixels: bool = True,
codebook_do_normalize: bool = True,
codebook_image_mean: Optional[Union[float, Iterable[float]]] = None,
codebook_image_std: Optional[Union[float, Iterable[float]]] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
codebook_size = get_size_dict(codebook_size)
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
codebook_crop_size = get_size_dict(codebook_crop_size)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else FLAVA_IMAGE_MEAN
self.image_std = image_std if image_std is not None else FLAVA_IMAGE_STD
self.return_image_mask = return_image_mask
self.input_size_patches = input_size_patches
self.total_mask_patches = total_mask_patches
self.mask_group_min_patches = mask_group_min_patches
self.mask_group_max_patches = mask_group_max_patches
self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio
self.return_codebook_pixels = return_codebook_pixels
self.codebook_do_resize = codebook_do_resize
self.codebook_size = codebook_size
self.codebook_resample = codebook_resample
self.codebook_do_center_crop = codebook_do_center_crop
self.codebook_crop_size = codebook_crop_size
self.codebook_do_rescale = codebook_do_rescale
self.codebook_rescale_factor = codebook_rescale_factor
self.codebook_do_map_pixels = codebook_do_map_pixels
self.codebook_do_normalize = codebook_do_normalize
        self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN
self.codebook_image_std = codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD
@lru_cache()
def masking_generator(
self,
input_size_patches,
total_mask_patches,
mask_group_min_patches,
mask_group_max_patches,
mask_group_min_aspect_ratio,
mask_group_max_aspect_ratio,
) -> FlavaMaskingGenerator:
return FlavaMaskingGenerator(
input_size=input_size_patches,
total_mask_patches=total_mask_patches,
mask_group_min_patches=mask_group_min_patches,
mask_group_max_patches=mask_group_max_patches,
mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
)
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must contain 'height' and 'width' keys. Got {size.keys()}")
return resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
)
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
any edge, the image is padded with 0's and then center cropped.
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
image_mean (`float` or `List[float]`):
Image mean.
image_std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def map_pixels(self, image: np.ndarray) -> np.ndarray:
return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS
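    # Note (added comment): with LOGIT_LAPLACE_EPS = 0.1 this maps x -> 0.8 * x + 0.1, squeezing pixel values from
    # [0, 1] into [0.1, 0.9] before they are passed to the codebook.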
def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_map_pixels: bool = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
) -> np.ndarray:
"""Preprocesses a single image."""
        if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if do_resize:
image = self.resize(image=image, size=size, resample=resample)
if do_center_crop:
image = self.center_crop(image=image, size=crop_size)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor)
if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std)
if do_map_pixels:
image = self.map_pixels(image)
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
return image
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[Dict[str, int]] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
# Mask related params
return_image_mask: Optional[bool] = None,
input_size_patches: Optional[int] = None,
total_mask_patches: Optional[int] = None,
mask_group_min_patches: Optional[int] = None,
mask_group_max_patches: Optional[int] = None,
mask_group_min_aspect_ratio: Optional[float] = None,
mask_group_max_aspect_ratio: Optional[float] = None,
# Codebook related params
return_codebook_pixels: Optional[bool] = None,
codebook_do_resize: Optional[bool] = None,
codebook_size: Optional[Dict[str, int]] = None,
codebook_resample: Optional[int] = None,
codebook_do_center_crop: Optional[bool] = None,
codebook_crop_size: Optional[Dict[str, int]] = None,
codebook_do_rescale: Optional[bool] = None,
codebook_rescale_factor: Optional[float] = None,
codebook_do_map_pixels: Optional[bool] = None,
codebook_do_normalize: Optional[bool] = None,
codebook_image_mean: Optional[Iterable[float]] = None,
codebook_image_std: Optional[Iterable[float]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image.
resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
return_image_mask (`bool`, *optional*, defaults to `self.return_image_mask`):
Whether to return the image mask.
            input_size_patches (`int`, *optional*, defaults to `self.input_size_patches`):
                Number of patches in the image in height and width direction. 14x14 = 196 total patches.
            total_mask_patches (`int`, *optional*, defaults to `self.total_mask_patches`):
                Total number of patches that should be masked.
            mask_group_min_patches (`int`, *optional*, defaults to `self.mask_group_min_patches`):
                Minimum number of patches that should be masked.
            mask_group_max_patches (`int`, *optional*, defaults to `self.mask_group_max_patches`):
                Maximum number of patches that should be masked.
            mask_group_min_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_min_aspect_ratio`):
                Minimum aspect ratio of the mask window.
            mask_group_max_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_max_aspect_ratio`):
                Maximum aspect ratio of the mask window.
return_codebook_pixels (`bool`, *optional*, defaults to `self.return_codebook_pixels`):
Whether to return the codebook pixels.
codebook_do_resize (`bool`, *optional*, defaults to `self.codebook_do_resize`):
Whether to resize the codebook pixels.
codebook_size (`Dict[str, int]`, *optional*, defaults to `self.codebook_size`):
Size of the codebook pixels.
codebook_resample (`int`, *optional*, defaults to `self.codebook_resample`):
Resampling filter to use if resizing the codebook pixels. This can be one of the enum
                `PILImageResampling`. Only has an effect if `codebook_do_resize` is set to `True`.
codebook_do_center_crop (`bool`, *optional*, defaults to `self.codebook_do_center_crop`):
Whether to center crop the codebook pixels.
codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `self.codebook_crop_size`):
Size of the center crop of the codebook pixels. Only has an effect if `codebook_do_center_crop` is set
to `True`.
codebook_do_rescale (`bool`, *optional*, defaults to `self.codebook_do_rescale`):
Whether to rescale the codebook pixels values between [0 - 1].
codebook_rescale_factor (`float`, *optional*, defaults to `self.codebook_rescale_factor`):
Rescale factor to rescale the codebook pixels by if `codebook_do_rescale` is set to `True`.
codebook_do_map_pixels (`bool`, *optional*, defaults to `self.codebook_do_map_pixels`):
Whether to map the codebook pixels values.
codebook_do_normalize (`bool`, *optional*, defaults to `self.codebook_do_normalize`):
Whether to normalize the codebook pixels.
codebook_image_mean (`float` or `List[float]`, *optional*, defaults to `self.codebook_image_mean`):
Codebook pixels mean to normalize the codebook pixels by if `codebook_do_normalize` is set to `True`.
codebook_image_std (`float` or `List[float]`, *optional*, defaults to `self.codebook_image_std`):
Codebook pixels standard deviation to normalize the codebook pixels by if `codebook_do_normalize` is
set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
return_image_mask = return_image_mask if return_image_mask is not None else self.return_image_mask
input_size_patches = input_size_patches if input_size_patches is not None else self.input_size_patches
total_mask_patches = total_mask_patches if total_mask_patches is not None else self.total_mask_patches
mask_group_min_patches = (
mask_group_min_patches if mask_group_min_patches is not None else self.mask_group_min_patches
)
mask_group_max_patches = (
mask_group_max_patches if mask_group_max_patches is not None else self.mask_group_max_patches
)
mask_group_min_aspect_ratio = (
mask_group_min_aspect_ratio
if mask_group_min_aspect_ratio is not None
else self.mask_group_min_aspect_ratio
)
mask_group_max_aspect_ratio = (
mask_group_max_aspect_ratio
if mask_group_max_aspect_ratio is not None
else self.mask_group_max_aspect_ratio
)
return_codebook_pixels = (
return_codebook_pixels if return_codebook_pixels is not None else self.return_codebook_pixels
)
codebook_do_resize = codebook_do_resize if codebook_do_resize is not None else self.codebook_do_resize
codebook_size = codebook_size if codebook_size is not None else self.codebook_size
codebook_size = get_size_dict(codebook_size)
codebook_resample = codebook_resample if codebook_resample is not None else self.codebook_resample
codebook_do_rescale = codebook_do_rescale if codebook_do_rescale is not None else self.codebook_do_rescale
codebook_rescale_factor = (
codebook_rescale_factor if codebook_rescale_factor is not None else self.codebook_rescale_factor
)
codebook_do_center_crop = (
codebook_do_center_crop if codebook_do_center_crop is not None else self.codebook_do_center_crop
)
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else self.codebook_crop_size
codebook_crop_size = get_size_dict(codebook_crop_size)
codebook_do_map_pixels = (
codebook_do_map_pixels if codebook_do_map_pixels is not None else self.codebook_do_map_pixels
)
codebook_do_normalize = (
codebook_do_normalize if codebook_do_normalize is not None else self.codebook_do_normalize
)
codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else self.codebook_image_mean
codebook_image_std = codebook_image_std if codebook_image_std is not None else self.codebook_image_std
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
processed_images = [
self._preprocess_image(
image=img,
do_resize=do_resize,
size=size,
resample=resample,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_map_pixels=False,
data_format=data_format,
)
for img in images
]
data = {"pixel_values": processed_images}
if return_codebook_pixels:
codebook_images = [
self._preprocess_image(
image=img,
do_resize=codebook_do_resize,
size=codebook_size,
resample=codebook_resample,
do_center_crop=codebook_do_center_crop,
crop_size=codebook_crop_size,
do_rescale=codebook_do_rescale,
rescale_factor=codebook_rescale_factor,
do_normalize=codebook_do_normalize,
image_mean=codebook_image_mean,
image_std=codebook_image_std,
do_map_pixels=codebook_do_map_pixels,
data_format=data_format,
)
for img in images
]
data["codebook_pixel_values"] = codebook_images
if return_image_mask:
mask_generator = self.masking_generator(
input_size_patches=input_size_patches,
total_mask_patches=total_mask_patches,
mask_group_min_patches=mask_group_min_patches,
mask_group_max_patches=mask_group_max_patches,
mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
)
masks = [mask_generator() for _ in images]
data["bool_masked_pos"] = masks
return BatchFeature(data=data, tensor_type=return_tensors)
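# Illustrative sketch (not part of the original module): with the defaults above, a single call can return the model
# pixel values, the codebook pixel values and a boolean patch mask. The input shape is an assumption chosen for
# illustration.
#
# >>> import numpy as np
# >>> image_processor = FlavaImageProcessor()
# >>> image = np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8)
# >>> outputs = image_processor(image, return_image_mask=True, return_codebook_pixels=True, return_tensors="np")
# >>> sorted(outputs.keys())
# ['bool_masked_pos', 'codebook_pixel_values', 'pixel_values']
# >>> outputs["pixel_values"].shape
# (1, 3, 224, 224)
# >>> outputs["codebook_pixel_values"].shape
# (1, 3, 112, 112)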

View File

@ -37,16 +37,16 @@ class GLPNImageProcessor(BaseImageProcessor):
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Set the class default for the `do_resize` parameter. Controls whether to resize the image's (height, width)
dimensions, rounding them down to the closest multiple of `size_divisor`.
Whether to resize the image's (height, width) dimensions, rounding them down to the closest multiple of
`size_divisor`. Can be overridden by `do_resize` in `preprocess`.
size_divisor (`int`, *optional*, defaults to 32):
Set the class default for the `size_divisor` parameter. When `do_resize` is `True`, images are resized so
their height and width are rounded down to the closest multiple of `size_divisor`.
When `do_resize` is `True`, images are resized so their height and width are rounded down to the closest
multiple of `size_divisor`. Can be overridden by `size_divisor` in `preprocess`.
resample (`PIL.Image` resampling filter, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
Set the class default for `resample`. Defines the resampling filter to use if resizing the image.
Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
do_rescale (`bool`, *optional*, defaults to `True`):
Set the class default for the `do_rescale` parameter. Controls whether or not to apply the scaling factor
(to make pixel values floats between 0. and 1.).
Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Can be
overridden by `do_rescale` in `preprocess`.
"""
model_input_names = ["pixel_values"]
@ -81,7 +81,7 @@ class GLPNImageProcessor(BaseImageProcessor):
`size_divisor`.
resample:
`PIL.Image` resampling filter to use when resizing the image e.g. `PIL.Image.Resampling.BILINEAR`.
data_format (`ChannelDimension`, *optional*):
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If `None`, the channel dimension format of the input
image is used. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
@ -108,7 +108,7 @@ class GLPNImageProcessor(BaseImageProcessor):
The image to rescale.
scale (`float`):
The scaling factor to rescale pixel values by.
data_format (`ChannelDimension`, *optional*):
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If `None`, the channel dimension format of the input
image is used. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
@ -146,14 +146,14 @@ class GLPNImageProcessor(BaseImageProcessor):
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.).
return_tensors (`str`, *optional*):
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- `None`: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.

@@ -14,168 +14,11 @@
# limitations under the License.
"""Feature extractor class for ImageGPT."""
from typing import List, Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import TensorType, logging
from ...utils import logging
from .image_processing_imagegpt import ImageGPTImageProcessor
logger = logging.get_logger(__name__)
def squared_euclidean_distance(a, b):
b = b.T
a2 = np.sum(np.square(a), axis=1)
b2 = np.sum(np.square(b), axis=0)
ab = np.matmul(a, b)
d = a2[:, None] - 2 * ab + b2[None, :]
return d
def color_quantize(x, clusters):
x = x.reshape(-1, 3)
d = squared_euclidean_distance(x, clusters)
return np.argmin(d, axis=1)
class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs an ImageGPT feature extractor. This feature extractor can be used to resize images to a smaller
resolution (such as 32x32 or 64x64), normalize them and finally color quantize them to obtain sequences of "pixel
values" (color clusters).
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
clusters (`np.ndarray`):
The color clusters to use, as a `np.ndarray` of shape `(n_clusters, 3)`.
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 32):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input to the range between -1 and +1.
"""
model_input_names = ["input_ids"]
def __init__(
self, clusters, do_resize=True, size=32, resample=PILImageResampling.BILINEAR, do_normalize=True, **kwargs
):
super().__init__(**kwargs)
self.clusters = np.asarray(clusters)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_normalize = do_normalize
def normalize(self, image):
"""
Normalizes `image` into the range -1 to +1.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to normalize.
Returns:
`np.ndarray`: The normalized image.
"""
image = self.to_numpy_array(image, rescale=False, channel_first=False)
return image / 127.5 - 1
def __call__(
self,
images: Union[
Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
],
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- Input IDs to be fed to a model, of shape `(batch_size, height * width)`.
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (resizing + normalization)
if self.do_resize and self.size is not None:
images = [self.resize(image, size=self.size, resample=self.resample) for image in images]
if self.do_normalize:
images = [self.normalize(image) for image in images]
# color quantize from (batch_size, height, width, 3) to (batch_size, height, width)
images = np.array(images)
images = color_quantize(images, self.clusters).reshape(images.shape[:-1])
# flatten to (batch_size, height*width)
batch_size = images.shape[0]
images = images.reshape(batch_size, -1)
# return as BatchFeature
data = {"input_ids": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
ImageGPTFeatureExtractor = ImageGPTImageProcessor

@@ -0,0 +1,239 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for ImageGPT."""
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import rescale, resize, to_channel_dimension_format
from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
from ...utils import logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
def squared_euclidean_distance(a, b):
b = b.T
a2 = np.sum(np.square(a), axis=1)
b2 = np.sum(np.square(b), axis=0)
ab = np.matmul(a, b)
d = a2[:, None] - 2 * ab + b2[None, :]
return d
def color_quantize(x, clusters):
x = x.reshape(-1, 3)
d = squared_euclidean_distance(x, clusters)
return np.argmin(d, axis=1)
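As a quick, illustrative sanity check of the two helpers above (assuming `color_quantize` as defined just above is in scope; the toy palette is made up):

```python
import numpy as np

# Toy palette of three clusters (black, grey, white), shape (n_clusters, 3).
clusters = np.array([[0, 0, 0], [128, 128, 128], [255, 255, 255]], dtype=float)

# A 1x2 RGB image; color_quantize returns one cluster id per pixel.
image = np.array([[[250, 250, 250], [5, 10, 0]]], dtype=float)
print(color_quantize(image, clusters))  # [2 0] -> nearest clusters are white and black
```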
class ImageGPTImageProcessor(BaseImageProcessor):
r"""
Constructs an ImageGPT image processor. This image processor can be used to resize images to a smaller resolution
(such as 32x32 or 64x64), normalize them and finally color quantize them to obtain sequences of "pixel values"
(color clusters).
Args:
clusters (`np.ndarray`, *optional*):
The color clusters to use, as a `np.ndarray` of shape `(n_clusters, 3)` when color quantizing. Can be
overridden by `clusters` in `preprocess`.
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's dimensions to `(size["height"], size["width"])`. Can be overridden by
`do_resize` in `preprocess`.
size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
Size of the image after resizing. Can be overridden by `size` in `preprocess`.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image pixel value to between [-1, 1]. Can be overridden by `do_normalize` in
`preprocess`.
do_color_quantize (`bool`, *optional*, defaults to `True`):
Whether to color quantize the image. Can be overridden by `do_color_quantize` in `preprocess`.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
# clusters is a first argument to maintain backwards compatibility with the old ImageGPTFeatureExtractor
clusters: Optional[np.ndarray] = None,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_normalize: bool = True,
do_color_quantize: bool = True,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 256, "width": 256}
size = get_size_dict(size)
self.clusters = clusters
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_normalize = do_normalize
self.do_color_quantize = do_color_quantize
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image to (size["height"], size["width"]).
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"Size dictionary must contain both height and width keys. Got {size.keys()}")
return resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
)
def normalize(
self,
image: np.ndarray,
data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Normalizes an image's pixel values to between [-1, 1].
Args:
image (`np.ndarray`):
Image to normalize.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
image = rescale(image=image, scale=1 / 127.5, data_format=data_format)
image = image - 1
return image
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_normalize: bool = None,
do_color_quantize: Optional[bool] = None,
clusters: Optional[Union[int, List[int]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
do_color_quantize (`bool`, *optional*, defaults to `self.do_color_quantize`):
Whether to color quantize the image.
clusters (`np.ndarray`, *optional*, defaults to `self.clusters`):
Clusters used to quantize the image of shape `(n_clusters, 3)`. Only has an effect if
`do_color_quantize` is set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Only has an effect if `do_color_quantize` is set to `False`.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)
resample = resample if resample is not None else self.resample
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
do_color_quantize = do_color_quantize if do_color_quantize is not None else self.do_color_quantize
clusters = clusters if clusters is not None else self.clusters
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_color_quantize and clusters is None:
raise ValueError("Clusters must be specified if do_color_quantize is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
if do_normalize:
images = [self.normalize(image=image) for image in images]
if do_color_quantize:
images = [to_channel_dimension_format(image, ChannelDimension.LAST) for image in images]
# color quantize from (batch_size, height, width, 3) to (batch_size, height, width)
images = np.array(images)
clusters = np.array(clusters)
images = color_quantize(images, clusters).reshape(images.shape[:-1])
# flatten to (batch_size, height*width)
batch_size = images.shape[0]
images = images.reshape(batch_size, -1)
# We need to convert back to a list of images to keep consistent behaviour across processors.
images = list(images)
else:
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"input_ids": images}
return BatchFeature(data=data, tensor_type=return_tensors)
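A minimal usage sketch of the new processor; the checkpoint name is an illustrative assumption, and any ImageGPT checkpoint that stores color `clusters` would work the same way:

```python
import numpy as np
from transformers import ImageGPTImageProcessor

# Checkpoint name is illustrative; it must provide the color `clusters`.
image_processor = ImageGPTImageProcessor.from_pretrained("openai/imagegpt-small")

image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
encoding = image_processor(image, return_tensors="np")
# Each image becomes a flat sequence of color-cluster ids.
print(encoding["input_ids"].shape)  # (1, height * width) after resizing
```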

@@ -992,11 +992,12 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
... )
>>> clusters = feature_extractor.clusters
>>> n_px = feature_extractor.size
>>> height = feature_extractor.size["height"]
>>> width = feature_extractor.size["width"]
>>> samples = output[:, 1:].cpu().detach().numpy()
>>> samples_img = [
... np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8) for s in samples
... np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [height, width, 3]).astype(np.uint8) for s in samples
... ] # convert color cluster tokens back to pixels
>>> f, axes = plt.subplots(1, batch_size, dpi=300)

@@ -16,226 +16,10 @@
Feature extractor class for LayoutLMv2.
"""
from typing import List, Optional, Union
from ...utils import logging
from .image_processing_layoutlmv2 import LayoutLMv2ImageProcessor
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import TensorType, is_pytesseract_available, logging, requires_backends
# soft dependency
if is_pytesseract_available():
import pytesseract
logger = logging.get_logger(__name__)
ImageInput = Union[
Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
]
def normalize_box(box, width, height):
return [
int(1000 * (box[0] / width)),
int(1000 * (box[1] / height)),
int(1000 * (box[2] / width)),
int(1000 * (box[3] / height)),
]
def apply_tesseract(image: Image.Image, lang: Optional[str], tesseract_config: Optional[str]):
"""Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
# apply OCR
data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config)
words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
# filter empty words and corresponding coordinates
irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
# turn coordinates into (left, top, left+width, top+height) format
actual_boxes = []
for x, y, w, h in zip(left, top, width, height):
actual_box = [x, y, x + w, y + h]
actual_boxes.append(actual_box)
image_width, image_height = image.size
# finally, normalize the bounding boxes
normalized_boxes = []
for box in actual_boxes:
normalized_boxes.append(normalize_box(box, image_width, image_height))
assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
return words, normalized_boxes
class LayoutLMv2FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a LayoutLMv2 feature extractor. This can be used to resize document images to the same size, as well as
to apply OCR on them in order to get a list of words and normalized bounding boxes.
This feature extractor inherits from [`~feature_extraction_utils.PreTrainedFeatureExtractor`] which contains most
of the main methods. Users should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 224):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
apply_ocr (`bool`, *optional*, defaults to `True`):
Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
ocr_lang (`str`, *optional*):
The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
used.
tesseract_config (`str`, *optional*):
Any additional custom configuration flags that are forwarded to the `config` parameter when calling
Tesseract. For example: '--psm 6'.
<Tip>
LayoutLMv2FeatureExtractor uses Google's Tesseract OCR engine under the hood.
</Tip>"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=224,
resample=PILImageResampling.BILINEAR,
apply_ocr=True,
ocr_lang=None,
tesseract_config="",
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.apply_ocr = apply_ocr
self.ocr_lang = ocr_lang
self.tesseract_config = tesseract_config
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
- **words** -- Optional words as identified by Tesseract OCR (only when [`LayoutLMv2FeatureExtractor`] was
initialized with `apply_ocr` set to `True`).
- **boxes** -- Optional bounding boxes as identified by Tesseract OCR, normalized based on the image size
(only when [`LayoutLMv2FeatureExtractor`] was initialized with `apply_ocr` set to `True`).
Examples:
```python
>>> from transformers import LayoutLMv2FeatureExtractor
>>> from PIL import Image
>>> # Document can be a png, jpg, etc. PDFs must be converted to images.
>>> image = Image.open(name_of_your_document).convert("RGB")
>>> # option 1: with apply_ocr=True (default)
>>> feature_extractor = LayoutLMv2FeatureExtractor()
>>> encoding = feature_extractor(image, return_tensors="pt")
>>> print(encoding.keys())
>>> # dict_keys(['pixel_values', 'words', 'boxes'])
>>> # option 2: with apply_ocr=False
>>> feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
>>> encoding = feature_extractor(image, return_tensors="pt")
>>> print(encoding.keys())
>>> # dict_keys(['pixel_values'])
```"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples), "
f"but is of type {type(images)}."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# Tesseract OCR to get words + normalized bounding boxes
if self.apply_ocr:
requires_backends(self, "pytesseract")
words_batch = []
boxes_batch = []
for image in images:
words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang, self.tesseract_config)
words_batch.append(words)
boxes_batch.append(boxes)
# transformations (resizing)
if self.do_resize and self.size is not None:
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
images = [self.to_numpy_array(image, rescale=False) for image in images]
# flip color channels from RGB to BGR (as Detectron2 requires this)
images = [image[::-1, :, :] for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
if self.apply_ocr:
encoded_inputs["words"] = words_batch
encoded_inputs["boxes"] = boxes_batch
return encoded_inputs
LayoutLMv2FeatureExtractor = LayoutLMv2ImageProcessor

@@ -0,0 +1,268 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for LayoutLMv2."""
from typing import Dict, Optional, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format, to_pil_image
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import is_pytesseract_available, logging, requires_backends
if is_vision_available():
import PIL
# soft dependency
if is_pytesseract_available():
import pytesseract
logger = logging.get_logger(__name__)
def normalize_box(box, width, height):
return [
int(1000 * (box[0] / width)),
int(1000 * (box[1] / height)),
int(1000 * (box[2] / width)),
int(1000 * (box[3] / height)),
]
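The helper above maps pixel boxes into the 0-1000 coordinate space LayoutLM models expect; a small worked example (assuming `normalize_box` as defined just above is in scope):

```python
# A box (left, top, right, bottom) in pixels on a 400x200 page
# is mapped into the 0-1000 coordinate space used by LayoutLM models.
print(normalize_box([40, 50, 200, 150], width=400, height=200))
# [100, 250, 500, 750]
```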
def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Optional[str] = None):
"""Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
tesseract_config = tesseract_config if tesseract_config is not None else ""
# apply OCR
pil_image = to_pil_image(image)
image_width, image_height = pil_image.size
data = pytesseract.image_to_data(pil_image, lang=lang, output_type="dict", config=tesseract_config)
words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
# filter empty words and corresponding coordinates
irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
# turn coordinates into (left, top, left+width, top+height) format
actual_boxes = []
for x, y, w, h in zip(left, top, width, height):
actual_box = [x, y, x + w, y + h]
actual_boxes.append(actual_box)
# finally, normalize the bounding boxes
normalized_boxes = []
for box in actual_boxes:
normalized_boxes.append(normalize_box(box, image_width, image_height))
assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
return words, normalized_boxes
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.LAST:
image = image[..., ::-1]
elif input_data_format == ChannelDimension.FIRST:
image = image[:, ::-1, ...]
else:
raise ValueError(f"Unsupported channel dimension: {input_data_format}")
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
return image
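A quick sanity check of the channel flip above (assuming `flip_channel_order` as defined just above; the all-red image is illustrative):

```python
import numpy as np

# A 2x2 RGB image in channels-last layout; flipping gives BGR order,
# which is what the Detectron2 backbone behind LayoutLMv2 expects.
rgb = np.full((2, 2, 3), (255, 0, 0), dtype=np.uint8)  # all-red image
bgr = flip_channel_order(rgb)
print(bgr[0, 0])  # [  0   0 255] -> blue channel now comes first
```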
class LayoutLMv2ImageProcessor(BaseImageProcessor):
r"""
Constructs a LayoutLMv2 image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to `(size["height"], size["width"])`. Can be
overridden by `do_resize` in `preprocess`.
size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the image after resizing. Can be overridden by `size` in `preprocess`.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
apply_ocr (`bool`, *optional*, defaults to `True`):
Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. Can be overridden by
`apply_ocr` in `preprocess`.
ocr_lang (`str`, *optional*):
The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
used. Can be overridden by `ocr_lang` in `preprocess`.
tesseract_config (`str`, *optional*):
Any additional custom configuration flags that are forwarded to the `config` parameter when calling
Tesseract. For example: '--psm 6'. Can be overridden by `tesseract_config` in `preprocess`.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
apply_ocr: bool = True,
ocr_lang: Optional[str] = None,
tesseract_config: Optional[str] = "",
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.apply_ocr = apply_ocr
self.ocr_lang = ocr_lang
self.tesseract_config = tesseract_config
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
output_size = (size["height"], size["width"])
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
apply_ocr: bool = None,
ocr_lang: Optional[str] = None,
tesseract_config: Optional[str] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Desired size of the output image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PIL.Image` resampling
filter. Only has an effect if `do_resize` is set to `True`.
apply_ocr (`bool`, *optional*, defaults to `self.apply_ocr`):
Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
ocr_lang (`str`, *optional*, defaults to `self.ocr_lang`):
The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
used.
tesseract_config (`str`, *optional*, defaults to `self.tesseract_config`):
Any additional custom configuration flags that are forwarded to the `config` parameter when calling
Tesseract.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)
resample = resample if resample is not None else self.resample
apply_ocr = apply_ocr if apply_ocr is not None else self.apply_ocr
ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang
tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if apply_ocr:
requires_backends(self, "pytesseract")
words_batch = []
boxes_batch = []
for image in images:
words, boxes = apply_tesseract(image, ocr_lang, tesseract_config)
words_batch.append(words)
boxes_batch.append(boxes)
if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
# flip color channels from RGB to BGR (as Detectron2 requires this)
images = [flip_channel_order(image) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
if apply_ocr:
data["words"] = words_batch
data["boxes"] = boxes_batch
return data
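A minimal usage sketch; `apply_ocr=False` keeps it free of the pytesseract soft dependency, and the synthetic document image is illustrative:

```python
import numpy as np
from transformers import LayoutLMv2ImageProcessor

# apply_ocr=False avoids the pytesseract requirement for this sketch.
image_processor = LayoutLMv2ImageProcessor(apply_ocr=False)

document = np.random.randint(0, 256, size=(1000, 762, 3), dtype=np.uint8)
encoding = image_processor(document, return_tensors="np")
print(encoding["pixel_values"].shape)  # (1, 3, 224, 224), channels flipped to BGR
```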

@@ -16,235 +16,11 @@
Feature extractor class for LayoutLMv3.
"""
from typing import List, Optional, Union
from ...utils import logging
from .image_processing_layoutlmv3 import LayoutLMv3ImageProcessor
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import TensorType, is_pytesseract_available, logging, requires_backends
# soft dependency
if is_pytesseract_available():
import pytesseract
logger = logging.get_logger(__name__)
ImageInput = Union[
Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
]
def normalize_box(box, width, height):
return [
int(1000 * (box[0] / width)),
int(1000 * (box[1] / height)),
int(1000 * (box[2] / width)),
int(1000 * (box[3] / height)),
]
def apply_tesseract(image: Image.Image, lang: Optional[str], tesseract_config: Optional[str]):
"""Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
# apply OCR
data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config)
words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
# filter empty words and corresponding coordinates
irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
# turn coordinates into (left, top, left+width, top+height) format
actual_boxes = []
for x, y, w, h in zip(left, top, width, height):
actual_box = [x, y, x + w, y + h]
actual_boxes.append(actual_box)
image_width, image_height = image.size
# finally, normalize the bounding boxes
normalized_boxes = []
for box in actual_boxes:
normalized_boxes.append(normalize_box(box, image_width, image_height))
assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
return words, normalized_boxes
class LayoutLMv3FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a LayoutLMv3 feature extractor. This can be used to resize + normalize document images, as well as to
apply OCR on them in order to get a list of words and normalized bounding boxes.
This feature extractor inherits from [`~feature_extraction_utils.PreTrainedFeatureExtractor`] which contains most
of the main methods. Users should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 224):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
apply_ocr (`bool`, *optional*, defaults to `True`):
Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
ocr_lang (`str`, *optional*):
The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
used.
tesseract_config (`str`, *optional*):
Any additional custom configuration flags that are forwarded to the `config` parameter when calling
Tesseract. For example: '--psm 6'.
<Tip>
LayoutLMv3FeatureExtractor uses Google's Tesseract OCR engine under the hood.
</Tip>"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=224,
resample=PILImageResampling.BILINEAR,
do_normalize=True,
image_mean=None,
image_std=None,
apply_ocr=True,
ocr_lang=None,
tesseract_config="",
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.apply_ocr = apply_ocr
self.ocr_lang = ocr_lang
self.tesseract_config = tesseract_config
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
- **words** -- Optional words as identified by Tesseract OCR (only when [`LayoutLMv3FeatureExtractor`] was
initialized with `apply_ocr` set to `True`).
- **boxes** -- Optional bounding boxes as identified by Tesseract OCR, normalized based on the image size
(only when [`LayoutLMv3FeatureExtractor`] was initialized with `apply_ocr` set to `True`).
Examples:
```python
>>> from transformers import LayoutLMv3FeatureExtractor
>>> from PIL import Image
>>> # Document can be a png, jpg, etc. PDFs must be converted to images.
>>> image = Image.open(name_of_your_document).convert("RGB")
>>> # option 1: with apply_ocr=True (default)
>>> feature_extractor = LayoutLMv3FeatureExtractor()
>>> encoding = feature_extractor(image, return_tensors="pt")
>>> print(encoding.keys())
>>> # dict_keys(['pixel_values', 'words', 'boxes'])
>>> # option 2: with apply_ocr=False
>>> feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
>>> encoding = feature_extractor(image, return_tensors="pt")
>>> print(encoding.keys())
>>> # dict_keys(['pixel_values'])
```"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples), "
f"but is of type {type(images)}."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# Tesseract OCR to get words + normalized bounding boxes
if self.apply_ocr:
requires_backends(self, "pytesseract")
words_batch = []
boxes_batch = []
for image in images:
words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang, self.tesseract_config)
words_batch.append(words)
boxes_batch.append(boxes)
# transformations (resizing + normalization)
if self.do_resize and self.size is not None:
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
if self.apply_ocr:
encoded_inputs["words"] = words_batch
encoded_inputs["boxes"] = boxes_batch
return encoded_inputs
LayoutLMv3FeatureExtractor = LayoutLMv3ImageProcessor

@@ -0,0 +1,371 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for LayoutLMv3."""
from typing import Dict, Iterable, Optional, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format, to_pil_image
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import is_pytesseract_available, logging, requires_backends
if is_vision_available():
import PIL
# soft dependency
if is_pytesseract_available():
import pytesseract
logger = logging.get_logger(__name__)
def normalize_box(box, width, height):
return [
int(1000 * (box[0] / width)),
int(1000 * (box[1] / height)),
int(1000 * (box[2] / width)),
int(1000 * (box[3] / height)),
]
def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Optional[str]):
"""Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
# apply OCR
pil_image = to_pil_image(image)
image_width, image_height = pil_image.size
data = pytesseract.image_to_data(pil_image, lang=lang, output_type="dict", config=tesseract_config)
words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
# filter empty words and corresponding coordinates
irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
# turn coordinates into (left, top, left+width, top+height) format
actual_boxes = []
for x, y, w, h in zip(left, top, width, height):
actual_box = [x, y, x + w, y + h]
actual_boxes.append(actual_box)
# finally, normalize the bounding boxes
normalized_boxes = []
for box in actual_boxes:
normalized_boxes.append(normalize_box(box, image_width, image_height))
assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
return words, normalized_boxes
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.LAST:
image = image[..., ::-1]
elif input_data_format == ChannelDimension.FIRST:
image = image[:, ::-1, ...]
else:
raise ValueError(f"Unsupported channel dimension: {input_data_format}")
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
return image
class LayoutLMv3ImageProcessor(BaseImageProcessor):
r"""
Constructs a LayoutLMv3 image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to `(size["height"], size["width"])`. Can be
overridden by `do_resize` in `preprocess`.
size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the image after resizing. Can be overridden by `size` in `preprocess`.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image's pixel values by the specified `rescale_value`. Can be overridden by
`do_rescale` in `preprocess`.
rescale_factor (`float`, *optional*, defaults to 1 / 255):
Value by which the image's pixel values are rescaled. Can be overridden by `rescale_factor` in
`preprocess`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`Iterable[float]` or `float`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`Iterable[float]` or `float`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
apply_ocr (`bool`, *optional*, defaults to `True`):
Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. Can be overridden by
the `apply_ocr` parameter in the `preprocess` method.
ocr_lang (`str`, *optional*):
The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
used. Can be overridden by the `ocr_lang` parameter in the `preprocess` method.
tesseract_config (`str`, *optional*):
Any additional custom configuration flags that are forwarded to the `config` parameter when calling
Tesseract. For example: '--psm 6'. Can be overridden by the `tesseract_config` parameter in the
`preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_value: float = 1 / 255,
do_normalize: bool = True,
image_mean: Union[float, Iterable[float]] = None,
image_std: Union[float, Iterable[float]] = None,
apply_ocr: bool = True,
ocr_lang: Optional[str] = None,
tesseract_config: Optional[str] = "",
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_value
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.apply_ocr = apply_ocr
self.ocr_lang = ocr_lang
self.tesseract_config = tesseract_config
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image to (size["height"], size["width"]) dimensions.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
output_size = (size["height"], size["width"])
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `Iterable[float]`):
Mean values to be used for normalization.
std (`float` or `Iterable[float]`):
Standard deviation values to be used for normalization.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Union[float, Iterable[float]] = None,
image_std: Union[float, Iterable[float]] = None,
apply_ocr: bool = None,
ocr_lang: Optional[str] = None,
tesseract_config: Optional[str] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Desired size of the output image after applying `resize`.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` filters.
Only has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image pixel values between [0, 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to apply to the image pixel values. Only has an effect if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `Iterable[float]`, *optional*, defaults to `self.image_mean`):
Mean values to be used for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `Iterable[float]`, *optional*, defaults to `self.image_std`):
Standard deviation values to be used for normalization. Only has an effect if `do_normalize` is set to
`True`.
apply_ocr (`bool`, *optional*, defaults to `self.apply_ocr`):
Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
ocr_lang (`str`, *optional*, defaults to `self.ocr_lang`):
The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
used.
tesseract_config (`str`, *optional*, defaults to `self.tesseract_config`):
Any additional custom configuration flags that are forwarded to the `config` parameter when calling
Tesseract.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
apply_ocr = apply_ocr if apply_ocr is not None else self.apply_ocr
ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang
tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("If do_normalize is True, image_mean and image_std must be specified.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
# Tesseract OCR to get words + normalized bounding boxes
if apply_ocr:
requires_backends(self, "pytesseract")
words_batch = []
boxes_batch = []
for image in images:
words, boxes = apply_tesseract(image, ocr_lang, tesseract_config)
words_batch.append(words)
boxes_batch.append(boxes)
if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
# flip color channels from RGB to BGR (as Detectron2 requires this)
images = [flip_channel_order(image) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
if apply_ocr:
data["words"] = words_batch
data["boxes"] = boxes_batch
return data
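
The `preprocess` above bundles OCR and pixel preprocessing into one call. A minimal usage sketch, assuming this is the LayoutLMv2 image processor added in this PR and that `pytesseract` is installed (the file name is illustrative):

```python
from PIL import Image
from transformers import LayoutLMv2ImageProcessor

image_processor = LayoutLMv2ImageProcessor(apply_ocr=True)
image = Image.open("document.png").convert("RGB")  # hypothetical scanned page

encoding = image_processor(image, return_tensors="pt")
print(encoding["pixel_values"].shape)   # (1, 3, height, width) after resizing
print(encoding["words"][0][:5])         # first OCR'd words of the first image
print(encoding["boxes"][0][:5])         # matching normalized bounding boxes
```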

View File

@ -14,148 +14,12 @@
# limitations under the License.
"""Feature extractor class for LeViT."""
from typing import Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, logging
from ...utils import logging
from .image_processing_levit import LevitImageProcessor
logger = logging.get_logger(__name__)
class LevitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a LeViT feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the shortest edge of the input to int(256/224 * `size`).
size (`int` or `Tuple(int)`, *optional*, defaults to 224):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the shorter side of the input will be resized to `size`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether or not to center crop the input to `size`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[float]`, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[float]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=224,
resample=PILImageResampling.BICUBIC,
do_center_crop=True,
do_normalize=True,
image_mean=IMAGENET_DEFAULT_MEAN,
image_std=IMAGENET_DEFAULT_STD,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (resizing + center cropping + normalization)
if self.do_resize and self.size is not None:
size_ = int((256 / 224) * self.size)
images = [
self.resize(image=image, size=size_, resample=self.resample, default_to_square=False)
for image in images
]
if self.do_center_crop:
images = [self.center_crop(image=image, size=self.size) for image in images]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
# Feature extractor for Levit is being replaced by image processor
LevitFeatureExtractor = LevitImageProcessor
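
Since the old module now only re-exports the new class, existing imports keep working unchanged; a quick sketch of the equivalence:

```python
from transformers.models.levit.feature_extraction_levit import LevitFeatureExtractor
from transformers.models.levit.image_processing_levit import LevitImageProcessor

# The legacy name is just an alias for the image processor after this change.
assert LevitFeatureExtractor is LevitImageProcessor
```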

View File

@ -0,0 +1,342 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for LeViT."""
from typing import Dict, Iterable, List, Optional, Union
import numpy as np
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
get_resize_output_image_size,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
logger = logging.get_logger(__name__)
class LevitImageProcessor(BaseImageProcessor):
r"""
Constructs a LeViT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the shortest edge of the input to int(256/224 * `size`). Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
Size of the output image after resizing. If size is a dict with keys "width" and "height", the image will
be resized to `(size["height"], size["width"])`. If size is a dict with key "shortest_edge", the shortest
edge value `c` is rescaled to `int(c * (256/224))`. The smaller edge of the image will be matched to this
value, i.e. if height > width, then image will be rescaled to `(size["shortest_edge"] * height / width,
size["shortest_edge"])`. Can be overridden by the `size` parameter in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether or not to center crop the input to `(crop_size["height"], crop_size["width"])`. Can be overridden
by the `do_center_crop` parameter in the `preprocess` method.
crop_size (`Dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
Desired image size after `center_crop`. Can be overridden by the `crop_size` parameter in the `preprocess`
method.
do_rescale (`bool`, *optional*, defaults to `True`):
Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
`preprocess` method.
image_mean (`List[float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`List[float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, Iterable[float]]] = IMAGENET_DEFAULT_MEAN,
image_std: Optional[Union[float, Iterable[float]]] = IMAGENET_DEFAULT_STD,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image.
If size is a dict with keys "width" and "height", the image will be resized to `(size["height"],
size["width"])`.
If size is a dict with key "shortest_edge", the shortest edge value `c` is rescaled to `int(c * (256/224))`.
The smaller edge of the image will be matched to this value, i.e. if height > width, then image will be rescaled
to `(size["shortest_edge"] * height / width, size["shortest_edge"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image after resizing. If size is a dict with keys "width" and "height", the image
will be resized to (height, width). If size is a dict with key "shortest_edge", the shortest edge value
`c` is rescaled to int(`c` * (256/224)). The smaller edge of the image will be matched to this value
i.e. if height > width, then image will be rescaled to (size * height / width, size).
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size_dict = get_size_dict(size, default_to_square=False)
# size_dict is a dict with either keys "height" and "width" or "shortest_edge"
if "shortest_edge" in size:
shortest_edge = int((256 / 224) * size["shortest_edge"])
output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False)
size_dict = {"height": output_size[0], "width": output_size[1]}
if "height" not in size_dict or "width" not in size_dict:
raise ValueError(
f"Size dict must have keys 'height' and 'width' or 'shortest_edge'. Got {size_dict.keys()}"
)
return resize(
image, size=(size_dict["height"], size_dict["width"]), resample=resample, data_format=data_format, **kwargs
)
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Center crop an image.
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Dict `{"height": int, "width": int}` specifying the size of the output image after cropping.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `List[float]`):
Image mean.
std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[Dict[str, int]] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, Iterable[float]]] = None,
image_std: Optional[Union[float, Iterable[float]]] = None,
return_tensors: Optional[TensorType] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> BatchFeature:
"""
Preprocess an image or batch of images to be used as input to a LeViT model.
Args:
images (`ImageInput`):
Image or batch of images to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the output image after resizing. If size is a dict with keys "width" and "height", the image
will be resized to (height, width). If size is a dict with key "shortest_edge", the shortest edge value
`c` is rescaled to int(`c` * (256/224)). The smaller edge of the image will be matched to this value
i.e. if height > width, then image will be rescaled to (size * height / width, size).
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the output image after center cropping. Crops images to (crop_size["height"],
crop_size["width"]).
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image pixel values by `rescale_factor`, typically to values between 0 and 1.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Factor to rescale the image pixel values by.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image pixel values by `image_mean` and `image_std`.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Mean to normalize the image pixel values by.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Standard deviation to normalize the image pixel values by.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [self.resize(image, size, resample) for image in images]
if do_center_crop:
images = [self.center_crop(image, crop_size) for image in images]
if do_rescale:
images = [self.rescale(image, rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image, image_mean, image_std) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
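
To make the resize/crop interplay concrete, a small sketch using the defaults: the shortest edge is first resized to `int(256/224 * 224) = 256`, then a 224x224 center crop is taken (the random array stands in for a real image):

```python
import numpy as np
from transformers import LevitImageProcessor

image_processor = LevitImageProcessor()  # size={"shortest_edge": 224}, crop_size=224x224
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # dummy HWC image

outputs = image_processor(image, return_tensors="np")
print(outputs["pixel_values"].shape)  # (1, 3, 224, 224): shortest edge -> 256, then center crop
```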

View File

@ -14,189 +14,11 @@
# limitations under the License.
"""Feature extractor class for MobileViT."""
from typing import List, Optional, Tuple, Union
from ...utils import logging
from .image_processing_mobilevit import MobileViTImageProcessor
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class MobileViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a MobileViT feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 288):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be resized to match the shorter side. Only has an effect if
`do_resize` is set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
image is padded with 0's and then center cropped.
crop_size (`int`, *optional*, defaults to 256):
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
do_flip_channel_order (`bool`, *optional*, defaults to `True`):
Whether to flip the color channels from RGB to BGR.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=288,
resample=PILImageResampling.BILINEAR,
do_center_crop=True,
crop_size=256,
do_flip_channel_order=True,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_flip_channel_order = do_flip_channel_order
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (resizing + normalization)
if self.do_resize and self.size is not None:
images = [
self.resize(image=image, size=self.size, resample=self.resample, default_to_square=False)
for image in images
]
if self.do_center_crop and self.crop_size is not None:
images = [self.center_crop(image, self.crop_size) for image in images]
images = [self.to_numpy_array(image) for image in images]
# the pretrained checkpoints assume images are BGR, not RGB
if self.do_flip_channel_order:
images = [self.flip_channel_order(image) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports
PyTorch.
Args:
outputs ([`MobileViTForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]`, *optional*):
A list of length `batch_size`, where each item is a `Tuple[int, int]` corresponding to the requested
final size (height, width) of each prediction. If left to None, predictions will not be resized.
Returns:
`List[torch.Tensor]`:
A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
`torch.Tensor` corresponds to a semantic class id.
"""
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
MobileViTFeatureExtractor = MobileViTImageProcessor

View File

@ -0,0 +1,364 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for MobileViT."""
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from transformers.utils import is_torch_available, is_torch_tensor, is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, get_resize_output_image_size, rescale, resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_vision_available():
import PIL
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension]) -> np.ndarray:
"""
Flip the color channels from RGB to BGR or vice versa.
Args:
image (`np.ndarray`):
The image, represented as a numpy array.
data_format (`ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
Returns:
`np.ndarray`: The image with the flipped color channels.
"""
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.LAST:
image = image[..., ::-1]
elif input_data_format == ChannelDimension.FIRST:
image = image[:, ::-1, ...]
else:
raise ValueError(f"Invalid input channel dimension format: {input_data_format}")
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
return image
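
For channels-last input the helper boils down to reversing the last axis; a tiny numpy sketch of what that does to an RGB pixel:

```python
import numpy as np

rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255              # a pure-red image in channels-last RGB

bgr = rgb[..., ::-1]           # the same slice the helper uses for ChannelDimension.LAST
print(rgb[0, 0], bgr[0, 0])    # [255 0 0] [0 0 255] -> red now sits where BGR expects it
```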
class MobileViTImageProcessor(BaseImageProcessor):
r"""
Constructs a MobileViT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
Controls the size of the output image after resizing. Can be overridden by the `size` parameter in the
`preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Defines the resampling filter to use if resizing the image. Can be overridden by the `resample` parameter
in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
image is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in
the `preprocess` method.
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
Desired output size `(size["height"], size["width"])` when applying center-cropping. Can be overridden by
the `crop_size` parameter in the `preprocess` method.
do_flip_channel_order (`bool`, *optional*, defaults to `True`):
Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order`
parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_flip_channel_order: bool = True,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
crop_size = get_size_dict(crop_size)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_flip_channel_order = do_flip_channel_order
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PIL.Image.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Controls the size of the output image. The shortest edge of the image will be resized to
`size["shortest_edge"]` while maintaining the aspect ratio.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size:
raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Center crop an image to size `(size["height"], size["width"])`. If the input size is smaller than `size` along
any edge, the image is padded with 0's and then center cropped.
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def flip_channel_order(
self, image: np.ndarray, data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
"""
Flip the color channels from RGB to BGR or vice versa.
Args:
image (`np.ndarray`):
The image, represented as a numpy array.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return flip_channel_order(image, data_format=data_format)
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_flip_channel_order: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> BatchFeature:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image by rescale factor.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the center crop if `do_center_crop` is set to `True`.
do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
Whether to flip the channel order of the image.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
do_flip_channel_order = (
do_flip_channel_order if do_flip_channel_order is not None else self.do_flip_channel_order
)
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
if do_center_crop:
images = [self.center_crop(image=image, size=crop_size) for image in images]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
# the pretrained checkpoints assume images are BGR, not RGB
if do_flip_channel_order:
images = [self.flip_channel_order(image=image) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports
PyTorch.
Args:
outputs ([`MobileViTForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]`, *optional*):
A list of length `batch_size`, where each item is a `Tuple[int, int]` corresponding to the requested
final size (height, width) of each prediction. If left to None, predictions will not be resized.
Returns:
`List[torch.Tensor]`:
A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
`torch.Tensor` corresponds to a semantic class id.
"""
# TODO: add support for other frameworks
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
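
Putting the processor and its post-processing together, a short end-to-end sketch for semantic segmentation (the checkpoint name and image path are illustrative):

```python
import torch
from PIL import Image
from transformers import MobileViTForSemanticSegmentation, MobileViTImageProcessor

image_processor = MobileViTImageProcessor()
model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")

image = Image.open("street_scene.jpg")
inputs = image_processor(image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# target_sizes is (height, width); PIL's .size is (width, height), hence the reversal.
segmentation = image_processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
print(segmentation.shape)  # per-pixel class ids at the original resolution
```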

View File

@ -14,179 +14,11 @@
# limitations under the License.
"""Feature extractor class for Perceiver."""
from typing import Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, logging
from ...utils import logging
from .image_processing_perceiver import PerceiverImageProcessor
logger = logging.get_logger(__name__)
class PerceiverFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a Perceiver feature extractor.
This feature extractor inherits from [`ImageFeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
image is padded with 0's and then center cropped.
crop_size (`int`, *optional*, defaults to 256):
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 224):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with `image_mean` and `image_std`.
image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_center_crop=True,
crop_size=256,
do_resize=True,
size=224,
resample=PILImageResampling.BICUBIC,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
def center_crop(self, image):
"""
Crops `image` to *self.crop_size* using a center crop. Note that if the image is too small to be cropped to the
size given, it will be padded (so the returned result has the requested size).
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to resize.
"""
if isinstance(image, Image.Image):
image = self.to_numpy_array(image)
image_height, image_width = image.shape[-2:]
padded_center_crop_size = (
(self.size / (self.crop_size)) * np.minimum(image_height, image_width).astype(np.float32)
).astype(np.int32)
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
crop_window = [offset_height, offset_width, padded_center_crop_size, padded_center_crop_size]
image = image[
:, crop_window[0] : crop_window[0] + crop_window[2], crop_window[1] : crop_window[1] + crop_window[3]
]
return image
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (center cropping + resizing + normalization)
if self.do_center_crop and self.crop_size is not None:
images = [self.center_crop(image) for image in images]
if self.do_resize and self.size is not None and self.resample is not None:
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
PerceiverFeatureExtractor = PerceiverImageProcessor

View File

@ -0,0 +1,330 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Perceiver."""
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
class PerceiverImageProcessor(BaseImageProcessor):
r"""
Constructs a Perceiver image processor.
Args:
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether or not to center crop the image. If the input size is smaller than `crop_size` along any edge, the
image will be padded with zeros and then center cropped. Can be overridden by the `do_center_crop`
parameter in the `preprocess` method.
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
Desired output size when applying center-cropping. Can be overridden by the `crop_size` parameter in the
`preprocess` method.
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image to `(size["height"], size["width"])`. Can be overridden by the `do_resize`
parameter in the `preprocess` method.
size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the image after resizing. Can be overridden by the `size` parameter in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Defines the resampling filter to use if resizing the image. Can be overridden by the `resample` parameter
in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter
in the `preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
crop_size = get_size_dict(crop_size)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def center_crop(
self,
image: np.ndarray,
crop_size: Dict[str, int],
size: Optional[int] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Center crop an image to `(size["height"] / crop_size["height"] * min_dim, size["width"] / crop_size["width"] *
min_dim)`, where `min_dim = min(height, width)` of the input image.
If the input size is smaller than `crop_size` along any edge, the image will be padded with zeros and then
center cropped.
Args:
image (`np.ndarray`):
Image to center crop.
crop_size (`Dict[str, int]`):
Desired output size after applying the center crop.
size (`Dict[str, int]`, *optional*):
Size of the image after resizing. If not provided, the self.size attribute will be used.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = self.size if size is None else size
size = get_size_dict(size)
crop_size = get_size_dict(crop_size)
height, width = get_image_size(image)
min_dim = min(height, width)
cropped_height = (size["height"] / crop_size["height"]) * min_dim
cropped_width = (size["width"] / crop_size["width"]) * min_dim
return center_crop(image, size=(cropped_height, cropped_width), data_format=data_format, **kwargs)
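
The crop here is relative rather than absolute: the window is `size / crop_size` of the image's shorter side, so the later resize to `size` always sees the same fraction of the picture. A worked sketch with the defaults (`size` 224, `crop_size` 256) on a 480x640 image:

```python
# Defaults: size={"height": 224, "width": 224}, crop_size={"height": 256, "width": 256}
height, width = 480, 640
min_dim = min(height, width)            # 480
cropped_height = (224 / 256) * min_dim  # 420.0
cropped_width = (224 / 256) * min_dim   # 420.0
# A 420x420 window is center-cropped, then the resize step maps it to 224x224.
```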
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PIL.Image.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BILINEAR`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
return resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `List[float]`):
Image mean.
std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def preprocess(
self,
images: ImageInput,
do_center_crop: Optional[bool] = None,
crop_size: Optional[Dict[str, int]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image to `crop_size`.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Desired output size after applying the center crop.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_center_crop and crop_size is None:
raise ValueError("If `do_center_crop` is set to `True`, `crop_size` must be provided.")
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and image standard deviation must be specified if do_normalize is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_center_crop:
images = [self.center_crop(image, crop_size, size=size) for image in images]
if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
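For orientation, a minimal usage sketch of the processor above. The class name `PerceiverImageProcessor` is assumed from the file this diff touches, and the input array is dummy data for illustration:

```python
# Minimal usage sketch, assuming the class above is exported as `PerceiverImageProcessor`.
import numpy as np
from transformers import PerceiverImageProcessor

processor = PerceiverImageProcessor()  # defaults: crop_size 256x256, size 224x224

# Dummy 480x640 RGB image in (height, width, channels) uint8 layout.
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)

# The dynamic center crop runs first: min_dim = min(480, 640) = 480, so the crop is
# (224 / 256 * 480, 224 / 256 * 480) = (420, 420); the crop is then resized to
# 224x224, rescaled to [0, 1] and normalized.
inputs = processor.preprocess(image, return_tensors="np")
print(inputs["pixel_values"].shape)  # (1, 3, 224, 224)
```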

View File

@ -14,161 +14,11 @@
# limitations under the License.
"""Feature extractor class for PoolFormer."""
import math
from typing import Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, logging
from ...utils import logging
from .image_processing_poolformer import PoolFormerImageProcessor
logger = logging.get_logger(__name__)
class PoolFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a PoolFormer feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize_and_center_crop (`bool`, *optional*, defaults to `True`):
Whether to resize the shortest edge of the image and center crop the input to a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 224):
Center crop the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be center cropped to (size, size). Only has an effect if
`do_resize_and_center_crop` is set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
crop_pct (`float`, *optional*, defaults to `0.9`):
The percentage of the image to crop from the center. Only has an effect if `do_resize_and_center_crop` is
set to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with `image_mean` and `image_std`.
image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize_and_center_crop=True,
size=224,
resample=PILImageResampling.BICUBIC,
crop_pct=0.9,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_resize_and_center_crop = do_resize_and_center_crop
self.size = size
self.resample = resample
self.crop_pct = crop_pct
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (resizing + center cropping + normalization)
if self.do_resize_and_center_crop and self.size is not None and self.crop_pct is not None:
if isinstance(self.size, (tuple, list)):
assert len(self.size) == 2
if self.size[-1] == self.size[-2]:
scale_size = int(math.floor(self.size[0] / self.crop_pct))
else:
scale_size = tuple([int(x / self.crop_pct) for x in self.size])
else:
scale_size = int(math.floor(self.size / self.crop_pct))
# resize shortest edge of the image
images = [
self.resize(image=image, size=scale_size, resample=self.resample, default_to_square=False)
for image in images
]
# center crop
images = [self.center_crop(image, size=self.size) for image in images]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
PoolFormerFeatureExtractor = PoolFormerImageProcessor
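The final assignment above is the backward-compatibility hook used throughout this PR: the legacy feature-extractor name is kept as an alias of the new image-processor class. A small sketch, assuming both names remain importable from the top-level `transformers` namespace:

```python
# Sketch of the aliasing pattern: old name and new name refer to the same class,
# so existing code that builds a PoolFormerFeatureExtractor keeps working.
from transformers import PoolFormerFeatureExtractor, PoolFormerImageProcessor

assert PoolFormerFeatureExtractor is PoolFormerImageProcessor

feature_extractor = PoolFormerFeatureExtractor()   # actually a PoolFormerImageProcessor
print(type(feature_extractor).__name__)            # "PoolFormerImageProcessor"
```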

View File

@ -0,0 +1,382 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for PoolFormer."""
import math
from typing import Dict, List, Optional, Union
import numpy as np
from transformers import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
get_resize_output_image_size,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
class PoolFormerImageProcessor(BaseImageProcessor):
r"""
Constructs a PoolFormer image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
`do_resize` in the `preprocess` method.
size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 256}`):
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. If crop_pct is
unset:
- size is `{"height": h, "width": w}`: the image is resized to `(h, w)`.
- size is `{"shortest_edge": s}`: the shortest edge of the image is resized to s whilst maintaining the
aspect ratio.
If crop_pct is set:
- size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h/crop_pct)),
int(floor(w/crop_pct)))`
- size is `{"height": c, "width": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)`
whilst maintaining the aspect ratio.
- size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)`
whilst maintaining the aspect ratio.
crop_pct (`float`, *optional*, defaults to `0.9`):
Percentage of the image to crop from the center. Can be overridden by `crop_pct` in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in the `preprocess`
method.
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
Size of the image after applying center crop. Only has an effect if `do_center_crop` is set to `True`. Can
be overridden by the `crop_size` parameter in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
`preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
crop_pct: float = 0.9,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
rescale_factor: Union[int, float] = 1 / 255,
do_rescale: bool = True,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 256}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
crop_size = get_size_dict(crop_size)
self.do_resize = do_resize
self.size = size
self.crop_pct = crop_pct
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
crop_pct: Optional[float] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image.
If crop_pct is unset:
- size is `{"height": h, "width": w}`: the image is resized to `(h, w)`.
- size is `{"shortest_edge": s}`: the shortest edge of the image is resized to s whilst maintaining the
aspect ratio.
if crop_pct is set:
- size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h/crop_pct)),
int(floor(w/crop_pct)))`
- size is `{"height": c, "width": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)`
whilst maintaining the aspect ratio.
- size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)`
whilst maintaining the aspect ratio.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
crop_pct (`float`, *optional*):
Percentage of the image that will be cropped from the center. If set, the resize target is scaled up to
`int(floor(size / crop_pct))` as described above, before the center crop is applied.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size and ("height" not in size or "width" not in size):
raise ValueError(f"size must contain 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
if crop_pct is not None:
if "shortest_edge" in size:
scale_size = int(math.floor(size["shortest_edge"] / crop_pct))
elif "height" in size and "width" in size:
if size["height"] == size["width"]:
scale_size = int(math.floor(size["height"] / crop_pct))
else:
scale_size = (
int(math.floor(size["height"] / crop_pct)),
int(math.floor(size["width"] / crop_pct)),
)
else:
raise ValueError("Invalid size for resize: {}".format(size))
output_size = get_resize_output_image_size(image, size=scale_size, default_to_square=False)
else:
if "shortest_edge" in size:
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
elif "height" in size and "width" in size:
output_size = (size["height"], size["width"])
else:
raise ValueError("Invalid size for resize: {}".format(size))
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Center crop an image to (size["height"], size["width"]). If the input size is smaller than `crop_size` along
any edge, the image is padded with 0's and then center cropped.
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `List[float]`):
Image mean.
std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
crop_pct: float = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after applying resize.
crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
Percentage of the image to crop. Only has an effect if `do_resize` is set to `True`.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the image after applying center crop.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values to the [0, 1] range.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
crop_pct = crop_pct if crop_pct is not None else self.crop_pct
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_center_crop and crop_pct is None:
raise ValueError("Crop_pct must be specified if do_center_crop is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [self.resize(image=image, size=size, crop_pct=crop_pct, resample=resample) for image in images]
if do_center_crop:
images = [self.center_crop(image=image, size=crop_size) for image in images]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
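As a sanity check on the `crop_pct` logic above, a short sketch with the class defaults (`size={"shortest_edge": 256}`, `crop_pct=0.9`, `crop_size` of 256x256); the input array is dummy data and the intermediate size is computed the same way as in `resize`:

```python
# Worked sketch of the resize-then-crop pipeline with the PoolFormer defaults.
import math
import numpy as np
from transformers import PoolFormerImageProcessor

processor = PoolFormerImageProcessor()

# With crop_pct set, the shortest edge is first resized to
# int(math.floor(256 / 0.9)) = 284 while keeping the aspect ratio, so the later
# 256x256 center crop covers roughly crop_pct of the resized image.
print(int(math.floor(256 / 0.9)))  # 284

image = np.random.randint(0, 256, size=(300, 500, 3), dtype=np.uint8)
inputs = processor.preprocess(image, return_tensors="np")
print(inputs["pixel_values"].shape)  # (1, 3, 256, 256)
```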

View File

@ -14,248 +14,11 @@
# limitations under the License.
"""Feature extractor class for SegFormer."""
from typing import List, Optional, Tuple, Union
from ...utils import logging
from .image_processing_segformer import SegformerImageProcessor
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class SegformerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a SegFormer feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input based on a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 512):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
ImageNet std.
reduce_labels (`bool`, *optional*, defaults to `False`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
background label will be replaced by 255.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=512,
resample=PILImageResampling.BILINEAR,
do_normalize=True,
image_mean=None,
image_std=None,
reduce_labels=False,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.reduce_labels = reduce_labels
def __call__(
self,
images: ImageInput,
segmentation_maps: ImageInput = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s) and optional corresponding segmentation maps.
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is
the number of channels, H and W are image height and width.
segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
- **labels** -- Optional labels to be fed to a model (when `segmentation_maps` are provided)
"""
# Input type checking for clearer error
valid_images = False
valid_segmentation_maps = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
# Check that segmentation maps has a valid type
if segmentation_maps is not None:
if isinstance(segmentation_maps, (Image.Image, np.ndarray)) or is_torch_tensor(segmentation_maps):
valid_segmentation_maps = True
elif isinstance(segmentation_maps, (list, tuple)):
if (
len(segmentation_maps) == 0
or isinstance(segmentation_maps[0], (Image.Image, np.ndarray))
or is_torch_tensor(segmentation_maps[0])
):
valid_segmentation_maps = True
if not valid_segmentation_maps:
raise ValueError(
"Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single"
" example),`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of"
" examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
if segmentation_maps is not None:
segmentation_maps = [segmentation_maps]
# reduce zero label if needed
if self.reduce_labels:
if segmentation_maps is not None:
for idx, map in enumerate(segmentation_maps):
if not isinstance(map, np.ndarray):
map = np.array(map)
# avoid using underflow conversion
map[map == 0] = 255
map = map - 1
map[map == 254] = 255
segmentation_maps[idx] = Image.fromarray(map.astype(np.uint8))
# transformations (resizing + normalization)
if self.do_resize and self.size is not None:
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
if segmentation_maps is not None:
segmentation_maps = [
self.resize(map, size=self.size, resample=Image.NEAREST) for map in segmentation_maps
]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
if segmentation_maps is not None:
labels = []
for map in segmentation_maps:
if not isinstance(map, np.ndarray):
map = np.array(map)
labels.append(map.astype(np.int64))
# cast to np.int64
data["labels"] = labels
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports
PyTorch.
Args:
outputs ([`SegformerForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
None, predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
SegformerFeatureExtractor = SegformerImageProcessor

View File

@ -0,0 +1,488 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Segformer."""
import warnings
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from transformers.utils import is_torch_available, is_torch_tensor, is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_vision_available():
import PIL.Image
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class SegformerImageProcessor(BaseImageProcessor):
r"""
Constructs a Segformer image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]`, *optional*, defaults to `{"height": 512, "width": 512}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_reduce_labels (`bool`, *optional*, defaults to `False`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
`preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_reduce_labels: bool = False,
**kwargs
) -> None:
if "reduce_labels" in kwargs:
warnings.warn(
"The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
"`do_reduce_labels` instead.",
FutureWarning,
)
do_reduce_labels = kwargs.pop("reduce_labels")
super().__init__(**kwargs)
size = size if size is not None else {"height": 512, "width": 512}
size = get_size_dict(size)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_reduce_labels = do_reduce_labels
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
return resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
)
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
any edge, the image is padded with 0's and then center cropped.
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `List[float]`):
Image mean.
std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def reduce_label(self, label: ImageInput) -> np.ndarray:
label = to_numpy_array(label)
# Avoid using underflow conversion
label[label == 0] = 255
label = label - 1
label[label == 254] = 255
return label
def _preprocess(
self,
image: ImageInput,
do_reduce_labels: bool,
do_resize: bool,
do_rescale: bool,
do_normalize: bool,
size: Optional[Dict[str, int]] = None,
resample: Optional[PILImageResampling] = None,
rescale_factor: Optional[float] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
):
if do_reduce_labels:
image = self.reduce_label(image)
if do_resize:
image = self.resize(image=image, size=size, resample=resample)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor)
if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std)
return image
def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
image = self._preprocess(
image=image,
do_reduce_labels=False,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
)
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
return image
def _preprocess_mask(
self,
segmentation_map: ImageInput,
do_reduce_labels: bool = None,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
) -> np.ndarray:
"""Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map)
# Add channel dimension if missing - needed for certain transformations
added_channel_dim = False
if segmentation_map.ndim == 2:
added_channel_dim = True
segmentation_map = segmentation_map[None, ...]
# reduce zero label if needed
segmentation_map = self._preprocess(
image=segmentation_map,
do_reduce_labels=do_reduce_labels,
do_resize=do_resize,
resample=PIL.Image.NEAREST,
size=size,
do_rescale=False,
do_normalize=False,
)
# Remove extra channel dimension if added for processing
if added_channel_dim:
segmentation_map = segmentation_map.squeeze(0)
segmentation_map = segmentation_map.astype(np.int64)
return segmentation_map
def __call__(self, images, segmentation_maps=None, **kwargs):
"""
Preprocesses a batch of images and optionally segmentation maps.
Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
passed in as positional arguments.
"""
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional[PILImageResampling] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
segmentation_maps (`ImageInput`, *optional*):
Segmentation map to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after `resize` is applied.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values to the [0, 1] range.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g.
ADE20k). The background label will be replaced by 255.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
resample = resample if resample is not None else self.resample
size = size if size is not None else self.size
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
if not is_batched(images):
images = [images]
segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
images = [
self._preprocess_image(
image=img,
do_resize=do_resize,
resample=resample,
size=size,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
)
for img in images
]
data = {"pixel_values": images}
if segmentation_maps is not None:
segmentation_maps = [
self._preprocess_mask(
segmentation_map=segmentation_map,
do_reduce_labels=do_reduce_labels,
do_resize=do_resize,
resample=PIL.Image.NEAREST,
size=size,
)
for segmentation_map in segmentation_maps
]
data["labels"] = segmentation_maps
return BatchFeature(data=data, tensor_type=return_tensors)
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports
PyTorch.
Args:
outputs ([`SegformerForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
None, predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
# TODO: add support for other frameworks
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
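A minimal end-to-end sketch of the Segformer processor above, including segmentation maps and the PyTorch-only post-processing step. The Hub checkpoint name is an assumption for illustration (any Segformer semantic-segmentation checkpoint should behave the same way), and the image/mask arrays are dummy data:

```python
import numpy as np
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

processor = SegformerImageProcessor()  # defaults: resize to 512x512, rescale, normalize
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512"  # assumed checkpoint, for illustration only
)

# Dummy image and per-pixel label map.
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
mask = np.random.randint(0, 150, size=(480, 640), dtype=np.uint8)

# Both images and segmentation maps can be passed to __call__; the mask is resized
# with nearest-neighbour interpolation and returned as int64 "labels".
inputs = processor(image, segmentation_maps=mask, return_tensors="pt")
outputs = model(pixel_values=inputs["pixel_values"], labels=inputs["labels"])

# Upsample the logits back to the original (height, width) and take the per-pixel argmax.
segmentation = processor.post_process_semantic_segmentation(outputs, target_sizes=[(480, 640)])
print(segmentation[0].shape)  # torch.Size([480, 640])
```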

View File

@ -14,159 +14,11 @@
# limitations under the License.
"""Feature extractor class for VideoMAE."""
from typing import Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
from ...utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, TensorType, logging
from ...utils import logging
from .image_processing_videomae import VideoMAEImageProcessor
logger = logging.get_logger(__name__)
class VideoMAEFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a VideoMAE feature extractor. This feature extractor can be used to prepare videos for the model.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the shorter edge of the input to a certain `size`.
size (`int`, *optional*, defaults to 224):
Resize the shorter edge of the input to the given size. Only has an effect if `do_resize` is set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the input to a certain `size`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=224,
resample=PILImageResampling.BILINEAR,
do_center_crop=True,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
def resize_video(self, video, size, resample="bilinear"):
return [self.resize(frame, size, resample, default_to_square=False) for frame in video]
def crop_video(self, video, size):
return [self.center_crop(frame, size) for frame in video]
def normalize_video(self, video, mean, std):
# video can be a list of PIL images, list of NumPy arrays or list of PyTorch tensors
# first: convert to list of NumPy arrays
video = [self.to_numpy_array(frame) for frame in video]
# second: stack to get (num_frames, num_channels, height, width)
video = np.stack(video, axis=0)
# third: normalize
if not isinstance(mean, np.ndarray):
mean = np.array(mean).astype(video.dtype)
if not isinstance(std, np.ndarray):
std = np.array(std).astype(video.dtype)
return (video - mean[None, :, None, None]) / std[None, :, None, None]
def __call__(
self, videos: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several video(s).
<Tip warning={true}>
NumPy arrays are converted to PIL images when resizing, so the most efficient is to pass PIL images.
</Tip>
Args:
videos (`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, `List[List[PIL.Image.Image]]`, `List[List[np.ndarray]]`,:
`List[List[torch.Tensor]]`): The video or batch of videos to be prepared. Each video should be a list
of frames, which can be either PIL images or NumPy arrays. In case of NumPy arrays/PyTorch tensors,
each frame should be of shape (H, W, C), where H and W are frame height and width, and C is a number of
channels.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, num_frames,
height, width).
"""
# Input type checking for clearer error
valid_videos = False
is_batched = False
# Check that videos have a valid type
if isinstance(videos, (list, tuple)):
if isinstance(videos[0], (Image.Image, np.ndarray)) or is_torch_tensor(videos[0]):
valid_videos = True
elif isinstance(videos[0], (list, tuple)) and (
isinstance(videos[0][0], (Image.Image, np.ndarray)) or is_torch_tensor(videos[0][0])
):
valid_videos = True
is_batched = True
if not valid_videos:
raise ValueError(
"Videos must of type `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]` (single"
" example), `List[List[PIL.Image.Image]]`, `List[List[np.ndarray]]`, `List[List[torch.Tensor]]` (batch"
" of examples)."
)
if not is_batched:
videos = [videos]
# transformations (resizing + center cropping + normalization)
if self.do_resize and self.size is not None:
videos = [self.resize_video(video, size=self.size, resample=self.resample) for video in videos]
if self.do_center_crop and self.size is not None:
videos = [self.crop_video(video, size=self.size) for video in videos]
if self.do_normalize:
videos = [self.normalize_video(video, mean=self.image_mean, std=self.image_std) for video in videos]
# return as BatchFeature
data = {"pixel_values": videos}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
VideoMAEFeatureExtractor = VideoMAEImageProcessor
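A minimal usage sketch (not part of the diff), assuming `VideoMAEImageProcessor` and the backwards-compatible `VideoMAEFeatureExtractor` alias above are both exported from `transformers`; the shapes in the comments follow from the default `shortest_edge=224` resize and 224x224 center crop:

```py
import numpy as np
from transformers import VideoMAEFeatureExtractor, VideoMAEImageProcessor

# The old name now resolves to the new image processor class.
extractor = VideoMAEFeatureExtractor()
assert isinstance(extractor, VideoMAEImageProcessor)

# A dummy 16-frame clip of (height, width, channels) frames.
video = [np.zeros((360, 640, 3), dtype=np.uint8) for _ in range(16)]
inputs = extractor(video, return_tensors="np")
print(inputs["pixel_values"].shape)  # expected (1, 16, 3, 224, 224)
```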

View File

@ -0,0 +1,380 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for VideoMAE."""
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
get_resize_output_image_size,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
is_valid_image,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
def make_batched(videos) -> List[List[ImageInput]]:
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
return videos
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
return [videos]
elif is_valid_image(videos):
return [[videos]]
raise ValueError(f"Could not make batched video from {videos}")
class VideoMAEImageProcessor(BaseImageProcessor):
r"""
Constructs a VideoMAE image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
Size of the output image after resizing. The shortest edge of the image will be resized to
`size["shortest_edge"]` while maintaining the aspect ratio of the original image. Can be overriden by
`size` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the image to the specified `crop_size`. Can be overridden by the `do_center_crop`
parameter in the `preprocess` method.
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the image after applying the center crop. Can be overridden by the `crop_size` parameter in the
`preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter
in the `preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
self.do_resize = do_resize
self.size = size
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image. If `size` is of the form `{"height": h, "width": w}`, the output image will
have the size `(h, w)`. If `size` is of the form `{"shortest_edge": s}`, the output image will have its
shortest edge of length `s` while keeping the aspect ratio of the original image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "shortest_edge" in size:
output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False)
elif "height" in size and "width" in size:
output_size = (size["height"], size["width"])
else:
raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
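# For example (documentation only), a 360 x 640 frame resized with size={"shortest_edge": 224} keeps its
# aspect ratio and comes out roughly 224 x 398, whereas size={"height": 224, "width": 224} forces exactly
# 224 x 224 regardless of aspect ratio.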
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `size` along any
edge, the image is padded with 0's and then center cropped.
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
image_mean (`float` or `List[float]`):
Image mean.
image_std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
) -> np.ndarray:
"""Preprocesses a single image."""
if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if do_resize:
image = self.resize(image=image, size=size, resample=resample)
if do_center_crop:
image = self.center_crop(image, size=crop_size)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor)
if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std)
image = to_channel_dimension_format(image, data_format)
return image
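# Note: when enabled, the transformations run in a fixed order (resize, center crop, rescale, normalize),
# and the result is always converted to `data_format` (channels-first by default), so a single (H, W, C)
# uint8 frame comes out as a float array of shape (num_channels, crop_height, crop_width).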
def preprocess(
self,
videos: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> BatchFeature:
"""
Preprocess an image or batch of images.
Args:
videos (`ImageInput`):
Video or batch of videos to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after applying resize.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the image after applying the center crop.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the inferred channel dimension format of the input image.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
if not valid_images(videos):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
videos = make_batched(videos)
videos = [
[
self._preprocess_image(
image=img,
do_resize=do_resize,
size=size,
resample=resample,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
)
for img in video
]
for video in videos
]
data = {"pixel_values": videos}
return BatchFeature(data=data, tensor_type=return_tensors)
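As the docstrings above note, every constructor default can be overridden per call. A small sketch (assuming the class is exported from `transformers`) that resizes straight to a fixed height and width and skips the center crop:

```py
import numpy as np
from transformers import VideoMAEImageProcessor

image_processor = VideoMAEImageProcessor()
clip = [np.zeros((240, 320, 3), dtype=np.uint8) for _ in range(8)]

# Per-call overrides take precedence over the values set in __init__.
outputs = image_processor.preprocess(
    clip, size={"height": 192, "width": 192}, do_center_crop=False, return_tensors="np"
)
print(outputs["pixel_values"].shape)  # expected (1, 8, 3, 192, 192)
```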

View File

@ -14,282 +14,11 @@
# limitations under the License.
"""Feature extractor class for ViLT."""
from typing import List, Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
from ...utils import logging
from .image_processing_vilt import ViltImageProcessor
logger = logging.get_logger(__name__)
class ViltFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a ViLT feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input based on `size`.
size (`int`, *optional*, defaults to 384):
Resize the shorter side of the input to the given size. Should be an integer. The longer side will be
limited to under int((1333 / 800) * size) while preserving the aspect ratio. Only has an effect if
`do_resize` is set to `True`.
size_divisor (`int`, *optional*, defaults to 32):
The size by which to make sure both the height and width can be divided.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values", "pixel_mask"]
def __init__(
self,
do_resize=True,
size=384,
size_divisor=32,
resample=PILImageResampling.BICUBIC,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.size_divisor = size_divisor
self.resample = resample
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def _resize(self, image, shorter=800, longer=1333, size_divisor=32, resample=PILImageResampling.BICUBIC):
"""
Resizes the shorter edge of `image` to `shorter` and limits the longer edge to under `longer`, while preserving
the aspect ratio. Also makes sure that both the height and width can be divided by `size_divisor`.
Based on original implementation:
https://github.com/dandelin/ViLT/blob/3db8b5035464afee84d951bf6322e1b27f1d072d/vilt/transforms/utils.py#L5
Args:
image (`PIL.Image`):
The image to resize.
shorter (`int`, *optional*, defaults to `800`):
The size to which to resize the shorter side of the image.
longer (`int`, *optional*, defaults to `1333`):
The size by which to limit the longer side of the image, while preserving the aspect ratio.
size_divisor (`int`, *optional*, defaults to `32`):
The size by which both the height and the width must be divisible.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter.
"""
if not isinstance(image, Image.Image):
image = self.to_pil_image(image)
w, h = image.size
min_size = shorter
max_size = longer
scale = min_size / min(w, h)
if h < w:
newh, neww = min_size, scale * w
else:
newh, neww = scale * h, min_size
if max(newh, neww) > max_size:
scale = max_size / max(newh, neww)
newh = newh * scale
neww = neww * scale
newh, neww = int(newh + 0.5), int(neww + 0.5)
newh, neww = newh // size_divisor * size_divisor, neww // size_divisor * size_divisor
return self.resize(image, size=(neww, newh), resample=resample)
def _max_by_axis(self, the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def pad_and_create_pixel_mask(
self, pixel_values_list: List["torch.Tensor"], return_tensors: Optional[Union[str, TensorType]] = None
):
"""
Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
Args:
pixel_values_list (`List[torch.Tensor]`):
List of images (pixel values) to be padded. Each image should be a tensor of shape (C, H, W).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model.
- **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
*"pixel_mask"* is in `self.model_input_names`).
"""
max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list])
c, h, w = max_size
padded_images = []
pixel_mask = []
for image in pixel_values_list:
# create padded image
padded_image = np.zeros((c, h, w), dtype=np.float32)
padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
padded_images.append(padded_image)
# create pixel mask
mask = np.zeros((h, w), dtype=np.int64)
mask[: image.shape[1], : image.shape[2]] = True
pixel_mask.append(mask)
# return as BatchFeature
data = {"pixel_values": padded_images, "pixel_mask": pixel_mask}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
def __call__(
self,
images: ImageInput,
pad_and_return_pixel_mask: Optional[bool] = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it is most efficient to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether or not to pad images up to the largest image in a batch and create a pixel mask.
If left to the default, will return a pixel mask that is:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
- **pixel_mask** -- Pixel mask to be fed to a model (when `return_pixel_mask=True` or if *"pixel_mask"* is
in `self.model_input_names`).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (resizing + normalization)
if self.do_resize and self.size is not None:
longer = int((1333 / 800) * self.size)
images = [
self._resize(
image=image,
shorter=self.size,
longer=longer,
size_divisor=self.size_divisor,
resample=self.resample,
)
for image in images
]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
if pad_and_return_pixel_mask:
# pad images up to largest image in batch and create pixel_mask
max_size = self._max_by_axis([list(image.shape) for image in images])
c, h, w = max_size
padded_images = []
pixel_mask = []
for image in images:
# create padded image
padded_image = np.zeros((c, h, w), dtype=np.float32)
padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
padded_images.append(padded_image)
# create pixel mask
mask = np.zeros((h, w), dtype=np.int64)
mask[: image.shape[1], : image.shape[2]] = True
pixel_mask.append(mask)
images = padded_images
# return as BatchFeature
data = {}
data["pixel_values"] = images
if pad_and_return_pixel_mask:
data["pixel_mask"] = pixel_mask
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
ViltFeatureExtractor = ViltImageProcessor
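A short sketch (not part of the diff) showing that the alias keeps the old contract of returning both `pixel_values` and `pixel_mask`; the shapes in the comments assume the default `shortest_edge=384` resize and `size_divisor=32`:

```py
import numpy as np
from transformers import ViltFeatureExtractor

feature_extractor = ViltFeatureExtractor()  # now an alias of ViltImageProcessor
images = [
    np.zeros((480, 640, 3), dtype=np.uint8),
    np.zeros((384, 384, 3), dtype=np.uint8),
]
encoding = feature_extractor(images, return_tensors="np")
print(encoding["pixel_values"].shape)  # expected (2, 3, 384, 512), padded to the largest image
print(encoding["pixel_mask"].shape)    # expected (2, 384, 512); 1 marks real pixels, 0 padding
```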

View File

@ -0,0 +1,487 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Vilt."""
import warnings
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
def max_across_indices(values: Iterable[Any]) -> List[Any]:
"""
Return the maximum value across all indices of an iterable of values.
"""
return [max(values_i) for values_i in zip(*values)]
def pad(
image: np.ndarray,
output_size: Tuple[int, int],
input_channel_dimension: Optional[ChannelDimension] = None,
data_format: Optional[ChannelDimension] = None,
) -> np.ndarray:
"""
Pad the bottom and right of the image with zeros to the output size.
Args:
image (`np.ndarray`):
Image to pad.
output_size (`Tuple[int, int]`):
Output size of the image.
input_channel_dimension (`ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be inferred from the input image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
if input_channel_dimension is None:
input_channel_dimension = infer_channel_dimension_format(image)
output_height, output_width = output_size
input_height, input_width = get_image_size(image)
pad_bottom = output_height - input_height
pad_right = output_width - input_width
if input_channel_dimension == ChannelDimension.FIRST:
padded_image = np.pad(image, [(0, 0), (0, pad_bottom), (0, pad_right)], mode="constant", constant_values=0)
elif input_channel_dimension == ChannelDimension.LAST:
padded_image = np.pad(image, [(0, pad_bottom), (0, pad_right), (0, 0)], mode="constant", constant_values=0)
else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
if data_format is not None:
padded_image = to_channel_dimension_format(padded_image, data_format)
return padded_image
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
Args:
image (`np.ndarray`):
Image to make the pixel mask for.
output_size (`Tuple[int, int]`):
Output size of the mask.
"""
input_height, input_width = get_image_size(image)
mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1
return mask
def get_max_dimensions(images: List[np.ndarray]) -> List[int]:
"""
Get the maximum height and width across all images in a batch.
"""
input_channel_dimension = infer_channel_dimension_format(images[0])
if input_channel_dimension == ChannelDimension.FIRST:
_, max_height, max_width = max_across_indices([img.shape for img in images])
elif input_channel_dimension == ChannelDimension.LAST:
max_height, max_width, _ = max_across_indices([img.shape for img in images])
else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
return (max_height, max_width)
def get_resize_output_image_size(
input_image: np.ndarray, shorter: int = 800, longer: int = 1333, size_divisor: int = 32
) -> Tuple[int, int]:
input_height, input_width = get_image_size(input_image)
min_size, max_size = shorter, longer
scale = min_size / min(input_height, input_width)
if input_height < input_width:
new_height = min_size
new_width = scale * input_width
else:
new_height = scale * input_height
new_width = min_size
if max(new_height, new_width) > max_size:
scale = max_size / max(new_height, new_width)
new_height = scale * new_height
new_width = scale * new_width
new_height, new_width = int(new_height + 0.5), int(new_width + 0.5)
new_height = new_height // size_divisor * size_divisor
new_width = new_width // size_divisor * size_divisor
return new_height, new_width
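# Worked example (documentation only): a 480 x 640 input with shorter=384, longer=639 and size_divisor=32
# is scaled by 384 / 480 = 0.8 to 384 x 512; the longer side stays below the 639 cap and both sides are
# already multiples of 32, so the returned size is (384, 512).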
class ViltImageProcessor(BaseImageProcessor):
r"""
Constructs a ViLT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 384}`):
Resize the shorter side of the input to `size["shortest_edge"]`. The longer side will be limited to under
`int((1333 / 800) * size["shortest_edge"])` while preserving the aspect ratio. Only has an effect if
`do_resize` is set to `True`. Can be overridden by the `size` parameter in the `preprocess` method.
size_divisor (`int`, *optional*, defaults to 32):
The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
overridden by the `resample` parameter in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
overridden by the `rescale_factor` parameter in the `preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by
the `do_pad` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
size_divisor: int = 32,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: bool = True,
**kwargs
) -> None:
if "pad_and_return_pixel_mask" in kwargs:
do_pad = kwargs.pop("pad_and_return_pixel_mask")
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 384}
size = get_size_dict(size, default_to_square=False)
self.do_resize = do_resize
self.size = size
self.size_divisor = size_divisor
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_pad = do_pad
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
size_divisor: int = 32,
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image.
Resizes the shorter side of the image to `size["shortest_edge"]` while preserving the aspect ratio. If the
longer side is larger than the max size `int(size["shortest_edge"] * 1333 / 800)`, the longer side is then
resized to the max size while preserving the aspect ratio.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Controls the size of the output image. Should be of the form `{"shortest_edge": int}`.
size_divisor (`int`, defaults to 32):
The image is resized to a size that is a multiple of this value.
resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size:
raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
shorter = size["shortest_edge"]
longer = int(1333 / 800 * shorter)
output_size = get_resize_output_image_size(image, shorter=shorter, longer=longer, size_divisor=size_divisor)
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `List[float]`):
Image mean.
std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def pad(
self,
images: List[np.ndarray],
return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
) -> BatchFeature:
"""
Pads a batch of images with zeros to the size of largest height and width in the batch and optionally returns
their corresponding pixel mask.
Args:
images (`List[np.ndarray]`):
Batch of images to pad.
return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether to return the pixel mask.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
pad_size = get_max_dimensions(images)
padded_images = [pad(image=image, output_size=pad_size, data_format=data_format) for image in images]
data = {"pixel_values": padded_images}
if return_pixel_mask:
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
data["pixel_mask"] = masks
return BatchFeature(data=data, tensor_type=return_tensors)
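# The masks are built by `make_pixel_mask`, so for an image of original size (h, w) padded to
# (max_height, max_width) the mask is 1 over the top-left h x w region and 0 over the padded
# bottom and right borders.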
def pad_and_create_pixel_mask(
self,
pixel_values_list: List[ImageInput],
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
) -> BatchFeature:
"""
Pads a batch of images with zeros to the size of largest height and width in the batch and returns their
corresponding pixel mask.
Args:
pixel_values_list (`List[np.ndarray]`):
Batch of images (pixel values) to pad.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
warnings.warn(
"This method is deprecated and will be removed in v4.26.0. Please use pad instead.", FutureWarning
)
# pad expects a list of np.ndarray, but the previous feature extractors expected torch tensors
images = [to_numpy_array(image) for image in pixel_values_list]
return self.pad(
images=images,
return_pixel_mask=True,
return_tensors=return_tensors,
data_format=data_format,
)
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
size_divisor: Optional[int] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> BatchFeature:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Controls the size of the image after `resize`. The shortest edge of the image is resized to
`size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
The image is resized to a size that is a multiple of this value.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to normalize the image by if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the image to the (max_height, max_width) in the batch. If `True`, a pixel mask is also
created and returned.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size_divisor = size_divisor if size_divisor is not None else self.size_divisor
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_pad = do_pad if do_pad is not None else self.do_pad
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [
self.resize(image=image, size=size, size_divisor=size_divisor, resample=resample) for image in images
]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
if do_pad:
encoded_outputs = self.pad(images, return_pixel_mask=True, return_tensors=return_tensors)
else:
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
return encoded_outputs
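A minimal sketch (assuming `ViltImageProcessor` is exported from `transformers`) of the two padding modes described above: with `do_pad=True` (the default) a `pixel_mask` is returned alongside `pixel_values`, and with `do_pad=False` only `pixel_values` is returned:

```py
import numpy as np
from transformers import ViltImageProcessor

image_processor = ViltImageProcessor()
image = np.zeros((480, 640, 3), dtype=np.uint8)

with_mask = image_processor.preprocess(image, return_tensors="np")
print(sorted(with_mask.keys()))  # ['pixel_mask', 'pixel_values']

without_mask = image_processor.preprocess(image, do_pad=False, return_tensors="np")
print(sorted(without_mask.keys()))  # ['pixel_values']
```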

View File

@ -14,139 +14,12 @@
# limitations under the License.
"""Feature extractor class for ViT."""
from typing import Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, logging
from ...utils import logging
from .image_processing_vit import ViTImageProcessor
logger = logging.get_logger(__name__)
class ViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a ViT feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 224):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=224,
resample=PILImageResampling.BILINEAR,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it is most efficient to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (resizing + normalization)
if self.do_resize and self.size is not None:
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
# Feature extractor for ViT is being replaced by image processor
ViTFeatureExtractor = ViTImageProcessor
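A minimal sketch (not part of the diff) of the replacement in practice, assuming both names are exported from `transformers`:

```py
import numpy as np
from transformers import ViTFeatureExtractor, ViTImageProcessor

# Both names now point at the same class, so existing code keeps working unchanged.
assert ViTFeatureExtractor is ViTImageProcessor

image_processor = ViTImageProcessor()  # defaults: resize to 224 x 224, rescale, normalize
image = np.zeros((500, 300, 3), dtype=np.uint8)
inputs = image_processor(image, return_tensors="np")
print(inputs["pixel_values"].shape)  # expected (1, 3, 224, 224)
```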

View File

@ -0,0 +1,275 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for ViT."""
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
logger = logging.get_logger(__name__)
class ViTImageProcessor(BaseImageProcessor):
r"""
Constructs a ViT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)
self.do_resize = do_resize
self.do_rescale = do_rescale
self.do_normalize = do_normalize
self.size = size
self.resample = resample
self.rescale_factor = rescale_factor
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The resized image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
return resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
)
def rescale(
self, image: np.ndarray, scale: float, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
) -> np.ndarray:
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`float`):
The scaling factor to rescale pixel values by.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The rescaled image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `List[float]`):
Image mean to use for normalization.
std (`float` or `List[float]`):
Image standard deviation to use for normalization.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The normalized image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)

    def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
):
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
resizing.
resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
`PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values to the [0, 1] range.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use if `do_normalize` is set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
resample = resample if resample is not None else self.resample
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
size_dict = get_size_dict(size)
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [self.resize(image=image, size=size_dict, resample=resample) for image in images]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
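
For reference, a minimal usage sketch of the processor defined above. The import path and the class name `ViTImageProcessor` are assumptions here (any of the image processors added in this commit follows the same preprocess pattern); the dummy image and printed shape are illustrative only.

```python
import numpy as np

# Assumed export name/path for the class defined above.
from transformers import ViTImageProcessor

# Dummy HWC uint8 image standing in for a real photo.
image = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)

# Defaults: resize to 224x224, rescale by 1/255, normalize with ImageNet mean/std.
image_processor = ViTImageProcessor()
inputs = image_processor.preprocess(image, return_tensors="np")

print(inputs["pixel_values"].shape)  # (1, 3, 224, 224), channels-first output
```
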

View File

@ -44,14 +44,16 @@ class BeitFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=20,
size=None,
do_center_crop=True,
crop_size=18,
crop_size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
reduce_labels=False,
do_reduce_labels=False,
):
size = size if size is not None else {"height": 20, "width": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -65,7 +67,7 @@ class BeitFeatureExtractionTester(unittest.TestCase):
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.reduce_labels = reduce_labels
self.do_reduce_labels = do_reduce_labels
def prepare_feat_extract_dict(self):
return {
@ -76,7 +78,7 @@ class BeitFeatureExtractionTester(unittest.TestCase):
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"reduce_labels": self.reduce_labels,
"do_reduce_labels": self.do_reduce_labels,
}
@ -141,8 +143,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -153,8 +155,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -173,8 +175,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -185,8 +187,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -205,8 +207,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -217,8 +219,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -239,16 +241,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
@ -262,16 +264,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
@ -287,16 +289,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
@ -312,16 +314,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
2,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
2,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)

View File

@ -43,14 +43,16 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=20,
size=None,
do_center_crop=True,
crop_size=18,
crop_size=None,
do_normalize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
do_convert_rgb=True,
):
size = size if size is not None else {"shortest_edge": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -151,8 +153,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -163,8 +165,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -183,8 +185,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -195,8 +197,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -215,8 +217,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -227,8 +229,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -276,8 +278,8 @@ class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, un
(
1,
self.expected_encoded_image_num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -288,7 +290,7 @@ class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, un
(
self.feature_extract_tester.batch_size,
self.expected_encoded_image_num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)

View File

@ -43,12 +43,13 @@ class ConvNextFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=20,
size=None,
crop_pct=0.875,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
size = size if size is not None else {"shortest_edge": 20}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -113,8 +114,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["shortest_edge"],
self.feature_extract_tester.size["shortest_edge"],
),
)
@ -125,8 +126,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["shortest_edge"],
self.feature_extract_tester.size["shortest_edge"],
),
)
@ -145,8 +146,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["shortest_edge"],
self.feature_extract_tester.size["shortest_edge"],
),
)
@ -157,8 +158,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["shortest_edge"],
self.feature_extract_tester.size["shortest_edge"],
),
)
@ -177,8 +178,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["shortest_edge"],
self.feature_extract_tester.size["shortest_edge"],
),
)
@ -189,7 +190,7 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["shortest_edge"],
self.feature_extract_tester.size["shortest_edge"],
),
)

View File

@ -43,13 +43,16 @@ class DeiTFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=20,
size=None,
do_center_crop=True,
crop_size=18,
crop_size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
size = size if size is not None else {"height": 20, "width": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -117,8 +120,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -129,8 +132,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -149,8 +152,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -161,8 +164,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -181,8 +184,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -193,7 +196,7 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)

View File

@ -43,11 +43,12 @@ class DPTFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=18,
size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -106,8 +107,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -118,8 +119,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -138,8 +139,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -150,8 +151,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -170,8 +171,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -182,7 +183,7 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)

View File

@ -28,11 +28,10 @@ if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
import PIL
from transformers import FlavaFeatureExtractor
from transformers.image_utils import PILImageResampling
from transformers.models.flava.feature_extraction_flava import (
from transformers.models.flava.image_processing_flava import (
FLAVA_CODEBOOK_MEAN,
FLAVA_CODEBOOK_STD,
FLAVA_IMAGE_MEAN,
@ -51,10 +50,12 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=224,
size=None,
do_center_crop=True,
crop_size=224,
crop_size=None,
resample=None,
do_rescale=True,
rescale_factor=1 / 255,
do_normalize=True,
image_mean=FLAVA_IMAGE_MEAN,
image_std=FLAVA_IMAGE_STD,
@ -65,23 +66,30 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
mask_group_min_aspect_ratio=0.3,
mask_group_max_aspect_ratio=None,
codebook_do_resize=True,
codebook_size=112,
codebook_size=None,
codebook_resample=None,
codebook_do_center_crop=True,
codebook_crop_size=112,
codebook_crop_size=None,
codebook_do_map_pixels=True,
codebook_do_normalize=True,
codebook_image_mean=FLAVA_CODEBOOK_MEAN,
codebook_image_std=FLAVA_CODEBOOK_STD,
):
size = size if size is not None else {"height": 224, "width": 224}
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.do_resize = do_resize
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.size = size
self.resample = resample if resample is not None else PILImageResampling.BICUBIC
self.resample = resample if resample is not None else PIL.Image.Resampling.BICUBIC
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
@ -97,7 +105,7 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
self.codebook_do_resize = codebook_do_resize
self.codebook_size = codebook_size
self.codebook_resample = codebook_resample if codebook_resample is not None else PILImageResampling.LANCZOS
self.codebook_resample = codebook_resample if codebook_resample is not None else PIL.Image.Resampling.LANCZOS
self.codebook_do_center_crop = codebook_do_center_crop
self.codebook_crop_size = codebook_crop_size
self.codebook_do_map_pixels = codebook_do_map_pixels
@ -113,6 +121,8 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
"do_resize": self.do_resize,
"size": self.size,
"resample": self.resample,
"do_rescale": self.do_rescale,
"rescale_factor": self.rescale_factor,
"do_center_crop": self.do_center_crop,
"crop_size": self.crop_size,
"input_size_patches": self.input_size_patches,
@ -133,7 +143,7 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
}
def get_expected_image_size(self):
return (self.size, self.size) if not isinstance(self.size, tuple) else self.size
return (self.size["height"], self.size["width"])
def get_expected_mask_size(self):
return (
@ -143,10 +153,7 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
)
def get_expected_codebook_image_size(self):
if not isinstance(self.codebook_size, tuple):
return (self.codebook_size, self.codebook_size)
else:
return self.codebook_size
return (self.codebook_size["height"], self.codebook_size["width"])
@require_torch
@ -172,6 +179,8 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
self.assertTrue(hasattr(feature_extractor, "resample"))
self.assertTrue(hasattr(feature_extractor, "crop_size"))
self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
self.assertTrue(hasattr(feature_extractor, "do_rescale"))
self.assertTrue(hasattr(feature_extractor, "rescale_factor"))
self.assertTrue(hasattr(feature_extractor, "masking_generator"))
self.assertTrue(hasattr(feature_extractor, "codebook_do_resize"))
self.assertTrue(hasattr(feature_extractor, "codebook_size"))
@ -192,7 +201,7 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
# create random PIL images
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
self.assertIsInstance(image, PIL.Image.Image)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt")
@ -324,7 +333,7 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
# create random PIL images
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
self.assertIsInstance(image, PIL.Image.Image)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_codebook_pixels=True, return_tensors="pt")

View File

@ -32,7 +32,7 @@ if is_vision_available():
from PIL import Image
from transformers import FlavaFeatureExtractor, FlavaProcessor
from transformers.models.flava.feature_extraction_flava import (
from transformers.models.flava.image_processing_flava import (
FLAVA_CODEBOOK_MEAN,
FLAVA_CODEBOOK_STD,
FLAVA_IMAGE_MEAN,
@ -69,7 +69,6 @@ class FlavaProcessorTest(unittest.TestCase):
"mask_group_max_aspect_ratio": None,
"codebook_do_resize": True,
"codebook_size": 112,
"codebook_resample": None,
"codebook_do_center_crop": True,
"codebook_crop_size": 112,
"codebook_do_map_pixels": True,

View File

@ -47,9 +47,10 @@ class ImageGPTFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=18,
size=None,
do_normalize=True,
):
size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels

View File

@ -43,9 +43,10 @@ class LayoutLMv2FeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=18,
size=None,
apply_ocr=True,
):
size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -97,8 +98,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -112,8 +113,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -132,8 +133,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -144,8 +145,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -164,8 +165,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -176,8 +177,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -210,12 +211,4 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
encoding = feature_extractor(image, return_tensors="pt")
self.assertEqual(
encoding.pixel_values.shape,
(
1,
3,
224,
224,
),
)
self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))

View File

@ -43,9 +43,10 @@ class LayoutLMv3FeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=18,
size=None,
apply_ocr=True,
):
size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -97,8 +98,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -112,8 +113,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -132,8 +133,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -144,8 +145,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -164,8 +165,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -176,8 +177,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)

View File

@ -43,12 +43,15 @@ class LevitFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=18,
size=None,
do_center_crop=True,
crop_size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
size = size if size is not None else {"shortest_edge": 18}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -58,6 +61,7 @@ class LevitFeatureExtractionTester(unittest.TestCase):
self.do_resize = do_resize
self.size = size
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
@ -70,6 +74,7 @@ class LevitFeatureExtractionTester(unittest.TestCase):
"do_resize": self.do_resize,
"do_center_crop": self.do_center_crop,
"size": self.size,
"crop_size": self.crop_size,
}
@ -113,8 +118,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -125,8 +130,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -145,8 +150,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -157,8 +162,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -177,8 +182,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -189,7 +194,7 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)

View File

@ -43,11 +43,13 @@ class MobileViTFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=20,
size=None,
do_center_crop=True,
crop_size=18,
crop_size=None,
do_flip_channel_order=True,
):
size = size if size is not None else {"shortest_edge": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -109,8 +111,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -121,8 +123,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -141,8 +143,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -153,8 +155,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -173,8 +175,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -185,7 +187,7 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)

View File

@ -41,12 +41,15 @@ class PoolFormerFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize_and_center_crop=True,
size=30,
size=None,
crop_pct=0.9,
crop_size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
size = size if size is not None else {"shortest_edge": 30}
crop_size = crop_size if crop_size is not None else {"height": 30, "width": 30}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -55,6 +58,7 @@ class PoolFormerFeatureExtractionTester(unittest.TestCase):
self.do_resize_and_center_crop = do_resize_and_center_crop
self.size = size
self.crop_pct = crop_pct
self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
@ -64,6 +68,7 @@ class PoolFormerFeatureExtractionTester(unittest.TestCase):
"size": self.size,
"do_resize_and_center_crop": self.do_resize_and_center_crop,
"crop_pct": self.crop_pct,
"crop_size": self.crop_size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
@ -111,8 +116,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -123,8 +128,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -143,8 +148,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -155,8 +160,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -175,8 +180,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -187,7 +192,7 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)

View File

@ -43,12 +43,13 @@ class SegformerFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=30,
size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
reduce_labels=False,
do_reduce_labels=False,
):
size = size if size is not None else {"height": 30, "width": 30}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -59,7 +60,7 @@ class SegformerFeatureExtractionTester(unittest.TestCase):
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.reduce_labels = reduce_labels
self.do_reduce_labels = do_reduce_labels
def prepare_feat_extract_dict(self):
return {
@ -68,7 +69,7 @@ class SegformerFeatureExtractionTester(unittest.TestCase):
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"reduce_labels": self.reduce_labels,
"do_reduce_labels": self.do_reduce_labels,
}
@ -112,7 +113,7 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "reduce_labels"))
self.assertTrue(hasattr(feature_extractor, "do_reduce_labels"))
def test_batch_feature(self):
pass
@ -132,8 +133,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -144,8 +145,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -164,8 +165,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -176,8 +177,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -196,8 +197,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -208,8 +209,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -230,16 +231,16 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
@ -253,16 +254,16 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
@ -278,16 +279,16 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
@ -303,16 +304,16 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
2,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
2,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)

View File

@ -44,11 +44,15 @@ class VideoMAEFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=18,
size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
crop_size=None,
):
size = size if size is not None else {"shortest_edge": 18}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -61,6 +65,7 @@ class VideoMAEFeatureExtractionTester(unittest.TestCase):
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.crop_size = crop_size
def prepare_feat_extract_dict(self):
return {
@ -69,6 +74,7 @@ class VideoMAEFeatureExtractionTester(unittest.TestCase):
"do_normalize": self.do_normalize,
"do_resize": self.do_resize,
"size": self.size,
"crop_size": self.crop_size,
}
@ -91,6 +97,7 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
self.assertTrue(hasattr(feature_extractor, "size"))
def test_batch_feature(self):
@ -113,8 +120,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
1,
self.feature_extract_tester.num_frames,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -126,8 +133,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_frames,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -148,8 +155,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
1,
self.feature_extract_tester.num_frames,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -161,8 +168,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_frames,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -183,8 +190,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
1,
self.feature_extract_tester.num_frames,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)
@ -196,7 +203,7 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_frames,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size["width"],
),
)

View File

@ -43,12 +43,13 @@ class ViltFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=30,
size=None,
size_divisor=2,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
size = size if size is not None else {"shortest_edge": 30}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -78,18 +79,19 @@ class ViltFeatureExtractionTester(unittest.TestCase):
assuming do_resize is set to True with a scalar size and size_divisor.
"""
if not batched:
size = self.size["shortest_edge"]
image = image_inputs[0]
if isinstance(image, Image.Image):
w, h = image.size
else:
h, w = image.shape[1], image.shape[2]
scale = self.size / min(w, h)
scale = size / min(w, h)
if h < w:
newh, neww = self.size, scale * w
newh, neww = size, scale * w
else:
newh, neww = scale * h, self.size
newh, neww = scale * h, size
max_size = int((1333 / 800) * self.size)
max_size = int((1333 / 800) * size)
if max(newh, neww) > max_size:
scale = max_size / max(newh, neww)
newh = newh * scale
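
To make the shortest-edge scaling in the hunk above concrete, here is a small standalone sketch of the same computation. It is not library code; the rescaling of the width inside the max-size branch and the trailing rounding/`size_divisor` handling are not shown in this hunk, so those parts are assumptions.

```python
def shortest_edge_resize(h, w, shortest_edge=30):
    """Standalone sketch of the scaling shown above (not library code)."""
    scale = shortest_edge / min(w, h)
    if h < w:
        newh, neww = shortest_edge, scale * w
    else:
        newh, neww = scale * h, shortest_edge

    max_size = int((1333 / 800) * shortest_edge)  # cap on the longer side
    if max(newh, neww) > max_size:
        scale = max_size / max(newh, neww)
        # the hunk shows the height being rescaled; rescaling the width the same
        # way is assumed here
        newh, neww = newh * scale, neww * scale
    return newh, neww

print(shortest_edge_resize(60, 80))   # -> (30, 40.0): shorter side pinned to 30
print(shortest_edge_resize(30, 400))  # -> (~3.7, 49.0): longer side capped at 49
```
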
@ -233,7 +235,7 @@ class ViltFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
def test_equivalence_pad_and_create_pixel_mask(self):
# Initialize feature_extractors
feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False, do_rescale=False)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
for image in image_inputs:

View File

@ -43,11 +43,12 @@ class ViTFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=18,
size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@ -109,8 +110,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -121,8 +122,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -141,8 +142,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -153,8 +154,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -173,8 +174,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
@ -185,7 +186,7 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)

View File

@ -0,0 +1,71 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.image_processing_utils import get_size_dict
class ImageProcessingUtilsTester(unittest.TestCase):
def test_get_size_dict(self):
# Test a dict with the wrong keys raises an error
inputs = {"wrong_key": 224}
with self.assertRaises(ValueError):
get_size_dict(inputs)
inputs = {"height": 224}
with self.assertRaises(ValueError):
get_size_dict(inputs)
inputs = {"width": 224, "shortest_edge": 224}
with self.assertRaises(ValueError):
get_size_dict(inputs)
# Test a dict with the correct keys is returned as is
inputs = {"height": 224, "width": 224}
outputs = get_size_dict(inputs)
self.assertEqual(outputs, inputs)
inputs = {"shortest_edge": 224}
outputs = get_size_dict(inputs)
self.assertEqual(outputs, {"shortest_edge": 224})
inputs = {"longest_edge": 224, "shortest_edge": 224}
outputs = get_size_dict(inputs)
self.assertEqual(outputs, {"longest_edge": 224, "shortest_edge": 224})
# Test a single int value which represents (size, size)
outputs = get_size_dict(224)
self.assertEqual(outputs, {"height": 224, "width": 224})
# Test a single int value which represents the shortest edge
outputs = get_size_dict(224, default_to_square=False)
self.assertEqual(outputs, {"shortest_edge": 224})
# Test a tuple of ints which represents (height, width)
outputs = get_size_dict((150, 200))
self.assertEqual(outputs, {"height": 150, "width": 200})
# Test a tuple of ints which represents (width, height)
outputs = get_size_dict((150, 200), height_width_order=False)
self.assertEqual(outputs, {"height": 200, "width": 150})
# Test an int representing the shortest edge and max_size which represents the longest edge
outputs = get_size_dict(224, max_size=256, default_to_square=False)
self.assertEqual(outputs, {"shortest_edge": 224, "longest_edge": 256})
# Test int with default_to_square=True and max_size fails
with self.assertRaises(ValueError):
get_size_dict(224, max_size=256, default_to_square=True)
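
As a quick companion to the assertions above, a standalone sketch of how the tested conversions let a processor keep accepting the old int/tuple `size` values while storing the new dict format. The wrapper class and its argument defaults are illustrative, not library code; only `get_size_dict` itself is the helper under test.

```python
from transformers.image_processing_utils import get_size_dict

class LegacySizeExample:
    """Illustrative only: mimics how an __init__ might normalize legacy sizes."""

    def __init__(self, size=224, crop_size=(18, 18)):
        # int -> shortest-edge style, tuple -> explicit (height, width)
        self.size = get_size_dict(size, default_to_square=False)  # {"shortest_edge": 224}
        self.crop_size = get_size_dict(crop_size)                 # {"height": 18, "width": 18}

args = LegacySizeExample()
print(args.size, args.crop_size)
```
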