mmpose/tests/test_models/test_backbones/test_v2v_net.py

18 lines
470 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmpose.models.backbones import V2VNet
class TestV2Vnet(TestCase):
    """Unit tests for the ``V2VNet`` backbone."""

    def test_v2v_net(self):
        """Test V2VNet forward output type and shape.

        Builds a V2VNet mapping 17 input channels to 15 output channels,
        feeds a random tensor of shape (2, 17, 32, 32, 32), and checks that
        the forward pass returns a non-empty tuple whose last element is a
        tensor of shape (2, 15, 32, 32, 32).
        """
        model = V2VNet(input_channels=17, output_channels=15)
        # Named ``volumes`` rather than ``input`` to avoid shadowing the
        # Python builtin.
        volumes = torch.randn(2, 17, 32, 32, 32)
        # Only output structure/shape is checked, so there is no need to
        # build the autograd graph (saves memory and time).
        with torch.no_grad():
            outputs = model(volumes)

        self.assertIsInstance(outputs, tuple)
        # Guard the ``outputs[-1]`` indexing below with a clear failure
        # message instead of an IndexError.
        self.assertGreater(len(outputs), 0)
        self.assertIsInstance(outputs[-1], torch.Tensor)
        self.assertEqual(outputs[-1].shape, (2, 15, 32, 32, 32))