mmpose/projects/example_project/models/example_head.py

78 lines
2.9 KiB
Python

from mmpose.models import HeatmapHead
from mmpose.registry import MODELS
# Register your head to the `MODELS`.
@MODELS.register_module()
class ExampleHead(HeatmapHead):
"""Implements an example head.
Implement the model head just like a normal pytorch module.
"""
def __init__(self, **kwargs) -> None:
print('Initializing ExampleHead...')
super().__init__(**kwargs)
def forward(self, feats):
"""Forward the network. The input is multi scale feature maps and the
output is the coordinates.
Args:
feats (Tuple[Tensor]): Multi scale feature maps.
Returns:
Tensor: output coordinates or heatmaps.
"""
return super().forward(feats)
def predict(self, feats, batch_data_samples, test_cfg={}):
"""Predict results from outputs. The behaviour of head during testing
should be defined in this function.
Args:
feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
features (or multiple multi-stage features in TTA)
batch_data_samples (List[:obj:`PoseDataSample`]): A list of
data samples for instances in a batch
test_cfg (dict): The runtime config for testing process. Defaults
to {}
Returns:
Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
``test_cfg['output_heatmap']==True``, return both pose and heatmap
prediction; otherwise only return the pose prediction.
The pose prediction is a list of ``InstanceData``, each contains
the following fields:
- keypoints (np.ndarray): predicted keypoint coordinates in
shape (num_instances, K, D) where K is the keypoint number
and D is the keypoint dimension
- keypoint_scores (np.ndarray): predicted keypoint scores in
shape (num_instances, K)
The heatmap prediction is a list of ``PixelData``, each contains
the following fields:
- heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
"""
return super().predict(feats, batch_data_samples, test_cfg)
def loss(self, feats, batch_data_samples, train_cfg={}) -> dict:
"""Calculate losses from a batch of inputs and data samples. The
behaviour of head during training should be defined in this function.
Args:
feats (Tuple[Tensor]): The multi-stage features
batch_data_samples (List[:obj:`PoseDataSample`]): A list of
data samples for instances in a batch
train_cfg (dict): The runtime config for training process.
Defaults to {}
Returns:
dict: A dictionary of losses.
"""
return super().loss(feats, batch_data_samples, train_cfg)