Add DINOv2 depth estimation (#26092)

* First draft

* Fix style

* More improvements

* Fix tests

* Fix tests

* Convert checkpoint

* Improve DPTImageProcessor

* Remove scripts, improve conversion script

* Remove print statements

* Fix test

* Improve docstring

* More improvements

* Fix style

* Fix image processor

* Add tests

* Address comments

* Address comments

* Make bias backwards compatible

* Address comment

* Address comment

* Address comment

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Address comments

* Add flag

* Add tests

* Make tests smaller

* Use regular BackboneOutput

* Fix all tests

* Update test

* Convert more checkpoints

* Convert giant checkpoints, add integration test

* Rename size_divisibility to size_divisor

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
NielsRogge 2023-11-13 17:20:42 +01:00 committed by GitHub
parent 3b59621310
commit 2422c38de6
8 changed files with 928 additions and 75 deletions

View File

@ -835,11 +835,12 @@ class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
if self.config.apply_layernorm:
hidden_state = self.layernorm(hidden_state)
if self.config.reshape_hidden_states:
hidden_state = hidden_state[:, 1:]
# this was actually a bug in the original implementation that we copied here,
# because normally the order is height, width
batch_size, _, height, width = pixel_values.shape
patch_size = self.config.patch_size
hidden_state = hidden_state[:, 1:, :].reshape(
batch_size, width // patch_size, height // patch_size, -1
)
hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
feature_maps += (hidden_state,)
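
For readers of the hunk above, a minimal sketch (with made-up tensor sizes, not taken from the diff) of why the reshape order matters: the patch tokens are laid out row by row, so the grid has to be recovered as (height // patch_size, width // patch_size) before permuting to channels-first.

import torch

batch_size, height, width, patch_size, hidden_size = 1, 70, 98, 14, 8
num_patches = (height // patch_size) * (width // patch_size)  # 5 * 7 = 35 patch tokens
hidden_state = torch.randn(batch_size, num_patches, hidden_size)

# rows vary slowest, so the grid is (height // patch_size, width // patch_size), not the reverse
feature_map = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
feature_map = feature_map.permute(0, 3, 1, 2).contiguous()
print(feature_map.shape)  # torch.Size([1, 8, 5, 7]) = (batch_size, channels, height // patch_size, width // patch_size)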

View File

@ -18,6 +18,7 @@ import copy
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import CONFIG_MAPPING
from ..bit import BitConfig
@ -91,6 +92,10 @@ class DPTConfig(PretrainedConfig):
The index of the features to use in the heads.
use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`):
Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
use_bias_in_fusion_residual (`bool`, *optional*, defaults to `True`):
Whether to use bias in the pre-activate residual units of the fusion blocks.
add_projection (`bool`, *optional*, defaults to `False`):
Whether to add a projection layer before the depth estimation head.
use_auxiliary_head (`bool`, *optional*, defaults to `True`):
Whether to use an auxiliary head during training.
auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
@ -104,7 +109,8 @@ class DPTConfig(PretrainedConfig):
neck_ignore_stages (`List[int]`, *optional*, defaults to `[0, 1]`):
Used only for the `hybrid` embedding type. The stages of the readout layers to ignore.
backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
Used only for the `hybrid` embedding type. The configuration of the backbone in a dictionary.
The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to
leverage the [`AutoBackbone`] API.
Example:
@ -145,6 +151,8 @@ class DPTConfig(PretrainedConfig):
fusion_hidden_size=256,
head_in_index=-1,
use_batch_norm_in_fusion_residual=False,
use_bias_in_fusion_residual=None,
add_projection=False,
use_auxiliary_head=True,
auxiliary_loss_weight=0.4,
semantic_loss_ignore_index=255,
@ -159,6 +167,7 @@ class DPTConfig(PretrainedConfig):
self.hidden_size = hidden_size
self.is_hybrid = is_hybrid
use_autobackbone = False
if self.is_hybrid:
if backbone_config is None:
logger.info("Initializing the config with a `BiT` backbone.")
@ -185,32 +194,49 @@ class DPTConfig(PretrainedConfig):
if readout_type != "project":
raise ValueError("Readout type must be 'project' when using `DPT-hybrid` mode.")
else:
self.backbone_config = None
elif backbone_config is not None:
use_autobackbone = True
if isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
self.backbone_config = backbone_config
self.backbone_featmap_shape = None
self.neck_ignore_stages = []
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.qkv_bias = qkv_bias
self.backbone_out_indices = backbone_out_indices
else:
self.backbone_config = backbone_config
self.backbone_featmap_shape = None
self.neck_ignore_stages = []
self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
self.num_attention_heads = None if use_autobackbone else num_attention_heads
self.intermediate_size = None if use_autobackbone else intermediate_size
self.hidden_dropout_prob = None if use_autobackbone else hidden_dropout_prob
self.attention_probs_dropout_prob = None if use_autobackbone else attention_probs_dropout_prob
self.layer_norm_eps = None if use_autobackbone else layer_norm_eps
self.image_size = None if use_autobackbone else image_size
self.patch_size = None if use_autobackbone else patch_size
self.num_channels = None if use_autobackbone else num_channels
self.qkv_bias = None if use_autobackbone else qkv_bias
self.backbone_out_indices = None if use_autobackbone else backbone_out_indices
if readout_type not in ["ignore", "add", "project"]:
raise ValueError("Readout_type must be one of ['ignore', 'add', 'project']")
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.readout_type = readout_type
self.reassemble_factors = reassemble_factors
self.neck_hidden_sizes = neck_hidden_sizes
self.fusion_hidden_size = fusion_hidden_size
self.head_in_index = head_in_index
self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual
self.use_bias_in_fusion_residual = use_bias_in_fusion_residual
self.add_projection = add_projection
# auxiliary head attributes (semantic segmentation)
self.use_auxiliary_head = use_auxiliary_head
self.auxiliary_loss_weight = auxiliary_loss_weight
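
To illustrate the new non-hybrid `backbone_config` path added above, here is a sketch (assuming the `facebook/dinov2-small` checkpoint and reusing the settings of the conversion script further below) of building a DPT config on top of a DINOv2 backbone:

from transformers import Dinov2Config, DPTConfig, DPTForDepthEstimation

backbone_config = Dinov2Config.from_pretrained(
    "facebook/dinov2-small",
    out_indices=[3, 6, 9, 12],
    apply_layernorm=False,
    reshape_hidden_states=False,
)
config = DPTConfig(
    backbone_config=backbone_config,
    neck_hidden_sizes=[48, 96, 192, 384],
    use_bias_in_fusion_residual=False,
    add_projection=True,
)
model = DPTForDepthEstimation(config)  # randomly initialized; the backbone is instantiated through AutoBackbone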

View File

@ -0,0 +1,384 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert DINOv2 + DPT checkpoints from the original repository. URL:
https://github.com/facebookresearch/dinov2/tree/main"""
import argparse
import itertools
import math
from pathlib import Path
import requests
import torch
from PIL import Image
from torchvision import transforms
from transformers import Dinov2Config, DPTConfig, DPTForDepthEstimation, DPTImageProcessor
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_dpt_config(model_name):
if "small" in model_name:
# equivalent to stage 3, stage 6, stage 9, stage 12
backbone_config = Dinov2Config.from_pretrained(
"facebook/dinov2-small", out_indices=[3, 6, 9, 12], apply_layernorm=False, reshape_hidden_states=False
)
neck_hidden_sizes = [48, 96, 192, 384]
elif "base" in model_name:
backbone_config = Dinov2Config.from_pretrained(
"facebook/dinov2-base", out_indices=[3, 6, 9, 12], apply_layernorm=False, reshape_hidden_states=False
)
neck_hidden_sizes = [96, 192, 384, 768]
elif "large" in model_name:
backbone_config = Dinov2Config.from_pretrained(
"facebook/dinov2-large", out_indices=[5, 12, 18, 24], apply_layernorm=False, reshape_hidden_states=False
)
neck_hidden_sizes = [128, 256, 512, 1024]
elif "giant" in model_name:
backbone_config = Dinov2Config.from_pretrained(
"facebook/dinov2-giant", out_indices=[10, 20, 30, 40], apply_layernorm=False, reshape_hidden_states=False
)
neck_hidden_sizes = [192, 384, 768, 1536]
else:
raise NotImplementedError("To do")
config = DPTConfig(
backbone_config=backbone_config,
neck_hidden_sizes=neck_hidden_sizes,
use_bias_in_fusion_residual=False,
add_projection=True,
)
return config
# here we list all DPT keys to be renamed (original name on the left, our name on the right)
def create_rename_keys_dpt(config):
rename_keys = []
# fmt: off
# activation postprocessing (projections, readout projections + resize blocks)
for i in range(4):
rename_keys.append((f"decode_head.reassemble_blocks.projects.{i}.conv.weight", f"neck.reassemble_stage.layers.{i}.projection.weight"))
rename_keys.append((f"decode_head.reassemble_blocks.projects.{i}.conv.bias", f"neck.reassemble_stage.layers.{i}.projection.bias"))
rename_keys.append((f"decode_head.reassemble_blocks.readout_projects.{i}.0.weight", f"neck.reassemble_stage.readout_projects.{i}.0.weight"))
rename_keys.append((f"decode_head.reassemble_blocks.readout_projects.{i}.0.bias", f"neck.reassemble_stage.readout_projects.{i}.0.bias"))
if i != 2:
rename_keys.append((f"decode_head.reassemble_blocks.resize_layers.{i}.weight", f"neck.reassemble_stage.layers.{i}.resize.weight"))
rename_keys.append((f"decode_head.reassemble_blocks.resize_layers.{i}.bias", f"neck.reassemble_stage.layers.{i}.resize.bias"))
# fusion layers
for i in range(4):
rename_keys.append((f"decode_head.fusion_blocks.{i}.project.conv.weight", f"neck.fusion_stage.layers.{i}.projection.weight"))
rename_keys.append((f"decode_head.fusion_blocks.{i}.project.conv.bias", f"neck.fusion_stage.layers.{i}.projection.bias"))
if i != 0:
rename_keys.append((f"decode_head.fusion_blocks.{i}.res_conv_unit1.conv1.conv.weight", f"neck.fusion_stage.layers.{i}.residual_layer1.convolution1.weight"))
rename_keys.append((f"decode_head.fusion_blocks.{i}.res_conv_unit1.conv2.conv.weight", f"neck.fusion_stage.layers.{i}.residual_layer1.convolution2.weight"))
rename_keys.append((f"decode_head.fusion_blocks.{i}.res_conv_unit2.conv1.conv.weight", f"neck.fusion_stage.layers.{i}.residual_layer2.convolution1.weight"))
rename_keys.append((f"decode_head.fusion_blocks.{i}.res_conv_unit2.conv2.conv.weight", f"neck.fusion_stage.layers.{i}.residual_layer2.convolution2.weight"))
# neck convolutions
for i in range(4):
rename_keys.append((f"decode_head.convs.{i}.conv.weight", f"neck.convs.{i}.weight"))
# head
rename_keys.append(("decode_head.project.conv.weight", "head.projection.weight"))
rename_keys.append(("decode_head.project.conv.bias", "head.projection.bias"))
for i in range(0, 5, 2):
rename_keys.append((f"decode_head.conv_depth.head.{i}.weight", f"head.head.{i}.weight"))
rename_keys.append((f"decode_head.conv_depth.head.{i}.bias", f"head.head.{i}.bias"))
# fmt: on
return rename_keys
# here we list all backbone keys to be renamed (original name on the left, our name on the right)
def create_rename_keys_backbone(config):
rename_keys = []
# fmt: off
# patch embedding layer
rename_keys.append(("cls_token", "backbone.embeddings.cls_token"))
rename_keys.append(("mask_token", "backbone.embeddings.mask_token"))
rename_keys.append(("pos_embed", "backbone.embeddings.position_embeddings"))
rename_keys.append(("patch_embed.proj.weight", "backbone.embeddings.patch_embeddings.projection.weight"))
rename_keys.append(("patch_embed.proj.bias", "backbone.embeddings.patch_embeddings.projection.bias"))
# Transformer encoder
for i in range(config.backbone_config.num_hidden_layers):
# layernorms
rename_keys.append((f"blocks.{i}.norm1.weight", f"backbone.encoder.layer.{i}.norm1.weight"))
rename_keys.append((f"blocks.{i}.norm1.bias", f"backbone.encoder.layer.{i}.norm1.bias"))
rename_keys.append((f"blocks.{i}.norm2.weight", f"backbone.encoder.layer.{i}.norm2.weight"))
rename_keys.append((f"blocks.{i}.norm2.bias", f"backbone.encoder.layer.{i}.norm2.bias"))
# MLP
if config.backbone_config.use_swiglu_ffn:
rename_keys.append((f"blocks.{i}.mlp.w12.weight", f"backbone.encoder.layer.{i}.mlp.w12.weight"))
rename_keys.append((f"blocks.{i}.mlp.w12.bias", f"backbone.encoder.layer.{i}.mlp.w12.bias"))
rename_keys.append((f"blocks.{i}.mlp.w3.weight", f"backbone.encoder.layer.{i}.mlp.w3.weight"))
rename_keys.append((f"blocks.{i}.mlp.w3.bias", f"backbone.encoder.layer.{i}.mlp.w3.bias"))
else:
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"backbone.encoder.layer.{i}.mlp.fc1.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"backbone.encoder.layer.{i}.mlp.fc1.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"backbone.encoder.layer.{i}.mlp.fc2.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"backbone.encoder.layer.{i}.mlp.fc2.bias"))
# layerscale
rename_keys.append((f"blocks.{i}.ls1.gamma", f"backbone.encoder.layer.{i}.layer_scale1.lambda1"))
rename_keys.append((f"blocks.{i}.ls2.gamma", f"backbone.encoder.layer.{i}.layer_scale2.lambda1"))
# attention projection layer
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"backbone.encoder.layer.{i}.attention.output.dense.weight"))
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"backbone.encoder.layer.{i}.attention.output.dense.bias"))
# fmt: on
rename_keys.append(("norm.weight", "backbone.layernorm.weight"))
rename_keys.append(("norm.bias", "backbone.layernorm.bias"))
return rename_keys
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config):
for i in range(config.backbone_config.num_hidden_layers):
# read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
hidden_size = config.backbone_config.hidden_size
# next, add query, keys and values (in that order) to the state dict
state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[:hidden_size, :]
state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[:hidden_size]
state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
hidden_size : hidden_size * 2, :
]
state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
hidden_size : hidden_size * 2
]
state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-hidden_size:, :]
state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-hidden_size:]
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
# We will verify our results on an image of cute cats
def prepare_img():
url = "https://dl.fbaipublicfiles.com/dinov2/images/example.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
name_to_url = {
"dpt-dinov2-small-nyu": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_dpt_head.pth",
"dpt-dinov2-small-kitti": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_kitti_dpt_head.pth",
"dpt-dinov2-base-nyu": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_nyu_dpt_head.pth",
"dpt-dinov2-base-kitti": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_kitti_dpt_head.pth",
"dpt-dinov2-large-nyu": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_nyu_dpt_head.pth",
"dpt-dinov2-large-kitti": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_kitti_dpt_head.pth",
"dpt-dinov2-giant-nyu": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_nyu_dpt_head.pth",
"dpt-dinov2-giant-kitti": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_kitti_dpt_head.pth",
}
def get_original_pixel_values(image):
class CenterPadding(object):
def __init__(self, multiple):
super().__init__()
self.multiple = multiple
def _get_pad(self, size):
new_size = math.ceil(size / self.multiple) * self.multiple
pad_size = new_size - size
pad_size_left = pad_size // 2
pad_size_right = pad_size - pad_size_left
return pad_size_left, pad_size_right
def __call__(self, img):
pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in img.shape[-2:][::-1]))
output = torch.nn.functional.pad(img, pads)
return output
def __repr__(self):
return self.__class__.__name__ + "()"
def make_depth_transform() -> transforms.Compose:
return transforms.Compose(
[
transforms.ToTensor(),
lambda x: 255.0 * x[:3], # Discard alpha component and scale by 255
transforms.Normalize(
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
),
CenterPadding(multiple=14),
]
)
transform = make_depth_transform()
original_pixel_values = transform(image).unsqueeze(0)
return original_pixel_values
@torch.no_grad()
def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, verify_logits):
"""
Copy/paste/tweak model's weights to our DPT structure.
"""
# define DPT configuration based on URL
checkpoint_url = name_to_url[model_name]
config = get_dpt_config(model_name)
# load original DPT state_dict from URL
print("URL:", checkpoint_url)
dpt_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["state_dict"]
# rename keys
rename_keys = create_rename_keys_dpt(config)
for src, dest in rename_keys:
rename_key(dpt_state_dict, src, dest)
# load original backbone state_dict from URL
if "small" in model_name:
original_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
elif "base" in model_name:
original_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
elif "large" in model_name:
original_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14")
elif "giant" in model_name:
original_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14")
else:
raise NotImplementedError("To do")
original_model.eval()
backbone_state_dict = original_model.state_dict()
# rename keys
rename_keys = create_rename_keys_backbone(config)
for src, dest in rename_keys:
rename_key(backbone_state_dict, src, dest)
# read in qkv matrices
read_in_q_k_v(backbone_state_dict, config)
for key, val in backbone_state_dict.copy().items():
val = backbone_state_dict.pop(key)
if "w12" in key:
key = key.replace("w12", "weights_in")
if "w3" in key:
key = key.replace("w3", "weights_out")
backbone_state_dict[key] = val
# merge state_dicts
state_dict = {**backbone_state_dict, **dpt_state_dict}
# load HuggingFace model
model = DPTForDepthEstimation(config)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing_keys)
print("Unexpected keys:", unexpected_keys)
assert missing_keys == [
"neck.fusion_stage.layers.0.residual_layer1.convolution1.weight",
"neck.fusion_stage.layers.0.residual_layer1.convolution2.weight",
]
model.eval()
# Verify image processor
processor = DPTImageProcessor(
do_resize=False,
do_rescale=False,
do_pad=True,
size_divisor=14,
do_normalize=True,
image_mean=(123.675, 116.28, 103.53),
image_std=(58.395, 57.12, 57.375),
)
image = prepare_img()
pixel_values = processor(image, return_tensors="pt").pixel_values.float()
original_pixel_values = get_original_pixel_values(image)
assert torch.allclose(pixel_values, original_pixel_values)
# Verify forward pass
with torch.no_grad():
outputs = model(pixel_values)
predicted_depth = outputs.predicted_depth
print("Shape of predicted depth:", predicted_depth.shape)
print("First values of predicted depth:", predicted_depth[0, :3, :3])
# assert logits
if verify_logits:
if model_name == "dpt-dinov2-small-nyu":
expected_shape = torch.Size([1, 576, 736])
expected_slice = torch.tensor(
[[3.3576, 3.4741, 3.4345], [3.4324, 3.5012, 3.2775], [3.2560, 3.3563, 3.2354]]
)
assert predicted_depth.shape == torch.Size(expected_shape)
assert torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-5)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model and processor to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print("Pushing model and processor to hub...")
model.push_to_hub(repo_id=f"facebook/{model_name}")
processor.push_to_hub(repo_id=f"facebook/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="dpt-dinov2-small-nyu",
type=str,
choices=name_to_url.keys(),
help="Name of the model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether to push the model to the hub after conversion.",
)
parser.add_argument(
"--verify_logits",
action="store_true",
required=False,
help="Path to the output PyTorch model directory.",
)
args = parser.parse_args()
convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits)
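
Once a checkpoint has been converted (or pushed to the hub; the `facebook/dpt-dinov2-small-kitti` repo id used in the integration test below is assumed here), running depth estimation looks roughly as follows:

import requests
import torch
from PIL import Image
from transformers import DPTForDepthEstimation, DPTImageProcessor

url = "https://dl.fbaipublicfiles.com/dinov2/images/example.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = DPTImageProcessor.from_pretrained("facebook/dpt-dinov2-small-kitti")
model = DPTForDepthEstimation.from_pretrained("facebook/dpt-dinov2-small-kitti")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    predicted_depth = model(**inputs).predicted_depth  # shape (batch_size, height, width)
print(predicted_depth.shape)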

View File

@ -229,12 +229,14 @@ def convert_dpt_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub
if "ade" in checkpoint_url
else torch.allclose(outputs[0, :3, :3], expected_slice)
)
print("Looks ok!")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
image_processor.save_pretrained(pytorch_dump_folder_path)
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
image_processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print("Pushing model to hub...")
@ -265,7 +267,7 @@ if __name__ == "__main__":
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
required=False,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
@ -276,6 +278,7 @@ if __name__ == "__main__":
"--model_name",
default="dpt-large",
type=str,
required=False,
help="Name of the model, in case you're pushing to the hub.",
)

View File

@ -20,7 +20,7 @@ from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format
from ...image_transforms import pad, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
@ -122,6 +122,12 @@ class DPTImageProcessor(BaseImageProcessor):
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `False`):
Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in
combination with DPT.
size_divisor (`int`, *optional*):
If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
DINOv2 paper, which uses the model in combination with DPT.
"""
model_input_names = ["pixel_values"]
@ -138,6 +144,8 @@ class DPTImageProcessor(BaseImageProcessor):
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: bool = False,
size_divisor: int = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@ -153,6 +161,8 @@ class DPTImageProcessor(BaseImageProcessor):
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_pad = do_pad
self.size_divisor = size_divisor
def resize(
self,
@ -208,6 +218,51 @@ class DPTImageProcessor(BaseImageProcessor):
**kwargs,
)
def pad_image(
self,
image: np.array,
size_divisor: int,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Center pad an image so that its height and width are a multiple of `size_divisor`.
Args:
image (`np.ndarray`):
Image to pad.
size_divisor (`int`):
The width and height of the image will be padded to a multiple of this number.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
def _get_pad(size, size_divisor):
new_size = math.ceil(size / size_divisor) * size_divisor
pad_size = new_size - size
pad_size_left = pad_size // 2
pad_size_right = pad_size - pad_size_left
return pad_size_left, pad_size_right
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
height, width = get_image_size(image, input_data_format)
pad_size_left, pad_size_right = _get_pad(height, size_divisor)
pad_size_top, pad_size_bottom = _get_pad(width, size_divisor)
return pad(image, ((pad_size_left, pad_size_right), (pad_size_top, pad_size_bottom)), data_format=data_format)
def preprocess(
self,
images: ImageInput,
@ -221,6 +276,8 @@ class DPTImageProcessor(BaseImageProcessor):
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: bool = None,
size_divisor: int = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
@ -286,6 +343,8 @@ class DPTImageProcessor(BaseImageProcessor):
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_pad = do_pad if do_pad is not None else self.do_pad
size_divisor = size_divisor if size_divisor is not None else self.size_divisor
images = make_list_of_images(images)
@ -304,6 +363,9 @@ class DPTImageProcessor(BaseImageProcessor):
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
if do_pad and size_divisor is None:
raise ValueError("Size divisibility must be specified if do_pad is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
@ -335,6 +397,12 @@ class DPTImageProcessor(BaseImageProcessor):
for image in images
]
if do_pad:
images = [
self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format)
for image in images
]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
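
A short sketch of the new padding options (arbitrary input size, only `do_pad`/`size_divisor` exercised): with `do_pad=True` and `size_divisor=14`, both spatial dimensions of the output are rounded up to multiples of 14, which is what the 14-pixel-patch DINOv2 backbone expects.

import numpy as np
from transformers import DPTImageProcessor

processor = DPTImageProcessor(do_resize=False, do_pad=True, size_divisor=14)
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

pixel_values = processor(image, return_tensors="pt").pixel_values
print(pixel_values.shape)  # torch.Size([1, 3, 490, 644]); 490 and 644 are both divisible by 14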

View File

@ -599,12 +599,13 @@ class DPTReassembleStage(nn.Module):
# When using DPT-Hybrid the readout type is set to "project". The sanity check is done on the config file
self.readout_projects = nn.ModuleList()
hidden_size = _get_backbone_hidden_size(config)
for i in range(len(config.neck_hidden_sizes)):
if i <= 1:
self.readout_projects.append(nn.Sequential(nn.Identity()))
elif i > 1:
self.readout_projects.append(
nn.Sequential(nn.Linear(2 * config.hidden_size, config.hidden_size), ACT2FN[config.hidden_act])
nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
)
def _init_reassemble_dpt(self, config):
@ -613,12 +614,13 @@ class DPTReassembleStage(nn.Module):
if config.readout_type == "project":
self.readout_projects = nn.ModuleList()
hidden_size = _get_backbone_hidden_size(config)
for _ in range(len(config.neck_hidden_sizes)):
self.readout_projects.append(
nn.Sequential(nn.Linear(2 * config.hidden_size, config.hidden_size), ACT2FN[config.hidden_act])
nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
)
def forward(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
"""
Args:
hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
@ -628,21 +630,24 @@ class DPTReassembleStage(nn.Module):
for i, hidden_state in enumerate(hidden_states):
if i not in self.neck_ignore_stages:
# reshape to (B, C, H, W)
hidden_state, cls_token = hidden_state[:, 1:], hidden_state[:, 0]
# reshape to (batch_size, num_channels, height, width)
cls_token, hidden_state = hidden_state[:, 0], hidden_state[:, 1:]
batch_size, sequence_length, num_channels = hidden_state.shape
size = int(math.sqrt(sequence_length))
hidden_state = hidden_state.reshape(batch_size, size, size, num_channels)
if patch_height is not None and patch_width is not None:
hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
else:
size = int(math.sqrt(sequence_length))
hidden_state = hidden_state.reshape(batch_size, size, size, num_channels)
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
feature_shape = hidden_state.shape
if self.config.readout_type == "project":
# reshape to (B, H*W, C)
# reshape to (batch_size, height*width, num_channels)
hidden_state = hidden_state.flatten(2).permute((0, 2, 1))
readout = cls_token.unsqueeze(1).expand_as(hidden_state)
# concatenate the readout token to the hidden states and project
hidden_state = self.readout_projects[i](torch.cat((hidden_state, readout), -1))
# reshape back to (B, C, H, W)
# reshape back to (batch_size, num_channels, height, width)
hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
elif self.config.readout_type == "add":
hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
@ -653,11 +658,19 @@ class DPTReassembleStage(nn.Module):
return out
def _get_backbone_hidden_size(config):
if config.backbone_config is not None and config.is_hybrid is False:
return config.backbone_config.hidden_size
else:
return config.hidden_size
class DPTReassembleLayer(nn.Module):
def __init__(self, config, channels, factor):
super().__init__()
# projection
self.projection = nn.Conv2d(in_channels=config.hidden_size, out_channels=channels, kernel_size=1)
hidden_size = _get_backbone_hidden_size(config)
self.projection = nn.Conv2d(in_channels=hidden_size, out_channels=channels, kernel_size=1)
# up/down sampling depending on factor
if factor > 1:
@ -710,24 +723,30 @@ class DPTPreActResidualLayer(nn.Module):
super().__init__()
self.use_batch_norm = config.use_batch_norm_in_fusion_residual
self.activation1 = ACT2FN["relu"]
use_bias_in_fusion_residual = (
config.use_bias_in_fusion_residual
if config.use_bias_in_fusion_residual is not None
else not self.use_batch_norm
)
self.activation1 = nn.ReLU()
self.convolution1 = nn.Conv2d(
config.fusion_hidden_size,
config.fusion_hidden_size,
kernel_size=3,
stride=1,
padding=1,
bias=not self.use_batch_norm,
bias=use_bias_in_fusion_residual,
)
self.activation2 = ACT2FN["relu"]
self.activation2 = nn.ReLU()
self.convolution2 = nn.Conv2d(
config.fusion_hidden_size,
config.fusion_hidden_size,
kernel_size=3,
stride=1,
padding=1,
bias=not self.use_batch_norm,
bias=use_bias_in_fusion_residual,
)
if self.use_batch_norm:
@ -973,8 +992,12 @@ class DPTNeck(nn.Module):
super().__init__()
self.config = config
# postprocessing
self.reassemble_stage = DPTReassembleStage(config)
# postprocessing: only required in case of a non-hierarchical backbone (e.g. ViT, BEiT)
if config.backbone_config is not None and config.backbone_config.model_type in ["swinv2"]:
self.reassemble_stage = None
else:
self.reassemble_stage = DPTReassembleStage(config)
self.convs = nn.ModuleList()
for channel in config.neck_hidden_sizes:
self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
@ -982,17 +1005,23 @@ class DPTNeck(nn.Module):
# fusion
self.fusion_stage = DPTFeatureFusionStage(config)
def forward(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
if not isinstance(hidden_states, list):
raise ValueError("hidden_states should be a list of tensors")
def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
"""
Args:
hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
List of hidden states from the backbone.
"""
if not isinstance(hidden_states, (tuple, list)):
raise ValueError("hidden_states should be a tuple or list of tensors")
if len(hidden_states) != len(self.config.neck_hidden_sizes):
raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
# postprocess hidden states
features = self.reassemble_stage(hidden_states)
if self.reassemble_stage is not None:
hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
features = [self.convs[i](feature) for i, feature in enumerate(features)]
features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
# fusion blocks
output = self.fusion_stage(features)
@ -1012,20 +1041,28 @@ class DPTDepthEstimationHead(nn.Module):
self.config = config
self.projection = None
if config.add_projection:
self.projection = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features = config.fusion_hidden_size
self.head = nn.Sequential(
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
ACT2FN["relu"],
nn.ReLU(),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
ACT2FN["relu"],
nn.ReLU(),
)
def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
# use last features
hidden_states = hidden_states[self.config.head_in_index]
if self.projection is not None:
hidden_states = self.projection(hidden_states)
hidden_states = nn.ReLU()(hidden_states)
predicted_depth = self.head(hidden_states)
predicted_depth = predicted_depth.squeeze(dim=1)
@ -1043,7 +1080,11 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.dpt = DPTModel(config, add_pooling_layer=False)
self.backbone = None
if config.backbone_config is not None and config.is_hybrid is False:
self.backbone = AutoBackbone.from_config(config.backbone_config)
else:
self.dpt = DPTModel(config, add_pooling_layer=False)
# Neck
self.neck = DPTNeck(config)
@ -1109,32 +1150,46 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
outputs = self.dpt(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=True, # we need the intermediate hidden states
return_dict=return_dict,
)
hidden_states = outputs.hidden_states if return_dict else outputs[1]
# only keep certain features based on config.backbone_out_indices
# note that the hidden_states also include the initial embeddings
if not self.config.is_hybrid:
hidden_states = [
feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
]
else:
backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
backbone_hidden_states.extend(
feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices[2:]
if self.backbone is not None:
outputs = self.backbone.forward_with_filtered_kwargs(
pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
)
hidden_states = outputs.feature_maps
else:
outputs = self.dpt(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=True, # we need the intermediate hidden states
return_dict=return_dict,
)
hidden_states = outputs.hidden_states if return_dict else outputs[1]
# only keep certain features based on config.backbone_out_indices
# note that the hidden_states also include the initial embeddings
if not self.config.is_hybrid:
hidden_states = [
feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
]
else:
backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
backbone_hidden_states.extend(
feature
for idx, feature in enumerate(hidden_states[1:])
if idx in self.config.backbone_out_indices[2:]
)
hidden_states = backbone_hidden_states
hidden_states = backbone_hidden_states
hidden_states = self.neck(hidden_states)
patch_height, patch_width = None, None
if self.config.backbone_config is not None and self.config.is_hybrid is False:
_, _, height, width = pixel_values.shape
patch_size = self.config.backbone_config.patch_size
patch_height = height // patch_size
patch_width = width // patch_size
hidden_states = self.neck(hidden_states, patch_height, patch_width)
predicted_depth = self.head(hidden_states)
@ -1167,7 +1222,7 @@ class DPTSemanticSegmentationHead(nn.Module):
self.head = nn.Sequential(
nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(features),
ACT2FN["relu"],
nn.ReLU(),
nn.Dropout(config.semantic_classifier_dropout),
nn.Conv2d(features, config.num_labels, kernel_size=1),
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
@ -1190,7 +1245,7 @@ class DPTAuxiliaryHead(nn.Module):
self.head = nn.Sequential(
nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(features),
ACT2FN["relu"],
nn.ReLU(),
nn.Dropout(0.1, False),
nn.Conv2d(features, config.num_labels, kernel_size=1),
)
@ -1287,7 +1342,7 @@ class DPTForSemanticSegmentation(DPTPreTrainedModel):
hidden_states = backbone_hidden_states
hidden_states = self.neck(hidden_states)
hidden_states = self.neck(hidden_states=hidden_states)
logits = self.head(hidden_states)
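
As a sanity sketch of the new backbone code path above (tiny random weights, hypothetical sizes, not an official example): because the neck now receives `patch_height` and `patch_width`, non-square inputs no longer depend on the old `sqrt(sequence_length)` assumption.

import torch
from transformers import Dinov2Config, DPTConfig, DPTForDepthEstimation

# tiny, randomly initialized configuration purely for illustration
backbone_config = Dinov2Config(
    image_size=32,
    patch_size=16,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    out_features=["stage1", "stage2"],
    reshape_hidden_states=False,
)
config = DPTConfig(
    backbone_config=backbone_config,
    neck_hidden_sizes=[16, 32],
    reassemble_factors=[4, 2],
    fusion_hidden_size=32,
)
model = DPTForDepthEstimation(config).eval()

pixel_values = torch.randn(1, 3, 48, 64)  # non-square input, both sides divisible by the patch size
with torch.no_grad():
    predicted_depth = model(pixel_values).predicted_depth
print(predicted_depth.shape)  # torch.Size([1, 48, 64])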

View File

@ -16,6 +16,8 @@
import unittest
import numpy as np
from transformers.file_utils import is_vision_available
from transformers.testing_utils import require_torch, require_vision
@ -97,6 +99,10 @@ class DPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "size_divisor"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
@ -104,3 +110,19 @@ class DPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
def test_padding(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
image = np.random.randn(3, 249, 491)
# test individual method
image = image_processing.pad_image(image, size_divisor=4)
self.assertTrue(image.shape[1] % 4 == 0)
self.assertTrue(image.shape[2] % 4 == 0)
# test by calling
pixel_values = image_processing.preprocess(
image, do_rescale=False, do_resize=False, do_pad=True, size_divisor=4, return_tensors="pt"
).pixel_values
self.assertTrue(pixel_values.shape[2] % 4 == 0)
self.assertTrue(pixel_values.shape[3] % 4 == 0)
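
For reference, a quick re-computation (outside the test) of the sizes those assertions rely on: center padding rounds each dimension up to the nearest multiple of `size_divisor`.

import math

def padded_size(size, size_divisor):
    # smallest multiple of size_divisor that is >= size
    return math.ceil(size / size_divisor) * size_divisor

print(padded_size(249, 4), padded_size(491, 4))  # 252 492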

View File

@ -0,0 +1,294 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch DPT model. """
import inspect
import unittest
from transformers import Dinov2Config, DPTConfig
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import MODEL_MAPPING, DPTForDepthEstimation
from transformers.models.dpt.modeling_dpt import DPT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import DPTImageProcessor
class DPTModelTester:
def __init__(
self,
parent,
batch_size=2,
num_channels=3,
image_size=32,
patch_size=16,
use_labels=True,
num_labels=3,
is_training=True,
hidden_size=4,
num_hidden_layers=2,
num_attention_heads=2,
intermediate_size=8,
out_features=["stage1", "stage2"],
apply_layernorm=False,
reshape_hidden_states=False,
neck_hidden_sizes=[2, 2],
fusion_hidden_size=6,
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.patch_size = patch_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.out_features = out_features
self.apply_layernorm = apply_layernorm
self.reshape_hidden_states = reshape_hidden_states
self.use_labels = use_labels
self.num_labels = num_labels
self.is_training = is_training
self.neck_hidden_sizes = neck_hidden_sizes
self.fusion_hidden_size = fusion_hidden_size
# DPT's sequence length
self.seq_length = (self.image_size // self.patch_size) ** 2 + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return DPTConfig(
backbone_config=self.get_backbone_config(),
neck_hidden_sizes=self.neck_hidden_sizes,
fusion_hidden_size=self.fusion_hidden_size,
)
def get_backbone_config(self):
return Dinov2Config(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
is_training=self.is_training,
out_features=self.out_features,
reshape_hidden_states=self.reshape_hidden_states,
)
def create_and_check_for_depth_estimation(self, config, pixel_values, labels):
config.num_labels = self.num_labels
model = DPTForDepthEstimation(config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
self.parent.assertEqual(result.predicted_depth.shape, (self.batch_size, self.image_size, self.image_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as DPT does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (DPTForDepthEstimation,) if is_torch_available() else ()
pipeline_model_mapping = {"depth-estimation": DPTForDepthEstimation} if is_torch_available() else {}
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = DPTModelTester(self)
self.config_tester = ConfigTester(self, config_class=DPTConfig, has_text_modality=False, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="DPT with AutoBackbone does not have a base model and hence no input_embeddings")
def test_inputs_embeds(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_for_depth_estimation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_depth_estimation(*config_and_inputs)
def test_training(self):
for model_class in self.all_model_classes:
if model_class.__name__ == "DPTForDepthEstimation":
continue
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
if model_class in get_values(MODEL_MAPPING):
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_training_gradient_checkpointing(self):
for model_class in self.all_model_classes:
if model_class.__name__ == "DPTForDepthEstimation":
continue
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_cache = False
config.return_dict = True
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
continue
model = model_class(config)
model.to(torch_device)
model.gradient_checkpointing_enable()
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
# Skip the check for the backbone
backbone_params = []
for name, module in model.named_modules():
if module.__class__.__name__ == "DPTViTHybridEmbeddings":
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
break
for name, param in model.named_parameters():
if param.requires_grad:
if name in backbone_params:
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@unittest.skip(reason="DPT with AutoBackbone does not have a base model and hence no input_embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="DPT with AutoBackbone does not have a base model")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="DPT with AutoBackbone does not have a base model")
def test_save_load_fast_init_to_base(self):
pass
@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = DPTForDepthEstimation.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_torch
@require_vision
@slow
class DPTModelIntegrationTest(unittest.TestCase):
def test_inference_depth_estimation(self):
image_processor = DPTImageProcessor.from_pretrained("facebook/dpt-dinov2-small-kitti")
model = DPTForDepthEstimation.from_pretrained("facebook/dpt-dinov2-small-kitti").to(torch_device)
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
# verify the predicted depth
expected_shape = torch.Size((1, 576, 736))
self.assertEqual(predicted_depth.shape, expected_shape)
expected_slice = torch.tensor(
[[6.0433, 7.1636, 7.4268], [6.9047, 7.2471, 7.2355], [7.9261, 8.0631, 8.0244]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
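
As a usage note (a common post-processing step, not part of the test): the predicted depth map can be resized back to the original image resolution with interpolation, for example:

import torch

# assuming `predicted_depth` of shape (batch_size, height, width) and `image` the original PIL image,
# as in the integration test above
prediction = torch.nn.functional.interpolate(
    predicted_depth.unsqueeze(1),
    size=image.size[::-1],  # PIL size is (width, height)
    mode="bicubic",
    align_corners=False,
).squeeze()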