mirror of https://github.com/open-mmlab/mmpose
73 lines
2.4 KiB
Python
73 lines
2.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import numpy as np
|
|
import torch
|
|
from mmengine.structures import PixelData
|
|
|
|
from mmpose.structures import MultilevelPixelData
|
|
|
|
|
|
class TestMultilevelPixelData(TestCase):
    """Unit tests for ``MultilevelPixelData``.

    Covers construction, per-level indexing, field assignment, the
    tensor/array and device conversion helpers, and field deletion.
    """

    def get_multi_level_pixel_data(self):
        """Build a three-level ``MultilevelPixelData`` fixture.

        Each level carries a random 17-channel NumPy ``heatmaps`` array and
        a random 1-channel torch ``masks`` tensor; ``num_keypoints`` is
        passed as metainfo.

        Returns:
            MultilevelPixelData: A freshly constructed instance.
        """
        metainfo = dict(num_keypoints=17)
        sizes = [(64, 48), (32, 24), (16, 12)]
        heatmaps = [np.random.rand(17, h, w) for h, w in sizes]
        masks = [torch.rand(1, h, w) for h, w in sizes]
        data = MultilevelPixelData(
            metainfo=metainfo, heatmaps=heatmaps, masks=masks)

        return data

    def test_init(self):
        # The fixture should expose its metainfo key, report one level per
        # input size, and yield a ``PixelData`` when indexed by level.
        data = self.get_multi_level_pixel_data()
        self.assertIn('num_keypoints', data)
        self.assertEqual(data.nlevel, 3)
        self.assertEqual(data.shape, ((64, 48), (32, 24), (16, 12)))
        self.assertIsInstance(data[0], PixelData)

    def test_setter(self):
        # test `set_field`
        # Only checks that assigning a new multi-level field does not raise.
        # NOTE(review): the last level here is (16, 8) while the fixture's
        # last level is (16, 12) -- confirm the mismatch is intentional.
        data = self.get_multi_level_pixel_data()
        sizes = [(64, 48), (32, 24), (16, 8)]
        offset_maps = [torch.rand(2, h, w) for h, w in sizes]
        data.offset_maps = offset_maps

        # test `to_tensor`
        data = self.get_multi_level_pixel_data()
        self.assertIsInstance(data[0].heatmaps, np.ndarray)
        data = data.to_tensor()
        self.assertIsInstance(data[0].heatmaps, torch.Tensor)

        # test `cpu`
        data = self.get_multi_level_pixel_data()
        self.assertIsInstance(data[0].heatmaps, np.ndarray)
        self.assertIsInstance(data[0].masks, torch.Tensor)
        self.assertEqual(data[0].masks.device.type, 'cpu')
        data = data.cpu()
        # NumPy fields are expected to stay NumPy after `cpu()`.
        self.assertIsInstance(data[0].heatmaps, np.ndarray)
        self.assertEqual(data[0].masks.device.type, 'cpu')

        # test `to`
        data = self.get_multi_level_pixel_data()
        self.assertEqual(data[0].masks.device.type, 'cpu')
        data = data.to('cpu')
        self.assertEqual(data[0].masks.device.type, 'cpu')

        # test `numpy`
        data = self.get_multi_level_pixel_data()
        self.assertIsInstance(data[0].masks, torch.Tensor)
        data = data.numpy()
        self.assertIsInstance(data[0].masks, np.ndarray)

    def test_deleter(self):
        # Each data field must be present before deletion and absent after.
        data = self.get_multi_level_pixel_data()

        for key in ['heatmaps', 'masks']:
            self.assertIn(key, data)
            # `delattr` performs the same attribute deletion as a `del`
            # statement, without building and exec-ing source text.
            delattr(data, key)
            self.assertNotIn(key, data)
|