Fix model parallelism for `BridgeTower` (#23039)

* Move the image and cross-modal embeddings and `logit_scale` onto the text embedding's device, and add `BridgeTowerResidualAttention` to `_no_split_modules`, so `BridgeTower` works under model parallelism

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar 2023-04-28 21:53:58 +02:00 committed by GitHub
parent d337631b91
commit b6865b9bef
1 changed file with 8 additions and 4 deletions

src/transformers/models/bridgetower/modeling_bridgetower.py

@@ -981,7 +981,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
     config_class = BridgeTowerConfig
     base_model_prefix = "bridgetower"
     supports_gradient_checkpointing = False
-    _no_split_modules = ["BridgeTowerSelfAttention"]
+    _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
 
     def _init_weights(self, module):
         if isinstance(module, BridgeTowerVisionModel):
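
The `_no_split_modules` list tells Accelerate's device-map inference which module classes must be kept whole on a single device. The vision tower's `BridgeTowerResidualAttention` blocks were previously eligible for splitting, so their internal residual additions could mix tensors from different GPUs. A minimal sketch of how the attribute is consumed when loading with a device map (the checkpoint name and dtype are illustrative assumptions, not part of this commit):

    import torch
    from transformers import BridgeTowerForContrastiveLearning

    model = BridgeTowerForContrastiveLearning.from_pretrained(
        "BridgeTower/bridgetower-large-itm-mlm-itc",  # assumed checkpoint
        device_map="auto",   # Accelerate shards layers across available GPUs
        torch_dtype=torch.float16,
    )
    # Internally, Accelerate's device-map inference receives
    # no_split_module_classes=model._no_split_modules, so every
    # BridgeTowerSelfAttention / BridgeTowerResidualAttention instance
    # stays on exactly one device.
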
@@ -1863,12 +1863,16 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
 
         # normalized features
         text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
-        image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2)
-        cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)
+        image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2).to(
+            device=text_embeds.device
+        )
+        cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2).to(
+            device=text_embeds.device
+        )
 
         logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)
 
-        logit_scale = self.logit_scale.exp()
+        logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
         logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
         logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
         logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale
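
The second hunk addresses the remaining cross-device failure: under a device map, `itc_image_head` and `itc_cross_modal_head` (and the `logit_scale` parameter) may land on a different GPU than the text head, and `torch.matmul` raises a RuntimeError when its operands live on different devices. A hedged sketch of the failure mode and the fix (the two-GPU layout is an assumption for illustration):

    import torch

    text_embeds = torch.randn(4, 512, device="cuda:0")
    image_embeds = torch.randn(4, 512, device="cuda:1")  # head placed on another GPU

    # Before the fix: RuntimeError, operands on cuda:1 vs. cuda:0
    # logits = torch.matmul(text_embeds, image_embeds.t())

    # After the fix: align everything on the text-embedding device first
    image_embeds = image_embeds.to(device=text_embeds.device)
    logits = torch.matmul(text_embeds, image_embeds.t())
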