# Mirror of https://github.com/open-mmlab/mmpose (Python, 146 lines, 4.6 KiB).
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import torch
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|
|
|
from mmpose.models.backbones import LiteHRNet
|
|
from mmpose.models.backbones.litehrnet import LiteHRModule
|
|
from mmpose.models.backbones.resnet import Bottleneck
|
|
|
|
|
|
class TestLiteHrnet(TestCase):
    """Unit tests for ``LiteHRModule`` and the ``LiteHRNet`` backbone."""

    @staticmethod
    def is_norm(modules):
        """Check whether ``modules`` is a batch-norm layer.

        Args:
            modules (nn.Module): Module to inspect.

        Returns:
            bool: True if ``modules`` is an instance of (a subclass of)
            ``_BatchNorm``, otherwise False.
        """
        return isinstance(modules, (_BatchNorm, ))

    @staticmethod
    def all_zeros(modules):
        """Check if the weight (and bias, when present) is all zero.

        Args:
            modules (nn.Module): Module exposing a ``weight`` parameter and,
                optionally, a ``bias`` parameter.

        Returns:
            bool: True if ``weight`` is all zeros and ``bias`` is either
            absent/``None`` or also all zeros.
        """
        weight_zero = torch.equal(modules.weight.data,
                                  torch.zeros_like(modules.weight.data))
        # Layers constructed without a bias (e.g. ``bias=False``) still have
        # a ``bias`` attribute set to ``None``, so ``hasattr`` alone would let
        # ``.data`` be accessed on ``None``.
        bias = getattr(modules, 'bias', None)
        if bias is None:
            bias_zero = True
        else:
            bias_zero = torch.equal(bias.data, torch.zeros_like(bias.data))

        return weight_zero and bias_zero

    @staticmethod
    def _build_litehr_module(module_type):
        """Build a single-branch, single-block 40-channel ``LiteHRModule``.

        Args:
            module_type (str): Value forwarded as ``module_type``.
        """
        return LiteHRModule(
            num_branches=1,
            num_blocks=1,
            in_channels=[
                40,
            ],
            reduce_ratio=8,
            module_type=module_type)

    def test_litehrmodule(self):
        # Test LiteHRModule forward with the 'LITE' module type; as exercised
        # here it is fed and indexed as a nested list of feature maps.
        block = self._build_litehr_module('LITE')
        x = torch.randn(2, 40, 56, 56)
        x_out = block([[x]])
        self.assertEqual(x_out[0][0].shape, torch.Size([2, 40, 56, 56]))

        # 'NAIVE' module type: fed and indexed as a flat list.
        block = self._build_litehr_module('NAIVE')
        x = torch.randn(2, 40, 56, 56)
        x_out = block([x])
        self.assertEqual(x_out[0].shape, torch.Size([2, 40, 56, 56]))

        # An unrecognised module type must be rejected at construction.
        with self.assertRaises(ValueError):
            self._build_litehr_module('none')

    @staticmethod
    def _build_extra(module_type):
        """Return a 3-stage LiteHRNet ``extra`` config.

        Args:
            module_type (str): Module type used for every one of the three
                stages (e.g. ``'LITE'`` or ``'NAIVE'``).
        """
        return dict(
            stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
            num_stages=3,
            stages_spec=dict(
                num_modules=(2, 4, 2),
                num_branches=(2, 3, 4),
                num_blocks=(2, 2, 2),
                module_type=(module_type, ) * 3,
                with_fuse=(True, True, True),
                reduce_ratios=(8, 8, 8),
                num_channels=(
                    (40, 80),
                    (40, 80, 160),
                    (40, 80, 160, 320),
                )),
            with_head=True)

    def _check_backbone(self, extra):
        """Run forward and weight-init checks on a LiteHRNet built from
        ``extra``."""
        model = LiteHRNet(extra, in_channels=3)

        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        self.assertIsInstance(feat, tuple)
        self.assertEqual(feat[-1].shape, torch.Size([2, 40, 56, 56]))

        # Test HRNet zero initialization of residual.
        # NOTE(review): if LiteHRNet contains no ``Bottleneck`` blocks this
        # loop asserts nothing -- confirm against litehrnet.py.
        model = LiteHRNet(extra, in_channels=3)
        model.init_weights()
        for m in model.modules():
            if isinstance(m, Bottleneck):
                self.assertTrue(self.all_zeros(m.norm3))
        model.train()

        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        self.assertIsInstance(feat, tuple)
        self.assertEqual(feat[-1].shape, torch.Size([2, 40, 56, 56]))

    def test_litehrnet_backbone(self):
        # Same checks for both supported module types; subTest reports
        # which configuration failed.
        for module_type in ('LITE', 'NAIVE'):
            with self.subTest(module_type=module_type):
                self._check_backbone(self._build_extra(module_type))
|