Fix model parallelism for `BridgeTower` (#23039)
* fix
---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
d337631b91
commit
b6865b9bef
|
@@ -981,7 +981,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
|
|||
config_class = BridgeTowerConfig
|
||||
base_model_prefix = "bridgetower"
|
||||
supports_gradient_checkpointing = False
|
||||
- _no_split_modules = ["BridgeTowerSelfAttention"]
|
||||
+ _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, BridgeTowerVisionModel):
|
||||
|
@@ -1863,12 +1863,16 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
|
|||
|
||||
# normalized features
|
||||
text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
|
||||
- image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2)
|
||||
- cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)
|
||||
+ image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2).to(
|
||||
+ device=text_embeds.device
|
||||
+ )
|
||||
+ cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2).to(
|
||||
+ device=text_embeds.device
|
||||
+ )
|
||||
|
||||
logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)
|
||||
|
||||
- logit_scale = self.logit_scale.exp()
|
||||
+ logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
|
||||
logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
||||
logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
|
||||
logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale
|
||||
|
|
Loading…
Reference in New Issue