enable training mask2former and maskformer for transformers trainer (#28277)
* fix get_num_masks output as [int] to int * fix loss size from torch.Size([1]) to torch.Size([])
This commit is contained in:
parent
6b8ec2588e
commit
4a66c0d952
|
@ -787,7 +787,7 @@ class Mask2FormerLoss(nn.Module):
|
|||
Computes the average number of target masks across the batch, for normalization purposes.
|
||||
"""
|
||||
num_masks = sum([len(classes) for classes in class_labels])
|
||||
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
|
||||
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
|
||||
return num_masks_pt
|
||||
|
||||
|
||||
|
|
|
@ -1193,7 +1193,7 @@ class MaskFormerLoss(nn.Module):
|
|||
Computes the average number of target masks across the batch, for normalization purposes.
|
||||
"""
|
||||
num_masks = sum([len(classes) for classes in class_labels])
|
||||
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
|
||||
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
|
||||
return num_masks_pt
|
||||
|
||||
|
||||
|
|
|
@ -190,7 +190,7 @@ class Mask2FormerModelTester:
|
|||
comm_check_on_output(result)
|
||||
|
||||
self.parent.assertTrue(result.loss is not None)
|
||||
self.parent.assertEqual(result.loss.shape, torch.Size([1]))
|
||||
self.parent.assertEqual(result.loss.shape, torch.Size([]))
|
||||
|
||||
|
||||
@require_torch
|
||||
|
|
|
@ -189,7 +189,7 @@ class MaskFormerModelTester:
|
|||
comm_check_on_output(result)
|
||||
|
||||
self.parent.assertTrue(result.loss is not None)
|
||||
self.parent.assertEqual(result.loss.shape, torch.Size([1]))
|
||||
self.parent.assertEqual(result.loss.shape, torch.Size([]))
|
||||
|
||||
|
||||
@require_torch
|
||||
|
|
Loading…
Reference in New Issue