diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 445fbb53e2..46be8c9d2c 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1623,6 +1623,7 @@ else: _import_structure["models.focalnet"].extend( [ "FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "FocalNetBackbone", "FocalNetForImageClassification", "FocalNetForMaskedImageModeling", "FocalNetModel", @@ -5178,6 +5179,7 @@ if TYPE_CHECKING: ) from .models.focalnet import ( FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST, + FocalNetBackbone, FocalNetForImageClassification, FocalNetForMaskedImageModeling, FocalNetModel, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index ccdda1af33..14847c7ad2 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -980,6 +980,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict( ("convnext", "ConvNextBackbone"), ("convnextv2", "ConvNextV2Backbone"), ("dinat", "DinatBackbone"), + ("focalnet", "FocalNetBackbone"), ("maskformer-swin", "MaskFormerSwinBackbone"), ("nat", "NatBackbone"), ("resnet", "ResNetBackbone"), diff --git a/src/transformers/models/focalnet/__init__.py b/src/transformers/models/focalnet/__init__.py index e082ae26d2..b043a006f9 100644 --- a/src/transformers/models/focalnet/__init__.py +++ b/src/transformers/models/focalnet/__init__.py @@ -30,6 +30,7 @@ else: "FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST", "FocalNetForImageClassification", "FocalNetForMaskedImageModeling", + "FocalNetBackbone", "FocalNetModel", "FocalNetPreTrainedModel", ] @@ -45,6 +46,7 @@ if TYPE_CHECKING: else: from .modeling_focalnet import ( FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST, + FocalNetBackbone, FocalNetForImageClassification, FocalNetForMaskedImageModeling, FocalNetModel, diff --git a/src/transformers/models/focalnet/configuration_focalnet.py b/src/transformers/models/focalnet/configuration_focalnet.py index 5bfecb5737..c6814e1dda 100644 --- a/src/transformers/models/focalnet/configuration_focalnet.py +++ b/src/transformers/models/focalnet/configuration_focalnet.py @@ -47,6 +47,8 @@ class FocalNetConfig(PretrainedConfig): use_conv_embed (`bool`, *optional*, defaults to `False`): Whether to use convolutional embedding. The authors noted that using convolutional embedding usually improve the performance, but it's not used by default. + hidden_sizes (`List[int]`, *optional*, defaults to `[192, 384, 768, 768]`): + Dimensionality (hidden size) at each stage. depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`): Depth (number of layers) of each stage in the encoder. focal_levels (`list(int)`, *optional*, defaults to `[2, 2, 2, 2]`): @@ -78,6 +80,14 @@ class FocalNetConfig(PretrainedConfig): The epsilon used by the layer normalization layers. encoder_stride (`int`, `optional`, defaults to 32): Factor to increase the spatial resolution by in the decoder head for masked image modeling. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Example: @@ -102,6 +112,7 @@ class FocalNetConfig(PretrainedConfig): num_channels=3, embed_dim=96, use_conv_embed=False, + hidden_sizes=[192, 384, 768, 768], depths=[2, 2, 6, 2], focal_levels=[2, 2, 2, 2], focal_windows=[3, 3, 3, 3], @@ -117,6 +128,8 @@ class FocalNetConfig(PretrainedConfig): initializer_range=0.02, layer_norm_eps=1e-5, encoder_stride=32, + out_features=None, + out_indices=None, **kwargs, ): super().__init__(**kwargs) @@ -126,6 +139,7 @@ class FocalNetConfig(PretrainedConfig): self.num_channels = num_channels self.embed_dim = embed_dim self.use_conv_embed = use_conv_embed + self.hidden_sizes = hidden_sizes self.depths = depths self.focal_levels = focal_levels self.focal_windows = focal_windows @@ -141,3 +155,36 @@ class FocalNetConfig(PretrainedConfig): self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.encoder_stride = encoder_stride + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)] + + if out_features is not None and out_indices is not None: + if len(out_features) != len(out_indices): + raise ValueError("out_features and out_indices should have the same length if both are set") + elif out_features != [self.stage_names[idx] for idx in out_indices]: + raise ValueError("out_features and out_indices should correspond to the same stages if both are set") + + if out_features is None and out_indices is not None: + out_features = [self.stage_names[idx] for idx in out_indices] + elif out_features is not None and out_indices is None: + out_indices = [self.stage_names.index(feature) for feature in out_features] + elif out_features is None and out_indices is None: + out_features = [self.stage_names[-1]] + out_indices = [len(self.stage_names) - 1] + + if out_features is not None: + if not isinstance(out_features, list): + raise ValueError("out_features should be a list") + for feature in out_features: + if feature not in self.stage_names: + raise ValueError( + f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}" + ) + if out_indices is not None: + if not isinstance(out_indices, (list, tuple)): + raise ValueError("out_indices should be a list or tuple") + for idx in out_indices: + if idx >= len(self.stage_names): + raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}") + + self.out_features = out_features + self.out_indices = out_indices diff --git a/src/transformers/models/focalnet/convert_focalnet_to_hf_format.py b/src/transformers/models/focalnet/convert_focalnet_to_hf_format.py index a23383e3ab..4aed159280 100644 --- a/src/transformers/models/focalnet/convert_focalnet_to_hf_format.py +++ b/src/transformers/models/focalnet/convert_focalnet_to_hf_format.py @@ -56,7 +56,6 @@ def get_focalnet_config(model_name): embed_dim = 128 elif "large" in model_name: embed_dim = 192 - focal_windows = [5, 5, 5, 5] elif "xlarge" in model_name: embed_dim = 256 elif "huge" in model_name: @@ -130,7 +129,10 @@ def convert_focalnet_checkpoint(model_name, pytorch_dump_folder_path, push_to_hu "focalnet-small-lrf": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_lrf.pth", "focalnet-base": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_srf.pth", "focalnet-base-lrf": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth", - "focalnet-large": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth", + "focalnet-large-lrf-fl3": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth", + "focalnet-large-lrf-fl4": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth", + "focalnet-xlarge-lrf-fl3": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384.pth", + "focalnet-xlarge-lrf-fl4": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384_fl4.pth", } # fmt: on diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index ff1b75e14b..cfd6468976 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -26,7 +26,8 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_utils import PreTrainedModel +from ...modeling_outputs import BackboneOutput +from ...modeling_utils import BackboneMixin, PreTrainedModel from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -209,7 +210,6 @@ class FocalNetEmbeddings(nn.Module): embeddings = embeddings * (1.0 - mask) + mask_tokens * mask embeddings = self.dropout(embeddings) - return embeddings, output_dimensions @@ -971,3 +971,81 @@ class FocalNetForImageClassification(FocalNetPreTrainedModel): hidden_states=outputs.hidden_states, reshaped_hidden_states=outputs.reshaped_hidden_states, ) + + +@add_start_docstrings( + """ + FocalNet backbone, to be used with frameworks like X-Decoder. + """, + FOCALNET_START_DOCSTRING, +) +class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + + self.stage_names = config.stage_names + self.focalnet = FocalNetModel(config) + + self.num_features = [config.embed_dim] + config.hidden_sizes + self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] + if config.out_indices is not None: + self.out_indices = config.out_indices + else: + self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features) + + # initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny-lrf") + >>> model = AutoBackbone.from_pretrained("microsoft/focalnet-tiny-lrf") + + >>> inputs = processor(image, return_tensors="pt") + >>> outputs = model(**inputs) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs = self.focalnet(pixel_values, output_hidden_states=True, return_dict=True) + + hidden_states = outputs.reshaped_hidden_states + + feature_maps = () + for idx, stage in enumerate(self.stage_names): + if stage in self.out_features: + feature_maps += (hidden_states[idx],) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 7fe538eccc..1e9845ba9b 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3002,6 +3002,13 @@ class FNetPreTrainedModel(metaclass=DummyObject): FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST = None +class FocalNetBackbone(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class FocalNetForImageClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/focalnet/test_modeling_focalnet.py b/tests/models/focalnet/test_modeling_focalnet.py index 4ddf8e63b9..75127e5fd3 100644 --- a/tests/models/focalnet/test_modeling_focalnet.py +++ b/tests/models/focalnet/test_modeling_focalnet.py @@ -22,6 +22,7 @@ from transformers import FocalNetConfig from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available +from ...test_backbone_common import BackboneTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor @@ -30,7 +31,12 @@ if is_torch_available(): import torch from torch import nn - from transformers import FocalNetForImageClassification, FocalNetForMaskedImageModeling, FocalNetModel + from transformers import ( + FocalNetBackbone, + FocalNetForImageClassification, + FocalNetForMaskedImageModeling, + FocalNetModel, + ) from transformers.models.focalnet.modeling_focalnet import FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST if is_vision_available(): @@ -48,6 +54,7 @@ class FocalNetModelTester: patch_size=2, num_channels=3, embed_dim=16, + hidden_sizes=[32, 64, 128], depths=[1, 2, 1], num_heads=[2, 2, 4], window_size=2, @@ -67,6 +74,7 @@ class FocalNetModelTester: type_sequence_label_size=10, encoder_stride=8, out_features=["stage1", "stage2"], + out_indices=[1, 2], ): self.parent = parent self.batch_size = batch_size @@ -74,6 +82,7 @@ class FocalNetModelTester: self.patch_size = patch_size self.num_channels = num_channels self.embed_dim = embed_dim + self.hidden_sizes = hidden_sizes self.depths = depths self.num_heads = num_heads self.window_size = window_size @@ -93,6 +102,7 @@ class FocalNetModelTester: self.type_sequence_label_size = type_sequence_label_size self.encoder_stride = encoder_stride self.out_features = out_features + self.out_indices = out_indices def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -111,6 +121,7 @@ class FocalNetModelTester: patch_size=self.patch_size, num_channels=self.num_channels, embed_dim=self.embed_dim, + hidden_sizes=self.hidden_sizes, depths=self.depths, num_heads=self.num_heads, window_size=self.window_size, @@ -126,6 +137,7 @@ class FocalNetModelTester: initializer_range=self.initializer_range, encoder_stride=self.encoder_stride, out_features=self.out_features, + out_indices=self.out_indices, ) def create_and_check_model(self, config, pixel_values, labels): @@ -139,6 +151,35 @@ class FocalNetModelTester: self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim)) + def create_and_check_backbone(self, config, pixel_values, labels): + model = FocalNetBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify feature maps + self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.image_size, 8, 8]) + + # verify channels + self.parent.assertEqual(len(model.channels), len(config.out_features)) + self.parent.assertListEqual(model.channels, config.hidden_sizes[:-1]) + + # verify backbone works with out_features=None + config.out_features = None + model = FocalNetBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify feature maps + self.parent.assertEqual(len(result.feature_maps), 1) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.image_size * 2, 4, 4]) + + # verify channels + self.parent.assertEqual(len(model.channels), 1) + self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]]) + def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels): model = FocalNetForMaskedImageModeling(config=config) model.to(torch_device) @@ -191,6 +232,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase): FocalNetModel, FocalNetForImageClassification, FocalNetForMaskedImageModeling, + FocalNetBackbone, ) if is_torch_available() else () @@ -204,7 +246,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase): def setUp(self): self.model_tester = FocalNetModelTester(self) - self.config_tester = ConfigTester(self, config_class=FocalNetConfig, embed_dim=37) + self.config_tester = ConfigTester(self, config_class=FocalNetConfig, embed_dim=37, has_text_modality=False) def test_config(self): self.create_and_test_config_common_properties() @@ -222,6 +264,10 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_backbone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_backbone(*config_and_inputs) + def test_for_masked_image_modeling(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs) @@ -234,14 +280,14 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase): def test_inputs_embeds(self): pass - @unittest.skip(reason="FocalNet Transformer does not use feedforward chunking") + @unittest.skip(reason="FocalNet does not use feedforward chunking") def test_feed_forward_chunking(self): pass def test_model_common_attributes(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: + for model_class in self.all_model_classes[:-1]: model = model_class(config) self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) x = model.get_output_embeddings() @@ -250,7 +296,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase): def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: + for model_class in self.all_model_classes[:-1]: model = model_class(config) signature = inspect.signature(model.forward) # signature.parameters is an OrderedDict => so arg_names order is deterministic @@ -309,7 +355,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase): else (self.model_tester.image_size, self.model_tester.image_size) ) - for model_class in self.all_model_classes: + for model_class in self.all_model_classes[:-1]: inputs_dict["output_hidden_states"] = True self.check_hidden_states_output(inputs_dict, config, model_class, image_size) @@ -337,7 +383,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase): padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0]) padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1]) - for model_class in self.all_model_classes: + for model_class in self.all_model_classes[:-1]: inputs_dict["output_hidden_states"] = True self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width)) @@ -393,3 +439,14 @@ class FocalNetModelIntegrationTest(unittest.TestCase): expected_slice = torch.tensor([0.2166, -0.4368, 0.2191]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) self.assertTrue(outputs.logits.argmax(dim=-1).item(), 281) + + +@require_torch +class FocalNetBackboneTest(BackboneTesterMixin, unittest.TestCase): + all_model_classes = (FocalNetBackbone,) if is_torch_available() else () + config_class = FocalNetConfig + + has_attentions = False + + def setUp(self): + self.model_tester = FocalNetModelTester(self) diff --git a/tests/test_backbone_common.py b/tests/test_backbone_common.py index 80e68a2f44..6bcf47004b 100644 --- a/tests/test_backbone_common.py +++ b/tests/test_backbone_common.py @@ -135,6 +135,8 @@ class BackboneTesterMixin: # Verify num_features has been initialized in the backbone init self.assertIsNotNone(backbone.num_features) self.assertTrue(len(backbone.channels) == len(backbone.out_indices)) + print(backbone.stage_names) + print(backbone.num_features) self.assertTrue(len(backbone.stage_names) == len(backbone.num_features)) self.assertTrue(len(backbone.channels) <= len(backbone.num_features)) self.assertTrue(len(backbone.out_feature_channels) == len(backbone.stage_names)) diff --git a/utils/check_repo.py b/utils/check_repo.py index 5bdec16b9e..7280381faf 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -836,6 +836,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [ "ConvNextBackbone", "ConvNextV2Backbone", "DinatBackbone", + "FocalNetBackbone", "MaskFormerSwinBackbone", "MaskFormerSwinConfig", "MaskFormerSwinModel",