Fix swin embeddings interpolation (#30936)
This commit is contained in:
parent
eae2b6b89e
commit
d24097e022
|
@ -205,9 +205,7 @@ class DonutSwinEmbeddings(nn.Module):
|
||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
_, num_channels, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
embeddings, output_dimensions = self.patch_embeddings(
|
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
||||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
|
||||||
)
|
|
||||||
embeddings = self.norm(embeddings)
|
embeddings = self.norm(embeddings)
|
||||||
batch_size, seq_len, _ = embeddings.size()
|
batch_size, seq_len, _ = embeddings.size()
|
||||||
|
|
||||||
|
@ -228,7 +226,7 @@ class DonutSwinEmbeddings(nn.Module):
|
||||||
return embeddings, output_dimensions
|
return embeddings, output_dimensions
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
|
# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin
|
||||||
class DonutSwinPatchEmbeddings(nn.Module):
|
class DonutSwinPatchEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
|
@ -260,21 +258,10 @@ class DonutSwinPatchEmbeddings(nn.Module):
|
||||||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def forward(
|
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||||
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
|
|
||||||
) -> Tuple[torch.Tensor, Tuple[int]]:
|
|
||||||
_, num_channels, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
if num_channels != self.num_channels:
|
|
||||||
raise ValueError(
|
|
||||||
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
|
||||||
)
|
|
||||||
# pad the input to be divisible by self.patch_size, if needed
|
# pad the input to be divisible by self.patch_size, if needed
|
||||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
|
||||||
raise ValueError(
|
|
||||||
f"Input image size ({height}*{width}) doesn't match model"
|
|
||||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
|
||||||
)
|
|
||||||
embeddings = self.projection(pixel_values)
|
embeddings = self.projection(pixel_values)
|
||||||
_, _, height, width = embeddings.shape
|
_, _, height, width = embeddings.shape
|
||||||
output_dimensions = (height, width)
|
output_dimensions = (height, width)
|
||||||
|
|
|
@ -197,9 +197,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
|
||||||
|
|
||||||
def forward(self, pixel_values, interpolate_pos_encoding):
|
def forward(self, pixel_values, interpolate_pos_encoding):
|
||||||
_, num_channels, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
embeddings, output_dimensions = self.patch_embeddings(
|
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
||||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
|
||||||
)
|
|
||||||
embeddings = self.norm(embeddings)
|
embeddings = self.norm(embeddings)
|
||||||
|
|
||||||
if self.position_embeddings is not None:
|
if self.position_embeddings is not None:
|
||||||
|
@ -213,7 +211,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
|
||||||
return embeddings, output_dimensions
|
return embeddings, output_dimensions
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
|
# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->MaskFormerSwin
|
||||||
class MaskFormerSwinPatchEmbeddings(nn.Module):
|
class MaskFormerSwinPatchEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
|
@ -245,21 +243,10 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
|
||||||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def forward(
|
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||||
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
|
|
||||||
) -> Tuple[torch.Tensor, Tuple[int]]:
|
|
||||||
_, num_channels, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
if num_channels != self.num_channels:
|
|
||||||
raise ValueError(
|
|
||||||
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
|
||||||
)
|
|
||||||
# pad the input to be divisible by self.patch_size, if needed
|
# pad the input to be divisible by self.patch_size, if needed
|
||||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
|
||||||
raise ValueError(
|
|
||||||
f"Input image size ({height}*{width}) doesn't match model"
|
|
||||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
|
||||||
)
|
|
||||||
embeddings = self.projection(pixel_values)
|
embeddings = self.projection(pixel_values)
|
||||||
_, _, height, width = embeddings.shape
|
_, _, height, width = embeddings.shape
|
||||||
output_dimensions = (height, width)
|
output_dimensions = (height, width)
|
||||||
|
|
|
@ -291,9 +291,7 @@ class SwinEmbeddings(nn.Module):
|
||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
_, num_channels, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
embeddings, output_dimensions = self.patch_embeddings(
|
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
||||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
|
||||||
)
|
|
||||||
embeddings = self.norm(embeddings)
|
embeddings = self.norm(embeddings)
|
||||||
batch_size, seq_len, _ = embeddings.size()
|
batch_size, seq_len, _ = embeddings.size()
|
||||||
|
|
||||||
|
@ -345,21 +343,10 @@ class SwinPatchEmbeddings(nn.Module):
|
||||||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def forward(
|
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||||
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
|
|
||||||
) -> Tuple[torch.Tensor, Tuple[int]]:
|
|
||||||
_, num_channels, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
if num_channels != self.num_channels:
|
|
||||||
raise ValueError(
|
|
||||||
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
|
||||||
)
|
|
||||||
# pad the input to be divisible by self.patch_size, if needed
|
# pad the input to be divisible by self.patch_size, if needed
|
||||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
|
||||||
raise ValueError(
|
|
||||||
f"Input image size ({height}*{width}) doesn't match model"
|
|
||||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
|
||||||
)
|
|
||||||
embeddings = self.projection(pixel_values)
|
embeddings = self.projection(pixel_values)
|
||||||
_, _, height, width = embeddings.shape
|
_, _, height, width = embeddings.shape
|
||||||
output_dimensions = (height, width)
|
output_dimensions = (height, width)
|
||||||
|
|
|
@ -334,9 +334,7 @@ class Swinv2Embeddings(nn.Module):
|
||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
_, num_channels, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
embeddings, output_dimensions = self.patch_embeddings(
|
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
||||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
|
||||||
)
|
|
||||||
embeddings = self.norm(embeddings)
|
embeddings = self.norm(embeddings)
|
||||||
batch_size, seq_len, _ = embeddings.size()
|
batch_size, seq_len, _ = embeddings.size()
|
||||||
|
|
||||||
|
@ -389,21 +387,10 @@ class Swinv2PatchEmbeddings(nn.Module):
|
||||||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def forward(
|
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||||
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
|
|
||||||
) -> Tuple[torch.Tensor, Tuple[int]]:
|
|
||||||
_, num_channels, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
if num_channels != self.num_channels:
|
|
||||||
raise ValueError(
|
|
||||||
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
|
||||||
)
|
|
||||||
# pad the input to be divisible by self.patch_size, if needed
|
# pad the input to be divisible by self.patch_size, if needed
|
||||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
|
||||||
raise ValueError(
|
|
||||||
f"Input image size ({height}*{width}) doesn't match model"
|
|
||||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
|
||||||
)
|
|
||||||
embeddings = self.projection(pixel_values)
|
embeddings = self.projection(pixel_values)
|
||||||
_, _, height, width = embeddings.shape
|
_, _, height, width = embeddings.shape
|
||||||
output_dimensions = (height, width)
|
output_dimensions = (height, width)
|
||||||
|
|
Loading…
Reference in New Issue