forked from TensorLayer/tensorlayer3
PifPaf_tracking_heads.py
This commit is contained in:
parent
3e3d554a21
commit
d2881e6f04
|
@ -0,0 +1,105 @@
|
||||||
|
import tensorlayer as tl
|
||||||
|
from tensorlayer.layers import Conv2d
|
||||||
|
from tensorlayer.layers import SequentialLayer
|
||||||
|
|
||||||
|
|
||||||
|
from .heads import HeadNetwork, CompositeField4
|
||||||
|
|
||||||
|
|
||||||
|
class TBaseSingleImage(HeadNetwork):
|
||||||
|
"""Filter the feature map so that they can be used by single image loss.
|
||||||
|
|
||||||
|
Training: only apply loss to image 0 of an image pair of image 0 and 1.
|
||||||
|
Evaluation with forward tracking pose: only keep image 0.
|
||||||
|
Evaluation with full tracking pose: keep all but stack group along feature dim.
|
||||||
|
"""
|
||||||
|
forward_tracking_pose = True
|
||||||
|
tracking_pose_length = 2
|
||||||
|
|
||||||
|
def __init__(self, meta, in_features):
|
||||||
|
super().__init__(meta, in_features)
|
||||||
|
self.head = CompositeField4(meta, in_features)
|
||||||
|
|
||||||
|
def forward(self, *args):
|
||||||
|
x = args[0]
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
x = x[::2]
|
||||||
|
elif self.forward_tracking_pose:
|
||||||
|
x = x[::self.tracking_pose_length]
|
||||||
|
|
||||||
|
x = self.head(x)
|
||||||
|
|
||||||
|
if not self.training and not self.forward_tracking_pose:
|
||||||
|
# full tracking pose eval
|
||||||
|
# TODO: stack batch dimension in feature dimension and adjust
|
||||||
|
# meta information (make it a property to dynamically return
|
||||||
|
# a different meta for evaluation)
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Tcaf(HeadNetwork):
|
||||||
|
"""Filter the feature map so that they can be used by single image loss.
|
||||||
|
|
||||||
|
Training: only apply loss to image 0 of an image pair of image 0 and 1.
|
||||||
|
Evaluation with forward tracking pose: only keep image 0.
|
||||||
|
Evaluation with full tracking pose: keep all.
|
||||||
|
"""
|
||||||
|
tracking_pose_length = 2
|
||||||
|
reduced_features = 512
|
||||||
|
|
||||||
|
_global_feature_reduction = None
|
||||||
|
_global_feature_compute = None
|
||||||
|
|
||||||
|
def __init__(self, meta, in_features):
|
||||||
|
super().__init__(meta, in_features)
|
||||||
|
|
||||||
|
if self._global_feature_reduction is None:
|
||||||
|
self.__class__._global_feature_reduction = SequentialLayer(
|
||||||
|
[Conv2d(self.reduced_features,
|
||||||
|
kernel_size=(1,1), bias=True,in_channels=in_features),
|
||||||
|
tl.ReLU(inplace=True)]
|
||||||
|
)
|
||||||
|
self.feature_reduction = self._global_feature_reduction
|
||||||
|
|
||||||
|
if self._global_feature_compute is None:
|
||||||
|
self.__class__._global_feature_compute = SequentialLayer(
|
||||||
|
[Conv2d(self.reduced_features * 2,kernel_size=(1,1) ,bias=True,in_channels=self.reduced_features * 2),
|
||||||
|
tl.ReLU(inplace=True)]
|
||||||
|
)
|
||||||
|
self.feature_compute = self._global_feature_compute
|
||||||
|
|
||||||
|
self.head = CompositeField4(meta, self.reduced_features * 2)
|
||||||
|
|
||||||
|
def forward(self, *args):
|
||||||
|
x = args[0]
|
||||||
|
|
||||||
|
# Batches that are not intended for tracking loss might have an
|
||||||
|
# odd number of images (or only 1 image).
|
||||||
|
# In that case, simply do not execute this head as the result should
|
||||||
|
# never be used.
|
||||||
|
if len(x) % 2 == 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
x = self.feature_reduction(x)
|
||||||
|
|
||||||
|
group_length = 2 if self.training else self.tracking_pose_length
|
||||||
|
primary = x[::group_length]
|
||||||
|
others = [x[i::group_length] for i in range(1, group_length)]
|
||||||
|
|
||||||
|
x = tl.stack([tl.concat([primary, o], dim=1) for o in others], dim=1)
|
||||||
|
x_shape = x.size()
|
||||||
|
x = tl.reshape(x, [x_shape[0] * x_shape[1]] + list(x_shape[2:]))
|
||||||
|
|
||||||
|
x = self.feature_compute(x)
|
||||||
|
|
||||||
|
x = self.head(x)
|
||||||
|
|
||||||
|
if self.tracking_pose_length != 2:
|
||||||
|
# TODO need to stack group from batch dim in feature dim and adjust
|
||||||
|
# meta info
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
return x
|
Loading…
Reference in New Issue