Fix swin embeddings interpolation (#30936)

amyeroberts 2024-05-21 15:40:19 +01:00 committed by GitHub
parent eae2b6b89e
commit d24097e022
4 changed files with 10 additions and 62 deletions
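
For context, a minimal usage sketch of the behaviour this fix targets: with position-encoding interpolation handled in the embeddings module rather than in the patch embeddings, a Swin checkpoint can be run on an input whose size differs from the training resolution. This is a sketch only, not part of the commit; the checkpoint name is illustrative, and it assumes `SwinModel.forward` exposes the `interpolate_pos_encoding` flag that appears in the embeddings signatures in the hunks below.

# Minimal sketch (not part of this commit): run Swin on a non-default input size.
# Assumes a transformers release containing this fix and that SwinModel.forward
# accepts interpolate_pos_encoding, matching the embeddings signature in this diff.
import numpy as np
import torch
from transformers import AutoImageProcessor, SwinModel

checkpoint = "microsoft/swin-tiny-patch4-window7-224"  # illustrative checkpoint
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = SwinModel.from_pretrained(checkpoint)

# A 384x384 image instead of the 224x224 resolution the checkpoint was trained with.
image = np.random.randint(0, 256, (384, 384, 3), dtype=np.uint8)
inputs = processor(images=image, do_resize=False, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)

# More patches than at 224x224, since the patch embeddings pad and patchify any size.
print(outputs.last_hidden_state.shape)

The four file diffs below follow the same pattern: the patch-embedding classes drop the interpolate_pos_encoding argument and the input-size check and are left to pad and patchify only, while the embeddings classes call them without the extra keyword.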

View File

@@ -205,9 +205,7 @@ class DonutSwinEmbeddings(nn.Module):
         interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
 
         batch_size, seq_len, _ = embeddings.size()
@@ -228,7 +226,7 @@ class DonutSwinEmbeddings(nn.Module):
         return embeddings, output_dimensions
 
 
-# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin
 class DonutSwinPatchEmbeddings(nn.Module):
     """
     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
@@ -260,21 +258,10 @@ class DonutSwinPatchEmbeddings(nn.Module):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)

View File

@@ -197,9 +197,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
 
     def forward(self, pixel_values, interpolate_pos_encoding):
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
 
         if self.position_embeddings is not None:
@@ -213,7 +211,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
         return embeddings, output_dimensions
 
 
-# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->MaskFormerSwin
 class MaskFormerSwinPatchEmbeddings(nn.Module):
     """
     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
@@ -245,21 +243,10 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)

View File

@@ -291,9 +291,7 @@ class SwinEmbeddings(nn.Module):
         interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
 
         batch_size, seq_len, _ = embeddings.size()
@@ -345,21 +343,10 @@ class SwinPatchEmbeddings(nn.Module):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)

View File

@@ -334,9 +334,7 @@ class Swinv2Embeddings(nn.Module):
         interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
 
         batch_size, seq_len, _ = embeddings.size()
@@ -389,21 +387,10 @@ class Swinv2PatchEmbeddings(nn.Module):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)