Add semantic segmentation post-processing method to MobileViT (#19105)

* add post-processing method for semantic segmentation

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Alara Dirik 2022-09-23 16:24:28 +03:00 committed by GitHub
parent 905635f5d3
commit 7e84723fe4
3 changed files with 73 additions and 2 deletions

docs/source/en/model_doc/mobilevit.mdx

@@ -66,6 +66,7 @@ This model was contributed by [matthijs](https://huggingface.co/Matthijs). The T
 [[autodoc]] MobileViTFeatureExtractor
     - __call__
+    - post_process_semantic_segmentation

 ## MobileViTModel

src/transformers/models/mobilevit/feature_extraction_mobilevit.py

@@ -14,16 +14,19 @@
 # limitations under the License.
 """Feature extractor class for MobileViT."""

-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 from PIL import Image

 from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
 from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
-from ...utils import TensorType, logging
+from ...utils import TensorType, is_torch_available, logging
+
+if is_torch_available():
+    import torch


 logger = logging.get_logger(__name__)
@@ -151,3 +154,46 @@ class MobileViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMi
         encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)

         return encoded_inputs
+
+    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
+        """
+        Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports
+        PyTorch.
+
+        Args:
+            outputs ([`MobileViTForSemanticSegmentation`]):
+                Raw outputs of the model.
+            target_sizes (`List[Tuple]`, *optional*):
+                A list of length `batch_size`, where each item is a `Tuple[int, int]` corresponding to the requested
+                final size (height, width) of each prediction. If left to `None`, predictions will not be resized.
+
+        Returns:
+            `List[torch.Tensor]`:
+                A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
+                corresponding to the `target_sizes` entry (if `target_sizes` is specified). Each entry of each
+                `torch.Tensor` corresponds to a semantic class id.
+        """
+        logits = outputs.logits
+
+        # Resize logits and compute semantic segmentation maps
+        if target_sizes is not None:
+            if len(logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
+            if is_torch_tensor(target_sizes):
+                target_sizes = target_sizes.numpy()
+
+            semantic_segmentation = []
+            for idx in range(len(logits)):
+                # Upsample the logits for one image to the requested (height, width), then take the per-pixel argmax
+                resized_logits = torch.nn.functional.interpolate(
+                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+                )
+                semantic_map = resized_logits[0].argmax(dim=0)
+                semantic_segmentation.append(semantic_map)
+        else:
+            semantic_segmentation = logits.argmax(dim=1)
+            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+        return semantic_segmentation
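
For reference, a minimal end-to-end usage sketch of the new method (not part of the diff). The checkpoint matches the one used in the test below; the image URL is an assumption for illustration:

```python
import requests
import torch
from PIL import Image

from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation

feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small")

# Any RGB image works; this COCO image is a common choice in transformers examples
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# PIL reports (width, height); post-processing expects (height, width)
segmentation = feature_extractor.post_process_semantic_segmentation(
    outputs=outputs, target_sizes=[image.size[::-1]]
)
print(segmentation[0].shape)  # e.g. torch.Size([480, 640])
```

Passing `target_sizes` resizes each map back to its source image; omitting it returns maps at the logits' native resolution.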

tests/models/mobilevit/test_modeling_mobilevit.py

@@ -340,3 +340,27 @@ class MobileViTModelIntegrationTest(unittest.TestCase):
         )

         self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
+
+    @slow
+    def test_post_processing_semantic_segmentation(self):
+        model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
+        model = model.to(torch_device)
+
+        feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
+
+        image = prepare_img()
+        inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        outputs.logits = outputs.logits.detach().cpu()
+
+        segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(50, 60)])
+        expected_shape = torch.Size((50, 60))
+        self.assertEqual(segmentation[0].shape, expected_shape)
+
+        segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
+        expected_shape = torch.Size((32, 32))
+        self.assertEqual(segmentation[0].shape, expected_shape)
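
Without `target_sizes`, the maps stay at the logits' native resolution (32×32 here, per the assertion above). A hedged sketch for turning a predicted map into a viewable grayscale mask; the helper name and scaling scheme are illustrative, not part of the library, and the default of 21 classes assumes the PASCAL VOC labels of the checkpoint used above:

```python
import numpy as np
import torch
from PIL import Image


def segmentation_to_image(semantic_map: torch.Tensor, num_classes: int = 21) -> Image.Image:
    """Spread class ids over 0-255 gray levels so the mask can be saved or displayed."""
    ids = semantic_map.cpu().numpy().astype(np.int32)
    scale = 255 // max(num_classes - 1, 1)
    return Image.fromarray((ids * scale).clip(0, 255).astype(np.uint8), mode="L")


# e.g. segmentation_to_image(segmentation[0]).save("mask.png")
```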