mmpose/projects/example_project/models/example_loss.py

41 lines
1.4 KiB
Python

from mmpose.models import KeypointMSELoss
from mmpose.registry import MODELS
# Register your loss with the `MODELS` registry.
@MODELS.register_module()
class ExampleLoss(KeypointMSELoss):
    """Example loss showing how a custom loss is added to a project.

    The class is registered in ``MODELS`` and is otherwise written the same
    way as a normal PyTorch module. It adds no computation of its own: both
    construction and the forward pass are delegated to ``KeypointMSELoss``.
    """

    def __init__(self, **kwargs) -> None:
        """Print a start-up message, then initialize the parent loss.

        Args:
            **kwargs: Keyword arguments passed unchanged to
                ``KeypointMSELoss.__init__``.
        """
        # The message is emitted before the parent initializer runs.
        print('Initializing ExampleLoss...')
        super().__init__(**kwargs)

    def forward(self, output, target, target_weights=None, mask=None):
        """Compute the loss by deferring to the parent implementation.

        The argument list should match what the ``head.loss`` function
        passes in.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): Predicted heatmaps of shape [B, K, H, W].
            target (Tensor): Ground-truth heatmaps of shape [B, K, H, W].
            target_weights (Tensor, optional): Weights of the different
                keypoints, shaped [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise). Defaults to ``None``.
            mask (Tensor, optional): Mask of valid heatmap pixels, shaped
                [B, K, H, W] or [B, 1, H, W]. When ``None``, no mask is
                applied. Defaults to ``None``.

        Returns:
            Tensor: The calculated loss.
        """
        # Arguments are handed over positionally, in the same order as the
        # signature above.
        loss = super().forward(output, target, target_weights, mask)
        return loss