# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from typing import Tuple
from unittest import TestCase

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from mmpose.codecs.associative_embedding import AssociativeEmbedding
from mmpose.models.losses.ae_loss import AssociativeEmbeddingLoss
from mmpose.testing._utils import get_coco_sample

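
# `AELoss` re-implements the Associative Embedding loss from MMPose v0.x and
# serves as a numerical reference for the `AssociativeEmbeddingLoss` module
# under test.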
class AELoss(nn.Module):
    """Associative Embedding loss in MMPose v0.x."""

    def __init__(self, loss_type):
        super().__init__()
        self.loss_type = loss_type

    @staticmethod
    def _make_input(t, requires_grad=False, device=torch.device('cpu')):
        """Make zero inputs for AE loss.

        Args:
            t (torch.Tensor): input
            requires_grad (bool): Option to use requires_grad.
            device: torch device

        Returns:
            torch.Tensor: zero input.
        """
        inp = torch.autograd.Variable(t, requires_grad=requires_grad)
        inp = inp.sum()
        inp = inp.to(device)
        return inp

    def singleTagLoss(self, pred_tag, joints):
        """Associative embedding loss for one image.

        Note:
            - heatmaps width: W
            - heatmaps height: H
            - max_num_people: M
            - num_keypoints: K

        Args:
            pred_tag (torch.Tensor[KxHxW,1]): tag of output for one image.
            joints (torch.Tensor[M,K,2]): joints information for one image.
        """
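        # Pull phase: gather the predicted tags at each person's annotated
        # keypoint locations and pull them towards their per-person mean
        # (the reference tag).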
        tags = []
        pull = 0
        # reshape the flat tag vector to (K, H*W, 1), with K=17 keypoints
        # as defined by COCO
        pred_tag = pred_tag.view(17, -1, 1)
        for joints_per_person in joints:
            tmp = []
            for k, joint in enumerate(joints_per_person):
                if joint[1] > 0:  # the joint is visible
                    tmp.append(pred_tag[k, joint[0]])
            if len(tmp) == 0:
                continue
            tmp = torch.stack(tmp)
            tags.append(torch.mean(tmp, dim=0))
            pull = pull + torch.mean((tmp - tags[-1].expand_as(tmp))**2)

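        # Degenerate cases: with no annotated person there is nothing to pull
        # or push; with a single person there is no pair to push apart.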
        num_tags = len(tags)
        if num_tags == 0:
            return (self._make_input(
                torch.zeros(1).float(), device=pred_tag.device),
                    self._make_input(
                        torch.zeros(1).float(), device=pred_tag.device))
        elif num_tags == 1:
            return (self._make_input(
                torch.zeros(1).float(), device=pred_tag.device), pull)

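        # Push phase: build the matrix of pairwise differences between
        # per-person reference tags and penalize pairs that lie close
        # together in the embedding space.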
        tags = torch.stack(tags)

        size = (num_tags, num_tags)
        A = tags.expand(*size)
        B = A.permute(1, 0)

        diff = A - B

        if self.loss_type == 'exp':
            diff = torch.pow(diff, 2)
            push = torch.exp(-diff)
            push = torch.sum(push)
        elif self.loss_type == 'max':
            diff = 1 - torch.abs(diff)
            push = torch.clamp(diff, min=0).sum() - num_tags
        else:
            raise ValueError('Unknown ae loss type')

        push_loss = push / ((num_tags - 1) * num_tags) * 0.5
        pull_loss = pull / num_tags

        return push_loss, pull_loss

    def forward(self, tags, keypoint_indices):
        assert tags.shape[0] == len(keypoint_indices)

        pull_loss = 0.
        push_loss = 0.

        for i in range(tags.shape[0]):
            # singleTagLoss returns (push, pull) for one image
            _push, _pull = self.singleTagLoss(tags[i].view(-1, 1),
                                              keypoint_indices[i])
            pull_loss += _pull
            push_loss += _push

        return pull_loss, push_loss

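
# The test encodes a synthetic COCO sample, synthesizes tag maps consistent
# with the encoded keypoint indices, and compares the current implementation
# against the v0.x reference above.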
class TestAssociativeEmbeddingLoss(TestCase):

    def _make_input(self, num_instance: int) -> Tuple[Tensor, Tensor]:
        """Encode a synthetic sample and build matching batch inputs."""

        encoder = AssociativeEmbedding(
            input_size=(256, 256), heatmap_size=(64, 64))

        data = get_coco_sample(
            img_shape=(256, 256), num_instances=num_instance)
        encoded = encoder.encode(data['keypoints'], data['keypoints_visible'])
        heatmaps = encoded['heatmaps']
        keypoint_indices = encoded['keypoint_indices']

        tags = self._get_tags(
            heatmaps, keypoint_indices, tag_per_keypoint=True)

        batch_tags = torch.from_numpy(tags[None])
        batch_keypoint_indices = [torch.from_numpy(keypoint_indices)]

        return batch_tags, batch_keypoint_indices

    def _get_tags(self,
                  heatmaps,
                  keypoint_indices,
                  tag_per_keypoint: bool,
                  with_randomness: bool = True):
        """Build synthetic tag maps where each person's keypoints share a
        distinct base tag value (the person index)."""

        K, H, W = heatmaps.shape
        N = keypoint_indices.shape[0]

        if tag_per_keypoint:
            tags = np.zeros((K, H, W), dtype=np.float32)
        else:
            tags = np.zeros((1, H, W), dtype=np.float32)

        for n, k in product(range(N), range(K)):
            y, x = np.unravel_index(keypoint_indices[n, k, 0], (H, W))

            # perturb the base tag so tags within a person differ slightly,
            # giving a non-zero pull loss
            randomness = np.random.rand() if with_randomness else 0

            if tag_per_keypoint:
                tags[k, y, x] = n + randomness
            else:
                tags[0, y, x] = n + randomness

        return tags

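    # Three checks: the new loss matches AELoss('exp'); loss_weight=0 zeroes
    # both terms; push_loss_factor=0 zeroes only the push term.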
    def test_loss(self):

        tags, keypoint_indices = self._make_input(num_instance=2)

        # test loss calculation
        loss_module = AssociativeEmbeddingLoss()
        pull_loss, push_loss = loss_module(tags, keypoint_indices)
        _pull_loss, _push_loss = AELoss('exp')(tags, keypoint_indices)

        self.assertTrue(torch.allclose(pull_loss, _pull_loss))
        self.assertTrue(torch.allclose(push_loss, _push_loss))

        # test loss weight
        loss_module = AssociativeEmbeddingLoss(loss_weight=0.)
        pull_loss, push_loss = loss_module(tags, keypoint_indices)

        self.assertTrue(torch.allclose(pull_loss, torch.zeros(1)))
        self.assertTrue(torch.allclose(push_loss, torch.zeros(1)))

        # test push loss factor
        loss_module = AssociativeEmbeddingLoss(push_loss_factor=0.)
        pull_loss, push_loss = loss_module(tags, keypoint_indices)

        self.assertFalse(torch.allclose(pull_loss, torch.zeros(1)))
        self.assertTrue(torch.allclose(push_loss, torch.zeros(1)))
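

# Optional entry point so the tests can also be run directly with `python`;
# a pytest/unittest runner does not need it.
if __name__ == '__main__':
    import unittest
    unittest.main()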