Enable dynamic resolution input for Swin Transformer and variants (#30656)
* add interpolation of positional encoding support to swin * add style changes * use default image processor and make size a dictionary Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove logits testing Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Refactor image size validation logic when interpolation is disabled Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove asserts in modeling Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add dynamic resolution input support to swinv2 * change size to ensure interpolation encoding path is triggered * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False * add dynamic resolution input to donut swin * add dynamic resolution input to maskformer swin --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
b6eb708bf1
commit
481a957814
|
@ -166,10 +166,48 @@ class DonutSwinEmbeddings(nn.Module):
|
|||
self.norm = nn.LayerNorm(config.embed_dim)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
if num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor],
|
||||
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
||||
_, num_channels, height, width = pixel_values.shape
|
||||
embeddings, output_dimensions = self.patch_embeddings(
|
||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
embeddings = self.norm(embeddings)
|
||||
batch_size, seq_len, _ = embeddings.size()
|
||||
|
||||
|
@ -180,7 +218,10 @@ class DonutSwinEmbeddings(nn.Module):
|
|||
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
||||
|
||||
if self.position_embeddings is not None:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
if interpolate_pos_encoding:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
|
@ -219,7 +260,9 @@ class DonutSwinPatchEmbeddings(nn.Module):
|
|||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||
return pixel_values
|
||||
|
||||
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||
def forward(
|
||||
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
|
||||
) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||
_, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
|
@ -227,6 +270,11 @@ class DonutSwinPatchEmbeddings(nn.Module):
|
|||
)
|
||||
# pad the input to be divisible by self.patch_size, if needed
|
||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model"
|
||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
embeddings = self.projection(pixel_values)
|
||||
_, _, height, width = embeddings.shape
|
||||
output_dimensions = (height, width)
|
||||
|
@ -849,6 +897,8 @@ SWIN_INPUTS_DOCSTRING = r"""
|
|||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
@ -899,6 +949,7 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
|
|||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, DonutSwinModelOutput]:
|
||||
r"""
|
||||
|
@ -921,7 +972,9 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
|
|||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
||||
|
||||
embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
embedding_output, input_dimensions = self.embeddings(
|
||||
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
|
|
|
@ -163,12 +163,50 @@ class MaskFormerSwinEmbeddings(nn.Module):
|
|||
self.norm = nn.LayerNorm(config.embed_dim)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
if num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
def forward(self, pixel_values, interpolate_pos_encoding):
|
||||
_, num_channels, height, width = pixel_values.shape
|
||||
embeddings, output_dimensions = self.patch_embeddings(
|
||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
embeddings = self.norm(embeddings)
|
||||
|
||||
if self.position_embeddings is not None:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
if interpolate_pos_encoding:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
|
@ -207,7 +245,9 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
|
|||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||
return pixel_values
|
||||
|
||||
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||
def forward(
|
||||
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
|
||||
) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||
_, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
|
@ -215,6 +255,11 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
|
|||
)
|
||||
# pad the input to be divisible by self.patch_size, if needed
|
||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model"
|
||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
embeddings = self.projection(pixel_values)
|
||||
_, _, height, width = embeddings.shape
|
||||
output_dimensions = (height, width)
|
||||
|
@ -780,6 +825,7 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
|
|||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
interpolate_pos_encoding=False,
|
||||
return_dict=None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
@ -798,7 +844,9 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
|
|||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
||||
|
||||
embedding_output, input_dimensions = self.embeddings(pixel_values)
|
||||
embedding_output, input_dimensions = self.embeddings(
|
||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
|
|
|
@ -252,10 +252,48 @@ class SwinEmbeddings(nn.Module):
|
|||
self.norm = nn.LayerNorm(config.embed_dim)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
if num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor],
|
||||
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
||||
_, num_channels, height, width = pixel_values.shape
|
||||
embeddings, output_dimensions = self.patch_embeddings(
|
||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
embeddings = self.norm(embeddings)
|
||||
batch_size, seq_len, _ = embeddings.size()
|
||||
|
||||
|
@ -266,7 +304,10 @@ class SwinEmbeddings(nn.Module):
|
|||
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
||||
|
||||
if self.position_embeddings is not None:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
if interpolate_pos_encoding:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
|
@ -304,7 +345,9 @@ class SwinPatchEmbeddings(nn.Module):
|
|||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||
return pixel_values
|
||||
|
||||
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||
def forward(
|
||||
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
|
||||
) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||
_, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
|
@ -312,6 +355,11 @@ class SwinPatchEmbeddings(nn.Module):
|
|||
)
|
||||
# pad the input to be divisible by self.patch_size, if needed
|
||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model"
|
||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
embeddings = self.projection(pixel_values)
|
||||
_, _, height, width = embeddings.shape
|
||||
output_dimensions = (height, width)
|
||||
|
@ -924,6 +972,8 @@ SWIN_INPUTS_DOCSTRING = r"""
|
|||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
@ -981,6 +1031,7 @@ class SwinModel(SwinPreTrainedModel):
|
|||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SwinModelOutput]:
|
||||
r"""
|
||||
|
@ -1003,7 +1054,9 @@ class SwinModel(SwinPreTrainedModel):
|
|||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
||||
|
||||
embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
embedding_output, input_dimensions = self.embeddings(
|
||||
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
|
@ -1074,6 +1127,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
|
|||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SwinMaskedImageModelingOutput]:
|
||||
r"""
|
||||
|
@ -1113,6 +1167,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
|
|||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
@ -1156,6 +1211,14 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
|
|||
"""
|
||||
Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
||||
the [CLS] token) e.g. for ImageNet.
|
||||
|
||||
<Tip>
|
||||
|
||||
Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by
|
||||
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
|
||||
position embeddings to the higher resolution.
|
||||
|
||||
</Tip>
|
||||
""",
|
||||
SWIN_START_DOCSTRING,
|
||||
)
|
||||
|
@ -1188,6 +1251,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
|
|||
labels: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SwinImageClassifierOutput]:
|
||||
r"""
|
||||
|
@ -1203,6 +1267,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
|
|||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
|
|
@ -295,10 +295,48 @@ class Swinv2Embeddings(nn.Module):
|
|||
self.norm = nn.LayerNorm(config.embed_dim)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
if num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor],
|
||||
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
||||
_, num_channels, height, width = pixel_values.shape
|
||||
embeddings, output_dimensions = self.patch_embeddings(
|
||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
embeddings = self.norm(embeddings)
|
||||
batch_size, seq_len, _ = embeddings.size()
|
||||
|
||||
|
@ -309,7 +347,10 @@ class Swinv2Embeddings(nn.Module):
|
|||
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
||||
|
||||
if self.position_embeddings is not None:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
if interpolate_pos_encoding:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
|
@ -348,7 +389,9 @@ class Swinv2PatchEmbeddings(nn.Module):
|
|||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||
return pixel_values
|
||||
|
||||
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||
def forward(
|
||||
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
|
||||
) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||
_, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
|
@ -356,6 +399,11 @@ class Swinv2PatchEmbeddings(nn.Module):
|
|||
)
|
||||
# pad the input to be divisible by self.patch_size, if needed
|
||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model"
|
||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
embeddings = self.projection(pixel_values)
|
||||
_, _, height, width = embeddings.shape
|
||||
output_dimensions = (height, width)
|
||||
|
@ -979,6 +1027,8 @@ SWINV2_INPUTS_DOCSTRING = r"""
|
|||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, default `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
@ -1031,6 +1081,7 @@ class Swinv2Model(Swinv2PreTrainedModel):
|
|||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, Swinv2ModelOutput]:
|
||||
r"""
|
||||
|
@ -1053,7 +1104,9 @@ class Swinv2Model(Swinv2PreTrainedModel):
|
|||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
||||
|
||||
embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
embedding_output, input_dimensions = self.embeddings(
|
||||
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
|
@ -1126,6 +1179,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
|
|||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, Swinv2MaskedImageModelingOutput]:
|
||||
r"""
|
||||
|
@ -1165,6 +1219,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
|
|||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
@ -1208,6 +1263,14 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
|
|||
"""
|
||||
Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
|
||||
of the [CLS] token) e.g. for ImageNet.
|
||||
|
||||
<Tip>
|
||||
|
||||
Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by
|
||||
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
|
||||
position embeddings to the higher resolution.
|
||||
|
||||
</Tip>
|
||||
""",
|
||||
SWINV2_START_DOCSTRING,
|
||||
)
|
||||
|
@ -1241,6 +1304,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
|
|||
labels: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, Swinv2ImageClassifierOutput]:
|
||||
r"""
|
||||
|
@ -1256,6 +1320,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
|
|||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
|
|
@ -493,6 +493,26 @@ class SwinModelIntegrationTest(unittest.TestCase):
|
|||
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
# Swin models have an `interpolate_pos_encoding` argument in their forward method,
|
||||
# allowing to interpolate the pre-trained position embeddings in order to use
|
||||
# the model on higher resolutions.
|
||||
model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").to(torch_device)
|
||||
|
||||
image_processor = self.default_image_processor
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
|
||||
pixel_values = inputs.pixel_values.to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 256, 768))
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class SwinBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||
|
|
|
@ -485,6 +485,26 @@ class Swinv2ModelIntegrationTest(unittest.TestCase):
|
|||
expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
# Swinv2 models have an `interpolate_pos_encoding` argument in their forward method,
|
||||
# allowing to interpolate the pre-trained position embeddings in order to use
|
||||
# the model on higher resolutions.
|
||||
model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256").to(torch_device)
|
||||
|
||||
image_processor = self.default_image_processor
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
|
||||
pixel_values = inputs.pixel_values.to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 256, 768))
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||
|
|
Loading…
Reference in New Issue