# mmpose/tests/test_models/test_backbones/test_litehrnet.py
# (file-viewer metadata: 146 lines, 4.6 KiB, Python)

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmpose.models.backbones import LiteHRNet
from mmpose.models.backbones.litehrnet import LiteHRModule
from mmpose.models.backbones.resnet import Bottleneck
class TestLiteHrnet(TestCase):
    """Unit tests for the ``LiteHRNet`` backbone and ``LiteHRModule``."""

    @staticmethod
    def is_norm(modules):
        """Check if ``modules`` is one of the norms (a ``_BatchNorm``)."""
        return isinstance(modules, (_BatchNorm, ))

    @staticmethod
    def all_zeros(modules):
        """Check if the weight (and bias, when present) is all zero.

        Args:
            modules (nn.Module): A layer with a ``weight`` parameter and,
                optionally, a ``bias`` parameter.

        Returns:
            bool: ``True`` if ``weight`` is all zeros and ``bias`` is either
            missing/``None`` or all zeros.
        """
        weight_zero = torch.equal(modules.weight.data,
                                  torch.zeros_like(modules.weight.data))
        # Layers created with ``bias=False`` still expose a ``bias``
        # attribute, but its value is ``None``; checking ``hasattr`` alone
        # would make ``bias.data`` raise AttributeError.
        bias = getattr(modules, 'bias', None)
        if bias is None:
            bias_zero = True
        else:
            bias_zero = torch.equal(bias.data, torch.zeros_like(bias.data))
        return weight_zero and bias_zero

    def test_litehrmodule(self):
        """LiteHRModule forward for 'LITE'/'NAIVE' and rejection of others."""
        # Test LiteHRModule forward with 'LITE' blocks; this test feeds a
        # nested list ([[x]]) and reads the output as x_out[0][0].
        block = LiteHRModule(
            num_branches=1,
            num_blocks=1,
            in_channels=[
                40,
            ],
            reduce_ratio=8,
            module_type='LITE')
        x = torch.randn(2, 40, 56, 56)
        x_out = block([[x]])
        self.assertEqual(x_out[0][0].shape, torch.Size([2, 40, 56, 56]))

        # 'NAIVE' blocks are fed a flat list ([x]) and read as x_out[0].
        block = LiteHRModule(
            num_branches=1,
            num_blocks=1,
            in_channels=[
                40,
            ],
            reduce_ratio=8,
            module_type='NAIVE')
        x = torch.randn(2, 40, 56, 56)
        x_out = block([x])
        self.assertEqual(x_out[0].shape, torch.Size([2, 40, 56, 56]))

        # An unknown ``module_type`` must raise ValueError on construction.
        with self.assertRaises(ValueError):
            LiteHRModule(
                num_branches=1,
                num_blocks=1,
                in_channels=[
                    40,
                ],
                reduce_ratio=8,
                module_type='none')

    @staticmethod
    def _make_extra(module_type):
        """Build the 3-stage LiteHRNet ``extra`` config used by the tests.

        Args:
            module_type (str): Block type used for every stage
                (e.g. ``'LITE'`` or ``'NAIVE'``).

        Returns:
            dict: The ``extra`` config passed to ``LiteHRNet``.
        """
        return dict(
            stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
            num_stages=3,
            stages_spec=dict(
                num_modules=(2, 4, 2),
                num_branches=(2, 3, 4),
                num_blocks=(2, 2, 2),
                module_type=(module_type, ) * 3,
                with_fuse=(True, True, True),
                reduce_ratios=(8, 8, 8),
                num_channels=(
                    (40, 80),
                    (40, 80, 160),
                    (40, 80, 160, 320),
                )),
            with_head=True)

    def _check_forward(self, model):
        """Run ``model`` on a random (2, 3, 224, 224) batch and check output.

        The output must be a tuple whose last element has shape
        ``(2, 40, 56, 56)``.
        """
        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        self.assertIsInstance(feat, tuple)
        self.assertEqual(feat[-1].shape, torch.Size([2, 40, 56, 56]))

    def test_litehrnet_backbone(self):
        """LiteHRNet forward (before and after init) for each block type."""
        for module_type in ('LITE', 'NAIVE'):
            with self.subTest(module_type=module_type):
                extra = self._make_extra(module_type)
                model = LiteHRNet(extra, in_channels=3)
                self._check_forward(model)

                # Test HRNet zero initialization of residual.
                # NOTE(review): this only asserts something if LiteHRNet
                # actually contains ``Bottleneck`` modules — confirm, since
                # otherwise the loop body never runs.
                model = LiteHRNet(extra, in_channels=3)
                model.init_weights()
                for m in model.modules():
                    if isinstance(m, Bottleneck):
                        self.assertTrue(self.all_zeros(m.norm3))
                model.train()
                self._check_forward(model)