From d24097e0229485287ff4959258c552168bd898c6 Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Tue, 21 May 2024 15:40:19 +0100
Subject: [PATCH] Fix swin embeddings interpolation (#30936)

---
 .../models/donut/modeling_donut_swin.py        | 19 +++----------------
 .../maskformer/modeling_maskformer_swin.py     | 19 +++----------------
 src/transformers/models/swin/modeling_swin.py  | 17 ++---------------
 .../models/swinv2/modeling_swinv2.py           | 17 ++---------------
 4 files changed, 10 insertions(+), 62 deletions(-)

diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py
index c1b27cd180..5b8a0e27b0 100644
--- a/src/transformers/models/donut/modeling_donut_swin.py
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -205,9 +205,7 @@ class DonutSwinEmbeddings(nn.Module):
         interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
 
@@ -228,7 +226,7 @@ class DonutSwinEmbeddings(nn.Module):
         return embeddings, output_dimensions
 
 
-# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin
 class DonutSwinPatchEmbeddings(nn.Module):
     """
     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
@@ -260,21 +258,10 @@ class DonutSwinPatchEmbeddings(nn.Module):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py
index fc9c642adc..ef607ec811 100644
--- a/src/transformers/models/maskformer/modeling_maskformer_swin.py
+++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -197,9 +197,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
 
     def forward(self, pixel_values, interpolate_pos_encoding):
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
 
         if self.position_embeddings is not None:
@@ -213,7 +211,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
         return embeddings, output_dimensions
 
 
-# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->MaskFormerSwin
 class MaskFormerSwinPatchEmbeddings(nn.Module):
     """
     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
@@ -245,21 +243,10 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
            pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index 2a6363c8e6..9f64f8009a 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -291,9 +291,7 @@ class SwinEmbeddings(nn.Module):
         interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
 
@@ -345,21 +343,10 @@ class SwinPatchEmbeddings(nn.Module):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py
index ac8ec197e5..48a4c65722 100644
--- a/src/transformers/models/swinv2/modeling_swinv2.py
+++ b/src/transformers/models/swinv2/modeling_swinv2.py
@@ -334,9 +334,7 @@ class Swinv2Embeddings(nn.Module):
         interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
 
@@ -389,21 +387,10 @@ class Swinv2PatchEmbeddings(nn.Module):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
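
Not part of the patch: a minimal usage sketch of the behaviour this change targets. It assumes a transformers build that already contains this commit and that SwinModel.forward exposes an interpolate_pos_encoding keyword; the config values and the 320x320 input size below are illustrative only, not taken from the diff.

    # Minimal sketch, assuming SwinModel.forward accepts `interpolate_pos_encoding`
    # on the branch this patch targets. Config values and input size are made up
    # for illustration.
    import torch
    from transformers import SwinConfig, SwinModel

    config = SwinConfig(image_size=224, use_absolute_embeddings=True)
    model = SwinModel(config)
    model.eval()

    # Input larger than config.image_size: with this patch the size check no
    # longer lives in SwinPatchEmbeddings, so the oversized input is padded and
    # projected there, while position embeddings are interpolated in
    # SwinEmbeddings.
    pixel_values = torch.randn(1, config.num_channels, 320, 320)

    with torch.no_grad():
        outputs = model(pixel_values, interpolate_pos_encoding=True)
    print(outputs.last_hidden_state.shape)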