forked from TensorLayer/tensorlayer3
PifPaf_tracking_heads.py
This commit is contained in:
parent
3e3d554a21
commit
d2881e6f04
|
@ -0,0 +1,105 @@
|
|||
import tensorlayer as tl
|
||||
from tensorlayer.layers import Conv2d
|
||||
from tensorlayer.layers import SequentialLayer
|
||||
|
||||
|
||||
from .heads import HeadNetwork, CompositeField4
|
||||
|
||||
|
||||
class TBaseSingleImage(HeadNetwork):
|
||||
"""Filter the feature map so that they can be used by single image loss.
|
||||
|
||||
Training: only apply loss to image 0 of an image pair of image 0 and 1.
|
||||
Evaluation with forward tracking pose: only keep image 0.
|
||||
Evaluation with full tracking pose: keep all but stack group along feature dim.
|
||||
"""
|
||||
forward_tracking_pose = True
|
||||
tracking_pose_length = 2
|
||||
|
||||
def __init__(self, meta, in_features):
|
||||
super().__init__(meta, in_features)
|
||||
self.head = CompositeField4(meta, in_features)
|
||||
|
||||
def forward(self, *args):
|
||||
x = args[0]
|
||||
|
||||
if self.training:
|
||||
x = x[::2]
|
||||
elif self.forward_tracking_pose:
|
||||
x = x[::self.tracking_pose_length]
|
||||
|
||||
x = self.head(x)
|
||||
|
||||
if not self.training and not self.forward_tracking_pose:
|
||||
# full tracking pose eval
|
||||
# TODO: stack batch dimension in feature dimension and adjust
|
||||
# meta information (make it a property to dynamically return
|
||||
# a different meta for evaluation)
|
||||
raise NotImplementedError
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Tcaf(HeadNetwork):
|
||||
"""Filter the feature map so that they can be used by single image loss.
|
||||
|
||||
Training: only apply loss to image 0 of an image pair of image 0 and 1.
|
||||
Evaluation with forward tracking pose: only keep image 0.
|
||||
Evaluation with full tracking pose: keep all.
|
||||
"""
|
||||
tracking_pose_length = 2
|
||||
reduced_features = 512
|
||||
|
||||
_global_feature_reduction = None
|
||||
_global_feature_compute = None
|
||||
|
||||
def __init__(self, meta, in_features):
|
||||
super().__init__(meta, in_features)
|
||||
|
||||
if self._global_feature_reduction is None:
|
||||
self.__class__._global_feature_reduction = SequentialLayer(
|
||||
[Conv2d(self.reduced_features,
|
||||
kernel_size=(1,1), bias=True,in_channels=in_features),
|
||||
tl.ReLU(inplace=True)]
|
||||
)
|
||||
self.feature_reduction = self._global_feature_reduction
|
||||
|
||||
if self._global_feature_compute is None:
|
||||
self.__class__._global_feature_compute = SequentialLayer(
|
||||
[Conv2d(self.reduced_features * 2,kernel_size=(1,1) ,bias=True,in_channels=self.reduced_features * 2),
|
||||
tl.ReLU(inplace=True)]
|
||||
)
|
||||
self.feature_compute = self._global_feature_compute
|
||||
|
||||
self.head = CompositeField4(meta, self.reduced_features * 2)
|
||||
|
||||
def forward(self, *args):
|
||||
x = args[0]
|
||||
|
||||
# Batches that are not intended for tracking loss might have an
|
||||
# odd number of images (or only 1 image).
|
||||
# In that case, simply do not execute this head as the result should
|
||||
# never be used.
|
||||
if len(x) % 2 == 1:
|
||||
return None
|
||||
|
||||
x = self.feature_reduction(x)
|
||||
|
||||
group_length = 2 if self.training else self.tracking_pose_length
|
||||
primary = x[::group_length]
|
||||
others = [x[i::group_length] for i in range(1, group_length)]
|
||||
|
||||
x = tl.stack([tl.concat([primary, o], dim=1) for o in others], dim=1)
|
||||
x_shape = x.size()
|
||||
x = tl.reshape(x, [x_shape[0] * x_shape[1]] + list(x_shape[2:]))
|
||||
|
||||
x = self.feature_compute(x)
|
||||
|
||||
x = self.head(x)
|
||||
|
||||
if self.tracking_pose_length != 2:
|
||||
# TODO need to stack group from batch dim in feature dim and adjust
|
||||
# meta info
|
||||
raise NotImplementedError
|
||||
|
||||
return x
|
Loading…
Reference in New Issue