mmpose/projects/yolox_pose/datasets/transforms.py

129 lines
4.3 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import numpy as np
from mmcv.transforms import BaseTransform
from mmdet.datasets.transforms import FilterAnnotations as FilterDetAnnotations
from mmdet.datasets.transforms import PackDetInputs
from mmdet.structures.bbox.box_type import autocast_box_type
from mmyolo.registry import TRANSFORMS
from .bbox_keypoint_structure import BBoxKeypoints
@TRANSFORMS.register_module()
class PoseToDetConverter(BaseTransform):
"""This transform converts the pose data element into a format that is
suitable for the mmdet transforms."""
def transform(self, results: dict) -> dict:
results['seg_map_path'] = None
results['height'] = results['img_shape'][0]
results['width'] = results['img_shape'][1]
num_instances = len(results.get('bbox', []))
if num_instances == 0:
results['bbox'] = np.empty((0, 4), dtype=np.float32)
results['keypoints'] = np.empty(
(0, len(results['flip_indices']), 2), dtype=np.float32)
results['keypoints_visible'] = np.empty(
(0, len(results['flip_indices'])), dtype=np.int32)
results['category_id'] = []
results['gt_bboxes'] = BBoxKeypoints(
data=results['bbox'],
keypoints=results['keypoints'],
keypoints_visible=results['keypoints_visible'],
flip_indices=results['flip_indices'],
)
results['gt_ignore_flags'] = np.array([False] * num_instances)
results['gt_bboxes_labels'] = np.array(results['category_id']) - 1
return results
@TRANSFORMS.register_module()
class PackDetPoseInputs(PackDetInputs):
mapping_table = {
'gt_bboxes': 'bboxes',
'gt_bboxes_labels': 'labels',
'gt_masks': 'masks',
'gt_keypoints': 'keypoints',
'gt_keypoints_visible': 'keypoints_visible'
}
def __init__(self,
meta_keys=('id', 'img_id', 'img_path', 'ori_shape',
'img_shape', 'scale_factor', 'flip',
'flip_direction', 'flip_indices', 'raw_ann_info'),
pack_transformed=False):
self.meta_keys = meta_keys
def transform(self, results: dict) -> dict:
# Add keypoints and their visibility to the results dictionary
results['gt_keypoints'] = results['gt_bboxes'].keypoints
results['gt_keypoints_visible'] = results[
'gt_bboxes'].keypoints_visible
# Ensure all keys in `self.meta_keys` are in the `results` dictionary,
# which is necessary for `PackDetInputs` but not guaranteed during
# inference with an inferencer
for key in self.meta_keys:
if key not in results:
results[key] = None
return super().transform(results)
@TRANSFORMS.register_module()
class FilterDetPoseAnnotations(FilterDetAnnotations):
"""Filter invalid annotations.
In addition to the conditions checked by ``FilterDetAnnotations``, this
filter adds a new condition requiring instances to have at least one
visible keypoints.
"""
@autocast_box_type()
def transform(self, results: dict) -> Union[dict, None]:
"""Transform function to filter annotations.
Args:
results (dict): Result dict.
Returns:
dict: Updated result dict.
"""
assert 'gt_bboxes' in results
gt_bboxes = results['gt_bboxes']
if gt_bboxes.shape[0] == 0:
return results
tests = []
if self.by_box:
tests.append(((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
(gt_bboxes.heights > self.min_gt_bbox_wh[1]) &
(gt_bboxes.num_keypoints > 0)).numpy())
if self.by_mask:
assert 'gt_masks' in results
gt_masks = results['gt_masks']
tests.append(gt_masks.areas >= self.min_gt_mask_area)
keep = tests[0]
for t in tests[1:]:
keep = keep & t
if not keep.any():
if self.keep_empty:
return None
keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags')
for key in keys:
if key in results:
results[key] = results[key][keep]
return results