# mmpose/projects/yolox_pose/datasets/bbox_keypoint_structure.py
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import List, Optional, Sequence, Tuple, Type, TypeVar, Union

import numpy as np
import torch
from mmdet.structures.bbox import HorizontalBoxes
from torch import Tensor

DeviceType = Union[str, torch.device]
T = TypeVar('T')
IndexType = Union[slice, int, list, torch.LongTensor, torch.cuda.LongTensor,
                  torch.BoolTensor, torch.cuda.BoolTensor, np.ndarray]


class BBoxKeypoints(HorizontalBoxes):
"""The BBoxKeypoints class is a combination of bounding boxes and keypoints
representation. The box format used in BBoxKeypoints is the same as
HorizontalBoxes.
Args:
data (Tensor or np.ndarray): The box data with shape of
(N, 4).
keypoints (Tensor or np.ndarray): The keypoint data with shape of
(N, K, 2).
keypoints_visible (Tensor or np.ndarray): The visibility of keypoints
with shape of (N, K).
dtype (torch.dtype, Optional): data type of boxes. Defaults to None.
device (str or torch.device, Optional): device of boxes.
Default to None.
clone (bool): Whether clone ``boxes`` or not. Defaults to True.
mode (str, Optional): the mode of boxes. If it is 'cxcywh', the
`data` will be converted to 'xyxy' mode. Defaults to None.
flip_indices (list, Optional): The indices of keypoints when the
images is flipped. Defaults to None.
Notes:
N: the number of instances.
K: the number of keypoints.
"""

    def __init__(self,
                 data: Union[Tensor, np.ndarray],
                 keypoints: Union[Tensor, np.ndarray],
                 keypoints_visible: Union[Tensor, np.ndarray],
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[DeviceType] = None,
                 clone: bool = True,
                 in_mode: Optional[str] = None,
                 flip_indices: Optional[List] = None) -> None:

        super().__init__(
            data=data,
            dtype=dtype,
            device=device,
            clone=clone,
            in_mode=in_mode)

        assert len(data) == len(keypoints)
        assert len(data) == len(keypoints_visible)
        assert keypoints.ndim == 3
        assert keypoints_visible.ndim == 2

        keypoints = torch.as_tensor(keypoints)
        keypoints_visible = torch.as_tensor(keypoints_visible)

        if device is not None:
            keypoints = keypoints.to(device=device)
            keypoints_visible = keypoints_visible.to(device=device)

        if clone:
            keypoints = keypoints.clone()
            keypoints_visible = keypoints_visible.clone()

        self.keypoints = keypoints
        self.keypoints_visible = keypoints_visible
        self.flip_indices = flip_indices

    def flip_(self,
              img_shape: Tuple[int, int],
              direction: str = 'horizontal') -> None:
        """Flip boxes & kpts horizontally in-place.

        Args:
            img_shape (Tuple[int, int]): A tuple of image height and width.
            direction (str): Flip direction. Only "horizontal" is supported.
                Defaults to "horizontal".
        """
        assert direction == 'horizontal'

        super().flip_(img_shape, direction)
        self.keypoints[..., 0] = img_shape[1] - self.keypoints[..., 0]
        self.keypoints = self.keypoints[:, self.flip_indices]
        self.keypoints_visible = self.keypoints_visible[:, self.flip_indices]

    def translate_(self, distances: Tuple[float, float]) -> None:
        """Translate boxes and keypoints in-place.

        Args:
            distances (Tuple[float, float]): translate distances. The first
                is horizontal distance and the second is vertical distance.
        """
        boxes = self.tensor
        assert len(distances) == 2
        self.tensor = boxes + boxes.new_tensor(distances).repeat(2)

        distances = self.keypoints.new_tensor(distances).reshape(1, 1, 2)
        self.keypoints = self.keypoints + distances

    def rescale_(self, scale_factor: Tuple[float, float]) -> None:
        """Rescale boxes & keypoints w.r.t. rescale_factor in-place.

        Note:
            Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes
            w.r.t ``scale_factor``. The difference is that ``resize_`` only
            changes the width and the height of boxes, but ``rescale_`` also
            rescales the box centers simultaneously.

        Args:
            scale_factor (Tuple[float, float]): factors for scaling boxes.
                The length should be 2.
        """
        boxes = self.tensor
        assert len(scale_factor) == 2
        self.tensor = boxes * boxes.new_tensor(scale_factor).repeat(2)

        scale_factor = self.keypoints.new_tensor(scale_factor).reshape(
            1, 1, 2)
        self.keypoints = self.keypoints * scale_factor

    def clip_(self, img_shape: Tuple[int, int]) -> None:
        """Clip bounding boxes and set invisible keypoints outside the image
        boundary in-place.

        Args:
            img_shape (Tuple[int, int]): A tuple of image height and width.
        """
        boxes = self.tensor
        boxes[..., 0::2] = boxes[..., 0::2].clamp(0, img_shape[1])
        boxes[..., 1::2] = boxes[..., 1::2].clamp(0, img_shape[0])

        kpt_outside = torch.logical_or(
            torch.logical_or(self.keypoints[..., 0] < 0,
                             self.keypoints[..., 1] < 0),
            torch.logical_or(self.keypoints[..., 0] > img_shape[1],
                             self.keypoints[..., 1] > img_shape[0]))
        self.keypoints_visible[kpt_outside] *= 0

    def project_(self, homography_matrix: Union[Tensor, np.ndarray]) -> None:
        """Geometrically transform bounding boxes and keypoints in-place
        using a homography matrix.

        Args:
            homography_matrix (Tensor or np.ndarray): A 3x3 tensor or ndarray
                representing the homography matrix for the transformation.
        """
        boxes = self.tensor
        if isinstance(homography_matrix, np.ndarray):
            homography_matrix = boxes.new_tensor(homography_matrix)

        # Convert boxes to corners in homogeneous coordinates
        corners = self.hbox2corner(boxes)
        corners = torch.cat(
            [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1)

        # Convert keypoints to homogeneous coordinates
        keypoints = torch.cat([
            self.keypoints,
            self.keypoints.new_ones(*self.keypoints.shape[:-1], 1)
        ], dim=-1)

        # Transpose corners and keypoints for matrix multiplication
        corners_T = torch.transpose(corners, -1, -2)
        keypoints_T = torch.transpose(keypoints, -1, 0).contiguous()
        keypoints_T = keypoints_T.flatten(1)

        # Apply homography matrix to corners and keypoints
        corners_T = torch.matmul(homography_matrix, corners_T)
        keypoints_T = torch.matmul(homography_matrix, keypoints_T)

        # Transpose back to original shape
        corners = torch.transpose(corners_T, -1, -2)
        keypoints_T = keypoints_T.reshape(3, self.keypoints.shape[1], -1)
        keypoints = torch.transpose(keypoints_T, -1, 0).contiguous()

        # Convert corners and keypoints back to non-homogeneous coordinates
        corners = corners[..., :2] / corners[..., 2:3]
        keypoints = keypoints[..., :2] / keypoints[..., 2:3]

        # Convert corners back to bounding boxes and update object attributes
        self.tensor = self.corner2hbox(corners)
        self.keypoints = keypoints
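
    # Illustration: ``project_`` flattens the homogeneous keypoints into a
    # (3, K*N) matrix so that a single ``matmul`` transforms every keypoint
    # at once. As a purely illustrative example of an input, a homography
    # that rotates by ``t`` radians around the image center ``(cx, cy)``
    # could be assembled as
    #     [[cos(t), -sin(t), cx - cx*cos(t) + cy*sin(t)],
    #      [sin(t),  cos(t), cy - cx*sin(t) - cy*cos(t)],
    #      [0,       0,      1]]
    # and passed to ``project_`` as a (3, 3) tensor or ndarray.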

    @classmethod
    def cat(cls: Type[T], box_list: Sequence[T], dim: int = 0) -> T:
        """Concatenates an instance list into one single instance. Similar to
        ``torch.cat``.

        Args:
            box_list (Sequence[T]): A sequence of instances.
            dim (int): The dimension over which the box and keypoint are
                concatenated. Defaults to 0.

        Returns:
            T: Concatenated instance.
        """
        assert isinstance(box_list, Sequence)
        if len(box_list) == 0:
            raise ValueError('box_list should not be an empty list.')

        assert dim == 0
        assert all(isinstance(boxes, cls) for boxes in box_list)

        th_box_list = torch.cat([boxes.tensor for boxes in box_list], dim=dim)
        th_kpt_list = torch.cat([boxes.keypoints for boxes in box_list],
                                dim=dim)
        th_kpt_vis_list = torch.cat(
            [boxes.keypoints_visible for boxes in box_list], dim=dim)
        flip_indices = box_list[0].flip_indices

        return cls(
            th_box_list,
            th_kpt_list,
            th_kpt_vis_list,
            clone=False,
            flip_indices=flip_indices)

    def __getitem__(self: T, index: IndexType) -> T:
        """Rewrite getitem to protect the last dimension shape."""
        boxes = self.tensor

        if isinstance(index, np.ndarray):
            index = torch.as_tensor(index, device=self.device)
        if isinstance(index, Tensor) and index.dtype == torch.bool:
            assert index.dim() < boxes.dim()
        elif isinstance(index, tuple):
            assert len(index) < boxes.dim()
            # `Ellipsis`(...) is commonly used in index like [None, ...].
            # When `Ellipsis` is in index, it must be the last item.
            if Ellipsis in index:
                assert index[-1] is Ellipsis

        boxes = boxes[index]
        keypoints = self.keypoints[index]
        keypoints_visible = self.keypoints_visible[index]

        if boxes.dim() == 1:
            boxes = boxes.reshape(1, -1)
            keypoints = keypoints.reshape(1, -1, 2)
            keypoints_visible = keypoints_visible.reshape(1, -1)

        return type(self)(
            boxes,
            keypoints,
            keypoints_visible,
            flip_indices=self.flip_indices,
            clone=False)

    @property
    def num_keypoints(self) -> Tensor:
        """Compute the number of visible keypoints for each object."""
        return self.keypoints_visible.sum(dim=1).int()

    def __deepcopy__(self, memo):
        """Only clone the tensors when applying deepcopy."""
        cls = self.__class__
        other = cls.__new__(cls)
        memo[id(self)] = other

        other.tensor = self.tensor.clone()
        other.keypoints = self.keypoints.clone()
        other.keypoints_visible = self.keypoints_visible.clone()
        other.flip_indices = deepcopy(self.flip_indices)
        return other

    def clone(self: T) -> T:
        """Reload ``clone`` for tensors."""
        return type(self)(
            self.tensor,
            self.keypoints,
            self.keypoints_visible,
            flip_indices=self.flip_indices,
            clone=True)

    def to(self: T, *args, **kwargs) -> T:
        """Reload ``to`` for tensors."""
        return type(self)(
            self.tensor.to(*args, **kwargs),
            self.keypoints.to(*args, **kwargs),
            self.keypoints_visible.to(*args, **kwargs),
            flip_indices=self.flip_indices,
            clone=False)
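

if __name__ == '__main__':
    # Minimal usage sketch, assuming random COCO-style data with 17 keypoints
    # per instance (illustrative only). It builds a ``BBoxKeypoints`` object
    # and exercises a few of the in-place transforms. The flip order below is
    # the standard COCO left/right keypoint pairing.
    num_instances, num_keypoints = 2, 17
    bbox_data = torch.rand(num_instances, 4) * 100
    bbox_data[:, 2:] += bbox_data[:, :2]  # ensure x2 >= x1 and y2 >= y1
    kpt_data = torch.rand(num_instances, num_keypoints, 2) * 100
    kpt_vis = torch.ones(num_instances, num_keypoints)
    flip_indices = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]

    boxes = BBoxKeypoints(
        data=bbox_data,
        keypoints=kpt_data,
        keypoints_visible=kpt_vis,
        flip_indices=flip_indices)

    img_shape = (120, 160)  # (height, width)
    boxes.flip_(img_shape)      # flip boxes & keypoints horizontally
    boxes.rescale_((0.5, 0.5))  # scale boxes & keypoints by 0.5
    boxes.clip_(img_shape)      # clip boxes, hide out-of-image keypoints
    print(boxes.num_keypoints)  # number of visible keypoints per instance
    print(boxes[[0]].tensor.shape)  # indexing keeps the box shape: (1, 4)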