mmpose/tests/test_evaluation/test_functional/test_transforms.py

57 lines
1.8 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase
import numpy as np
from mmpose.evaluation.functional import (transform_ann, transform_pred,
transform_sigmas)
class TestKeypointEval(TestCase):
def test_transform_sigmas(self):
mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
num_keypoints = 5
sigmas = np.random.rand(17)
new_sigmas = transform_sigmas(sigmas, num_keypoints, mapping)
self.assertEqual(len(new_sigmas), 5)
for i, j in mapping:
self.assertEqual(sigmas[i], new_sigmas[j])
def test_transform_ann(self):
mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
num_keypoints = 5
kpt_info = dict(
num_keypoints=17,
keypoints=np.random.randint(3, size=(17 * 3, )).tolist())
kpt_info_copy = deepcopy(kpt_info)
_ = transform_ann(kpt_info, num_keypoints, mapping)
self.assertEqual(kpt_info['num_keypoints'], 5)
self.assertEqual(len(kpt_info['keypoints']), 15)
for i, j in mapping:
self.assertListEqual(kpt_info_copy['keypoints'][i * 3:i * 3 + 3],
kpt_info['keypoints'][j * 3:j * 3 + 3])
def test_transform_pred(self):
mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
num_keypoints = 5
kpt_info = dict(
num_keypoints=17,
keypoints=np.random.randint(3, size=(
1,
17,
3,
)),
keypoint_scores=np.ones((1, 17)))
_ = transform_pred(kpt_info, num_keypoints, mapping)
self.assertEqual(kpt_info['num_keypoints'], 5)
self.assertEqual(len(kpt_info['keypoints']), 1)