forked from TensorLayer/tensorlayer3
59 lines
1.6 KiB
Python
59 lines
1.6 KiB
Python
from collections import defaultdict
import logging
import weakref

from .PifPafbackbone import BaseNetwork
from .running_cache import RunningCache
LOG = logging.getLogger(__name__)


class Signal:
    """In-process Signal infrastructure.

    Objects can subscribe to 'events'. Events are triggered with the emit()
    function and are fanned out to all subscribers.

    The registry lives on the class, so every subscription is shared by the
    whole process and keeps a strong reference to its callable until it is
    removed with ``unsubscribe()``.
    """

    # Maps event name -> list of subscribed callables, in subscription order.
    subscribers = defaultdict(list)

    @classmethod
    def emit(cls, name, *args, **kwargs):
        """Call every subscriber of ``name`` with ``*args`` and ``**kwargs``.

        Subscribers run in subscription order. An exception raised by a
        subscriber propagates and stops the remaining fan-out. Emitting an
        event nobody subscribed to is a no-op and does not create a registry
        entry.
        """
        # Iterate over a snapshot: a subscriber that subscribes/unsubscribes
        # while being called must not change which callables this emit()
        # reaches (and must not cause an endless loop by appending to the
        # very list being iterated).
        subscribers = tuple(cls.subscribers.get(name, ()))
        LOG.debug('emit %s to %d subscribers', name, len(subscribers))
        for subscriber in subscribers:
            subscriber(*args, **kwargs)

    @classmethod
    def subscribe(cls, name, subscriber):
        """Register ``subscriber`` to be called on every ``emit(name, ...)``.

        The same callable may be registered more than once; it is then
        called once per registration.
        """
        LOG.debug('subscribe to %s', name)
        cls.subscribers[name].append(subscriber)

    @classmethod
    def unsubscribe(cls, name, subscriber):
        """Remove one registration of ``subscriber`` for event ``name``.

        Returns:
            True if a registration was removed, False if ``subscriber`` was
            not subscribed to ``name``.
        """
        subscribers = cls.subscribers.get(name)
        if not subscribers:
            return False
        try:
            subscribers.remove(subscriber)
        except ValueError:
            return False
        # Drop empty entries so the registry does not accumulate dead names.
        if not subscribers:
            del cls.subscribers[name]
        return True
|
|
|
|
|
|
class TrackingBase(BaseNetwork):
    """Wrap a single-image backbone and cache its features between calls.

    The wrapper is named after the wrapped backbone with a ``'t'`` prefix
    and reuses its ``stride`` and ``out_features``. When not training, each
    backbone output is passed through a ``RunningCache`` built from
    ``cached_items``; the cache is rebuilt whenever the ``'eval_reset'``
    signal is emitted.
    """

    # Forwarded verbatim to RunningCache. NOTE(review): looks like frame
    # offsets (0 = current, -1 = previous) -- confirm against RunningCache.
    cached_items = [0, -1]

    def __init__(self, single_image_backbone):
        super().__init__(
            't' + single_image_backbone.name,
            stride=single_image_backbone.stride,
            out_features=single_image_backbone.out_features,
        )
        self.single_image_backbone = single_image_backbone
        self.running_cache = RunningCache(self.cached_items)

        # Signal's registry is class-level and lives for the whole process.
        # Subscribing ``self.reset`` directly would store a strong reference
        # to this network there, so no TrackingBase could ever be freed and
        # every 'eval_reset' would keep resetting discarded instances.
        # Subscribe a closure over a weak reference instead; it must not
        # capture ``self`` itself.
        self_ref = weakref.ref(self)

        def _reset_if_alive(*args, **kwargs):
            instance = self_ref()
            if instance is not None:
                instance.reset(*args, **kwargs)

        Signal.subscribe('eval_reset', _reset_if_alive)

    def reset(self):
        """Discard all cached features by replacing the cache with a new one."""
        # Release the old cache before constructing its replacement.
        del self.running_cache
        self.running_cache = RunningCache(self.cached_items)

    def forward(self, *args):
        """Run the backbone on ``args[0]``; when not training, apply the cache.

        Positional arguments after the first are ignored.
        """
        x = args[0]

        # backbone
        x = self.single_image_backbone(x)

        # feature cache: only used at inference time
        if not self.training:
            x = self.running_cache(x)

        return x