# Mirror of https://github.com/open-mmlab/mmpose
# 318 lines, 14 KiB, Python
import numpy as np


class TopDownGenerateTargetFewShot:
    """Generate the target heatmap.

    Required keys: 'joints_3d', 'joints_3d_visible', 'ann_info'.

    Modified keys: 'target', and 'target_weight'.

    Args:
        sigma (float | list | tuple): Sigma of the heatmap gaussian. Used by
            the 'MSRA' encoding and by the 'UDP' encoding when
            ``target_type='GaussianHeatMap'``. A list/tuple generates one
            target per value, stacked along a new leading axis.
        kernel: Kernel of heatmap gaussian for the 'Megvii' approach. It is
            stored but never read: this class only implements 'MSRA' and
            'UDP'.
        valid_radius_factor (float | list | tuple): Radius of the positive
            area of the classification map, as a fraction of the heatmap
            height. Only used by the 'UDP' encoding when
            ``target_type='CombinedTarget'``; a list/tuple is stacked the
            same way as ``sigma``.
        target_type (str): Supported targets: 'GaussianHeatMap',
            'CombinedTarget'. Default: 'GaussianHeatMap'. Only consulted by
            the 'UDP' encoding.
            CombinedTarget: The combination of classification target
            (response map) and regression target (offset map).
            Paper ref: Huang et al. The Devil is in the Details: Delving into
            Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
        encoding (str): Approach to generate target heatmaps.
            Currently supported approaches: 'MSRA', 'UDP'. Default: 'MSRA'.
        unbiased_encoding (bool): Option to use unbiased encoding methods
            with the 'MSRA' approach.
            Paper ref: Zhang et al. Distribution-Aware Coordinate
            Representation for Human Pose Estimation (CVPR 2020).
    """

    def __init__(self,
                 sigma=2,
                 kernel=(11, 11),
                 valid_radius_factor=0.0546875,
                 target_type='GaussianHeatMap',
                 encoding='MSRA',
                 unbiased_encoding=False):
        self.sigma = sigma
        self.unbiased_encoding = unbiased_encoding
        self.kernel = kernel
        self.valid_radius_factor = valid_radius_factor
        self.target_type = target_type
        self.encoding = encoding

    def _msra_generate_target(self, cfg, joints_3d, joints_3d_visible, sigma):
        """Generate the target heatmap via "MSRA" approach.

        Args:
            cfg (dict): data config. Reads 'image_size', 'heatmap_size'
                (as ``W, H``), 'joint_weights' and
                'use_different_joint_weights'.
            joints_3d: np.ndarray ([num_joints, 3])
            joints_3d_visible: np.ndarray ([num_joints, 3])
            sigma: Sigma of heatmap gaussian

        Returns:
            tuple: A tuple containing targets.

            - target (np.ndarray[num_joints, H, W]): Target heatmaps.
            - target_weight (np.ndarray[num_joints, 1]):
              (1: visible, 0: invisible)
        """
        num_joints = len(joints_3d)
        image_size = np.asarray(cfg['image_size'])
        W, H = cfg['heatmap_size']
        joint_weights = cfg['joint_weights']
        use_different_joint_weights = cfg['use_different_joint_weights']
        # Per-joint loss weights are rejected by this few-shot variant.
        assert not use_different_joint_weights

        target_weight = np.zeros((num_joints, 1), dtype=np.float32)
        target = np.zeros((num_joints, H, W), dtype=np.float32)

        # Input-image pixels per heatmap pixel; the same for every joint.
        feat_stride = image_size / [W, H]
        # 3-sigma rule
        tmp_size = sigma * 3

        if self.unbiased_encoding:
            # Full heatmap coordinate grids: the gaussian is evaluated at the
            # exact (sub-pixel) joint location over the whole map.
            x = np.arange(0, W, 1, np.float32)
            y = np.arange(0, H, 1, np.float32)[:, None]
            for joint_id in range(num_joints):
                target_weight[joint_id] = joints_3d_visible[joint_id, 0]

                mu_x = joints_3d[joint_id][0] / feat_stride[0]
                mu_y = joints_3d[joint_id][1] / feat_stride[1]
                # Check that any part of the gaussian is in-bounds
                ul = [mu_x - tmp_size, mu_y - tmp_size]
                br = [mu_x + tmp_size + 1, mu_y + tmp_size + 1]
                if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:
                    target_weight[joint_id] = 0

                if target_weight[joint_id] > 0.5:
                    target[joint_id] = np.exp(
                        -((x - mu_x)**2 + (y - mu_y)**2) / (2 * sigma**2))
        else:
            # The gaussian is not normalized, we want the center value to
            # equal 1. It depends only on sigma, so build it once and paste
            # the in-bounds part around each rounded joint location.
            size = 2 * tmp_size + 1
            x = np.arange(0, size, 1, np.float32)
            y = x[:, None]
            x0 = y0 = size // 2
            g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))

            for joint_id in range(num_joints):
                target_weight[joint_id] = joints_3d_visible[joint_id, 0]

                mu_x = int(joints_3d[joint_id][0] / feat_stride[0] + 0.5)
                mu_y = int(joints_3d[joint_id][1] / feat_stride[1] + 0.5)
                # Check that any part of the gaussian is in-bounds
                ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
                br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
                if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:
                    target_weight[joint_id] = 0

                if target_weight[joint_id] > 0.5:
                    # Usable gaussian range
                    g_x = max(0, -ul[0]), min(br[0], W) - ul[0]
                    g_y = max(0, -ul[1]), min(br[1], H) - ul[1]
                    # Image range
                    img_x = max(0, ul[0]), min(br[0], W)
                    img_y = max(0, ul[1]), min(br[1], H)

                    target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
                        g[g_y[0]:g_y[1], g_x[0]:g_x[1]]

        # Unreachable while the assertion above holds.
        if use_different_joint_weights:
            target_weight = np.multiply(target_weight, joint_weights)

        return target, target_weight

    def _udp_generate_target(self, cfg, joints_3d, joints_3d_visible, factor,
                             target_type):
        """Generate the target heatmap via 'UDP' approach. Paper ref: Huang et
        al. The Devil is in the Details: Delving into Unbiased Data Processing
        for Human Pose Estimation (CVPR 2020).

        Note:
            num keypoints: K
            heatmap height: H
            heatmap width: W
            num target channels: C
            C = K if target_type=='GaussianHeatMap'
            C = 3*K if target_type=='CombinedTarget'

        Args:
            cfg (dict): data config. Reads 'image_size', 'heatmap_size'
                (as ``W, H``), 'joint_weights' and
                'use_different_joint_weights'.
            joints_3d (np.ndarray[K, 3]): Annotated keypoints.
            joints_3d_visible (np.ndarray[K, 3]): Visibility of keypoints.
            factor (float): kernel factor for GaussianHeatMap target or
                valid radius factor for CombinedTarget.
            target_type (str): 'GaussianHeatMap' or 'CombinedTarget'.
                GaussianHeatMap: Heatmap target with gaussian distribution.
                CombinedTarget: The combination of classification target
                (response map) and regression target (offset map).

        Returns:
            tuple: A tuple containing targets.

            - target (np.ndarray[C, H, W]): Target heatmaps.
            - target_weight (np.ndarray[K, 1]): (1: visible, 0: invisible)

        Raises:
            ValueError: If ``target_type`` is not supported.
        """
        if target_type not in ('GaussianHeatMap', 'CombinedTarget'):
            raise ValueError(f'Target type {target_type} is not supported!')

        num_joints = len(joints_3d)
        image_size = np.asarray(cfg['image_size'])
        heatmap_size = np.asarray(cfg['heatmap_size'])
        W, H = heatmap_size
        joint_weights = cfg['joint_weights']
        use_different_joint_weights = cfg['use_different_joint_weights']
        # Per-joint loss weights are rejected by this few-shot variant.
        assert not use_different_joint_weights

        target_weight = np.ones((num_joints, 1), dtype=np.float32)
        target_weight[:, 0] = joints_3d_visible[:, 0]

        # UDP maps the first/last image pixel onto the first/last heatmap
        # pixel, hence the "- 1.0" on both sizes.
        feat_stride = (image_size - 1.0) / (heatmap_size - 1.0)

        if target_type == 'GaussianHeatMap':
            target = np.zeros((num_joints, H, W), dtype=np.float32)

            tmp_size = factor * 3

            # prepare for gaussian
            size = 2 * tmp_size + 1
            x = np.arange(0, size, 1, np.float32)
            y = x[:, None]

            for joint_id in range(num_joints):
                # Exact and rounded joint location in heatmap coordinates.
                mu_x_ac = joints_3d[joint_id][0] / feat_stride[0]
                mu_y_ac = joints_3d[joint_id][1] / feat_stride[1]
                mu_x = int(mu_x_ac + 0.5)
                mu_y = int(mu_y_ac + 0.5)
                # Check that any part of the gaussian is in-bounds
                ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
                br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
                if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:
                    target_weight[joint_id] = 0
                    continue

                if target_weight[joint_id] > 0.5:
                    # Shift the kernel centre by the rounding residual so the
                    # peak sits at the exact (sub-pixel) location.
                    x0 = size // 2 + (mu_x_ac - mu_x)
                    y0 = size // 2 + (mu_y_ac - mu_y)
                    g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * factor**2))

                    # Usable gaussian range
                    g_x = max(0, -ul[0]), min(br[0], W) - ul[0]
                    g_y = max(0, -ul[1]), min(br[1], H) - ul[1]
                    # Image range
                    img_x = max(0, ul[0]), min(br[0], W)
                    img_y = max(0, ul[1]), min(br[1], H)

                    target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
                        g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
        else:  # 'CombinedTarget'
            target = np.zeros((num_joints, 3, H * W), dtype=np.float32)
            # Flattened (row-major) x / y coordinate of every heatmap pixel.
            feat_x_int, feat_y_int = np.meshgrid(np.arange(0, W),
                                                 np.arange(0, H))
            feat_x_int = feat_x_int.flatten()
            feat_y_int = feat_y_int.flatten()
            # Calculate the radius of the positive area in classification
            # heatmap.
            valid_radius = factor * H
            for joint_id in range(num_joints):
                mu_x = joints_3d[joint_id][0] / feat_stride[0]
                mu_y = joints_3d[joint_id][1] / feat_stride[1]
                x_offset = (mu_x - feat_x_int) / valid_radius
                y_offset = (mu_y - feat_y_int) / valid_radius
                dis = x_offset**2 + y_offset**2
                keep_pos = np.where(dis <= 1)[0]
                if target_weight[joint_id] > 0.5:
                    # Channel 0: response map; channels 1/2: x/y offsets.
                    target[joint_id, 0, keep_pos] = 1
                    target[joint_id, 1, keep_pos] = x_offset[keep_pos]
                    target[joint_id, 2, keep_pos] = y_offset[keep_pos]
            target = target.reshape(num_joints * 3, H, W)

        # Unreachable while the assertion above holds.
        if use_different_joint_weights:
            target_weight = np.multiply(target_weight, joint_weights)

        return target, target_weight

    def __call__(self, results):
        """Generate the target heatmap.

        Args:
            results (dict): Must contain 'joints_3d', 'joints_3d_visible'
                and 'ann_info'.

        Returns:
            dict: ``results`` with 'target' and 'target_weight' set. When the
            active factor (``sigma``, or ``valid_radius_factor`` for UDP
            CombinedTarget) is a list/tuple, both arrays gain a leading axis
            with one entry per factor.

        Raises:
            ValueError: If ``encoding`` is not 'MSRA' or 'UDP', or if
                ``target_type`` is unsupported under the 'UDP' encoding.
        """
        joints_3d = results['joints_3d']
        joints_3d_visible = results['joints_3d_visible']

        # Pick the factor(s) driving target generation and how many target
        # channels each joint produces.
        if self.encoding == 'MSRA':
            factors = self.sigma
            channel_factor = 1
        elif self.encoding == 'UDP':
            if self.target_type == 'CombinedTarget':
                factors = self.valid_radius_factor
                channel_factor = 3
            elif self.target_type == 'GaussianHeatMap':
                factors = self.sigma
                channel_factor = 1
            else:
                raise ValueError(
                    f'Target type {self.target_type} is not supported!')
        else:
            raise ValueError(
                f'Encoding approach {self.encoding} is not supported!')

        cfg = results['ann_info']

        def generate(factor):
            # Dispatch a single factor to the selected encoder.
            if self.encoding == 'MSRA':
                return self._msra_generate_target(cfg, joints_3d,
                                                  joints_3d_visible, factor)
            return self._udp_generate_target(cfg, joints_3d,
                                             joints_3d_visible, factor,
                                             self.target_type)

        if isinstance(factors, (list, tuple)):
            num_joints = len(joints_3d)
            W, H = cfg['heatmap_size']
            generated = [generate(factor) for factor in factors]
            # Seeding with empty arrays fixes the output shape/dtype even
            # when the factor list is empty.
            target = np.concatenate(
                [np.empty((0, channel_factor * num_joints, H, W),
                          dtype=np.float32)] +
                [t[None] for t, _ in generated],
                axis=0)
            target_weight = np.concatenate(
                [np.empty((0, num_joints, 1), dtype=np.float32)] +
                [w[None] for _, w in generated],
                axis=0)
        else:
            target, target_weight = generate(factors)

        results['target'] = target
        results['target_weight'] = target_weight

        return results