Remove float64 cast for OwlVit and OwlV2 to support MPS device (#31071)
Remove float64
This commit is contained in:
parent
936ab7bae5
commit
c31473ed44
|
@ -1276,7 +1276,6 @@ class Owlv2ClassPredictionHead(nn.Module):
|
|||
if query_mask.ndim > 1:
|
||||
query_mask = torch.unsqueeze(query_mask, dim=-2)
|
||||
|
||||
pred_logits = pred_logits.to(torch.float64)
|
||||
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
|
||||
pred_logits = pred_logits.to(torch.float32)
|
||||
|
||||
|
|
|
@ -1257,7 +1257,6 @@ class OwlViTClassPredictionHead(nn.Module):
|
|||
if query_mask.ndim > 1:
|
||||
query_mask = torch.unsqueeze(query_mask, dim=-2)
|
||||
|
||||
pred_logits = pred_logits.to(torch.float64)
|
||||
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
|
||||
pred_logits = pred_logits.to(torch.float32)
|
||||
|
||||
|
|
Loading…
Reference in New Issue