# mmpose/tests/test_codecs/test_associative_embedding.py
#
# 243 lines
# 8.6 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from unittest import TestCase
import numpy as np
import torch
from munkres import Munkres
from mmpose.codecs import AssociativeEmbedding
from mmpose.registry import KEYPOINT_CODECS
from mmpose.testing import get_coco_sample
class TestAssociativeEmbedding(TestCase):
def setUp(self) -> None:
self.decode_keypoint_order = [
0, 1, 2, 3, 4, 5, 6, 11, 12, 7, 8, 9, 10, 13, 14, 15, 16
]
def test_build(self):
cfg = dict(
type='AssociativeEmbedding',
input_size=(256, 256),
heatmap_size=(64, 64),
use_udp=False,
decode_keypoint_order=self.decode_keypoint_order,
)
codec = KEYPOINT_CODECS.build(cfg)
self.assertIsInstance(codec, AssociativeEmbedding)
def test_encode(self):
data = get_coco_sample(img_shape=(256, 256), num_instances=1)
# w/o UDP
codec = AssociativeEmbedding(
input_size=(256, 256),
heatmap_size=(64, 64),
use_udp=False,
decode_keypoint_order=self.decode_keypoint_order)
encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
heatmaps = encoded['heatmaps']
keypoint_indices = encoded['keypoint_indices']
keypoint_weights = encoded['keypoint_weights']
self.assertEqual(heatmaps.shape, (17, 64, 64))
self.assertEqual(keypoint_indices.shape, (1, 17, 2))
self.assertEqual(keypoint_weights.shape, (1, 17))
for k in range(heatmaps.shape[0]):
index_expected = np.argmax(heatmaps[k])
index_encoded = keypoint_indices[0, k, 0]
self.assertEqual(index_expected, index_encoded)
# w/ UDP
codec = AssociativeEmbedding(
input_size=(256, 256),
heatmap_size=(64, 64),
use_udp=True,
decode_keypoint_order=self.decode_keypoint_order)
encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
heatmaps = encoded['heatmaps']
keypoint_indices = encoded['keypoint_indices']
keypoint_weights = encoded['keypoint_weights']
self.assertEqual(heatmaps.shape, (17, 64, 64))
self.assertEqual(keypoint_indices.shape, (1, 17, 2))
self.assertEqual(keypoint_weights.shape, (1, 17))
for k in range(heatmaps.shape[0]):
index_expected = np.argmax(heatmaps[k])
index_encoded = keypoint_indices[0, k, 0]
self.assertEqual(index_expected, index_encoded)
def _get_tags(self,
heatmaps,
keypoint_indices,
tag_per_keypoint: bool,
tag_dim: int = 1):
K, H, W = heatmaps.shape
N = keypoint_indices.shape[0]
if tag_per_keypoint:
tags = np.zeros((K * tag_dim, H, W), dtype=np.float32)
else:
tags = np.zeros((tag_dim, H, W), dtype=np.float32)
for n, k in product(range(N), range(K)):
y, x = np.unravel_index(keypoint_indices[n, k, 0], (H, W))
if tag_per_keypoint:
tags[k::K, y, x] = n
else:
tags[:, y, x] = n
return tags
def _sort_preds(self, keypoints_pred, scores_pred, keypoints_gt):
"""Sort multi-instance predictions to best match the ground-truth.
Args:
keypoints_pred (np.ndarray): predictions in shape (N, K, D)
scores (np.ndarray): predictions in shape (N, K)
keypoints_gt (np.ndarray): ground-truth in shape (N, K, D)
Returns:
np.ndarray: Sorted predictions
"""
assert keypoints_gt.shape == keypoints_pred.shape
costs = np.linalg.norm(
keypoints_gt[None] - keypoints_pred[:, None], ord=2,
axis=3).mean(axis=2)
match = Munkres().compute(costs)
keypoints_pred_sorted = np.zeros_like(keypoints_pred)
scores_pred_sorted = np.zeros_like(scores_pred)
for i, j in match:
keypoints_pred_sorted[i] = keypoints_pred[j]
scores_pred_sorted[i] = scores_pred[j]
return keypoints_pred_sorted, scores_pred_sorted
def test_decode(self):
data = get_coco_sample(
img_shape=(256, 256), num_instances=2, non_occlusion=True)
# w/o UDP
codec = AssociativeEmbedding(
input_size=(256, 256),
heatmap_size=(64, 64),
use_udp=False,
decode_keypoint_order=self.decode_keypoint_order)
encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
heatmaps = encoded['heatmaps']
keypoint_indices = encoded['keypoint_indices']
tags = self._get_tags(
heatmaps, keypoint_indices, tag_per_keypoint=True)
# to Tensor
batch_heatmaps = torch.from_numpy(heatmaps[None])
batch_tags = torch.from_numpy(tags[None])
batch_keypoints, batch_keypoint_scores, batch_instance_scores = \
codec.batch_decode(batch_heatmaps, batch_tags)
self.assertIsInstance(batch_keypoints, list)
self.assertIsInstance(batch_keypoint_scores, list)
self.assertEqual(len(batch_keypoints), 1)
self.assertEqual(len(batch_keypoint_scores), 1)
keypoints, scores = self._sort_preds(batch_keypoints[0],
batch_keypoint_scores[0],
data['keypoints'])
self.assertIsInstance(keypoints, np.ndarray)
self.assertIsInstance(scores, np.ndarray)
self.assertEqual(keypoints.shape, (2, 17, 2))
self.assertEqual(scores.shape, (2, 17))
self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0))
# w/o UDP, tag_imd=2
codec = AssociativeEmbedding(
input_size=(256, 256),
heatmap_size=(64, 64),
use_udp=False,
decode_keypoint_order=self.decode_keypoint_order)
encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
heatmaps = encoded['heatmaps']
keypoint_indices = encoded['keypoint_indices']
tags = self._get_tags(
heatmaps, keypoint_indices, tag_per_keypoint=True, tag_dim=2)
# to Tensor
batch_heatmaps = torch.from_numpy(heatmaps[None])
batch_tags = torch.from_numpy(tags[None])
batch_keypoints, batch_keypoint_scores, batch_instance_scores = \
codec.batch_decode(batch_heatmaps, batch_tags)
self.assertIsInstance(batch_keypoints, list)
self.assertIsInstance(batch_keypoint_scores, list)
self.assertEqual(len(batch_keypoints), 1)
self.assertEqual(len(batch_keypoint_scores), 1)
keypoints, scores = self._sort_preds(batch_keypoints[0],
batch_keypoint_scores[0],
data['keypoints'])
self.assertIsInstance(keypoints, np.ndarray)
self.assertIsInstance(scores, np.ndarray)
self.assertEqual(keypoints.shape, (2, 17, 2))
self.assertEqual(scores.shape, (2, 17))
self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0))
# w/ UDP
codec = AssociativeEmbedding(
input_size=(256, 256),
heatmap_size=(64, 64),
use_udp=True,
decode_keypoint_order=self.decode_keypoint_order)
encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
heatmaps = encoded['heatmaps']
keypoint_indices = encoded['keypoint_indices']
tags = self._get_tags(
heatmaps, keypoint_indices, tag_per_keypoint=True)
# to Tensor
batch_heatmaps = torch.from_numpy(heatmaps[None])
batch_tags = torch.from_numpy(tags[None])
batch_keypoints, batch_keypoint_scores, batch_instance_scores = \
codec.batch_decode(batch_heatmaps, batch_tags)
self.assertIsInstance(batch_keypoints, list)
self.assertIsInstance(batch_keypoint_scores, list)
self.assertEqual(len(batch_keypoints), 1)
self.assertEqual(len(batch_keypoint_scores), 1)
keypoints, scores = self._sort_preds(batch_keypoints[0],
batch_keypoint_scores[0],
data['keypoints'])
self.assertIsInstance(keypoints, np.ndarray)
self.assertIsInstance(scores, np.ndarray)
self.assertEqual(keypoints.shape, (2, 17, 2))
self.assertEqual(scores.shape, (2, 17))
self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0))