Control first downsample stride in ResNet (#26374)

* control first downsample stride

* reduce first only works for ResNetBottleNeckLayer

* fix param name

* fix style
This commit is contained in:
jiqing-feng 2023-10-10 12:45:24 +08:00 committed by GitHub
parent a5e6df82c0
commit 592f2eabd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 7 deletions

View File

@ -59,6 +59,8 @@ class ResNetConfig(BackboneConfigMixin, PretrainedConfig):
are supported.
downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
If `True`, the first stage will downsample the inputs using a `stride` of 2.
downsample_in_bottleneck (`bool`, *optional*, defaults to `False`):
If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
@ -94,6 +96,7 @@ class ResNetConfig(BackboneConfigMixin, PretrainedConfig):
layer_type="bottleneck",
hidden_act="relu",
downsample_in_first_stage=False,
downsample_in_bottleneck=False,
out_features=None,
out_indices=None,
**kwargs,
@ -108,6 +111,7 @@ class ResNetConfig(BackboneConfigMixin, PretrainedConfig):
self.layer_type = layer_type
self.hidden_act = hidden_act
self.downsample_in_first_stage = downsample_in_first_stage
self.downsample_in_bottleneck = downsample_in_bottleneck
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names

View File

@ -149,11 +149,18 @@ class ResNetBottleNeckLayer(nn.Module):
A classic ResNet's bottleneck layer composed by three `3x3` convolutions.
The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`.
convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If
`downsample_in_bottleneck` is true, downsample will be in the first layer instead of the second layer.
"""
def __init__(
self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu", reduction: int = 4
self,
in_channels: int,
out_channels: int,
stride: int = 1,
activation: str = "relu",
reduction: int = 4,
downsample_in_bottleneck: bool = False,
):
super().__init__()
should_apply_shortcut = in_channels != out_channels or stride != 1
@ -162,8 +169,10 @@ class ResNetBottleNeckLayer(nn.Module):
ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
)
self.layer = nn.Sequential(
ResNetConvLayer(in_channels, reduces_channels, kernel_size=1),
ResNetConvLayer(reduces_channels, reduces_channels, stride=stride),
ResNetConvLayer(
in_channels, reduces_channels, kernel_size=1, stride=stride if downsample_in_bottleneck else 1
),
ResNetConvLayer(reduces_channels, reduces_channels, stride=stride if not downsample_in_bottleneck else 1),
ResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None),
)
self.activation = ACT2FN[activation]
@ -194,10 +203,18 @@ class ResNetStage(nn.Module):
layer = ResNetBottleNeckLayer if config.layer_type == "bottleneck" else ResNetBasicLayer
if config.layer_type == "bottleneck":
first_layer = layer(
in_channels,
out_channels,
stride=stride,
activation=config.hidden_act,
downsample_in_bottleneck=config.downsample_in_bottleneck,
)
else:
first_layer = layer(in_channels, out_channels, stride=stride, activation=config.hidden_act)
self.layers = nn.Sequential(
# downsampling is done in the first layer with stride of 2
layer(in_channels, out_channels, stride=stride, activation=config.hidden_act),
*[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)],
first_layer, *[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)]
)
def forward(self, input: Tensor) -> Tensor: