# mirror of https://github.com/open-mmlab/mmpose
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmpose.models.losses.regression_loss import SoftWeightSmoothL1Loss
class TestSoftWeightSmoothL1Loss(TestCase):
    """Unit tests for ``SoftWeightSmoothL1Loss``."""

    def test_loss(self):
        """Verify loss values with and without target weighting, plus the
        ``reduction`` argument of ``smooth_l1_loss``."""

        def _check(criterion, pred, target, expected, weight=None):
            # Helper: compute the loss and compare against the known value.
            if weight is None:
                actual = criterion(pred, target)
            else:
                actual = criterion(pred, target, weight)
            self.assertTrue(torch.allclose(actual, torch.tensor(expected)))

        zeros = torch.zeros((1, 3, 2))
        ones = torch.ones((1, 3, 2))

        # Loss without target weights.
        criterion = SoftWeightSmoothL1Loss(use_target_weight=False, beta=0.5)
        _check(criterion, torch.zeros((1, 3, 2)), torch.zeros((1, 3, 2)), 0.)
        _check(criterion, ones, zeros, .75)

        # Loss with target weights; empty channels still supervised.
        criterion = SoftWeightSmoothL1Loss(
            use_target_weight=True, supervise_empty=True)
        weight = torch.arange(6).reshape(1, 3, 2).float()
        _check(criterion, ones, zeros, 1.25, weight)

        # Loss that skips empty (zero-weight) channels.
        criterion = SoftWeightSmoothL1Loss(
            use_target_weight=True, supervise_empty=False)
        _check(criterion, ones, zeros, 1.5, weight)

        # An unknown reduction mode must raise ValueError.
        with self.assertRaises(ValueError):
            _ = criterion.smooth_l1_loss(ones, zeros, reduction='fake')

        # 'sum' reduction accumulates the element-wise losses.
        out = criterion.smooth_l1_loss(ones, zeros, reduction='sum')
        self.assertTrue(torch.allclose(out, torch.tensor(3.0)))