Remove float64 cast for OwlVit and OwlV2 to support MPS device (#31071)

Remove float64
This commit is contained in:
Pavel Iakubovskii 2024-05-28 10:41:40 +00:00 committed by GitHub
parent 936ab7bae5
commit c31473ed44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 0 additions and 2 deletions

View File

@ -1276,7 +1276,6 @@ class Owlv2ClassPredictionHead(nn.Module):
if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2)
pred_logits = pred_logits.to(torch.float64)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
pred_logits = pred_logits.to(torch.float32)

View File

@ -1257,7 +1257,6 @@ class OwlViTClassPredictionHead(nn.Module):
if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2)
pred_logits = pred_logits.to(torch.float64)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
pred_logits = pred_logits.to(torch.float32)