Add focalnet backbone (#23104)

Adds a FocalNet backbone to return features from all stages
Alara Dirik 2023-05-03 19:32:42 +03:00 committed by GitHub
parent ca7eb27ed5
commit 441658dd6c
10 changed files with 210 additions and 11 deletions

View File

@@ -1623,6 +1623,7 @@ else:
_import_structure["models.focalnet"].extend(
[
"FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"FocalNetBackbone",
"FocalNetForImageClassification",
"FocalNetForMaskedImageModeling",
"FocalNetModel",
@@ -5178,6 +5179,7 @@ if TYPE_CHECKING:
)
from .models.focalnet import (
FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST,
FocalNetBackbone,
FocalNetForImageClassification,
FocalNetForMaskedImageModeling,
FocalNetModel,

View File

@@ -980,6 +980,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
("convnext", "ConvNextBackbone"),
("convnextv2", "ConvNextV2Backbone"),
("dinat", "DinatBackbone"),
("focalnet", "FocalNetBackbone"),
("maskformer-swin", "MaskFormerSwinBackbone"),
("nat", "NatBackbone"),
("resnet", "ResNetBackbone"),

View File

@@ -30,6 +30,7 @@ else:
"FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"FocalNetForImageClassification",
"FocalNetForMaskedImageModeling",
"FocalNetBackbone",
"FocalNetModel",
"FocalNetPreTrainedModel",
]
@@ -45,6 +46,7 @@ if TYPE_CHECKING:
else:
from .modeling_focalnet import (
FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST,
FocalNetBackbone,
FocalNetForImageClassification,
FocalNetForMaskedImageModeling,
FocalNetModel,

View File

@@ -47,6 +47,8 @@ class FocalNetConfig(PretrainedConfig):
use_conv_embed (`bool`, *optional*, defaults to `False`):
Whether to use convolutional embedding. The authors noted that using convolutional embedding usually
improves performance, but it is not used by default.
hidden_sizes (`List[int]`, *optional*, defaults to `[192, 384, 768, 768]`):
Dimensionality (hidden size) at each stage.
depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`):
Depth (number of layers) of each stage in the encoder.
focal_levels (`list(int)`, *optional*, defaults to `[2, 2, 2, 2]`):
@@ -78,6 +80,14 @@ class FocalNetConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
encoder_stride (`int`, *optional*, defaults to 32):
Factor to increase the spatial resolution by in the decoder head for masked image modeling.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage.
Example:
@@ -102,6 +112,7 @@ class FocalNetConfig(PretrainedConfig):
num_channels=3,
embed_dim=96,
use_conv_embed=False,
hidden_sizes=[192, 384, 768, 768],
depths=[2, 2, 6, 2],
focal_levels=[2, 2, 2, 2],
focal_windows=[3, 3, 3, 3],
@@ -117,6 +128,8 @@ class FocalNetConfig(PretrainedConfig):
initializer_range=0.02,
layer_norm_eps=1e-5,
encoder_stride=32,
out_features=None,
out_indices=None,
**kwargs,
):
super().__init__(**kwargs)
@@ -126,6 +139,7 @@ class FocalNetConfig(PretrainedConfig):
self.num_channels = num_channels
self.embed_dim = embed_dim
self.use_conv_embed = use_conv_embed
self.hidden_sizes = hidden_sizes
self.depths = depths
self.focal_levels = focal_levels
self.focal_windows = focal_windows
@@ -141,3 +155,36 @@ class FocalNetConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.encoder_stride = encoder_stride
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
if out_features is not None and out_indices is not None:
if len(out_features) != len(out_indices):
raise ValueError("out_features and out_indices should have the same length if both are set")
elif out_features != [self.stage_names[idx] for idx in out_indices]:
raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
if out_features is None and out_indices is not None:
out_features = [self.stage_names[idx] for idx in out_indices]
elif out_features is not None and out_indices is None:
out_indices = [self.stage_names.index(feature) for feature in out_features]
elif out_features is None and out_indices is None:
out_features = [self.stage_names[-1]]
out_indices = [len(self.stage_names) - 1]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
if out_indices is not None:
if not isinstance(out_indices, (list, tuple)):
raise ValueError("out_indices should be a list or tuple")
for idx in out_indices:
if idx >= len(self.stage_names):
raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
self.out_features = out_features
self.out_indices = out_indices
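The defaulting and validation logic above is easiest to see in action. A minimal sketch, assuming this commit is installed (the default `depths=[2, 2, 6, 2]` yields a stem plus four stages):

```python
from transformers import FocalNetConfig

# Neither argument set: fall back to the last stage.
config = FocalNetConfig()
print(config.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(config.out_features)  # ['stage4']
print(config.out_indices)   # [4]

# Only out_indices set: out_features is derived from stage_names.
config = FocalNetConfig(out_indices=[1, 2])
print(config.out_features)  # ['stage1', 'stage2']

# Inconsistent arguments are rejected by the checks above.
try:
    FocalNetConfig(out_features=["stage1"], out_indices=[1, 2])
except ValueError as err:
    print(err)  # out_features and out_indices should have the same length if both are set
```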

View File

@@ -56,7 +56,6 @@ def get_focalnet_config(model_name):
embed_dim = 128
elif "large" in model_name:
embed_dim = 192
focal_windows = [5, 5, 5, 5]
elif "xlarge" in model_name:
embed_dim = 256
elif "huge" in model_name:
@@ -130,7 +129,10 @@ def convert_focalnet_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
"focalnet-small-lrf": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_lrf.pth",
"focalnet-base": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_srf.pth",
"focalnet-base-lrf": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth",
"focalnet-large": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth",
"focalnet-large-lrf-fl3": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth",
"focalnet-large-lrf-fl4": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth",
"focalnet-xlarge-lrf-fl3": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384.pth",
"focalnet-xlarge-lrf-fl4": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384_fl4.pth",
}
# fmt: on

View File

@@ -26,7 +26,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
@@ -209,7 +210,6 @@ class FocalNetEmbeddings(nn.Module):
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
embeddings = self.dropout(embeddings)
return embeddings, output_dimensions
@@ -971,3 +971,81 @@ class FocalNetForImageClassification(FocalNetPreTrainedModel):
hidden_states=outputs.hidden_states,
reshaped_hidden_states=outputs.reshaped_hidden_states,
)
@add_start_docstrings(
"""
FocalNet backbone, to be used with frameworks like X-Decoder.
""",
FOCALNET_START_DOCSTRING,
)
class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
self.stage_names = config.stage_names
self.focalnet = FocalNetModel(config)
self.num_features = [config.embed_dim] + config.hidden_sizes
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
if config.out_indices is not None:
self.out_indices = config.out_indices
else:
self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
# initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: torch.Tensor,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BackboneOutput:
"""
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, AutoBackbone
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny-lrf")
>>> model = AutoBackbone.from_pretrained("microsoft/focalnet-tiny-lrf")
>>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs)
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs = self.focalnet(pixel_values, output_hidden_states=True, return_dict=True)
hidden_states = outputs.reshaped_hidden_states
feature_maps = ()
for idx, stage in enumerate(self.stage_names):
if stage in self.out_features:
feature_maps += (hidden_states[idx],)
if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs.hidden_states,)
return output
return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=None,
)
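Continuing the docstring example above, `feature_maps` holds one `(batch, channels, height, width)` tensor per entry in `out_features`; with the default config only the last stage is returned. A hedged follow-up (the shape comment is an expectation for focalnet-tiny-lrf at 224x224, not a verified output):

```python
for name, feature_map in zip(model.out_features, outputs.feature_maps):
    print(name, tuple(feature_map.shape))
# Expected roughly: stage4 (1, 768, 7, 7) -- 224 / 4 (patch embedding) / 2**3
# (three stage downsamplings) = 7 spatial positions, 768 channels in the last stage.
```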

View File

@@ -3002,6 +3002,13 @@ class FNetPreTrainedModel(metaclass=DummyObject):
FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
class FocalNetBackbone(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class FocalNetForImageClassification(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -22,6 +22,7 @@ from transformers import FocalNetConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
@@ -30,7 +31,12 @@ if is_torch_available():
import torch
from torch import nn
from transformers import FocalNetForImageClassification, FocalNetForMaskedImageModeling, FocalNetModel
from transformers import (
FocalNetBackbone,
FocalNetForImageClassification,
FocalNetForMaskedImageModeling,
FocalNetModel,
)
from transformers.models.focalnet.modeling_focalnet import FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
@@ -48,6 +54,7 @@ class FocalNetModelTester:
patch_size=2,
num_channels=3,
embed_dim=16,
hidden_sizes=[32, 64, 128],
depths=[1, 2, 1],
num_heads=[2, 2, 4],
window_size=2,
@@ -67,6 +74,7 @@ class FocalNetModelTester:
type_sequence_label_size=10,
encoder_stride=8,
out_features=["stage1", "stage2"],
out_indices=[1, 2],
):
self.parent = parent
self.batch_size = batch_size
@@ -74,6 +82,7 @@ class FocalNetModelTester:
self.patch_size = patch_size
self.num_channels = num_channels
self.embed_dim = embed_dim
self.hidden_sizes = hidden_sizes
self.depths = depths
self.num_heads = num_heads
self.window_size = window_size
@@ -93,6 +102,7 @@ class FocalNetModelTester:
self.type_sequence_label_size = type_sequence_label_size
self.encoder_stride = encoder_stride
self.out_features = out_features
self.out_indices = out_indices
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -111,6 +121,7 @@ class FocalNetModelTester:
patch_size=self.patch_size,
num_channels=self.num_channels,
embed_dim=self.embed_dim,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
num_heads=self.num_heads,
window_size=self.window_size,
@@ -126,6 +137,7 @@ class FocalNetModelTester:
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
out_features=self.out_features,
out_indices=self.out_indices,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -139,6 +151,35 @@ class FocalNetModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_backbone(self, config, pixel_values, labels):
model = FocalNetBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.image_size, 8, 8])
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[:-1])
# verify backbone works with out_features=None
config.out_features = None
model = FocalNetBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.image_size * 2, 4, 4])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
model = FocalNetForMaskedImageModeling(config=config)
model.to(torch_device)
@@ -191,6 +232,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
FocalNetModel,
FocalNetForImageClassification,
FocalNetForMaskedImageModeling,
FocalNetBackbone,
)
if is_torch_available()
else ()
@@ -204,7 +246,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
def setUp(self):
self.model_tester = FocalNetModelTester(self)
self.config_tester = ConfigTester(self, config_class=FocalNetConfig, embed_dim=37)
self.config_tester = ConfigTester(self, config_class=FocalNetConfig, embed_dim=37, has_text_modality=False)
def test_config(self):
self.create_and_test_config_common_properties()
@@ -222,6 +264,10 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
def test_for_masked_image_modeling(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
@@ -234,14 +280,14 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
@unittest.skip(reason="FocalNet Transformer does not use feedforward chunking")
@unittest.skip(reason="FocalNet does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
for model_class in self.all_model_classes[:-1]:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
@@ -250,7 +296,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
for model_class in self.all_model_classes[:-1]:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
@@ -309,7 +355,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
else (self.model_tester.image_size, self.model_tester.image_size)
)
for model_class in self.all_model_classes:
for model_class in self.all_model_classes[:-1]:
inputs_dict["output_hidden_states"] = True
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
@@ -337,7 +383,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
for model_class in self.all_model_classes:
for model_class in self.all_model_classes[:-1]:
inputs_dict["output_hidden_states"] = True
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
@@ -393,3 +439,14 @@ class FocalNetModelIntegrationTest(unittest.TestCase):
expected_slice = torch.tensor([0.2166, -0.4368, 0.2191]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
self.assertEqual(outputs.logits.argmax(dim=-1).item(), 281)
@require_torch
class FocalNetBackboneTest(BackboneTesterMixin, unittest.TestCase):
all_model_classes = (FocalNetBackbone,) if is_torch_available() else ()
config_class = FocalNetConfig
has_attentions = False
def setUp(self):
self.model_tester = FocalNetModelTester(self)

View File

@@ -135,6 +135,8 @@ class BackboneTesterMixin:
# Verify num_features has been initialized in the backbone init
self.assertIsNotNone(backbone.num_features)
self.assertTrue(len(backbone.channels) == len(backbone.out_indices))
self.assertTrue(len(backbone.stage_names) == len(backbone.num_features))
self.assertTrue(len(backbone.channels) <= len(backbone.num_features))
self.assertTrue(len(backbone.out_feature_channels) == len(backbone.stage_names))
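For the backbone added in this PR, these invariants work out concretely as follows; a sketch using the tester's config values from this diff (`channels` and `out_indices` are provided via `BackboneMixin`):

```python
from transformers import FocalNetBackbone, FocalNetConfig

config = FocalNetConfig(
    embed_dim=16, hidden_sizes=[32, 64, 128], depths=[1, 2, 1], out_indices=[1, 2]
)
backbone = FocalNetBackbone(config)

assert backbone.stage_names == ["stem", "stage1", "stage2", "stage3"]
assert backbone.num_features == [16, 32, 64, 128]  # [embed_dim] + hidden_sizes
assert backbone.channels == [32, 64]               # one entry per out index
assert len(backbone.channels) == len(backbone.out_indices)
```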

View File

@@ -836,6 +836,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
"ConvNextBackbone",
"ConvNextV2Backbone",
"DinatBackbone",
"FocalNetBackbone",
"MaskFormerSwinBackbone",
"MaskFormerSwinConfig",
"MaskFormerSwinModel",