DeformableDETR two stage support bfloat16 (#30907)

Update modeling_deformable_detr.py
This commit is contained in:
Donggeun Yu 2024-05-20 17:51:04 +09:00 committed by GitHub
parent 5d0bf59b4d
commit 66b0d9ee5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 2 deletions

View File

@ -1616,8 +1616,8 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
grid_y, grid_x = meshgrid(
torch.linspace(0, height - 1, height, dtype=torch.float32, device=enc_output.device),
torch.linspace(0, width - 1, width, dtype=torch.float32, device=enc_output.device),
torch.linspace(0, height - 1, height, dtype=enc_output.dtype, device=enc_output.device),
torch.linspace(0, width - 1, width, dtype=enc_output.dtype, device=enc_output.device),
indexing="ij",
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)