# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from mmcv.ops import batched_nms
from mmdet.models.utils import filter_scores_and_topk
from mmdet.utils import ConfigType, OptInstanceList
from mmengine.config import ConfigDict
from mmengine.model import ModuleList, bias_init_with_prob
from mmengine.structures import InstanceData
from mmyolo.models.dense_heads import YOLOXHead, YOLOXHeadModule
from mmyolo.registry import MODELS
from torch import Tensor

from .utils import OutputSaveFunctionWrapper, OutputSaveObjectWrapper


@MODELS.register_module()
class YOLOXPoseHeadModule(YOLOXHeadModule):
    """YOLOXPoseHeadModule serves as a head module for `YOLOX-Pose`.

    In comparison to `YOLOXHeadModule`, this module introduces branches for
    keypoint prediction.
    """

    def __init__(self, num_keypoints: int, *args, **kwargs):
        self.num_keypoints = num_keypoints
        super().__init__(*args, **kwargs)

    def _init_layers(self):
        """Initializes the layers in the head module."""
        super()._init_layers()

        # The pose branch requires additional layers for precise regression
        self.stacked_convs *= 2

        # Create separate layers for each level of feature maps
        pose_convs, offsets_preds, vis_preds = [], [], []
        for _ in self.featmap_strides:
            pose_convs.append(self._build_stacked_convs())
            offsets_preds.append(
                nn.Conv2d(self.feat_channels, self.num_keypoints * 2, 1))
            vis_preds.append(
                nn.Conv2d(self.feat_channels, self.num_keypoints, 1))

        self.multi_level_pose_convs = ModuleList(pose_convs)
        self.multi_level_conv_offsets = ModuleList(offsets_preds)
        self.multi_level_conv_vis = ModuleList(vis_preds)

    def init_weights(self):
        """Initialize weights of the head."""
        super().init_weights()

        # Use prior in model initialization to improve stability
        bias_init = bias_init_with_prob(0.01)
        for conv_vis in self.multi_level_conv_vis:
            conv_vis.bias.data.fill_(bias_init)
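
        # With a prior probability of 0.01, `bias_init_with_prob` fills the
        # bias with -log((1 - 0.01) / 0.01) ≈ -4.6, so predicted visibility
        # starts near zero probability after the sigmoid.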

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network."""
        offsets_pred, vis_pred = [], []
        for i in range(len(x)):
            pose_feat = self.multi_level_pose_convs[i](x[i])
            offsets_pred.append(self.multi_level_conv_offsets[i](pose_feat))
            vis_pred.append(self.multi_level_conv_vis[i](pose_feat))

        return (*super().forward(x), offsets_pred, vis_pred)
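
    # Shape sketch: the tuple returned by `forward` extends the base YOLOX
    # outputs (cls_scores, bbox_preds, objectnesses) with per-level keypoint
    # offsets of shape (B, num_keypoints * 2, H, W) and visibility logits of
    # shape (B, num_keypoints, H, W).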


@MODELS.register_module()
class YOLOXPoseHead(YOLOXHead):
    """YOLOXPoseHead head used in `YOLO-Pose
    <https://arxiv.org/abs/2204.06806>`_.

    Args:
        loss_pose (ConfigDict, optional): Config of keypoint OKS loss.
    """

    def __init__(
        self,
        loss_pose: Optional[ConfigType] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.loss_pose = MODELS.build(loss_pose)
        self.num_keypoints = self.head_module.num_keypoints

        # set up buffers to save variables generated in methods of
        # the class's base class
        self._log = defaultdict(list)
        self.sampler = OutputSaveObjectWrapper(self.sampler)

        # ensure that the `sigmas` in self.assigner.oks_calculator
        # is on the same device as the model
        if hasattr(self.assigner, 'oks_calculator'):
            self.add_module('assigner_oks_calculator',
                            self.assigner.oks_calculator)
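            # Registering it as a submodule ensures its buffers (e.g. the
            # OKS `sigmas`) move with the model on `.to(device)`.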

    def _clear(self):
        """Clear variable buffers."""
        self.sampler.clear()
        self._log.clear()

    def loss_by_feat(self,
                     cls_scores: Sequence[Tensor],
                     bbox_preds: Sequence[Tensor],
                     objectnesses: Sequence[Tensor],
                     kpt_preds: Sequence[Tensor],
                     vis_preds: Sequence[Tensor],
                     batch_gt_instances: Sequence[InstanceData],
                     batch_img_metas: Sequence[dict],
                     batch_gt_instances_ignore: OptInstanceList = None
                     ) -> dict:
        """Calculate the loss based on the features extracted by the
        detection head.

        In addition to the base class method, keypoint losses are also
        calculated in this method.
        """

        self._clear()

        # collect keypoints coordinates and visibility from model predictions
        kpt_preds = torch.cat([
            kpt_pred.flatten(2).permute(0, 2, 1).contiguous()
            for kpt_pred in kpt_preds
        ], dim=1)

        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=cls_scores[0].dtype,
            device=cls_scores[0].device,
            with_stride=True)
        grid_priors = torch.cat(mlvl_priors)

        flatten_kpts = self.decode_pose(grid_priors[..., :2], kpt_preds,
                                        grid_priors[..., 2])

        vis_preds = torch.cat([
            vis_pred.flatten(2).permute(0, 2, 1).contiguous()
            for vis_pred in vis_preds
        ], dim=1)

        # compute detection losses and collect targets for keypoints
        # predictions simultaneously
        self._log['pred_keypoints'] = list(
            flatten_kpts.detach().split(1, dim=0))
        self._log['pred_keypoints_vis'] = list(
            vis_preds.detach().split(1, dim=0))

        losses = super().loss_by_feat(cls_scores, bbox_preds, objectnesses,
                                      batch_gt_instances, batch_img_metas,
                                      batch_gt_instances_ignore)
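
        # NOTE: during the base-class call above, the wrapped `sampler`
        # records each sampling result; those records are read back below to
        # gather keypoint targets for the positive instances.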
        kpt_targets, vis_targets = [], []
        sampling_results = self.sampler.log['sample']
        sampling_result_idx = 0
        for gt_instances in batch_gt_instances:
            if len(gt_instances) > 0:
                sampling_result = sampling_results[sampling_result_idx]
                kpt_target = gt_instances['keypoints'][
                    sampling_result.pos_assigned_gt_inds]
                vis_target = gt_instances['keypoints_visible'][
                    sampling_result.pos_assigned_gt_inds]
                sampling_result_idx += 1
                kpt_targets.append(kpt_target)
                vis_targets.append(vis_target)

        if len(kpt_targets) > 0:
            kpt_targets = torch.cat(kpt_targets, 0)
            vis_targets = torch.cat(vis_targets, 0)

        # compute keypoint losses
        if len(kpt_targets) > 0:
            vis_targets = (vis_targets > 0).float()
            pos_masks = torch.cat(self._log['foreground_mask'], 0)
            bbox_targets = torch.cat(self._log['bbox_target'], 0)
            loss_kpt = self.loss_pose(
                flatten_kpts.view(-1, self.num_keypoints, 2)[pos_masks],
                kpt_targets, vis_targets, bbox_targets)
            loss_vis = self.loss_cls(
                vis_preds.view(-1, self.num_keypoints)[pos_masks],
                vis_targets) / vis_targets.sum()
        else:
            loss_kpt = kpt_preds.sum() * 0
            loss_vis = vis_preds.sum() * 0

        losses.update(dict(loss_kpt=loss_kpt, loss_vis=loss_vis))

        self._clear()
        return losses

    @torch.no_grad()
    def _get_targets_single(
            self,
            priors: Tensor,
            cls_preds: Tensor,
            decoded_bboxes: Tensor,
            objectness: Tensor,
            gt_instances: InstanceData,
            img_meta: dict,
            gt_instances_ignore: Optional[InstanceData] = None) -> tuple:
        """Calculates targets for a single image, and saves them to the log.

        This method is similar to the `_get_targets_single` method in the
        base class, but additionally saves the foreground mask and bbox
        targets to the log.
        """

        # Construct a combined representation of bboxes and keypoints to
        # ensure keypoints are also involved in the positive sample
        # assignment process
        kpt = self._log['pred_keypoints'].pop(0).squeeze(0)
        kpt_vis = self._log['pred_keypoints_vis'].pop(0).squeeze(0)
        kpt = torch.cat((kpt, kpt_vis.unsqueeze(-1)), dim=-1)
        decoded_bboxes = torch.cat((decoded_bboxes, kpt.flatten(1)), dim=1)
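        # `kpt.flatten(1)` has shape (num_priors, num_keypoints * 3), so each
        # row of `decoded_bboxes` now holds 4 + 3 * num_keypoints values.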

        targets = super()._get_targets_single(priors, cls_preds,
                                              decoded_bboxes, objectness,
                                              gt_instances, img_meta,
                                              gt_instances_ignore)
        self._log['foreground_mask'].append(targets[0])
        self._log['bbox_target'].append(targets[3])
        return targets

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        objectnesses: Optional[List[Tensor]] = None,
                        kpt_preds: Optional[List[Tensor]] = None,
                        vis_preds: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = True,
                        with_nms: bool = True) -> List[InstanceData]:
        """Transform a batch of output features extracted by the head into
        bbox and keypoint results.

        In addition to the base class method, keypoint predictions are also
        calculated in this method.
        """

        # calculate predicted bboxes and get the kept instances indices
        with OutputSaveFunctionWrapper(
                filter_scores_and_topk,
                super().predict_by_feat.__globals__) as outputs_1:
            with OutputSaveFunctionWrapper(
                    batched_nms,
                    super()._bbox_post_process.__globals__) as outputs_2:
                results_list = super().predict_by_feat(
                    cls_scores, bbox_preds, objectnesses, batch_img_metas,
                    cfg, rescale, with_nms)
                keep_indices_topk = [out[2] for out in outputs_1]
                keep_indices_nms = [out[1] for out in outputs_2]
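
        # NOTE: the wrappers above record every call the base class makes to
        # `filter_scores_and_topk` and `batched_nms`, so the same keep
        # indices can be applied to the keypoint predictions below.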

        num_imgs = len(batch_img_metas)

        # recover keypoints coordinates from model predictions
        featmap_sizes = [vis_pred.shape[2:] for vis_pred in vis_preds]
        priors = torch.cat(self.mlvl_priors)
        strides = [
            priors.new_full((featmap_size.numel() * self.num_base_priors, ),
                            stride)
            for featmap_size, stride in zip(featmap_sizes,
                                            self.featmap_strides)
        ]
        strides = torch.cat(strides)
        kpt_preds = torch.cat([
            kpt_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.num_keypoints * 2)
            for kpt_pred in kpt_preds
        ], dim=1)
        flatten_decoded_kpts = self.decode_pose(priors, kpt_preds, strides)

        vis_preds = torch.cat([
            vis_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.num_keypoints)
            for vis_pred in vis_preds
        ], dim=1).sigmoid()

        # select keypoints predictions according to bbox scores and nms result
        keep_indices_nms_idx = 0
        for pred_instances, kpts, kpts_vis, img_meta, keep_idxs in zip(
                results_list, flatten_decoded_kpts, vis_preds,
                batch_img_metas, keep_indices_topk):

            pred_instances.bbox_scores = pred_instances.scores

            if len(pred_instances) == 0:
                pred_instances.keypoints = kpts[:0]
                pred_instances.keypoint_scores = kpts_vis[:0]
                continue

            kpts = kpts[keep_idxs]
            kpts_vis = kpts_vis[keep_idxs]

            if rescale:
                pad_param = img_meta.get('pad_param', None)
                scale_factor = img_meta['scale_factor']
                if pad_param is not None:
                    kpts -= kpts.new_tensor([pad_param[2], pad_param[0]])
                kpts /= kpts.new_tensor(scale_factor).repeat(
                    (1, self.num_keypoints, 1))

            keep_idxs_nms = keep_indices_nms[keep_indices_nms_idx]
            kpts = kpts[keep_idxs_nms]
            kpts_vis = kpts_vis[keep_idxs_nms]
            keep_indices_nms_idx += 1

            pred_instances.keypoints = kpts
            pred_instances.keypoint_scores = kpts_vis

        return results_list

    def predict(self,
                x: Tuple[Tensor],
                batch_data_samples,
                rescale: bool = False):
        predictions = [
            pred_instances.numpy() for pred_instances in super().predict(
                x, batch_data_samples, rescale)
        ]
        return predictions

    def decode_pose(self, grids: torch.Tensor, offsets: torch.Tensor,
                    strides: Union[torch.Tensor, int]) -> torch.Tensor:
        """Decode regression offsets to keypoints.

        Args:
            grids (torch.Tensor): The coordinates of the feature map grids.
            offsets (torch.Tensor): The predicted offset of each keypoint
                relative to its corresponding grid.
            strides (torch.Tensor | int): The stride of the feature map for
                each instance.

        Returns:
            torch.Tensor: The decoded keypoints coordinates.
        """

        if isinstance(strides, int):
            strides = torch.tensor([strides]).to(offsets)

        strides = strides.reshape(1, -1, 1, 1)
        offsets = offsets.reshape(*offsets.shape[:2], -1, 2)
        xy_coordinates = (offsets[..., :2] * strides) + grids.unsqueeze(1)
        return xy_coordinates
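
    # Shape sketch: `offsets` reshapes to (B, N, K, 2), `strides` broadcasts
    # as (1, N, 1, 1), and `grids.unsqueeze(1)` as (N, 1, 2), so the decoded
    # keypoints come out as (B, N, K, 2).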

    @staticmethod
    def gt_instances_preprocess(batch_gt_instances: List[InstanceData], *args,
                                **kwargs) -> List[InstanceData]:
        """Return the ground-truth instances unchanged, skipping the batched
        preprocessing performed in the base class."""
        return batch_gt_instances
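

if __name__ == '__main__':
    # A minimal, self-contained sketch (not part of the original module)
    # illustrating the arithmetic in `decode_pose`: offsets are scaled by
    # the per-prior stride, then shifted by the grid coordinates. All
    # values below are synthetic.
    num_priors, num_keypoints = 4, 17
    grids = torch.rand(num_priors, 2) * 32  # (N, 2) grid coordinates
    offsets = torch.randn(1, num_priors, num_keypoints * 2)  # (B, N, K*2)
    strides = torch.full((num_priors, ), 8.0)  # (N, ) per-prior stride

    decoded = (offsets.reshape(1, num_priors, num_keypoints, 2) *
               strides.reshape(1, -1, 1, 1)) + grids.unsqueeze(1)
    assert decoded.shape == (1, num_priors, num_keypoints, 2)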