forked from TensorLayer/tensorlayer3
42 lines
1.6 KiB
Python
42 lines
1.6 KiB
Python
|
|
|
|
"""Transform a roidb into a trainable roidb by adding a bunch of metadata."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
import PIL
|
|
|
|
def prepare_roidb(imdb):
|
|
"""Enrich the imdb's roidb by adding some derived quantities that
|
|
are useful for training. This function precomputes the maximum
|
|
overlap, taken over ground-truth boxes, between each ROI and
|
|
each ground-truth box. The class with maximum overlap is also
|
|
recorded.
|
|
"""
|
|
roidb = imdb.roidb
|
|
if not (imdb.name.startswith('coco')):
|
|
sizes = [PIL.Image.open(imdb.image_path_at(i)).size
|
|
for i in range(imdb.num_images)]
|
|
for i in range(len(imdb.image_index)):
|
|
roidb[i]['image'] = imdb.image_path_at(i)
|
|
if not (imdb.name.startswith('coco')):
|
|
roidb[i]['width'] = sizes[i][0]
|
|
roidb[i]['height'] = sizes[i][1]
|
|
# need gt_overlaps as a dense array for argmax
|
|
gt_overlaps = roidb[i]['gt_overlaps'].toarray()
|
|
# max overlap with gt over classes (columns)
|
|
max_overlaps = gt_overlaps.max(axis=1)
|
|
# gt class that had the max overlap
|
|
max_classes = gt_overlaps.argmax(axis=1)
|
|
roidb[i]['max_classes'] = max_classes
|
|
roidb[i]['max_overlaps'] = max_overlaps
|
|
# sanity checks
|
|
# max overlap of 0 => class should be zero (background)
|
|
zero_inds = np.where(max_overlaps == 0)[0]
|
|
assert all(max_classes[zero_inds] == 0)
|
|
# max overlap > 0 => class should not be zero (must be a fg class)
|
|
nonzero_inds = np.where(max_overlaps > 0)[0]
|
|
assert all(max_classes[nonzero_inds] != 0)
|