mirror of https://github.com/open-mmlab/mmpose
41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
from mmpose.models import KeypointMSELoss
|
|
from mmpose.registry import MODELS
|
|
|
|
|
|
# Register your loss to the `MODELS`.
|
|
@MODELS.register_module()
|
|
class ExampleLoss(KeypointMSELoss):
|
|
"""Implements an example loss.
|
|
|
|
Implement the loss just like a normal pytorch module.
|
|
"""
|
|
|
|
def __init__(self, **kwargs) -> None:
|
|
print('Initializing ExampleLoss...')
|
|
super().__init__(**kwargs)
|
|
|
|
def forward(self, output, target, target_weights=None, mask=None):
|
|
"""Forward function of loss. The input arguments should match those
|
|
given in `head.loss` function.
|
|
|
|
Note:
|
|
- batch_size: B
|
|
- num_keypoints: K
|
|
- heatmaps height: H
|
|
- heatmaps weight: W
|
|
|
|
Args:
|
|
output (Tensor): The output heatmaps with shape [B, K, H, W]
|
|
target (Tensor): The target heatmaps with shape [B, K, H, W]
|
|
target_weights (Tensor, optional): The target weights of differet
|
|
keypoints, with shape [B, K] (keypoint-wise) or
|
|
[B, K, H, W] (pixel-wise).
|
|
mask (Tensor, optional): The masks of valid heatmap pixels in
|
|
shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
|
|
be applied. Defaults to ``None``
|
|
|
|
Returns:
|
|
Tensor: The calculated loss.
|
|
"""
|
|
return super().forward(output, target, target_weights, mask)
|