# Forked from TensorLayer/tensorlayer3 (50 lines, 1.8 KiB).
#----------------------------------------------------#
#   Generate the ground-truth files for the test set
#----------------------------------------------------#
|
|
import glob
|
|
import os
|
|
import sys
|
|
import xml.etree.ElementTree as ET
|
|
|
|
#---------------------------------------------------#
#   Load the class names
#---------------------------------------------------#
|
|
def get_classes(classes_path):
    """Load class names from a text file, one name per line.

    Leading/trailing whitespace (including the newline) is stripped from
    every line; blank lines come back as empty strings.

    Args:
        classes_path: Path to the class-name file.

    Returns:
        list of str: The stripped class names, in file order.
    """
    with open(classes_path) as class_file:
        return [line.strip() for line in class_file]
|
|
|
|
#---------------------------------------------------#
#   Convert every test-set VOC annotation into a
#   plain-text ground-truth file in ./input/ground-truth
#---------------------------------------------------#
def _is_difficult(obj):
    """Return True if a VOC <object> element is tagged <difficult>1</difficult>.

    A missing <difficult> tag, or one whose text is empty or blank, is
    treated as "not difficult" rather than raising on int(None) / int("").
    """
    difficult = obj.find('difficult')
    if difficult is None or difficult.text is None or not difficult.text.strip():
        return False
    return int(difficult.text) == 1


def _ground_truth_lines(annotation_path):
    """Parse one VOC XML annotation and return its ground-truth lines.

    Each returned line is "<name> <xmin> <ymin> <xmax> <ymax>", with the word
    "difficult" appended for objects flagged as difficult, and ends with a
    newline. Coordinates are copied verbatim from the XML text.

    Raises whatever ET.parse raises (e.g. OSError, ET.ParseError) for a
    missing or malformed annotation file.
    """
    root = ET.parse(annotation_path).getroot()
    lines = []
    for obj in root.findall('object'):
        obj_name = obj.find('name').text

        bndbox = obj.find('bndbox')
        left = bndbox.find('xmin').text
        top = bndbox.find('ymin').text
        right = bndbox.find('xmax').text
        bottom = bndbox.find('ymax').text

        if _is_difficult(obj):
            lines.append("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
        else:
            lines.append("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
    return lines


# Test-set image ids, one per whitespace-separated token; the `with` block
# closes the file handle as soon as the ids have been read.
with open('VOCdevkit/VOC2007/ImageSets/Main/test.txt') as id_file:
    image_ids = id_file.read().strip().split()

# Creates ./input and ./input/ground-truth in one call; no-op if they exist.
os.makedirs("./input/ground-truth", exist_ok=True)

for image_id in image_ids:
    # Parse before opening the output file so a missing or malformed
    # annotation does not leave an empty/truncated ground-truth file behind.
    gt_lines = _ground_truth_lines(
        os.path.join("VOCdevkit/VOC2007/Annotations", image_id + ".xml"))
    with open(os.path.join("./input/ground-truth", image_id + ".txt"), "w") as new_f:
        new_f.writelines(gt_lines)

print("Conversion completed!")
|