mmpose/tests/test_models/test_losses/test_regression_losses.py

49 lines
1.7 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmpose.models.losses.regression_loss import SoftWeightSmoothL1Loss
class TestSoftWeightSmoothL1Loss(TestCase):
def test_loss(self):
# test loss w/o target_weight
loss = SoftWeightSmoothL1Loss(use_target_weight=False, beta=0.5)
fake_pred = torch.zeros((1, 3, 2))
fake_label = torch.zeros((1, 3, 2))
self.assertTrue(
torch.allclose(loss(fake_pred, fake_label), torch.tensor(0.)))
fake_pred = torch.ones((1, 3, 2))
fake_label = torch.zeros((1, 3, 2))
self.assertTrue(
torch.allclose(loss(fake_pred, fake_label), torch.tensor(.75)))
# test loss w/ target_weight
loss = SoftWeightSmoothL1Loss(
use_target_weight=True, supervise_empty=True)
fake_pred = torch.ones((1, 3, 2))
fake_label = torch.zeros((1, 3, 2))
fake_weight = torch.arange(6).reshape(1, 3, 2).float()
self.assertTrue(
torch.allclose(
loss(fake_pred, fake_label, fake_weight), torch.tensor(1.25)))
# test loss that does not take empty channels into account
loss = SoftWeightSmoothL1Loss(
use_target_weight=True, supervise_empty=False)
self.assertTrue(
torch.allclose(
loss(fake_pred, fake_label, fake_weight), torch.tensor(1.5)))
with self.assertRaises(ValueError):
_ = loss.smooth_l1_loss(fake_pred, fake_label, reduction='fake')
output = loss.smooth_l1_loss(fake_pred, fake_label, reduction='sum')
self.assertTrue(torch.allclose(output, torch.tensor(3.0)))