# tensorlayer3/tracking_base.py
#
# (59 lines, 1.6 KiB, Python)

from collections import defaultdict
import logging
from .PifPafbackbone import BaseNetwork
from .running_cache import RunningCache
LOG = logging.getLogger(__name__)


class Signal:
    """In-process Signal infrastructure (a tiny publish/subscribe registry).

    Objects subscribe callables to named events with :meth:`subscribe`.
    :meth:`emit` calls every callable subscribed to that name, in
    subscription order, forwarding the emit arguments.

    The registry lives on the class, so it is shared by every user of
    ``Signal`` in the process; nothing in this class removes subscribers.
    """

    # Event name -> list of subscribed callables. Class-level on purpose:
    # this is the single process-wide registry.
    subscribers = defaultdict(list)

    @classmethod
    def emit(cls, name, *args, **kwargs):
        """Call every subscriber of event ``name`` with ``*args, **kwargs``.

        Emitting a name nobody subscribed to is a no-op and does not add an
        entry to the registry. Subscribers run in subscription order; an
        exception raised by one propagates and the remaining subscribers are
        not called.
        """
        # Iterate over a snapshot: a subscriber that modifies the list for
        # this event while we are emitting (e.g. by subscribing again) must
        # neither change which callables this emission reaches nor be able to
        # keep the loop running forever.
        subscribers = tuple(cls.subscribers.get(name, ()))
        LOG.debug('emit %s to %d subscribers', name, len(subscribers))
        for subscriber in subscribers:
            subscriber(*args, **kwargs)

    @classmethod
    def subscribe(cls, name, subscriber):
        """Register callable ``subscriber`` to be called on every ``emit(name)``.

        The callable is stored by strong reference, which keeps alive any
        object it is bound to.
        """
        LOG.debug('subscribe to %s', name)
        cls.subscribers[name].append(subscriber)
class TrackingBase(BaseNetwork):
    """Wrap a single-image backbone and cache its features in eval mode.

    The wrapper is named ``'t' + <backbone name>`` and reuses the backbone's
    ``stride`` and ``out_features``. Outside training, backbone output is
    passed through a ``RunningCache``, which is rebuilt every time the
    ``'eval_reset'`` signal is emitted.
    """

    # Handed unchanged to RunningCache; presumably the frame offsets to keep
    # (0 = current, -1 = previous) -- TODO confirm against RunningCache.
    cached_items = [0, -1]

    def __init__(self, single_image_backbone):
        tracking_name = 't' + single_image_backbone.name
        super().__init__(
            tracking_name,
            stride=single_image_backbone.stride,
            out_features=single_image_backbone.out_features,
        )
        self.single_image_backbone = single_image_backbone
        self.running_cache = RunningCache(self.cached_items)
        # NOTE(review): the bound method goes into the class-level Signal
        # registry, which keeps this instance alive; nothing in this file
        # ever unsubscribes it.
        Signal.subscribe('eval_reset', self.reset)

    def reset(self):
        """Discard cached features by swapping in a brand-new RunningCache."""
        del self.running_cache
        fresh_cache = RunningCache(self.cached_items)
        self.running_cache = fresh_cache

    def forward(self, *args):
        """Run the backbone on ``args[0]``; further positional args are ignored.

        In training mode the backbone output is returned as-is; otherwise it
        is first fed through the running cache and the cache's result is
        returned.
        """
        features = self.single_image_backbone(args[0])
        # assumes BaseNetwork exposes a PyTorch-style ``training`` flag --
        # TODO confirm for the tensorlayer3 base class.
        if self.training:
            return features
        return self.running_cache(features)