Add Swinv2 backbone (#27742)

* First draft

* More improvements

* More improvements

* Make all tests pass

* Remove script

* Update image processor

* Address comments

* Use new gradient checkpointing method

* Convert checkpoints, add integration test

* Do not keep aspect ratio for now

* Set keep_aspect_ratio=False for beit, add integration test

* Remove print statement
NielsRogge 2023-12-22 12:12:56 +01:00 committed by GitHub
parent 1ef86c4f56
commit c9fb250a25
14 changed files with 667 additions and 130 deletions


@ -3213,6 +3213,7 @@ else:
_import_structure["models.swinv2"].extend(
[
"SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Swinv2Backbone",
"Swinv2ForImageClassification",
"Swinv2ForMaskedImageModeling",
"Swinv2Model",
@ -7541,6 +7542,7 @@ if TYPE_CHECKING:
)
from .models.swinv2 import (
SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST,
Swinv2Backbone,
Swinv2ForImageClassification,
Swinv2ForMaskedImageModeling,
Swinv2Model,


@ -1116,6 +1116,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
("nat", "NatBackbone"),
("resnet", "ResNetBackbone"),
("swin", "SwinBackbone"),
("swinv2", "Swinv2Backbone"),
("timm_backbone", "TimmBackbone"),
("vitdet", "VitDetBackbone"),
]
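With this mapping entry in place, the auto classes can dispatch a Swinv2 configuration to the new backbone class. A minimal sketch (randomly initialised weights, purely illustrative):

```python
# Minimal sketch: the ("swinv2", "Swinv2Backbone") entry lets AutoBackbone resolve a
# Swinv2Config to Swinv2Backbone. Weights are randomly initialised here.
from transformers import AutoBackbone, Swinv2Config

config = Swinv2Config(out_features=["stage3", "stage4"])
backbone = AutoBackbone.from_config(config)
print(type(backbone).__name__)  # Swinv2Backbone
```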


@ -207,8 +207,9 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
model.eval()
# Check outputs on an image
# We set `keep_aspect_ratio=False` as our current BEiT does not support arbitrary window sizes
processor = DPTImageProcessor(
size={"height": image_size, "width": image_size}, keep_aspect_ratio=True, ensure_multiple_of=32
size={"height": image_size, "width": image_size}, keep_aspect_ratio=False, ensure_multiple_of=32
)
image = prepare_img()


@ -0,0 +1,322 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert DPT 3.1 checkpoints from the MiDaS repository. URL: https://github.com/isl-org/MiDaS"""
import argparse
from pathlib import Path
import requests
import torch
from PIL import Image
from transformers import DPTConfig, DPTForDepthEstimation, DPTImageProcessor, Swinv2Config
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_dpt_config(model_name):
if "tiny" in model_name:
embed_dim = 96
depths = (2, 2, 6, 2)
num_heads = (3, 6, 12, 24)
window_size = 16
# note: for the Swinv2-tiny variant, the authors used window_size = 16
# as seen here: https://github.com/isl-org/MiDaS/blob/bdc4ed64c095e026dc0a2f17cabb14d58263decb/midas/backbones/swin2.py#L26
pretrained_window_sizes = (0, 0, 0, 0)
elif "base" in model_name:
embed_dim = 128
depths = (2, 2, 18, 2)
num_heads = (4, 8, 16, 32)
window_size = 24
pretrained_window_sizes = (12, 12, 12, 6)
elif "large" in model_name:
embed_dim = 192
depths = (2, 2, 18, 2)
num_heads = (6, 12, 24, 48)
window_size = 24
pretrained_window_sizes = (12, 12, 12, 6)
if "384" in model_name:
image_size = 384
elif "256" in model_name:
image_size = 256
else:
raise ValueError("Model not supported, to do")
backbone_config = Swinv2Config(
image_size=image_size,
embed_dim=embed_dim,
depths=depths,
window_size=window_size,
pretrained_window_sizes=pretrained_window_sizes,
num_heads=num_heads,
out_features=["stage1", "stage2", "stage3", "stage4"],
)
if model_name == "dpt-swinv2-tiny-256":
neck_hidden_sizes = [96, 192, 384, 768]
elif model_name == "dpt-swinv2-base-384":
neck_hidden_sizes = [128, 256, 512, 1024]
elif model_name == "dpt-swinv2-large-384":
neck_hidden_sizes = [192, 384, 768, 1536]
config = DPTConfig(backbone_config=backbone_config, neck_hidden_sizes=neck_hidden_sizes)
return config, image_size
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config):
rename_keys = []
# fmt: off
# stem
rename_keys.append(("pretrained.model.patch_embed.proj.weight", "backbone.embeddings.patch_embeddings.projection.weight"))
rename_keys.append(("pretrained.model.patch_embed.proj.bias", "backbone.embeddings.patch_embeddings.projection.bias"))
rename_keys.append(("pretrained.model.patch_embed.norm.weight", "backbone.embeddings.norm.weight"))
rename_keys.append(("pretrained.model.patch_embed.norm.bias", "backbone.embeddings.norm.bias"))
# transformer encoder
for i in range(len(config.backbone_config.depths)):
for j in range(config.backbone_config.depths[i]):
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.logit_scale", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.logit_scale"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.cpb_mlp.0.weight", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.continuous_position_bias_mlp.0.weight"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.cpb_mlp.0.bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.continuous_position_bias_mlp.0.bias"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.cpb_mlp.2.weight", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.continuous_position_bias_mlp.2.weight"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.q_bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.query.bias"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.v_bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.value.bias"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.proj.weight", f"backbone.encoder.layers.{i}.blocks.{j}.attention.output.dense.weight"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.proj.bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.output.dense.bias"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.norm1.weight", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_before.weight"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.norm1.bias", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_before.bias"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.mlp.fc1.weight", f"backbone.encoder.layers.{i}.blocks.{j}.intermediate.dense.weight"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.mlp.fc1.bias", f"backbone.encoder.layers.{i}.blocks.{j}.intermediate.dense.bias"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.mlp.fc2.weight", f"backbone.encoder.layers.{i}.blocks.{j}.output.dense.weight"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.mlp.fc2.bias", f"backbone.encoder.layers.{i}.blocks.{j}.output.dense.bias"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.norm2.weight", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_after.weight"))
rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.norm2.bias", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_after.bias"))
# downsample parameters
if i in [0,1,2]:
rename_keys.append((f"pretrained.model.layers.{i}.downsample.reduction.weight", f"backbone.encoder.layers.{i}.downsample.reduction.weight"))
rename_keys.append((f"pretrained.model.layers.{i}.downsample.norm.weight", f"backbone.encoder.layers.{i}.downsample.norm.weight"))
rename_keys.append((f"pretrained.model.layers.{i}.downsample.norm.bias", f"backbone.encoder.layers.{i}.downsample.norm.bias"))
# note: unlike plain ViT backbones, hierarchical backbones such as Swinv2 and LeViT don't require activation postprocessing (readout projections + resize blocks)
# refinenet (tricky here)
mapping = {1:3, 2:2, 3:1, 4:0}
for i in range(1, 5):
j = mapping[i]
rename_keys.append((f"scratch.refinenet{i}.out_conv.weight", f"neck.fusion_stage.layers.{j}.projection.weight"))
rename_keys.append((f"scratch.refinenet{i}.out_conv.bias", f"neck.fusion_stage.layers.{j}.projection.bias"))
rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.weight"))
rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.bias"))
rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.weight"))
rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.bias"))
rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.weight"))
rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.bias"))
rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.weight"))
rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.bias"))
# scratch convolutions
for i in range(4):
rename_keys.append((f"scratch.layer{i+1}_rn.weight", f"neck.convs.{i}.weight"))
# head
for i in range(0, 5, 2):
rename_keys.append((f"scratch.output_conv.{i}.weight", f"head.head.{i}.weight"))
rename_keys.append((f"scratch.output_conv.{i}.bias", f"head.head.{i}.bias"))
return rename_keys
def remove_ignore_keys_(state_dict):
ignore_keys = ["pretrained.model.head.weight", "pretrained.model.head.bias"]
for k in ignore_keys:
state_dict.pop(k, None)
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, model):
for i in range(len(config.backbone_config.depths)):
for j in range(config.backbone_config.depths[i]):
dim = model.backbone.encoder.layers[i].blocks[j].attention.self.all_head_size
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"pretrained.model.layers.{i}.blocks.{j}.attn.qkv.weight")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.query.weight"] = in_proj_weight[:dim, :]
state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.key.weight"] = in_proj_weight[
dim : dim * 2, :
]
state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.value.weight"] = in_proj_weight[
-dim:, :
]
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, verify_logits, push_to_hub):
"""
Copy/paste/tweak model's weights to our DPT structure.
"""
name_to_url = {
"dpt-swinv2-tiny-256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt",
"dpt-swinv2-base-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt",
"dpt-swinv2-large-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt",
}
# define DPT configuration based on URL
checkpoint_url = name_to_url[model_name]
config, image_size = get_dpt_config(model_name)
# load original state_dict from URL
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
# load HuggingFace model
model = DPTForDepthEstimation(config)
# remove certain keys
remove_ignore_keys_(state_dict)
# rename keys
rename_keys = create_rename_keys(config)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
# read in qkv matrices
read_in_q_k_v(state_dict, config, model)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing_keys)
print("Unexpected keys:", unexpected_keys)
model.eval()
# Check outputs on an image
processor = DPTImageProcessor(size={"height": image_size, "width": image_size})
image = prepare_img()
processor(image, return_tensors="pt")
if verify_logits:
from torchvision import transforms
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
transforms = transforms.Compose(
[
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
]
)
pixel_values = transforms(image).unsqueeze(0)
# forward pass
with torch.no_grad():
outputs = model(pixel_values)
predicted_depth = outputs.predicted_depth
print("Shape of predicted depth:", predicted_depth.shape)
print("First values of predicted depth:", predicted_depth[0, :3, :3])
# assert logits
if model_name == "dpt-swinv2-base-384":
# OK, checked
expected_shape = torch.Size([1, 384, 384])
expected_slice = torch.tensor(
[
[1998.5575, 1997.3887, 2009.2981],
[1952.8607, 1979.6488, 2001.0854],
[1953.7697, 1961.7711, 1968.8904],
],
)
elif model_name == "dpt-swinv2-tiny-256":
# OK, checked
expected_shape = torch.Size([1, 256, 256])
expected_slice = torch.tensor(
[[978.9163, 976.5215, 978.5349], [974.1859, 971.7249, 975.8046], [971.3419, 970.3118, 971.6830]],
)
elif model_name == "dpt-swinv2-large-384":
# OK, checked
expected_shape = torch.Size([1, 384, 384])
expected_slice = torch.tensor(
[
[1203.7206, 1200.1495, 1197.8234],
[1196.2484, 1183.5033, 1186.4640],
[1178.8131, 1182.3260, 1174.3975],
],
)
assert predicted_depth.shape == torch.Size(expected_shape)
assert torch.allclose(predicted_depth[0, :3, :3], expected_slice)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model and processor to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print("Pushing model and processor to hub...")
model.push_to_hub(repo_id=f"Intel/{model_name}")
processor.push_to_hub(repo_id=f"Intel/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="dpt-swinv2-base-384",
type=str,
choices=["dpt-swinv2-tiny-256", "dpt-swinv2-base-384", "dpt-swinv2-large-384"],
help="Name of the model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--verify_logits",
action="store_true",
help="Whether to verify logits after conversion.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether to push the model to the hub after conversion.",
)
args = parser.parse_args()
convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub)
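For reference, a minimal usage sketch of the conversion function above (the output directory name is an arbitrary example; the function downloads the original MiDaS checkpoint, so network access is assumed):

```python
# Usage sketch, assuming the functions above are in scope (script imported or run
# interactively). Downloads the original MiDaS weights and converts them to DPT format.
convert_dpt_checkpoint(
    model_name="dpt-swinv2-tiny-256",
    pytorch_dump_folder_path="dpt-swinv2-tiny-256-converted",  # example output directory
    verify_logits=False,  # set to True (requires torchvision) to check the slices above
    push_to_hub=False,
)
```

The same conversion can be run through the argparse entry point defined above.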


@ -100,7 +100,7 @@ class DPTImageProcessor(BaseImageProcessor):
Whether to resize the image's (height, width) dimensions. Can be overridden by `do_resize` in `preprocess`.
size (`Dict[str, int]`, *optional*, defaults to `{"height": 384, "width": 384}`):
Size of the image after resizing. Can be overridden by `size` in `preprocess`.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Defines the resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
@ -136,7 +136,7 @@ class DPTImageProcessor(BaseImageProcessor):
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
resample: PILImageResampling = PILImageResampling.BICUBIC,
keep_aspect_ratio: bool = False,
ensure_multiple_of: int = 1,
do_rescale: bool = True,
@ -202,6 +202,7 @@ class DPTImageProcessor(BaseImageProcessor):
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
output_size = get_resize_output_image_size(
image,
output_size=(size["height"], size["width"]),
@ -381,7 +382,14 @@ class DPTImageProcessor(BaseImageProcessor):
if do_resize:
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
self.resize(
image=image,
size=size,
resample=resample,
keep_aspect_ratio=keep_aspect_ratio,
ensure_multiple_of=ensure_multiple_of,
input_data_format=input_data_format,
)
for image in images
]
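The change above forwards `keep_aspect_ratio` and `ensure_multiple_of` from `preprocess` to `resize`, so they are no longer silently dropped. A sketch mirroring the new `test_keep_aspect_ratio` test added further below:

```python
# With keep_aspect_ratio=True a single scale factor (the one closest to 1) is applied to
# both sides, and ensure_multiple_of rounds each side to a multiple of 32:
# (489, 640) -> (512, 672).
import numpy as np

from transformers import DPTImageProcessor

processor = DPTImageProcessor(
    size={"height": 512, "width": 512}, keep_aspect_ratio=True, ensure_multiple_of=32
)
image = np.zeros((489, 640, 3))  # height x width x channels

pixel_values = processor(image, return_tensors="pt").pixel_values
print(list(pixel_values.shape))  # [1, 3, 512, 672]
```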


@ -489,11 +489,12 @@ class Swin2SROutput(nn.Module):
class Swin2SRLayer(nn.Module):
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.set_shift_and_window_size(input_resolution)
window_size, shift_size = self._compute_window_shift(
(config.window_size, config.window_size), (shift_size, shift_size)
)
self.window_size = window_size[0]
self.shift_size = shift_size[0]
self.attention = Swin2SRAttention(
config=config,
dim=dim,
@ -509,29 +510,10 @@ class Swin2SRLayer(nn.Module):
self.output = Swin2SROutput(config, dim)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
def set_shift_and_window_size(self, input_resolution):
target_window_size = (
self.window_size
if isinstance(self.window_size, collections.abc.Iterable)
else (self.window_size, self.window_size)
)
target_shift_size = (
self.shift_size
if isinstance(self.shift_size, collections.abc.Iterable)
else (self.shift_size, self.shift_size)
)
window_dim = input_resolution[0].item() if torch.is_tensor(input_resolution[0]) else input_resolution[0]
self.window_size = window_dim if window_dim <= target_window_size[0] else target_window_size[0]
self.shift_size = (
0
if input_resolution
<= (
self.window_size
if isinstance(self.window_size, collections.abc.Iterable)
else (self.window_size, self.window_size)
)
else target_shift_size[0]
)
def _compute_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
return window_size, shift_size
def get_attn_mask(self, height, width, dtype):
if self.shift_size > 0:
@ -574,12 +556,7 @@ class Swin2SRLayer(nn.Module):
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if not always_partition:
self.set_shift_and_window_size(input_dimensions)
else:
pass
height, width = input_dimensions
batch_size, _, channels = hidden_states.size()
shortcut = hidden_states
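The refactor replaces the stateful `set_shift_and_window_size` (which mutated the layer on every forward pass) with a single computation at init time. A standalone sketch of the shared logic, with illustrative numbers:

```python
# Standalone sketch of the _compute_window_shift logic used by Swin2SRLayer (here) and
# Swinv2Layer (below): the window is clamped to the input resolution, and shifting is
# disabled along any dimension the clamped window already covers.
from typing import List, Tuple


def compute_window_shift(
    input_resolution: Tuple[int, int],
    target_window_size: Tuple[int, int],
    target_shift_size: Tuple[int, int],
) -> Tuple[List[int], List[int]]:
    window_size = [r if r <= w else w for r, w in zip(input_resolution, target_window_size)]
    shift_size = [0 if r <= w else s for r, w, s in zip(input_resolution, window_size, target_shift_size)]
    return window_size, shift_size


print(compute_window_shift((64, 64), (8, 8), (4, 4)))  # ([8, 8], [4, 4])
print(compute_window_shift((4, 4), (8, 8), (4, 4)))    # ([4, 4], [0, 0]): clamped window, no shift
```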


@ -33,6 +33,7 @@ else:
"Swinv2ForMaskedImageModeling",
"Swinv2Model",
"Swinv2PreTrainedModel",
"Swinv2Backbone",
]
@ -47,6 +48,7 @@ if TYPE_CHECKING:
else:
from .modeling_swinv2 import (
SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST,
Swinv2Backbone,
Swinv2ForImageClassification,
Swinv2ForMaskedImageModeling,
Swinv2Model,


@ -16,6 +16,7 @@
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
logger = logging.get_logger(__name__)
@ -27,7 +28,7 @@ SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
}
class Swinv2Config(PretrainedConfig):
class Swinv2Config(BackboneConfigMixin, PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Swinv2Model`]. It is used to instantiate a Swin
Transformer v2 model according to the specified arguments, defining the model architecture. Instantiating a
@ -53,6 +54,8 @@ class Swinv2Config(PretrainedConfig):
Number of attention heads in each layer of the Transformer encoder.
window_size (`int`, *optional*, defaults to 7):
Size of windows.
pretrained_window_sizes (`list(int)`, *optional*, defaults to `[0, 0, 0, 0]`):
Size of windows during pretraining.
mlp_ratio (`float`, *optional*, defaults to 4.0):
Ratio of MLP hidden dimensionality to embedding dimensionality.
qkv_bias (`bool`, *optional*, defaults to `True`):
@ -74,6 +77,14 @@ class Swinv2Config(PretrainedConfig):
The epsilon used by the layer normalization layers.
encoder_stride (`int`, *optional*, defaults to 32):
Factor to increase the spatial resolution by in the decoder head for masked image modeling.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage.
Example:
@ -106,6 +117,7 @@ class Swinv2Config(PretrainedConfig):
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
pretrained_window_sizes=[0, 0, 0, 0],
mlp_ratio=4.0,
qkv_bias=True,
hidden_dropout_prob=0.0,
@ -116,6 +128,8 @@ class Swinv2Config(PretrainedConfig):
initializer_range=0.02,
layer_norm_eps=1e-5,
encoder_stride=32,
out_features=None,
out_indices=None,
**kwargs,
):
super().__init__(**kwargs)
@ -128,6 +142,7 @@ class Swinv2Config(PretrainedConfig):
self.num_layers = len(depths)
self.num_heads = num_heads
self.window_size = window_size
self.pretrained_window_sizes = pretrained_window_sizes
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.hidden_dropout_prob = hidden_dropout_prob
@ -138,7 +153,10 @@ class Swinv2Config(PretrainedConfig):
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.encoder_stride = encoder_stride
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)
# we set the hidden_size attribute in order to make Swinv2 work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.pretrained_window_sizes = (0, 0, 0, 0)
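With `BackboneConfigMixin`, `out_features` and `out_indices` are aligned against the generated `stage_names`, so specifying either one is enough. A small sketch of the new configuration surface:

```python
# Sketch of the new backbone-related fields on Swinv2Config. out_indices is derived
# from out_features via get_aligned_output_features_output_indices.
from transformers import Swinv2Config

config = Swinv2Config(
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    window_size=8,
    pretrained_window_sizes=[0, 0, 0, 0],
    out_features=["stage2", "stage4"],
)
print(config.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(config.out_features)  # ['stage2', 'stage4']
print(config.out_indices)   # [2, 4]
```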


@ -23,10 +23,11 @@ from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
@ -37,6 +38,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_swinv2 import Swinv2Config
@ -641,11 +643,12 @@ class Swinv2Output(nn.Module):
class Swinv2Layer(nn.Module):
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.set_shift_and_window_size(input_resolution)
window_size, shift_size = self._compute_window_shift(
(config.window_size, config.window_size), (shift_size, shift_size)
)
self.window_size = window_size[0]
self.shift_size = shift_size[0]
self.attention = Swinv2Attention(
config=config,
dim=dim,
@ -661,29 +664,10 @@ class Swinv2Layer(nn.Module):
self.output = Swinv2Output(config, dim)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
def set_shift_and_window_size(self, input_resolution):
target_window_size = (
self.window_size
if isinstance(self.window_size, collections.abc.Iterable)
else (self.window_size, self.window_size)
)
target_shift_size = (
self.shift_size
if isinstance(self.shift_size, collections.abc.Iterable)
else (self.shift_size, self.shift_size)
)
window_dim = input_resolution[0].item() if torch.is_tensor(input_resolution[0]) else input_resolution[0]
self.window_size = window_dim if window_dim <= target_window_size[0] else target_window_size[0]
self.shift_size = (
0
if input_resolution
<= (
self.window_size
if isinstance(self.window_size, collections.abc.Iterable)
else (self.window_size, self.window_size)
)
else target_shift_size[0]
)
def _compute_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
return window_size, shift_size
def get_attn_mask(self, height, width, dtype):
if self.shift_size > 0:
@ -726,12 +710,7 @@ class Swinv2Layer(nn.Module):
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if not always_partition:
self.set_shift_and_window_size(input_dimensions)
else:
pass
height, width = input_dimensions
batch_size, _, channels = hidden_states.size()
shortcut = hidden_states
@ -791,24 +770,18 @@ class Swinv2Stage(nn.Module):
super().__init__()
self.config = config
self.dim = dim
window_size = (
config.window_size
if isinstance(config.window_size, collections.abc.Iterable)
else (config.window_size, config.window_size)
)
self.blocks = nn.ModuleList(
[
Swinv2Layer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=[0, 0] if (i % 2 == 0) else [window_size[0] // 2, window_size[1] // 2],
pretrained_window_size=pretrained_window_size,
)
for i in range(depth)
]
)
blocks = []
for i in range(depth):
block = Swinv2Layer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
pretrained_window_size=pretrained_window_size,
)
blocks.append(block)
self.blocks = nn.ModuleList(blocks)
# patch merging layer
if downsample is not None:
@ -818,21 +791,22 @@ class Swinv2Stage(nn.Module):
self.pointing = False
# Copied from transformers.models.swin.modeling_swin.SwinStage.forward
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
)
hidden_states = layer_outputs[0]
@ -860,25 +834,24 @@ class Swinv2Encoder(nn.Module):
if self.config.pretrained_window_sizes is not None:
pretrained_window_sizes = config.pretrained_window_sizes
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
self.layers = nn.ModuleList(
[
Swinv2Stage(
config=config,
dim=int(config.embed_dim * 2**i_layer),
input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
depth=config.depths[i_layer],
num_heads=config.num_heads[i_layer],
drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
downsample=Swinv2PatchMerging if (i_layer < self.num_layers - 1) else None,
pretrained_window_size=pretrained_window_sizes[i_layer],
)
for i_layer in range(self.num_layers)
]
)
layers = []
for i_layer in range(self.num_layers):
stage = Swinv2Stage(
config=config,
dim=int(config.embed_dim * 2**i_layer),
input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
depth=config.depths[i_layer],
num_heads=config.num_heads[i_layer],
drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
downsample=Swinv2PatchMerging if (i_layer < self.num_layers - 1) else None,
pretrained_window_size=pretrained_window_sizes[i_layer],
)
layers.append(stage)
self.layers = nn.ModuleList(layers)
self.gradient_checkpointing = False
# Copied from transformers.models.swin.modeling_swin.SwinEncoder.forward with SwinEncoderOutput->Swinv2EncoderOutput
def forward(
self,
hidden_states: torch.Tensor,
@ -887,7 +860,6 @@ class Swinv2Encoder(nn.Module):
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
always_partition: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, Swinv2EncoderOutput]:
all_hidden_states = () if output_hidden_states else None
@ -907,11 +879,14 @@ class Swinv2Encoder(nn.Module):
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask
)
else:
layer_outputs = layer_module(
hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
)
hidden_states = layer_outputs[0]
@ -942,7 +917,11 @@ class Swinv2Encoder(nn.Module):
all_self_attentions += layer_outputs[3:]
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states]
if v is not None
)
return Swinv2EncoderOutput(
last_hidden_state=hidden_states,
@ -1323,3 +1302,99 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
)
@add_start_docstrings(
"""
Swinv2 backbone, to be used with frameworks like DETR and MaskFormer.
""",
SWINV2_START_DOCSTRING,
)
class Swinv2Backbone(Swinv2PreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
self.embeddings = Swinv2Embeddings(config)
self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid)
# initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
@add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BackboneOutput:
"""
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, AutoBackbone
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
>>> model = AutoBackbone.from_pretrained(
... "microsoft/swinv2-tiny-patch4-window8-256", out_features=["stage1", "stage2", "stage3", "stage4"]
... )
>>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> feature_maps = outputs.feature_maps
>>> list(feature_maps[-1].shape)
[1, 768, 8, 8]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
embedding_output, input_dimensions = self.embeddings(pixel_values)
outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=None,
output_attentions=output_attentions,
output_hidden_states=True,
output_hidden_states_before_downsampling=True,
return_dict=return_dict,
)
hidden_states = outputs.reshaped_hidden_states if return_dict else outputs[-1]
feature_maps = ()
for stage, hidden_state in zip(self.stage_names, hidden_states):
if stage in self.out_features:
feature_maps += (hidden_state,)
if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs[1],)
if output_attentions:
output += (outputs[2],)
return output
return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
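The main downstream consumer of the new backbone is DPT 3.1. A sketch of wiring it into `DPTForDepthEstimation`, mirroring `get_dpt_config()` from the conversion script for the tiny 256px variant (this builds a randomly initialised model; for real predictions use the converted `Intel/dpt-swinv2-tiny-256` checkpoint):

```python
# Sketch: Swinv2 as a DPT backbone, with the tiny-256 hyperparameters taken from
# get_dpt_config() in the conversion script above. Weights are randomly initialised.
from transformers import DPTConfig, DPTForDepthEstimation, Swinv2Config

backbone_config = Swinv2Config(
    image_size=256,
    embed_dim=96,
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
    window_size=16,
    pretrained_window_sizes=(0, 0, 0, 0),
    out_features=["stage1", "stage2", "stage3", "stage4"],
)
config = DPTConfig(backbone_config=backbone_config, neck_hidden_sizes=[96, 192, 384, 768])
model = DPTForDepthEstimation(config)
```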


@ -7759,6 +7759,13 @@ class Swin2SRPreTrainedModel(metaclass=DummyObject):
SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST = None
class Swinv2Backbone(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Swinv2ForImageClassification(metaclass=DummyObject):
_backends = ["torch"]


@ -126,3 +126,13 @@ class DPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
).pixel_values
self.assertTrue(pixel_values.shape[2] % 4 == 0)
self.assertTrue(pixel_values.shape[3] % 4 == 0)
def test_keep_aspect_ratio(self):
size = {"height": 512, "width": 512}
image_processor = DPTImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=32)
image = np.zeros((489, 640, 3))
pixel_values = image_processor(image, return_tensors="pt").pixel_values
self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672])


@ -258,7 +258,7 @@ def prepare_img():
@require_vision
@slow
class DPTModelIntegrationTest(unittest.TestCase):
def test_inference_depth_estimation(self):
def test_inference_depth_estimation_dinov2(self):
image_processor = DPTImageProcessor.from_pretrained("facebook/dpt-dinov2-small-kitti")
model = DPTForDepthEstimation.from_pretrained("facebook/dpt-dinov2-small-kitti").to(torch_device)
@ -279,3 +279,47 @@ class DPTModelIntegrationTest(unittest.TestCase):
).to(torch_device)
self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
def test_inference_depth_estimation_beit(self):
image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-beit-base-384")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-beit-base-384").to(torch_device)
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
# verify the predicted depth
expected_shape = torch.Size((1, 384, 384))
self.assertEqual(predicted_depth.shape, expected_shape)
expected_slice = torch.tensor(
[[2669.7061, 2663.7144, 2674.9399], [2633.9326, 2650.9092, 2665.4270], [2621.8271, 2632.0129, 2637.2290]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
def test_inference_depth_estimation_swinv2(self):
image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256").to(torch_device)
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
# verify the predicted depth
expected_shape = torch.Size((1, 256, 256))
self.assertEqual(predicted_depth.shape, expected_shape)
expected_slice = torch.tensor(
[[1032.7719, 1025.1886, 1030.2661], [1023.7619, 1021.0075, 1024.9121], [1022.5667, 1018.8522, 1021.4145]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
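The converted checkpoints can also be exercised end to end through the depth-estimation pipeline. A hedged sketch, assuming the checkpoint was pushed to the Hub as `Intel/dpt-swinv2-tiny-256` by the conversion script:

```python
# Sketch: depth estimation with the converted Swinv2-based DPT checkpoint via the
# pipeline API. Assumes the "Intel/dpt-swinv2-tiny-256" repo exists on the Hub.
import requests
from PIL import Image

from transformers import pipeline

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

depth_estimator = pipeline("depth-estimation", model="Intel/dpt-swinv2-tiny-256")
result = depth_estimator(image)
print(result["predicted_depth"].shape)  # torch.Size([1, 256, 256]) for this checkpoint
result["depth"].save("depth.png")  # normalised depth map as a PIL image
```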


@ -14,12 +14,14 @@
# limitations under the License.
""" Testing suite for the PyTorch Swinv2 model. """
import collections
import inspect
import unittest
from transformers import Swinv2Config
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@ -29,7 +31,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import Swinv2ForImageClassification, Swinv2ForMaskedImageModeling, Swinv2Model
from transformers import Swinv2Backbone, Swinv2ForImageClassification, Swinv2ForMaskedImageModeling, Swinv2Model
from transformers.models.swinv2.modeling_swinv2 import SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
@ -65,6 +67,8 @@ class Swinv2ModelTester:
use_labels=True,
type_sequence_label_size=10,
encoder_stride=8,
out_features=["stage1", "stage2"],
out_indices=[1, 2],
):
self.parent = parent
self.batch_size = batch_size
@ -90,6 +94,8 @@ class Swinv2ModelTester:
self.use_labels = use_labels
self.type_sequence_label_size = type_sequence_label_size
self.encoder_stride = encoder_stride
self.out_features = out_features
self.out_indices = out_indices
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@ -122,6 +128,8 @@ class Swinv2ModelTester:
layer_norm_eps=self.layer_norm_eps,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
out_features=self.out_features,
out_indices=self.out_indices,
)
def create_and_check_model(self, config, pixel_values, labels):
@ -135,6 +143,33 @@ class Swinv2ModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_backbone(self, config, pixel_values, labels):
model = Swinv2Backbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
# verify backbone works with out_features=None
config.out_features = None
model = Swinv2Backbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
model = Swinv2ForMaskedImageModeling(config=config)
model.to(torch_device)
@ -172,7 +207,14 @@ class Swinv2ModelTester:
@require_torch
class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(Swinv2Model, Swinv2ForImageClassification, Swinv2ForMaskedImageModeling) if is_torch_available() else ()
(
Swinv2Model,
Swinv2ForImageClassification,
Swinv2ForMaskedImageModeling,
Swinv2Backbone,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{"feature-extraction": Swinv2Model, "image-classification": Swinv2ForImageClassification}
@ -201,6 +243,10 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
# TODO: check if this works again for PyTorch 2.x.y
@unittest.skip(reason="Got `CUDA error: misaligned address` with PyTorch 2.0.0.")
def test_multi_gpu_data_parallel_forward(self):
@ -219,6 +265,18 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
@ -263,11 +321,8 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
else:
# also another +1 for reshaped_hidden_states
added_hidden_states = 2
# also another +1 for reshaped_hidden_states
added_hidden_states = 1 if model_class.__name__ == "Swinv2Backbone" else 2
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.attentions
@ -308,17 +363,18 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
[num_patches, self.model_tester.embed_dim],
)
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
if not model_class.__name__ == "Swinv2Backbone":
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = (
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
)
self.assertListEqual(
list(reshaped_hidden_states.shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = (
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
)
self.assertListEqual(
list(reshaped_hidden_states.shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -380,6 +436,10 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = Swinv2Model.from_pretrained(model_name)
self.assertIsNotNone(model)
@unittest.skip(reason="Swinv2 does not support feedforward chunking yet")
def test_feed_forward_chunking(self):
pass
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -425,3 +485,12 @@ class Swinv2ModelIntegrationTest(unittest.TestCase):
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@require_torch
class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin):
all_model_classes = (Swinv2Backbone,) if is_torch_available() else ()
config_class = Swinv2Config
def setUp(self):
self.model_tester = Swinv2ModelTester(self)


@ -988,6 +988,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
"NatBackbone",
"ResNetBackbone",
"SwinBackbone",
"Swinv2Backbone",
"TimmBackbone",
"TimmBackboneConfig",
"VitDetBackbone",