Add PerSAM [bis] (#23659)
* Add PerSAM args * Make attn_sim optional * Rename to attention_similarity * Add docstrigns * Improve docstrings
This commit is contained in:
parent
aa30cd4f3f
commit
527ab894e5
|
@ -224,7 +224,7 @@ class SamAttention(nn.Module):
|
|||
hidden_states = hidden_states.transpose(1, 2)
|
||||
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
||||
|
||||
def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
|
||||
def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
|
||||
# Input projections
|
||||
query = self.q_proj(query)
|
||||
key = self.k_proj(key)
|
||||
|
@ -242,6 +242,10 @@ class SamAttention(nn.Module):
|
|||
attn = attn / math.sqrt(c_per_head)
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
|
||||
if attention_similarity is not None:
|
||||
attn = attn + attention_similarity
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
|
||||
# Get output
|
||||
out = attn @ value
|
||||
out = self._recombine_heads(out, point_batch_size)
|
||||
|
@ -290,6 +294,7 @@ class SamTwoWayAttentionBlock(nn.Module):
|
|||
keys: Tensor,
|
||||
query_point_embedding: Tensor,
|
||||
key_point_embedding: Tensor,
|
||||
attention_similarity: Tensor,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
# Self attention block
|
||||
|
@ -305,7 +310,9 @@ class SamTwoWayAttentionBlock(nn.Module):
|
|||
query = queries + query_point_embedding
|
||||
key = keys + key_point_embedding
|
||||
|
||||
attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
|
||||
attn_out = self.cross_attn_token_to_image(
|
||||
query=query, key=key, value=keys, attention_similarity=attention_similarity
|
||||
)
|
||||
queries = queries + attn_out
|
||||
|
||||
queries = self.layer_norm2(queries)
|
||||
|
@ -353,6 +360,8 @@ class SamTwoWayTransformer(nn.Module):
|
|||
point_embeddings: Tensor,
|
||||
image_embeddings: Tensor,
|
||||
image_positional_embeddings: Tensor,
|
||||
attention_similarity: Tensor,
|
||||
target_embedding=None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
|
@ -377,11 +386,15 @@ class SamTwoWayTransformer(nn.Module):
|
|||
|
||||
# Apply transformer blocks and final layernorm
|
||||
for layer in self.layers:
|
||||
if target_embedding is not None:
|
||||
queries += target_embedding
|
||||
|
||||
queries, keys, attention_outputs = layer(
|
||||
queries=queries,
|
||||
keys=keys,
|
||||
query_point_embedding=point_embeddings,
|
||||
key_point_embedding=image_positional_embeddings,
|
||||
attention_similarity=attention_similarity,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -460,6 +473,8 @@ class SamMaskDecoder(nn.Module):
|
|||
dense_prompt_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
output_attentions: Optional[bool] = None,
|
||||
attention_similarity: torch.Tensor = None,
|
||||
target_embedding: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
|
@ -500,6 +515,8 @@ class SamMaskDecoder(nn.Module):
|
|||
point_embeddings=point_embeddings,
|
||||
image_embeddings=image_embeddings,
|
||||
image_positional_embeddings=image_positional_embeddings,
|
||||
attention_similarity=attention_similarity,
|
||||
target_embedding=target_embedding,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
iou_token_out = point_embedding[:, :, 0, :]
|
||||
|
@ -576,8 +593,12 @@ class SamMaskEmbedding(nn.Module):
|
|||
self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
|
||||
self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
|
||||
self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
|
||||
self.layer_norm1 = SamLayerNorm(self.mask_input_channels, config.layer_norm_eps)
|
||||
self.layer_norm2 = SamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps)
|
||||
self.layer_norm1 = SamLayerNorm(
|
||||
self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
|
||||
)
|
||||
self.layer_norm2 = SamLayerNorm(
|
||||
self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
|
||||
)
|
||||
|
||||
def forward(self, masks):
|
||||
hidden_states = self.conv1(masks)
|
||||
|
@ -1146,6 +1167,12 @@ SAM_INPUTS_DOCSTRING = r"""
|
|||
In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
|
||||
bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
|
||||
"best" mask, by specifying `multimask_output=False`.
|
||||
attention_similarity (`torch.FloatTensor`, *optional*):
|
||||
Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
|
||||
model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
|
||||
target_embedding (`torch.FloatTensor`, *optional*):
|
||||
Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
|
||||
the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
|
@ -1265,6 +1292,8 @@ class SamModel(SamPreTrainedModel):
|
|||
input_masks: Optional[torch.LongTensor] = None,
|
||||
image_embeddings: Optional[torch.FloatTensor] = None,
|
||||
multimask_output: bool = True,
|
||||
attention_similarity: Optional[torch.FloatTensor] = None,
|
||||
target_embedding: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict=None,
|
||||
|
@ -1374,6 +1403,8 @@ class SamModel(SamPreTrainedModel):
|
|||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
attention_similarity=attention_similarity,
|
||||
target_embedding=target_embedding,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue