mmpose/tests/test_structures/test_pose_data_sample.py

127 lines
4.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
from mmengine.structures import InstanceData, PixelData
from mmpose.structures import MultilevelPixelData, PoseDataSample
class TestPoseDataSample(TestCase):
    """Unit tests for ``PoseDataSample`` construction, setters and
    deleters."""

    def get_pose_data_sample(self, multilevel: bool = False):
        """Build a ``PoseDataSample`` filled with random test data.

        Args:
            multilevel (bool): If True, ``gt_fields`` is a
                ``MultilevelPixelData`` holding heatmaps (numpy) and masks
                (torch) at three resolutions; otherwise it is a single-level
                ``PixelData`` with a ``(17, 64, 48)`` heatmap tensor.
                Defaults to False.

        Returns:
            PoseDataSample: A sample carrying ``gt_instances``,
            ``pred_instances``, ``gt_fields``, ``pred_fields`` and the meta
            keys ``img_shape``, ``crop_size`` and ``heatmap_size``.
        """
        # meta
        pose_meta = dict(
            img_shape=(600, 900),  # [h, w]
            crop_size=(256, 192),  # [h, w]
            heatmap_size=(64, 48),  # [h, w]
        )

        # gt_instances: one instance with 17 keypoints
        gt_instances = InstanceData()
        gt_instances.bboxes = torch.rand(1, 4)
        gt_instances.keypoints = torch.rand(1, 17, 2)
        gt_instances.keypoints_visible = torch.rand(1, 17)

        # pred_instances
        pred_instances = InstanceData()
        pred_instances.keypoints = torch.rand(1, 17, 2)
        pred_instances.keypoint_scores = torch.rand(1, 17)

        # gt_fields
        if multilevel:
            # generate multilevel gt_fields
            metainfo = dict(num_keypoints=17)
            sizes = [(64, 48), (32, 24), (16, 12)]
            heatmaps = [np.random.rand(17, h, w) for h, w in sizes]
            masks = [torch.rand(1, h, w) for h, w in sizes]
            gt_fields = MultilevelPixelData(
                metainfo=metainfo, heatmaps=heatmaps, masks=masks)
        else:
            gt_fields = PixelData()
            gt_fields.heatmaps = torch.rand(17, 64, 48)

        # pred_fields
        pred_fields = PixelData()
        pred_fields.heatmaps = torch.rand(17, 64, 48)

        data_sample = PoseDataSample(
            gt_instances=gt_instances,
            pred_instances=pred_instances,
            gt_fields=gt_fields,
            pred_fields=pred_fields,
            metainfo=pose_meta)

        return data_sample

    @staticmethod
    def _equal(x, y):
        """Return True if ``x`` and ``y`` have the same type and values.

        The type check is deliberately strict, so a numpy array never equals
        a tensor with the same values. This is what lets the tests detect
        whether ``to_tensor`` actually converted the data.
        """
        if type(x) is not type(y):
            return False
        if isinstance(x, torch.Tensor):
            return torch.allclose(x, y)
        if isinstance(x, np.ndarray):
            return np.allclose(x, y)
        return x == y

    def test_init(self):
        """Meta info and instance data passed to the constructor are kept."""
        data_sample = self.get_pose_data_sample()
        self.assertIn('img_shape', data_sample)
        self.assertEqual(len(data_sample.gt_instances), 1)

    def test_setter(self):
        """Fields can be reassigned and are returned unchanged."""
        data_sample = self.get_pose_data_sample()

        # test gt_instances
        data_sample.gt_instances = InstanceData()
        self.assertIsInstance(data_sample.gt_instances, InstanceData)

        # test gt_fields
        data_sample.gt_fields = PixelData()
        self.assertIsInstance(data_sample.gt_fields, PixelData)

        # test multilevel gt_fields
        data_sample = self.get_pose_data_sample(multilevel=True)
        data_sample.gt_fields = MultilevelPixelData()
        self.assertIsInstance(data_sample.gt_fields, MultilevelPixelData)

        # test pred_instances as pytorch tensor
        pred_instances_data = dict(
            keypoints=torch.rand(1, 17, 2), scores=torch.rand(1, 17, 1))
        data_sample.pred_instances = InstanceData(**pred_instances_data)
        self.assertTrue(
            self._equal(data_sample.pred_instances.keypoints,
                        pred_instances_data['keypoints']))
        self.assertTrue(
            self._equal(data_sample.pred_instances.scores,
                        pred_instances_data['scores']))

        # test pred_fields as numpy array
        pred_fields_data = dict(heatmaps=np.random.rand(17, 64, 48))
        data_sample.pred_fields = PixelData(**pred_fields_data)
        self.assertTrue(
            self._equal(data_sample.pred_fields.heatmaps,
                        pred_fields_data['heatmaps']))

        # test to_tensor: the numpy heatmaps must come back as a tensor
        data_sample = data_sample.to_tensor()
        self.assertTrue(
            self._equal(data_sample.pred_fields.heatmaps,
                        torch.from_numpy(pred_fields_data['heatmaps'])))

    def test_deleter(self):
        """Deleting a field removes it from the data sample."""
        data_sample = self.get_pose_data_sample()

        for key in ('gt_instances', 'pred_instances', 'gt_fields',
                    'pred_fields'):
            self.assertIn(key, data_sample)
            # delattr is equivalent to `del data_sample.<key>` without the
            # need to build and exec source code at runtime
            delattr(data_sample, key)
            self.assertNotIn(key, data_sample)