DeformableDETR two stage support bfloat16 (#30907)
Update modeling_deformable_detr.py
This commit is contained in:
parent
5d0bf59b4d
commit
66b0d9ee5d
|
@ -1616,8 +1616,8 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
|||
valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
|
||||
|
||||
grid_y, grid_x = meshgrid(
|
||||
torch.linspace(0, height - 1, height, dtype=torch.float32, device=enc_output.device),
|
||||
torch.linspace(0, width - 1, width, dtype=torch.float32, device=enc_output.device),
|
||||
torch.linspace(0, height - 1, height, dtype=enc_output.dtype, device=enc_output.device),
|
||||
torch.linspace(0, width - 1, width, dtype=enc_output.dtype, device=enc_output.device),
|
||||
indexing="ij",
|
||||
)
|
||||
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
|
||||
|
|
Loading…
Reference in New Issue