# mirror of https://github.com/open-mmlab/mmpose
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmpose.models.utils.transformer import GAUEncoder, SinePositionalEncoding
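

# The tests below exercise SinePositionalEncoding. For orientation, here is a
# minimal sketch of the classic fixed sinusoidal scheme from "Attention Is
# All You Need". It is an illustrative assumption about the technique in
# general: the name `_sine_pos_enc_sketch`, the base of 10000, and the
# sin/cos channel layout are NOT taken from the mmpose implementation, whose
# internals may differ.
def _sine_pos_enc_sketch(num_positions: int, out_channels: int) -> torch.Tensor:
    """Return a ``(num_positions, out_channels)`` sinusoidal encoding."""
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
    # One frequency per channel pair, geometrically spaced from 1 to 1/10000,
    # mirroring the half-size `dim_t` buffer asserted in test_init below.
    dim_t = torch.arange(out_channels // 2, dtype=torch.float32)
    inv_freq = 1.0 / (10000**(2 * dim_t / out_channels))
    angles = position * inv_freq  # (num_positions, out_channels // 2)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)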


class TestSinePositionalEncoding(TestCase):

    def test_init(self):

        # dim_t is a fixed (non-learnable) buffer with one entry per
        # channel pair
        spe = SinePositionalEncoding(out_channels=128)
        self.assertTrue(hasattr(spe, 'dim_t'))
        self.assertFalse(spe.dim_t.requires_grad)
        self.assertEqual(spe.dim_t.size(0), 128 // 2)

        # learnable=True makes dim_t trainable
        spe = SinePositionalEncoding(out_channels=128, learnable=True)
        self.assertTrue(spe.dim_t.requires_grad)

        # eval_size pre-computes an encoding buffer named after the size
        spe = SinePositionalEncoding(out_channels=128, eval_size=10)
        self.assertTrue(hasattr(spe, 'pos_enc_10'))
        self.assertEqual(spe.pos_enc_10.size(-1), 128)

        spe = SinePositionalEncoding(
            out_channels=128, eval_size=(2, 3), spatial_dim=2)
        self.assertTrue(hasattr(spe, 'pos_enc_(2, 3)'))
        self.assertSequenceEqual(
            getattr(spe, 'pos_enc_(2, 3)').shape[-2:], (128, 2))

    def test_generate_pos_encoding(self):

        # spatial_dim = 1
        spe = SinePositionalEncoding(out_channels=128)
        pos_enc = spe.generate_pos_encoding(size=10)
        self.assertSequenceEqual(pos_enc.shape, (10, 128))

        # explicit positions may be passed instead of a size
        position = torch.arange(8)
        pos_enc = spe.generate_pos_encoding(position=position)
        self.assertSequenceEqual(pos_enc.shape, (8, 128))

        # size and position are mutually exclusive
        with self.assertRaises(AssertionError):
            pos_enc = spe.generate_pos_encoding(size=10, position=position)

        # spatial_dim = 2: a scalar size is treated as a 10 x 10 grid, so
        # the flattened output has 100 positions with one encoding per axis
        # stacked in the last dimension
        spe = SinePositionalEncoding(out_channels=128, spatial_dim=2)
        pos_enc = spe.generate_pos_encoding(size=10)
        self.assertSequenceEqual(pos_enc.shape, (100, 128, 2))

        pos_enc = spe.generate_pos_encoding(size=(5, 6))
        self.assertSequenceEqual(pos_enc.shape, (30, 128, 2))

        position = torch.arange(8).unsqueeze(1).repeat(1, 2)
        pos_enc = spe.generate_pos_encoding(position=position)
        self.assertSequenceEqual(pos_enc.shape, (8, 128, 2))

        with self.assertRaises(AssertionError):
            pos_enc = spe.generate_pos_encoding(size=10, position=position)

        # a tensor is not a valid size
        with self.assertRaises(ValueError):
            pos_enc = spe.generate_pos_encoding(size=position)

    def test_apply_additional_pos_enc(self):

        # spatial_dim = 1: the encoding is added to the feature, so shapes
        # are preserved
        spe = SinePositionalEncoding(out_channels=128)
        pos_enc = spe.generate_pos_encoding(size=10)
        feature = torch.randn(2, 3, 10, 128)
        out_feature = spe.apply_additional_pos_enc(feature, pos_enc,
                                                   spe.spatial_dim)
        self.assertSequenceEqual(feature.shape, out_feature.shape)

        # spatial_dim = 2: out_channels is halved here, presumably because
        # the encodings of the two axes are concatenated to match the
        # 128-channel feature
        spe = SinePositionalEncoding(out_channels=128 // 2, spatial_dim=2)
        pos_enc = spe.generate_pos_encoding(size=(2, 5))
        feature = torch.randn(2, 3, 10, 128)
        out_feature = spe.apply_additional_pos_enc(feature, pos_enc,
                                                   spe.spatial_dim)
        self.assertSequenceEqual(feature.shape, out_feature.shape)

    def test_apply_rotary_pos_enc(self):

        # spatial_dim = 1
        spe = SinePositionalEncoding(out_channels=128)
        pos_enc = spe.generate_pos_encoding(size=10)
        feature = torch.randn(2, 3, 10, 128)
        out_feature = spe.apply_rotary_pos_enc(feature, pos_enc,
                                               spe.spatial_dim)
        self.assertSequenceEqual(feature.shape, out_feature.shape)

        # spatial_dim = 2
        spe = SinePositionalEncoding(out_channels=128, spatial_dim=2)
        pos_enc = spe.generate_pos_encoding(size=(2, 5))
        feature = torch.randn(2, 3, 10, 128)
        out_feature = spe.apply_rotary_pos_enc(feature, pos_enc,
                                               spe.spatial_dim)
        self.assertSequenceEqual(feature.shape, out_feature.shape)
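

# apply_rotary_pos_enc above follows the rotary-embedding (RoPE) idea:
# channel pairs are rotated by position-dependent angles. A minimal 1-D
# sketch of that rotation is given here. It is an assumption about the
# general technique; `_rope_sketch` and its even/odd channel pairing are
# illustrative, not the exact mmpose implementation.
def _rope_sketch(feature: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """Rotate channel pairs of ``feature`` (..., seq, 2 * d) by ``angles``
    of shape (seq, d)."""
    x1, x2 = feature[..., 0::2], feature[..., 1::2]  # split into pairs
    sin, cos = angles.sin(), angles.cos()
    # Standard 2-D rotation applied independently to every channel pair.
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return rotated.flatten(-2)  # interleave back to (..., seq, 2 * d)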


class TestGAUEncoder(TestCase):

    def test_init(self):
        gau = GAUEncoder(in_token_dims=64, out_token_dims=64)
        # equal in/out dims enable the residual shortcut
        self.assertTrue(gau.shortcut)

        gau = GAUEncoder(in_token_dims=64, out_token_dims=64, dropout_rate=0.5)
        self.assertTrue(hasattr(gau, 'dropout'))

    def test_forward(self):
        gau = GAUEncoder(in_token_dims=64, out_token_dims=64)

        # compatibility with inputs of various dimensions
        feat = torch.randn(2, 3, 64)
        with torch.no_grad():
            out_feat = gau.forward(feat)
        self.assertSequenceEqual(feat.shape, out_feat.shape)

        feat = torch.randn(1, 2, 3, 64)
        with torch.no_grad():
            out_feat = gau.forward(feat)
        self.assertSequenceEqual(feat.shape, out_feat.shape)

        feat = torch.randn(1, 2, 3, 4, 64)
        with torch.no_grad():
            out_feat = gau.forward(feat)
        self.assertSequenceEqual(feat.shape, out_feat.shape)

        # positional encoding
        gau = GAUEncoder(
            s=32, in_token_dims=64, out_token_dims=64, pos_enc=True)
        feat = torch.randn(2, 3, 64)
        spe = SinePositionalEncoding(out_channels=32)
        pos_enc = spe.generate_pos_encoding(size=3)
        with torch.no_grad():
            out_feat = gau.forward(feat, pos_enc=pos_enc)
        self.assertSequenceEqual(feat.shape, out_feat.shape)

        gau = GAUEncoder(
            s=32,
            in_token_dims=64,
            out_token_dims=64,
            pos_enc=True,
            spatial_dim=2)
        feat = torch.randn(1, 2, 6, 64)
        spe = SinePositionalEncoding(out_channels=32, spatial_dim=2)
        pos_enc = spe.generate_pos_encoding(size=(2, 3))
        with torch.no_grad():
            out_feat = gau.forward(feat, pos_enc=pos_enc)
        self.assertSequenceEqual(feat.shape, out_feat.shape)

        # attention mask
        gau = GAUEncoder(in_token_dims=64, out_token_dims=64)
        feat = torch.randn(2, 3, 64)
        mask = torch.rand(2, 3, 3)
        with torch.no_grad():
            out_feat = gau.forward(feat, mask=mask)
        self.assertSequenceEqual(feat.shape, out_feat.shape)
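

# GAUEncoder is a Gated Attention Unit style block. For reference, a minimal
# sketch of the gating pattern (two projections, one modulated by an
# attention-weighted value) is given below. This is an assumption about the
# GAU technique in general; the helper and its weights are hypothetical and
# not the mmpose implementation, which also adds normalization, the shortcut,
# dropout, and positional encoding.
def _gau_gating_sketch(x: torch.Tensor, w_u: torch.Tensor, w_v: torch.Tensor,
                       attn: torch.Tensor) -> torch.Tensor:
    """Gate an attention output: ``u * (attn @ v)``.

    x: (batch, seq, dim); w_u, w_v: (dim, expanded); attn: (seq, seq).
    """
    u = x @ w_u  # gating branch
    v = x @ w_v  # value branch
    return u * (attn @ v)  # element-wise gate on the attended values


# e.g. _gau_gating_sketch(torch.randn(2, 3, 64), torch.randn(64, 128),
#                         torch.randn(64, 128),
#                         torch.softmax(torch.randn(3, 3), -1))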