mirror of https://github.com/open-mmlab/mmpose
65 lines
2.0 KiB
Python
65 lines
2.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import List, Sequence, Tuple
from unittest import TestCase

import torch

from mmpose.models.necks import FeatureMapProcessor
|
|
|
|
|
|
class TestFeatureMapProcessor(TestCase):
    """Unit tests for the ``FeatureMapProcessor`` neck."""

    def _get_feats(
        self,
        batch_size: int = 2,
        feat_shapes: Sequence[Tuple[int, int, int]] = ((32, 1, 1), ),
    ) -> List[torch.Tensor]:
        """Build random feature maps to feed into the neck.

        Args:
            batch_size (int): Size of the leading batch dimension.
                Defaults to 2.
            feat_shapes (Sequence[Tuple[int, int, int]]): Per-sample shape
                of each feature map; one tensor is created per entry.
                Defaults to ``((32, 1, 1), )``.

        Returns:
            List[torch.Tensor]: One float32 tensor per entry of
            ``feat_shapes``, each of shape ``(batch_size, ) + shape``.
        """
        # The default is an immutable tuple so it cannot be shared and
        # mutated across calls; ``tuple(shape)`` also accepts list shapes.
        return [
            torch.rand((batch_size, ) + tuple(shape), dtype=torch.float32)
            for shape in feat_shapes
        ]

    def test_init(self):
        """Check how the constructor stores and validates its arguments."""
        # An integer ``select_index`` is expected to be exposed as a 1-tuple.
        neck = FeatureMapProcessor(select_index=0)
        self.assertSequenceEqual(neck.select_index, (0, ))

        # A zero ``scale_factor`` must be rejected with an AssertionError.
        with self.assertRaises(AssertionError):
            FeatureMapProcessor(scale_factor=0.0)

    def test_call(self):
        """Check the number and shapes of outputs for several configs."""
        inputs = self._get_feats(
            batch_size=2, feat_shapes=[(2, 16, 16), (4, 8, 8), (8, 4, 4)])

        # Selecting a single index yields exactly that feature map.
        neck = FeatureMapProcessor(select_index=0)
        output = neck(inputs)
        self.assertEqual(len(output), 1)
        self.assertSequenceEqual(output[0].shape, (2, 2, 16, 16))

        # Selecting several indices keeps the requested order (2, then 1).
        neck = FeatureMapProcessor(select_index=(2, 1))
        output = neck(inputs)
        self.assertEqual(len(output), 2)
        self.assertSequenceEqual(output[1].shape, (2, 4, 8, 8))
        self.assertSequenceEqual(output[0].shape, (2, 8, 4, 4))

        # Concatenation merges the selected maps into one output with
        # 4 + 8 = 12 channels at the spatial size of the first selected map.
        neck = FeatureMapProcessor(select_index=(1, 2), concat=True)
        output = neck(inputs)
        self.assertEqual(len(output), 1)
        self.assertSequenceEqual(output[0].shape, (2, 12, 8, 8))

        # With the order reversed the first selected map is 4x4; the
        # expected 8x8 output presumably comes from ``scale_factor=2``
        # being applied to the concatenated 4x4 map.
        neck = FeatureMapProcessor(
            select_index=(2, 1), concat=True, scale_factor=2)
        output = neck(inputs)
        self.assertEqual(len(output), 1)
        self.assertSequenceEqual(output[0].shape, (2, 12, 8, 8))

        # ``torch.rand`` only produces non-negative values, so shift the
        # inputs below zero; otherwise the ReLU check could never fail.
        signed_inputs = [feat - 0.5 for feat in inputs]
        neck = FeatureMapProcessor(concat=True, apply_relu=True)
        output = neck(signed_inputs)
        self.assertEqual(len(output), 1)
        self.assertSequenceEqual(output[0].shape, (2, 14, 16, 16))
        # After a ReLU the smallest element (not the largest) must be >= 0.
        self.assertGreaterEqual(output[0].min().item(), 0)
|