From c9fb250a255cec542f9ec7bd072e3db2ebcbcd70 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Fri, 22 Dec 2023 12:12:56 +0100 Subject: [PATCH] Add Swinv2 backbone (#27742) * First draft * More improvements * More improvements * Make all tests pass * Remove script * Update image processor * Address comments * Use new gradient checkpointing method * Convert checkpoints, add integration test * Do not keep aspect ratio for now * Set keep_aspect_ratio=False for beit, add integration test * Remove print statement --- src/transformers/__init__.py | 2 + src/transformers/models/auto/modeling_auto.py | 1 + .../models/dpt/convert_dpt_beit_to_hf.py | 3 +- .../models/dpt/convert_dpt_swinv2_to_hf.py | 322 ++++++++++++++++++ .../models/dpt/image_processing_dpt.py | 14 +- .../models/swin2sr/modeling_swin2sr.py | 41 +-- src/transformers/models/swinv2/__init__.py | 2 + .../models/swinv2/configuration_swinv2.py | 22 +- .../models/swinv2/modeling_swinv2.py | 223 ++++++++---- src/transformers/utils/dummy_pt_objects.py | 7 + tests/models/dpt/test_image_processing_dpt.py | 10 + .../dpt/test_modeling_dpt_auto_backbone.py | 46 ++- tests/models/swinv2/test_modeling_swinv2.py | 103 +++++- utils/check_repo.py | 1 + 14 files changed, 667 insertions(+), 130 deletions(-) create mode 100644 src/transformers/models/dpt/convert_dpt_swinv2_to_hf.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 98139511d2..e7d168154a 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3213,6 +3213,7 @@ else: _import_structure["models.swinv2"].extend( [ "SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST", + "Swinv2Backbone", "Swinv2ForImageClassification", "Swinv2ForMaskedImageModeling", "Swinv2Model", @@ -7541,6 +7542,7 @@ if TYPE_CHECKING: ) from .models.swinv2 import ( SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST, + Swinv2Backbone, Swinv2ForImageClassification, Swinv2ForMaskedImageModeling, Swinv2Model, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e562bd28bd..9978b13530 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1116,6 +1116,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict( ("nat", "NatBackbone"), ("resnet", "ResNetBackbone"), ("swin", "SwinBackbone"), + ("swinv2", "Swinv2Backbone"), ("timm_backbone", "TimmBackbone"), ("vitdet", "VitDetBackbone"), ] diff --git a/src/transformers/models/dpt/convert_dpt_beit_to_hf.py b/src/transformers/models/dpt/convert_dpt_beit_to_hf.py index 1e7d438e02..eb335a0ea2 100644 --- a/src/transformers/models/dpt/convert_dpt_beit_to_hf.py +++ b/src/transformers/models/dpt/convert_dpt_beit_to_hf.py @@ -207,8 +207,9 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub): model.eval() # Check outputs on an image + # We set `keep_aspect_ratio=False` as our current BEiT does not support arbitrary window sizes processor = DPTImageProcessor( - size={"height": image_size, "width": image_size}, keep_aspect_ratio=True, ensure_multiple_of=32 + size={"height": image_size, "width": image_size}, keep_aspect_ratio=False, ensure_multiple_of=32 ) image = prepare_img() diff --git a/src/transformers/models/dpt/convert_dpt_swinv2_to_hf.py b/src/transformers/models/dpt/convert_dpt_swinv2_to_hf.py new file mode 100644 index 0000000000..fd6522ab6b --- /dev/null +++ b/src/transformers/models/dpt/convert_dpt_swinv2_to_hf.py @@ -0,0 +1,322 @@ +# coding=utf-8 +# Copyright 2023 The 
HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DPT 3.1 checkpoints from the MiDaS repository. URL: https://github.com/isl-org/MiDaS""" + + +import argparse +from pathlib import Path + +import requests +import torch +from PIL import Image + +from transformers import DPTConfig, DPTForDepthEstimation, DPTImageProcessor, Swinv2Config +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_dpt_config(model_name): + if "tiny" in model_name: + embed_dim = 96 + depths = (2, 2, 6, 2) + num_heads = (3, 6, 12, 24) + window_size = 16 + # note: for Swinv2-tiny authors used the window_size = 16 variant + # as seen here: https://github.com/isl-org/MiDaS/blob/bdc4ed64c095e026dc0a2f17cabb14d58263decb/midas/backbones/swin2.py#L26 + pretrained_window_sizes = (0, 0, 0, 0) + elif "base" in model_name: + embed_dim = 128 + depths = (2, 2, 18, 2) + num_heads = (4, 8, 16, 32) + window_size = 24 + pretrained_window_sizes = (12, 12, 12, 6) + elif "large" in model_name: + embed_dim = 192 + depths = (2, 2, 18, 2) + num_heads = (6, 12, 24, 48) + window_size = 24 + pretrained_window_sizes = (12, 12, 12, 6) + + if "384" in model_name: + image_size = 384 + elif "256" in model_name: + image_size = 256 + else: + raise ValueError("Model not supported, to do") + + backbone_config = Swinv2Config( + image_size=image_size, + embed_dim=embed_dim, + depths=depths, + window_size=window_size, + pretrained_window_sizes=pretrained_window_sizes, + num_heads=num_heads, + out_features=["stage1", "stage2", "stage3", "stage4"], + ) + + if model_name == "dpt-swinv2-tiny-256": + neck_hidden_sizes = [96, 192, 384, 768] + elif model_name == "dpt-swinv2-base-384": + neck_hidden_sizes = [128, 256, 512, 1024] + elif model_name == "dpt-swinv2-large-384": + neck_hidden_sizes = [192, 384, 768, 1536] + + config = DPTConfig(backbone_config=backbone_config, neck_hidden_sizes=neck_hidden_sizes) + + return config, image_size + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config): + rename_keys = [] + + # fmt: off + # stem + rename_keys.append(("pretrained.model.patch_embed.proj.weight", "backbone.embeddings.patch_embeddings.projection.weight")) + rename_keys.append(("pretrained.model.patch_embed.proj.bias", "backbone.embeddings.patch_embeddings.projection.bias")) + rename_keys.append(("pretrained.model.patch_embed.norm.weight", "backbone.embeddings.norm.weight")) + rename_keys.append(("pretrained.model.patch_embed.norm.bias", "backbone.embeddings.norm.bias")) + + # transformer encoder + for i in range(len(config.backbone_config.depths)): + for j in range(config.backbone_config.depths[i]): + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.logit_scale", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.logit_scale")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.cpb_mlp.0.weight", 
f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.continuous_position_bias_mlp.0.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.cpb_mlp.0.bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.continuous_position_bias_mlp.0.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.cpb_mlp.2.weight", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.continuous_position_bias_mlp.2.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.q_bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.query.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.v_bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.value.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.proj.weight", f"backbone.encoder.layers.{i}.blocks.{j}.attention.output.dense.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.proj.bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.output.dense.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.norm1.weight", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_before.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.norm1.bias", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_before.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.mlp.fc1.weight", f"backbone.encoder.layers.{i}.blocks.{j}.intermediate.dense.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.mlp.fc1.bias", f"backbone.encoder.layers.{i}.blocks.{j}.intermediate.dense.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.mlp.fc2.weight", f"backbone.encoder.layers.{i}.blocks.{j}.output.dense.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.mlp.fc2.bias", f"backbone.encoder.layers.{i}.blocks.{j}.output.dense.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.norm2.weight", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_after.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.norm2.bias", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_after.bias")) + + # downsample parameters + if i in [0,1,2]: + rename_keys.append((f"pretrained.model.layers.{i}.downsample.reduction.weight", f"backbone.encoder.layers.{i}.downsample.reduction.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.downsample.norm.weight", f"backbone.encoder.layers.{i}.downsample.norm.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.downsample.norm.bias", f"backbone.encoder.layers.{i}.downsample.norm.bias")) + + # note: non-Transformer backbones like Swinv2, LeViT et al don't require activation postprocessing (readout projections + resize blocks) + + # refinenet (tricky here) + mapping = {1:3, 2:2, 3:1, 4:0} + + for i in range(1, 5): + j = mapping[i] + rename_keys.append((f"scratch.refinenet{i}.out_conv.weight", f"neck.fusion_stage.layers.{j}.projection.weight")) + rename_keys.append((f"scratch.refinenet{i}.out_conv.bias", f"neck.fusion_stage.layers.{j}.projection.bias")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.weight")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.bias")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv2.weight", 
f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.weight")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.bias")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.weight")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.bias")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.weight")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.bias")) + + # scratch convolutions + for i in range(4): + rename_keys.append((f"scratch.layer{i+1}_rn.weight", f"neck.convs.{i}.weight")) + + # head + for i in range(0, 5, 2): + rename_keys.append((f"scratch.output_conv.{i}.weight", f"head.head.{i}.weight")) + rename_keys.append((f"scratch.output_conv.{i}.bias", f"head.head.{i}.bias")) + + return rename_keys + + +def remove_ignore_keys_(state_dict): + ignore_keys = ["pretrained.model.head.weight", "pretrained.model.head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, model): + for i in range(len(config.backbone_config.depths)): + for j in range(config.backbone_config.depths[i]): + dim = model.backbone.encoder.layers[i].blocks[j].attention.self.all_head_size + # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"pretrained.model.layers.{i}.blocks.{j}.attn.qkv.weight") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.query.weight"] = in_proj_weight[:dim, :] + state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.key.weight"] = in_proj_weight[ + dim : dim * 2, : + ] + state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.value.weight"] = in_proj_weight[ + -dim:, : + ] + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, verify_logits, push_to_hub): + """ + Copy/paste/tweak model's weights to our DPT structure. 
+ """ + + name_to_url = { + "dpt-swinv2-tiny-256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt", + "dpt-swinv2-base-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt", + "dpt-swinv2-large-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt", + } + + # define DPT configuration based on URL + checkpoint_url = name_to_url[model_name] + config, image_size = get_dpt_config(model_name) + # load original state_dict from URL + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu") + + # load HuggingFace model + model = DPTForDepthEstimation(config) + + # remove certain keys + remove_ignore_keys_(state_dict) + # rename keys + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + # read in qkv matrices + read_in_q_k_v(state_dict, config, model) + + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + model.eval() + + # Check outputs on an image + processor = DPTImageProcessor(size={"height": image_size, "width": image_size}) + + image = prepare_img() + processor(image, return_tensors="pt") + + if verify_logits: + from torchvision import transforms + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + transforms = transforms.Compose( + [ + transforms.Resize((image_size, image_size)), + transforms.ToTensor(), + ] + ) + pixel_values = transforms(image).unsqueeze(0) + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values) + + predicted_depth = outputs.predicted_depth + + print("Shape of predicted depth:", predicted_depth.shape) + print("First values of predicted depth:", predicted_depth[0, :3, :3]) + + # assert logits + if model_name == "dpt-swinv2-base-384": + # OK, checked + expected_shape = torch.Size([1, 384, 384]) + expected_slice = torch.tensor( + [ + [1998.5575, 1997.3887, 2009.2981], + [1952.8607, 1979.6488, 2001.0854], + [1953.7697, 1961.7711, 1968.8904], + ], + ) + elif model_name == "dpt-swinv2-tiny-256": + # OK, checked + expected_shape = torch.Size([1, 256, 256]) + expected_slice = torch.tensor( + [[978.9163, 976.5215, 978.5349], [974.1859, 971.7249, 975.8046], [971.3419, 970.3118, 971.6830]], + ) + elif model_name == "dpt-swinv2-large-384": + # OK, checked + expected_shape = torch.Size([1, 384, 384]) + expected_slice = torch.tensor( + [ + [1203.7206, 1200.1495, 1197.8234], + [1196.2484, 1183.5033, 1186.4640], + [1178.8131, 1182.3260, 1174.3975], + ], + ) + + assert predicted_depth.shape == torch.Size(expected_shape) + assert torch.allclose(predicted_depth[0, :3, :3], expected_slice) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model and processor to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print("Pushing model and processor to hub...") + model.push_to_hub(repo_id=f"Intel/{model_name}") + processor.push_to_hub(repo_id=f"Intel/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="dpt-swinv2-base-384", + type=str, + choices=["dpt-swinv2-tiny-256", "dpt-swinv2-base-384", "dpt-swinv2-large-384"], + 
help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--verify_logits", + action="store_true", + help="Whether to verify logits after conversion.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether to push the model to the hub after conversion.", + ) + + args = parser.parse_args() + convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub) diff --git a/src/transformers/models/dpt/image_processing_dpt.py b/src/transformers/models/dpt/image_processing_dpt.py index 49403140b9..ec1b8fead2 100644 --- a/src/transformers/models/dpt/image_processing_dpt.py +++ b/src/transformers/models/dpt/image_processing_dpt.py @@ -100,7 +100,7 @@ class DPTImageProcessor(BaseImageProcessor): Whether to resize the image's (height, width) dimensions. Can be overidden by `do_resize` in `preprocess`. size (`Dict[str, int]` *optional*, defaults to `{"height": 384, "width": 384}`): Size of the image after resizing. Can be overidden by `size` in `preprocess`. - resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): Defines the resampling filter to use if resizing the image. Can be overidden by `resample` in `preprocess`. keep_aspect_ratio (`bool`, *optional*, defaults to `False`): If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can @@ -136,7 +136,7 @@ class DPTImageProcessor(BaseImageProcessor): self, do_resize: bool = True, size: Dict[str, int] = None, - resample: PILImageResampling = PILImageResampling.BILINEAR, + resample: PILImageResampling = PILImageResampling.BICUBIC, keep_aspect_ratio: bool = False, ensure_multiple_of: int = 1, do_rescale: bool = True, @@ -202,6 +202,7 @@ class DPTImageProcessor(BaseImageProcessor): size = get_size_dict(size) if "height" not in size or "width" not in size: raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. 
Got {size.keys()}") + output_size = get_resize_output_image_size( image, output_size=(size["height"], size["width"]), @@ -381,7 +382,14 @@ class DPTImageProcessor(BaseImageProcessor): if do_resize: images = [ - self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + self.resize( + image=image, + size=size, + resample=resample, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=ensure_multiple_of, + input_data_format=input_data_format, + ) for image in images ] diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index 1884a4a2c4..b3ef7a2a2f 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -489,11 +489,12 @@ class Swin2SROutput(nn.Module): class Swin2SRLayer(nn.Module): def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0): super().__init__() - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.shift_size = shift_size - self.window_size = config.window_size self.input_resolution = input_resolution - self.set_shift_and_window_size(input_resolution) + window_size, shift_size = self._compute_window_shift( + (config.window_size, config.window_size), (shift_size, shift_size) + ) + self.window_size = window_size[0] + self.shift_size = shift_size[0] self.attention = Swin2SRAttention( config=config, dim=dim, @@ -509,29 +510,10 @@ class Swin2SRLayer(nn.Module): self.output = Swin2SROutput(config, dim) self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) - def set_shift_and_window_size(self, input_resolution): - target_window_size = ( - self.window_size - if isinstance(self.window_size, collections.abc.Iterable) - else (self.window_size, self.window_size) - ) - target_shift_size = ( - self.shift_size - if isinstance(self.shift_size, collections.abc.Iterable) - else (self.shift_size, self.shift_size) - ) - window_dim = input_resolution[0].item() if torch.is_tensor(input_resolution[0]) else input_resolution[0] - self.window_size = window_dim if window_dim <= target_window_size[0] else target_window_size[0] - self.shift_size = ( - 0 - if input_resolution - <= ( - self.window_size - if isinstance(self.window_size, collections.abc.Iterable) - else (self.window_size, self.window_size) - ) - else target_shift_size[0] - ) + def _compute_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] + shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + return window_size, shift_size def get_attn_mask(self, height, width, dtype): if self.shift_size > 0: @@ -574,12 +556,7 @@ class Swin2SRLayer(nn.Module): input_dimensions: Tuple[int, int], head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, - always_partition: Optional[bool] = False, ) -> Tuple[torch.Tensor, torch.Tensor]: - if not always_partition: - self.set_shift_and_window_size(input_dimensions) - else: - pass height, width = input_dimensions batch_size, _, channels = hidden_states.size() shortcut = hidden_states diff --git a/src/transformers/models/swinv2/__init__.py b/src/transformers/models/swinv2/__init__.py index 5b3bb21cad..b104662e08 100644 --- a/src/transformers/models/swinv2/__init__.py +++ b/src/transformers/models/swinv2/__init__.py @@ -33,6 +33,7 @@ 
else: "Swinv2ForMaskedImageModeling", "Swinv2Model", "Swinv2PreTrainedModel", + "Swinv2Backbone", ] @@ -47,6 +48,7 @@ if TYPE_CHECKING: else: from .modeling_swinv2 import ( SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST, + Swinv2Backbone, Swinv2ForImageClassification, Swinv2ForMaskedImageModeling, Swinv2Model, diff --git a/src/transformers/models/swinv2/configuration_swinv2.py b/src/transformers/models/swinv2/configuration_swinv2.py index 1dac62583c..3c839e3f94 100644 --- a/src/transformers/models/swinv2/configuration_swinv2.py +++ b/src/transformers/models/swinv2/configuration_swinv2.py @@ -16,6 +16,7 @@ from ...configuration_utils import PretrainedConfig from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices logger = logging.get_logger(__name__) @@ -27,7 +28,7 @@ SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP = { } -class Swinv2Config(PretrainedConfig): +class Swinv2Config(BackboneConfigMixin, PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Swinv2Model`]. It is used to instantiate a Swin Transformer v2 model according to the specified arguments, defining the model architecture. Instantiating a @@ -53,6 +54,8 @@ class Swinv2Config(PretrainedConfig): Number of attention heads in each layer of the Transformer encoder. window_size (`int`, *optional*, defaults to 7): Size of windows. + pretrained_window_sizes (`list(int)`, *optional*, defaults to `[0, 0, 0, 0]`): + Size of windows during pretraining. mlp_ratio (`float`, *optional*, defaults to 4.0): Ratio of MLP hidden dimensionality to embedding dimensionality. qkv_bias (`bool`, *optional*, defaults to `True`): @@ -74,6 +77,14 @@ class Swinv2Config(PretrainedConfig): The epsilon used by the layer normalization layers. encoder_stride (`int`, *optional*, defaults to 32): Factor to increase the spatial resolution by in the decoder head for masked image modeling. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. 
Example: @@ -106,6 +117,7 @@ class Swinv2Config(PretrainedConfig): depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, + pretrained_window_sizes=[0, 0, 0, 0], mlp_ratio=4.0, qkv_bias=True, hidden_dropout_prob=0.0, @@ -116,6 +128,8 @@ class Swinv2Config(PretrainedConfig): initializer_range=0.02, layer_norm_eps=1e-5, encoder_stride=32, + out_features=None, + out_indices=None, **kwargs, ): super().__init__(**kwargs) @@ -128,6 +142,7 @@ class Swinv2Config(PretrainedConfig): self.num_layers = len(depths) self.num_heads = num_heads self.window_size = window_size + self.pretrained_window_sizes = pretrained_window_sizes self.mlp_ratio = mlp_ratio self.qkv_bias = qkv_bias self.hidden_dropout_prob = hidden_dropout_prob @@ -138,7 +153,10 @@ class Swinv2Config(PretrainedConfig): self.layer_norm_eps = layer_norm_eps self.initializer_range = initializer_range self.encoder_stride = encoder_stride + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) # we set the hidden_size attribute in order to make Swinv2 work with VisionEncoderDecoderModel # this indicates the channel dimension after the last stage of the model self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) - self.pretrained_window_sizes = (0, 0, 0, 0) diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index ebe9426689..ed5130c02e 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -23,10 +23,11 @@ from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint -from torch import nn +from torch import Tensor, nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import ( @@ -37,6 +38,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.backbone_utils import BackboneMixin from .configuration_swinv2 import Swinv2Config @@ -641,11 +643,12 @@ class Swinv2Output(nn.Module): class Swinv2Layer(nn.Module): def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0): super().__init__() - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.shift_size = shift_size - self.window_size = config.window_size self.input_resolution = input_resolution - self.set_shift_and_window_size(input_resolution) + window_size, shift_size = self._compute_window_shift( + (config.window_size, config.window_size), (shift_size, shift_size) + ) + self.window_size = window_size[0] + self.shift_size = shift_size[0] self.attention = Swinv2Attention( config=config, dim=dim, @@ -661,29 +664,10 @@ class Swinv2Layer(nn.Module): self.output = Swinv2Output(config, dim) self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) - def set_shift_and_window_size(self, input_resolution): - target_window_size = ( - self.window_size - if isinstance(self.window_size, collections.abc.Iterable) - else (self.window_size, self.window_size) - ) - target_shift_size = ( - self.shift_size - if isinstance(self.shift_size, collections.abc.Iterable) - else (self.shift_size, self.shift_size) - ) - 
window_dim = input_resolution[0].item() if torch.is_tensor(input_resolution[0]) else input_resolution[0] - self.window_size = window_dim if window_dim <= target_window_size[0] else target_window_size[0] - self.shift_size = ( - 0 - if input_resolution - <= ( - self.window_size - if isinstance(self.window_size, collections.abc.Iterable) - else (self.window_size, self.window_size) - ) - else target_shift_size[0] - ) + def _compute_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] + shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + return window_size, shift_size def get_attn_mask(self, height, width, dtype): if self.shift_size > 0: @@ -726,12 +710,7 @@ class Swinv2Layer(nn.Module): input_dimensions: Tuple[int, int], head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, - always_partition: Optional[bool] = False, ) -> Tuple[torch.Tensor, torch.Tensor]: - if not always_partition: - self.set_shift_and_window_size(input_dimensions) - else: - pass height, width = input_dimensions batch_size, _, channels = hidden_states.size() shortcut = hidden_states @@ -791,24 +770,18 @@ class Swinv2Stage(nn.Module): super().__init__() self.config = config self.dim = dim - window_size = ( - config.window_size - if isinstance(config.window_size, collections.abc.Iterable) - else (config.window_size, config.window_size) - ) - self.blocks = nn.ModuleList( - [ - Swinv2Layer( - config=config, - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - shift_size=[0, 0] if (i % 2 == 0) else [window_size[0] // 2, window_size[1] // 2], - pretrained_window_size=pretrained_window_size, - ) - for i in range(depth) - ] - ) + blocks = [] + for i in range(depth): + block = Swinv2Layer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + pretrained_window_size=pretrained_window_size, + ) + blocks.append(block) + self.blocks = nn.ModuleList(blocks) # patch merging layer if downsample is not None: @@ -818,21 +791,22 @@ class Swinv2Stage(nn.Module): self.pointing = False - # Copied from transformers.models.swin.modeling_swin.SwinStage.forward def forward( self, hidden_states: torch.Tensor, input_dimensions: Tuple[int, int], head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, - always_partition: Optional[bool] = False, ) -> Tuple[torch.Tensor]: height, width = input_dimensions for i, layer_module in enumerate(self.blocks): layer_head_mask = head_mask[i] if head_mask is not None else None layer_outputs = layer_module( - hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, ) hidden_states = layer_outputs[0] @@ -860,25 +834,24 @@ class Swinv2Encoder(nn.Module): if self.config.pretrained_window_sizes is not None: pretrained_window_sizes = config.pretrained_window_sizes dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] - self.layers = nn.ModuleList( - [ - Swinv2Stage( - config=config, - dim=int(config.embed_dim * 2**i_layer), - input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), - depth=config.depths[i_layer], - num_heads=config.num_heads[i_layer], - 
drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], - downsample=Swinv2PatchMerging if (i_layer < self.num_layers - 1) else None, - pretrained_window_size=pretrained_window_sizes[i_layer], - ) - for i_layer in range(self.num_layers) - ] - ) + + layers = [] + for i_layer in range(self.num_layers): + stage = Swinv2Stage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=Swinv2PatchMerging if (i_layer < self.num_layers - 1) else None, + pretrained_window_size=pretrained_window_sizes[i_layer], + ) + layers.append(stage) + self.layers = nn.ModuleList(layers) self.gradient_checkpointing = False - # Copied from transformers.models.swin.modeling_swin.SwinEncoder.forward with SwinEncoderOutput->Swinv2EncoderOutput def forward( self, hidden_states: torch.Tensor, @@ -887,7 +860,6 @@ class Swinv2Encoder(nn.Module): output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, output_hidden_states_before_downsampling: Optional[bool] = False, - always_partition: Optional[bool] = False, return_dict: Optional[bool] = True, ) -> Union[Tuple, Swinv2EncoderOutput]: all_hidden_states = () if output_hidden_states else None @@ -907,11 +879,14 @@ class Swinv2Encoder(nn.Module): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask ) else: layer_outputs = layer_module( - hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, ) hidden_states = layer_outputs[0] @@ -942,7 +917,11 @@ class Swinv2Encoder(nn.Module): all_self_attentions += layer_outputs[3:] if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states] + if v is not None + ) return Swinv2EncoderOutput( last_hidden_state=hidden_states, @@ -1323,3 +1302,99 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel): attentions=outputs.attentions, reshaped_hidden_states=outputs.reshaped_hidden_states, ) + + +@add_start_docstrings( + """ + Swinv2 backbone, to be used with frameworks like DETR and MaskFormer. 
+ """, + SWINV2_START_DOCSTRING, +) +class Swinv2Backbone(Swinv2PreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] + self.embeddings = Swinv2Embeddings(config) + self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid) + + # initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256") + >>> model = AutoBackbone.from_pretrained( + ... "microsoft/swinv2-tiny-patch4-window8-256", out_features=["stage1", "stage2", "stage3", "stage4"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 2048, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output, input_dimensions = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=None, + output_attentions=output_attentions, + output_hidden_states=True, + output_hidden_states_before_downsampling=True, + return_dict=return_dict, + ) + + hidden_states = outputs.reshaped_hidden_states if return_dict else outputs[-1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + feature_maps += (hidden_state,) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs[1],) + if output_attentions: + output += (outputs[2],) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index b9b3e9b580..3832d48f2e 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -7759,6 +7759,13 @@ class Swin2SRPreTrainedModel(metaclass=DummyObject): SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST = None +class Swinv2Backbone(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Swinv2ForImageClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/dpt/test_image_processing_dpt.py 
b/tests/models/dpt/test_image_processing_dpt.py index a70165048b..2cc72274c4 100644 --- a/tests/models/dpt/test_image_processing_dpt.py +++ b/tests/models/dpt/test_image_processing_dpt.py @@ -126,3 +126,13 @@ class DPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): ).pixel_values self.assertTrue(pixel_values.shape[2] % 4 == 0) self.assertTrue(pixel_values.shape[3] % 4 == 0) + + def test_keep_aspect_ratio(self): + size = {"height": 512, "width": 512} + image_processor = DPTImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=32) + + image = np.zeros((489, 640, 3)) + + pixel_values = image_processor(image, return_tensors="pt").pixel_values + + self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672]) diff --git a/tests/models/dpt/test_modeling_dpt_auto_backbone.py b/tests/models/dpt/test_modeling_dpt_auto_backbone.py index 76ab220583..aa240f0599 100644 --- a/tests/models/dpt/test_modeling_dpt_auto_backbone.py +++ b/tests/models/dpt/test_modeling_dpt_auto_backbone.py @@ -258,7 +258,7 @@ def prepare_img(): @require_vision @slow class DPTModelIntegrationTest(unittest.TestCase): - def test_inference_depth_estimation(self): + def test_inference_depth_estimation_dinov2(self): image_processor = DPTImageProcessor.from_pretrained("facebook/dpt-dinov2-small-kitti") model = DPTForDepthEstimation.from_pretrained("facebook/dpt-dinov2-small-kitti").to(torch_device) @@ -279,3 +279,47 @@ class DPTModelIntegrationTest(unittest.TestCase): ).to(torch_device) self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4)) + + def test_inference_depth_estimation_beit(self): + image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-beit-base-384") + model = DPTForDepthEstimation.from_pretrained("Intel/dpt-beit-base-384").to(torch_device) + + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + predicted_depth = outputs.predicted_depth + + # verify the predicted depth + expected_shape = torch.Size((1, 384, 384)) + self.assertEqual(predicted_depth.shape, expected_shape) + + expected_slice = torch.tensor( + [[2669.7061, 2663.7144, 2674.9399], [2633.9326, 2650.9092, 2665.4270], [2621.8271, 2632.0129, 2637.2290]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4)) + + def test_inference_depth_estimation_swinv2(self): + image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256") + model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256").to(torch_device) + + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + predicted_depth = outputs.predicted_depth + + # verify the predicted depth + expected_shape = torch.Size((1, 256, 256)) + self.assertEqual(predicted_depth.shape, expected_shape) + + expected_slice = torch.tensor( + [[1032.7719, 1025.1886, 1030.2661], [1023.7619, 1021.0075, 1024.9121], [1022.5667, 1018.8522, 1021.4145]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4)) diff --git a/tests/models/swinv2/test_modeling_swinv2.py b/tests/models/swinv2/test_modeling_swinv2.py index 9b9d08b39f..ebe05a9a71 100644 --- a/tests/models/swinv2/test_modeling_swinv2.py +++ b/tests/models/swinv2/test_modeling_swinv2.py @@ -14,12 +14,14 @@ # 
limitations under the License. """ Testing suite for the PyTorch Swinv2 model. """ import collections +import inspect import unittest from transformers import Swinv2Config from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available +from ...test_backbone_common import BackboneTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -29,7 +31,7 @@ if is_torch_available(): import torch from torch import nn - from transformers import Swinv2ForImageClassification, Swinv2ForMaskedImageModeling, Swinv2Model + from transformers import Swinv2Backbone, Swinv2ForImageClassification, Swinv2ForMaskedImageModeling, Swinv2Model from transformers.models.swinv2.modeling_swinv2 import SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST if is_vision_available(): @@ -65,6 +67,8 @@ class Swinv2ModelTester: use_labels=True, type_sequence_label_size=10, encoder_stride=8, + out_features=["stage1", "stage2"], + out_indices=[1, 2], ): self.parent = parent self.batch_size = batch_size @@ -90,6 +94,8 @@ class Swinv2ModelTester: self.use_labels = use_labels self.type_sequence_label_size = type_sequence_label_size self.encoder_stride = encoder_stride + self.out_features = out_features + self.out_indices = out_indices def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -122,6 +128,8 @@ class Swinv2ModelTester: layer_norm_eps=self.layer_norm_eps, initializer_range=self.initializer_range, encoder_stride=self.encoder_stride, + out_features=self.out_features, + out_indices=self.out_indices, ) def create_and_check_model(self, config, pixel_values, labels): @@ -135,6 +143,33 @@ class Swinv2ModelTester: self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim)) + def create_and_check_backbone(self, config, pixel_values, labels): + model = Swinv2Backbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify hidden states + self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16]) + + # verify channels + self.parent.assertEqual(len(model.channels), len(config.out_features)) + + # verify backbone works with out_features=None + config.out_features = None + model = Swinv2Backbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify feature maps + self.parent.assertEqual(len(result.feature_maps), 1) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4]) + + # verify channels + self.parent.assertEqual(len(model.channels), 1) + def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels): model = Swinv2ForMaskedImageModeling(config=config) model.to(torch_device) @@ -172,7 +207,14 @@ class Swinv2ModelTester: @require_torch class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (Swinv2Model, Swinv2ForImageClassification, Swinv2ForMaskedImageModeling) if is_torch_available() else () + ( + Swinv2Model, + Swinv2ForImageClassification, + Swinv2ForMaskedImageModeling, + Swinv2Backbone, + ) + if 
is_torch_available() + else () ) pipeline_model_mapping = ( {"feature-extraction": Swinv2Model, "image-classification": Swinv2ForImageClassification} @@ -201,6 +243,10 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_backbone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_backbone(*config_and_inputs) + # TODO: check if this works again for PyTorch 2.x.y @unittest.skip(reason="Got `CUDA error: misaligned address` with PyTorch 2.0.0.") def test_multi_gpu_data_parallel_forward(self): @@ -219,6 +265,18 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): x = model.get_output_embeddings() self.assertTrue(x is None or isinstance(x, nn.Linear)) + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True @@ -263,11 +321,8 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): with torch.no_grad(): outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - if hasattr(self.model_tester, "num_hidden_states_types"): - added_hidden_states = self.model_tester.num_hidden_states_types - else: - # also another +1 for reshaped_hidden_states - added_hidden_states = 2 + # also another +1 for reshaped_hidden_states + added_hidden_states = 1 if model_class.__name__ == "Swinv2Backbone" else 2 self.assertEqual(out_len + added_hidden_states, len(outputs)) self_attentions = outputs.attentions @@ -308,17 +363,18 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): [num_patches, self.model_tester.embed_dim], ) - reshaped_hidden_states = outputs.reshaped_hidden_states - self.assertEqual(len(reshaped_hidden_states), expected_num_layers) + if not model_class.__name__ == "Swinv2Backbone": + reshaped_hidden_states = outputs.reshaped_hidden_states + self.assertEqual(len(reshaped_hidden_states), expected_num_layers) - batch_size, num_channels, height, width = reshaped_hidden_states[0].shape - reshaped_hidden_states = ( - reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1) - ) - self.assertListEqual( - list(reshaped_hidden_states.shape[-2:]), - [num_patches, self.model_tester.embed_dim], - ) + batch_size, num_channels, height, width = reshaped_hidden_states[0].shape + reshaped_hidden_states = ( + reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1) + ) + self.assertListEqual( + list(reshaped_hidden_states.shape[-2:]), + [num_patches, self.model_tester.embed_dim], + ) def test_hidden_states_output(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -380,6 +436,10 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model = Swinv2Model.from_pretrained(model_name) self.assertIsNotNone(model) + @unittest.skip(reason="Swinv2 
does not support feedforward chunking yet") + def test_feed_forward_chunking(self): + pass + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -425,3 +485,12 @@ class Swinv2ModelIntegrationTest(unittest.TestCase): self.assertEqual(outputs.logits.shape, expected_shape) expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) + + +@require_torch +class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin): + all_model_classes = (Swinv2Backbone,) if is_torch_available() else () + config_class = Swinv2Config + + def setUp(self): + self.model_tester = Swinv2ModelTester(self) diff --git a/utils/check_repo.py b/utils/check_repo.py index 372b1c903b..23f004cf78 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -988,6 +988,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [ "NatBackbone", "ResNetBackbone", "SwinBackbone", + "Swinv2Backbone", "TimmBackbone", "TimmBackboneConfig", "VitDetBackbone",
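
For reviewers, a minimal usage sketch (not part of the patch) of what this PR enables end to end, assuming the converted checkpoint has been pushed to `Intel/dpt-swinv2-tiny-256` as referenced in the integration test above; the image URL is the same COCO sample used by `prepare_img()` in the conversion script:

```python
import requests
import torch
from PIL import Image

from transformers import DPTForDepthEstimation, DPTImageProcessor

# Same cats image the conversion script and tests verify against
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Assumes the converted DPT + Swinv2 backbone checkpoint is available on the Hub
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256")
model.eval()

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Per the integration test, the tiny-256 variant predicts a (1, 256, 256) depth map
predicted_depth = outputs.predicted_depth
print(predicted_depth.shape)
```

The standalone `Swinv2Backbone` added here can likewise be loaded through `AutoBackbone` with `out_features=["stage1", ..., "stage4"]`, as shown in the docstring example in `modeling_swinv2.py` above.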