mirror of https://github.com/open-mmlab/mmpose
243 lines
8.6 KiB
Python
243 lines
8.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from itertools import product
|
|
from unittest import TestCase
|
|
|
|
import numpy as np
|
|
import torch
|
|
from munkres import Munkres
|
|
|
|
from mmpose.codecs import AssociativeEmbedding
|
|
from mmpose.registry import KEYPOINT_CODECS
|
|
from mmpose.testing import get_coco_sample
|
|
|
|
|
|
class TestAssociativeEmbedding(TestCase):
    """Unit tests for the ``AssociativeEmbedding`` keypoint codec.

    Covers construction through the registry, encoding (heatmap peaks must
    agree with the returned keypoint indices) and decoding (synthetic,
    perfectly separable tags must let the codec recover the ground-truth
    keypoints), each with and without UDP.
    """

    def setUp(self) -> None:
        # Order passed to the codec's ``decode_keypoint_order`` argument;
        # indices 11 and 12 are moved right after 6.
        self.decode_keypoint_order = [
            0, 1, 2, 3, 4, 5, 6, 11, 12, 7, 8, 9, 10, 13, 14, 15, 16
        ]

    def _build_codec(self, use_udp: bool):
        """Create a 256x256 -> 64x64 codec using the shared keypoint order.

        Args:
            use_udp (bool): value forwarded to the codec's ``use_udp``

        Returns:
            AssociativeEmbedding: the constructed codec
        """
        return AssociativeEmbedding(
            input_size=(256, 256),
            heatmap_size=(64, 64),
            use_udp=use_udp,
            decode_keypoint_order=self.decode_keypoint_order)

    def test_build(self):
        cfg = dict(
            type='AssociativeEmbedding',
            input_size=(256, 256),
            heatmap_size=(64, 64),
            use_udp=False,
            decode_keypoint_order=self.decode_keypoint_order,
        )
        codec = KEYPOINT_CODECS.build(cfg)
        self.assertIsInstance(codec, AssociativeEmbedding)

    def _check_encode(self, codec, data) -> None:
        """Encode ``data`` with ``codec`` and check the encoded outputs.

        ``data`` must hold a single instance with 17 keypoints; the codec is
        expected to produce 64x64 heatmaps.

        Args:
            codec (AssociativeEmbedding): codec under test
            data (dict): sample with ``keypoints`` and ``keypoints_visible``
        """
        encoded = codec.encode(data['keypoints'], data['keypoints_visible'])

        heatmaps = encoded['heatmaps']
        keypoint_indices = encoded['keypoint_indices']
        keypoint_weights = encoded['keypoint_weights']

        self.assertEqual(heatmaps.shape, (17, 64, 64))
        self.assertEqual(keypoint_indices.shape, (1, 17, 2))
        self.assertEqual(keypoint_weights.shape, (1, 17))

        # The flat argmax of each heatmap channel must equal the index the
        # codec reported for that keypoint (column 0 of keypoint_indices).
        for k in range(heatmaps.shape[0]):
            index_expected = np.argmax(heatmaps[k])
            index_encoded = keypoint_indices[0, k, 0]
            self.assertEqual(index_expected, index_encoded)

    def test_encode(self):
        data = get_coco_sample(img_shape=(256, 256), num_instances=1)

        # w/o UDP, then w/ UDP
        for use_udp in (False, True):
            with self.subTest(use_udp=use_udp):
                self._check_encode(self._build_codec(use_udp), data)

    def _get_tags(self,
                  heatmaps,
                  keypoint_indices,
                  tag_per_keypoint: bool,
                  tag_dim: int = 1):
        """Build ideal associative-embedding tag maps for a test sample.

        At the heatmap location of keypoint ``k`` of instance ``n`` every
        tag channel belonging to that keypoint is set to ``n``, so all
        keypoints of one instance share a tag and different instances are
        separated by at least 1. All other locations stay 0.

        Args:
            heatmaps (np.ndarray): heatmaps in shape (K, H, W)
            keypoint_indices (np.ndarray): encoded keypoint indices in shape
                (N, K, 2); column 0 is the flat index into an (H, W) map
            tag_per_keypoint (bool): if True, produce ``K * tag_dim``
                channels where channel ``k + d * K`` is dimension ``d`` of
                keypoint ``k``'s tag; otherwise produce ``tag_dim`` channels
                shared by all keypoints
            tag_dim (int): dimension of each tag. Defaults to 1

        Returns:
            np.ndarray: float32 tag maps in shape (K * tag_dim, H, W) if
            ``tag_per_keypoint`` else (tag_dim, H, W)
        """
        K, H, W = heatmaps.shape
        N = keypoint_indices.shape[0]

        if tag_per_keypoint:
            tags = np.zeros((K * tag_dim, H, W), dtype=np.float32)
        else:
            tags = np.zeros((tag_dim, H, W), dtype=np.float32)

        for n, k in product(range(N), range(K)):
            y, x = np.unravel_index(keypoint_indices[n, k, 0], (H, W))
            if tag_per_keypoint:
                # channels k, k + K, k + 2K, ... hold keypoint k's tag
                tags[k::K, y, x] = n
            else:
                tags[:, y, x] = n

        return tags

    def _sort_preds(self, keypoints_pred, scores_pred, keypoints_gt):
        """Sort multi-instance predictions to best match the ground-truth.

        Predictions are matched to ground-truth instances with the Munkres
        (Hungarian) algorithm on the mean per-keypoint L2 distance, then
        reordered so that entry ``j`` of the output is the prediction
        matched to ``keypoints_gt[j]``.

        Args:
            keypoints_pred (np.ndarray): predictions in shape (N, K, D)
            scores_pred (np.ndarray): prediction scores in shape (N, K)
            keypoints_gt (np.ndarray): ground-truth in shape (N, K, D)

        Returns:
            tuple:
            - keypoints_pred_sorted (np.ndarray): sorted keypoints in shape
              (N, K, D)
            - scores_pred_sorted (np.ndarray): sorted scores in shape (N, K)
        """
        assert keypoints_gt.shape == keypoints_pred.shape
        # costs[i, j] = mean L2 distance between prediction i and GT j
        costs = np.linalg.norm(
            keypoints_gt[None] - keypoints_pred[:, None], ord=2,
            axis=3).mean(axis=2)
        # Munkres takes a list-of-lists cost matrix and returns
        # (row, col) pairs, i.e. (prediction index, ground-truth index).
        match = Munkres().compute(costs.tolist())
        keypoints_pred_sorted = np.zeros_like(keypoints_pred)
        scores_pred_sorted = np.zeros_like(scores_pred)
        for i_pred, j_gt in match:
            # Put prediction i_pred into the slot of its matched GT j_gt.
            # (The reversed assignment only coincides with this when the
            # matching is its own inverse, e.g. for any N <= 2.)
            keypoints_pred_sorted[j_gt] = keypoints_pred[i_pred]
            scores_pred_sorted[j_gt] = scores_pred[i_pred]

        return keypoints_pred_sorted, scores_pred_sorted

    def _check_decode(self, codec, data, tag_dim: int = 1) -> None:
        """Round-trip ``data`` through encode, synthetic tags and decode.

        Tags come from :meth:`_get_tags`, so keypoints of instance ``n``
        carry tag value ``n``. The decoded keypoints, re-ordered to match
        the ground truth, must lie within 4 px of it (one heatmap stride
        for the 256 -> 64 downsample). ``data`` must hold two instances
        with 17 keypoints each.

        Args:
            codec (AssociativeEmbedding): codec under test
            data (dict): sample with ``keypoints`` and ``keypoints_visible``
            tag_dim (int): dimension of the synthetic tags. Defaults to 1
        """
        encoded = codec.encode(data['keypoints'], data['keypoints_visible'])

        heatmaps = encoded['heatmaps']
        keypoint_indices = encoded['keypoint_indices']

        tags = self._get_tags(
            heatmaps,
            keypoint_indices,
            tag_per_keypoint=True,
            tag_dim=tag_dim)

        # add a batch dimension and convert to Tensor
        batch_heatmaps = torch.from_numpy(heatmaps[None])
        batch_tags = torch.from_numpy(tags[None])

        # instance scores are not checked by this test
        batch_keypoints, batch_keypoint_scores, _ = codec.batch_decode(
            batch_heatmaps, batch_tags)

        self.assertIsInstance(batch_keypoints, list)
        self.assertIsInstance(batch_keypoint_scores, list)
        self.assertEqual(len(batch_keypoints), 1)
        self.assertEqual(len(batch_keypoint_scores), 1)

        keypoints, scores = self._sort_preds(batch_keypoints[0],
                                             batch_keypoint_scores[0],
                                             data['keypoints'])

        self.assertIsInstance(keypoints, np.ndarray)
        self.assertIsInstance(scores, np.ndarray)
        self.assertEqual(keypoints.shape, (2, 17, 2))
        self.assertEqual(scores.shape, (2, 17))

        self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0))

    def test_decode(self):
        data = get_coco_sample(
            img_shape=(256, 256), num_instances=2, non_occlusion=True)

        # w/o UDP; w/o UDP with tag_dim=2; w/ UDP
        for use_udp, tag_dim in ((False, 1), (False, 2), (True, 1)):
            with self.subTest(use_udp=use_udp, tag_dim=tag_dim):
                self._check_decode(
                    self._build_codec(use_udp), data, tag_dim=tag_dim)
|