Add PerSAM [bis] (#23659)

* Add PerSAM args

* Make attn_sim optional

* Rename to attention_similarity

* Add docstrigns

* Improve docstrings
This commit is contained in:
NielsRogge 2023-05-23 11:43:12 +02:00 committed by GitHub
parent aa30cd4f3f
commit 527ab894e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 35 additions and 4 deletions

View File

@ -224,7 +224,7 @@ class SamAttention(nn.Module):
hidden_states = hidden_states.transpose(1, 2)
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
@ -242,6 +242,10 @@ class SamAttention(nn.Module):
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
if attention_similarity is not None:
attn = attn + attention_similarity
attn = torch.softmax(attn, dim=-1)
# Get output
out = attn @ value
out = self._recombine_heads(out, point_batch_size)
@ -290,6 +294,7 @@ class SamTwoWayAttentionBlock(nn.Module):
keys: Tensor,
query_point_embedding: Tensor,
key_point_embedding: Tensor,
attention_similarity: Tensor,
output_attentions: bool = False,
):
# Self attention block
@ -305,7 +310,9 @@ class SamTwoWayAttentionBlock(nn.Module):
query = queries + query_point_embedding
key = keys + key_point_embedding
attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
attn_out = self.cross_attn_token_to_image(
query=query, key=key, value=keys, attention_similarity=attention_similarity
)
queries = queries + attn_out
queries = self.layer_norm2(queries)
@ -353,6 +360,8 @@ class SamTwoWayTransformer(nn.Module):
point_embeddings: Tensor,
image_embeddings: Tensor,
image_positional_embeddings: Tensor,
attention_similarity: Tensor,
target_embedding=None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@ -377,11 +386,15 @@ class SamTwoWayTransformer(nn.Module):
# Apply transformer blocks and final layernorm
for layer in self.layers:
if target_embedding is not None:
queries += target_embedding
queries, keys, attention_outputs = layer(
queries=queries,
keys=keys,
query_point_embedding=point_embeddings,
key_point_embedding=image_positional_embeddings,
attention_similarity=attention_similarity,
output_attentions=output_attentions,
)
@ -460,6 +473,8 @@ class SamMaskDecoder(nn.Module):
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
output_attentions: Optional[bool] = None,
attention_similarity: torch.Tensor = None,
target_embedding: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
@ -500,6 +515,8 @@ class SamMaskDecoder(nn.Module):
point_embeddings=point_embeddings,
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
)
iou_token_out = point_embedding[:, :, 0, :]
@ -576,8 +593,12 @@ class SamMaskEmbedding(nn.Module):
self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
self.layer_norm1 = SamLayerNorm(self.mask_input_channels, config.layer_norm_eps)
self.layer_norm2 = SamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps)
self.layer_norm1 = SamLayerNorm(
self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
)
self.layer_norm2 = SamLayerNorm(
self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
)
def forward(self, masks):
hidden_states = self.conv1(masks)
@ -1146,6 +1167,12 @@ SAM_INPUTS_DOCSTRING = r"""
In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
"best" mask, by specifying `multimask_output=False`.
attention_similarity (`torch.FloatTensor`, *optional*):
Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
target_embedding (`torch.FloatTensor`, *optional*):
Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@ -1265,6 +1292,8 @@ class SamModel(SamPreTrainedModel):
input_masks: Optional[torch.LongTensor] = None,
image_embeddings: Optional[torch.FloatTensor] = None,
multimask_output: bool = True,
attention_similarity: Optional[torch.FloatTensor] = None,
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict=None,
@ -1374,6 +1403,8 @@ class SamModel(SamPreTrainedModel):
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
)