Fix BeitFeatureExtractor postprocessing (#19119)

* return post-processed segmentations as list, add test
* use torch to resize logits
* fix assertion error if no target_size is specified
Author: Alara Dirik, 2022-09-20 18:53:40 +03:00 (committed by GitHub)
parent 36e356caa4
commit 36b9a99433
2 changed files with 47 additions and 22 deletions
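
For orientation (this sketch is not part of the commit): after the fix, `post_process_semantic_segmentation` always returns a list of 2D class-id maps, one per image, whether or not `target_sizes` is passed. A minimal sketch of the new call contract, with a random tensor standing in for a real preprocessed image; it mirrors the integration test added below:

import torch
from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation

model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)

# a random tensor stands in for a real preprocessed 640x640 image
pixel_values = torch.randn(1, 3, 640, 640)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)

# with target_sizes, each map is resized to the requested (height, width)
maps = feature_extractor.post_process_semantic_segmentation(outputs, target_sizes=[(500, 300)])
assert maps[0].shape == (500, 300)

# without target_sizes, maps keep the spatial resolution of the logits
maps = feature_extractor.post_process_semantic_segmentation(outputs)
assert maps[0].shape == (160, 160)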

@@ -226,43 +226,43 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         return encoded_inputs
 
-    def post_process_semantic_segmentation(self, outputs, target_sizes: Union[TensorType, List[Tuple]] = None):
+    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
         """
         Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
 
         Args:
             outputs ([`BeitForSemanticSegmentation`]):
                 Raw outputs of the model.
-            target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
-                Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to
+            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
+                List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
                 None, predictions will not be resized.
 
         Returns:
-            semantic_segmentation: `torch.Tensor` of shape `(batch_size, 2)` or `List[torch.Tensor]` of length
-            `batch_size`, where each item is a semantic segmentation map of the corresponding target_sizes entry (if
-            `target_sizes` is specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
+            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
+            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
+            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
         """
         logits = outputs.logits
 
-        if len(logits) != len(target_sizes):
-            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
-
-        if target_sizes is not None and target_sizes.shape[1] != 2:
-            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
-
-        semantic_segmentation = logits.argmax(dim=1)
-
-        # Resize semantic segmentation maps
+        # Resize logits and compute semantic segmentation maps
         if target_sizes is not None:
+            if len(logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
             if is_torch_tensor(target_sizes):
                 target_sizes = target_sizes.numpy()
 
-            resized_maps = []
-            semantic_segmentation = semantic_segmentation.numpy()
-            for idx in range(len(semantic_segmentation)):
-                resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx])
-                resized_maps.append(resized)
-            semantic_segmentation = [torch.Tensor(np.array(image)) for image in resized_maps]
+            semantic_segmentation = []
+
+            for idx in range(len(logits)):
+                resized_logits = torch.nn.functional.interpolate(
+                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+                )
+                semantic_map = resized_logits[0].argmax(dim=0)
+                semantic_segmentation.append(semantic_map)
+        else:
+            semantic_segmentation = logits.argmax(dim=1)
+            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
 
         return semantic_segmentation
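
The substantive change above: the old path took `argmax` first and then resized the integer label map as an image, which failed outright when `target_sizes` was `None` (the length check ran on `None`) and can blend class ids into interpolated values that are not valid classes. The new path bilinearly interpolates the raw logits and takes `argmax` afterwards. A standalone sketch of that operation on dummy logits:

import torch

# dummy logits for one image: (num_classes, height, width), e.g. 150 ADE20K classes
logits = torch.randn(150, 160, 160)
target_size = (500, 300)

# resize the continuous class scores first, then pick the winning class per pixel
resized_logits = torch.nn.functional.interpolate(
    logits.unsqueeze(dim=0), size=target_size, mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
assert semantic_map.shape == (500, 300)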

@@ -455,3 +455,28 @@ class BeitModelIntegrationTest(unittest.TestCase):
         )
 
         self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
+
+    @slow
+    def test_post_processing_semantic_segmentation(self):
+        model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
+        model = model.to(torch_device)
+
+        feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
+
+        ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
+        image = Image.open(ds[0]["file"])
+        inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        outputs.logits = outputs.logits.detach().cpu()
+
+        segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
+        expected_shape = torch.Size((500, 300))
+        self.assertEqual(segmentation[0].shape, expected_shape)
+
+        segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
+        expected_shape = torch.Size((160, 160))
+        self.assertEqual(segmentation[0].shape, expected_shape)
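
A hypothetical follow-on, not part of the commit, showing how a caller might read one of the returned maps; `id2label` is the standard id-to-label mapping on Hugging Face model configs:

predicted_map = segmentation[0]  # 2D torch.Tensor of class ids
present_ids = torch.unique(predicted_map).tolist()
present_labels = [model.config.id2label[int(i)] for i in present_ids]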