tensorlayer3/factory.py

36 lines
899 B
Python

"""Factory method for easily getting imdbs by name."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__sets = {}
from coco import coco
# Set up coco_2014_<split>
for year in ['2014']:
for split in ['train', 'val', 'minival', 'valminusminival', 'trainval']:
name = 'coco_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year: coco(split, year))
# Set up coco_2015_<split>
for year in ['2015']:
for split in ['test', 'test-dev']:
name = 'coco_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year: coco(split, year))
def get_imdb(name):
"""Get an imdb (image database) by name."""
if name not in __sets:
raise KeyError('Unknown dataset: {}'.format(name))
return __sets[name]()
def list_imdbs():
"""List all registered imdbs."""
return list(__sets.keys())