# Mirror of https://github.com/open-mmlab/mmpose
# 57 lines, 1.8 KiB, Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from copy import deepcopy
|
|
from unittest import TestCase
|
|
|
|
import numpy as np
|
|
|
|
from mmpose.evaluation.functional import (transform_ann, transform_pred,
|
|
transform_sigmas)
|
|
|
|
|
|
class TestKeypointEval(TestCase):
    """Tests for the keypoint-remapping helpers.

    Each test projects a 17-keypoint layout onto a 5-keypoint layout using
    ``mapping``, a list of ``(source_index, target_index)`` pairs. Only target
    indices 0-3 appear in ``mapping``; target index 4 receives no source
    keypoint.
    """

    def test_transform_sigmas(self):
        """Each mapped target sigma must equal its source sigma."""
        mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
        num_keypoints = 5
        sigmas = np.random.rand(17)

        new_sigmas = transform_sigmas(sigmas, num_keypoints, mapping)

        self.assertEqual(len(new_sigmas), num_keypoints)
        for i, j in mapping:
            self.assertEqual(sigmas[i], new_sigmas[j])

    def test_transform_ann(self):
        """``transform_ann`` must rewrite the annotation dict it is given.

        ``keypoints`` is a flat list holding 3 values per keypoint, so
        keypoint ``k`` occupies the slice ``[k * 3:k * 3 + 3]``.
        """
        mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
        num_keypoints = 5

        kpt_info = dict(
            num_keypoints=17,
            keypoints=np.random.randint(3, size=(17 * 3, )).tolist())
        # The assertions read ``kpt_info`` after the call (in-place update),
        # so keep an independent snapshot of the input values.
        kpt_info_copy = deepcopy(kpt_info)

        transform_ann(kpt_info, num_keypoints, mapping)

        self.assertEqual(kpt_info['num_keypoints'], num_keypoints)
        self.assertEqual(len(kpt_info['keypoints']), num_keypoints * 3)
        for i, j in mapping:
            self.assertListEqual(kpt_info_copy['keypoints'][i * 3:i * 3 + 3],
                                 kpt_info['keypoints'][j * 3:j * 3 + 3])

    def test_transform_pred(self):
        """``transform_pred`` must remap batched keypoints and their scores.

        Inputs have a leading batch axis of size 1: ``keypoints`` is
        ``(1, 17, 3)`` and ``keypoint_scores`` is ``(1, 17)``.
        """
        mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
        num_keypoints = 5

        kpt_info = dict(
            num_keypoints=17,
            keypoints=np.random.randint(3, size=(1, 17, 3)),
            keypoint_scores=np.ones((1, 17)))
        # Snapshot the input so mapped values can be compared after the
        # in-place update.
        kpt_info_copy = deepcopy(kpt_info)

        transform_pred(kpt_info, num_keypoints, mapping)

        self.assertEqual(kpt_info['num_keypoints'], num_keypoints)
        self.assertEqual(len(kpt_info['keypoints']), 1)
        # ``len`` only sees the batch axis, which is already 1 before the
        # call, so check the keypoint axis and the remapped values as well.
        self.assertEqual(kpt_info['keypoints'].shape[:2], (1, num_keypoints))
        self.assertEqual(kpt_info['keypoint_scores'].shape, (1, num_keypoints))
        for i, j in mapping:
            np.testing.assert_array_equal(kpt_info_copy['keypoints'][:, i],
                                          kpt_info['keypoints'][:, j])
            np.testing.assert_array_equal(
                kpt_info_copy['keypoint_scores'][:, i],
                kpt_info['keypoint_scores'][:, j])