# tensorlayer3/tensorlayer/files/dataset_loaders/mpii_dataset.py
#! /usr/bin/python
# -*- coding: utf-8 -*-
import os
from tensorlayer import logging
from tensorlayer.files.utils import (del_file, folder_exists, load_file_list, maybe_download_and_extract)
__all__ = ['load_mpii_pose_dataset']
def load_mpii_pose_dataset(path='data', is_16_pos_only=False):
    """Load MPII Human Pose Dataset.

    Downloads (if necessary) the MPII annotation archive and image archive
    into ``path/mpii_human_pose``, parses the ``.mat`` annotation file, drops
    entries whose image file is missing on disk, and returns full image paths
    together with per-image annotation lists.

    Parameters
    -----------
    path : str
        The path that the data is downloaded to.
    is_16_pos_only : boolean
        If True, only return the people that contain 16 pose keypoints.
        (Usually used for single person pose estimation.)

    Returns
    ----------
    img_train_list : list of str
        The image directories of training data.
    ann_train_list : list of dict
        The annotations of training data.
    img_test_list : list of str
        The image directories of testing data.
    ann_test_list : list of dict
        The annotations of testing data.

    Examples
    --------
    >>> import pprint
    >>> import tensorlayer as tl
    >>> img_train_list, ann_train_list, img_test_list, ann_test_list = tl.files.load_mpii_pose_dataset()
    >>> image = tl.vis.read_image(img_train_list[0])
    >>> tl.vis.draw_mpii_pose_to_image(image, ann_train_list[0], 'image.png')
    >>> pprint.pprint(ann_train_list[0])

    References
    -----------
    - `MPII Human Pose Dataset. CVPR 14 <http://human-pose.mpi-inf.mpg.de>`__
    - `MPII Human Pose Models. CVPR 16 <http://pose.mpi-inf.mpg.de>`__
    - `MPII Human Shape, Poselet Conditioned Pictorial Structures and etc <http://pose.mpi-inf.mpg.de/#related>`__
    - `MPII Keyponts and ID <http://human-pose.mpi-inf.mpg.de/#download>`__
    """
    path = os.path.join(path, 'mpii_human_pose')
    logging.info("Load or Download MPII Human Pose > {}".format(path))

    # --- annotation archive (zip containing the .mat file) ---
    url = "http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/"
    tar_filename = "mpii_human_pose_v1_u12_2.zip"
    extracted_filename = "mpii_human_pose_v1_u12_2"
    if folder_exists(os.path.join(path, extracted_filename)) is False:
        logging.info("[MPII] (annotation) {} is nonexistent in {}".format(extracted_filename, path))
        maybe_download_and_extract(tar_filename, path, url, extract=True)
        del_file(os.path.join(path, tar_filename))

    # --- image archive (same mirror) ---
    tar_filename = "mpii_human_pose_v1.tar.gz"
    extracted_filename2 = "images"
    if folder_exists(os.path.join(path, extracted_filename2)) is False:
        # BUG FIX: previously this logged `extracted_filename` (the annotation
        # folder) instead of the images folder.
        logging.info("[MPII] (images) {} is nonexistent in {}".format(extracted_filename2, path))
        maybe_download_and_extract(tar_filename, path, url, extract=True)
        del_file(os.path.join(path, tar_filename))

    # Parse annotation; format is documented at
    # http://human-pose.mpi-inf.mpg.de/#download
    import scipy.io as sio
    logging.info("reading annotations from mat file ...")

    ann_train_list = []
    ann_test_list = []
    img_train_list = []
    img_test_list = []

    def save_joints():
        """Populate the four lists above from the .mat annotation file.

        One entry per image; each entry of ``ann_*_list`` is a list of
        per-person dicts with keys: filename, train, head_rect, is_visible,
        joint_pos.
        """
        mat = sio.loadmat(os.path.join(path, extracted_filename, "mpii_human_pose_v1_u12_1.mat"))
        for anno, train_flag in zip(  # all images
                mat['RELEASE']['annolist'][0, 0][0], mat['RELEASE']['img_train'][0, 0][0]):
            img_fn = anno['image']['name'][0, 0][0]
            train_flag = int(train_flag)
            if train_flag:
                img_train_list.append(img_fn)
                ann_train_list.append([])
            else:
                img_test_list.append(img_fn)
                ann_test_list.append([])

            # Only images with per-person keypoint annotations produce data;
            # the head rectangle is taken per person inside the loop below.
            if 'annopoints' in str(anno['annorect'].dtype):
                annopoints = anno['annorect']['annopoints'][0]
                head_x1s = anno['annorect']['x1'][0]
                head_y1s = anno['annorect']['y1'][0]
                head_x2s = anno['annorect']['x2'][0]
                head_y2s = anno['annorect']['y2'][0]
                for annopoint, head_x1, head_y1, head_x2, head_y2 in zip(annopoints, head_x1s, head_y1s, head_x2s,
                                                                         head_y2s):
                    if not annopoint.size:
                        # empty annotation for this person: skip
                        continue
                    head_rect = [
                        float(head_x1[0, 0]),
                        float(head_y1[0, 0]),
                        float(head_x2[0, 0]),
                        float(head_y2[0, 0])
                    ]
                    # joint coordinates, keyed by integer joint id
                    annopoint = annopoint['point'][0, 0]
                    j_id = [str(j_i[0, 0]) for j_i in annopoint['id'][0]]
                    xs = [x[0, 0] for x in annopoint['x'][0]]
                    ys = [y[0, 0] for y in annopoint['y'][0]]
                    joint_pos = {}
                    for _j_id, _x, _y in zip(j_id, xs, ys):
                        joint_pos[int(_j_id)] = [float(_x), float(_y)]
                    # visibility per joint (may be absent in the mat file)
                    if 'is_visible' in str(annopoint.dtype):
                        vis = [v[0] if v.size > 0 else [0] for v in annopoint['is_visible'][0]]
                        vis = dict([(k, int(v[0])) if len(v) > 0 else v for k, v in zip(j_id, vis)])
                    else:
                        vis = None
                    # keep only people with all 16 keypoints when requested
                    if (not is_16_pos_only) or (len(joint_pos) == 16):
                        data = {
                            'filename': img_fn,
                            'train': train_flag,
                            'head_rect': head_rect,
                            'is_visible': vis,
                            'joint_pos': joint_pos
                        }
                        if train_flag:
                            ann_train_list[-1].append(data)
                        else:
                            ann_test_list[-1].append(data)

    save_joints()

    ## read images dir and drop annotations whose image file is missing
    logging.info("reading images list ...")
    img_dir = os.path.join(path, extracted_filename2)
    # set() for O(1) membership tests (the list holds ~25k file names)
    _img_set = set(load_file_list(path=img_dir, regx=r'\.jpg', printable=False))

    def _drop_missing(img_list, ann_list, tag):
        """Return (imgs, anns) keeping only pairs whose image exists in img_dir."""
        kept_img, kept_ann = [], []
        for im, ann in zip(img_list, ann_list):
            if im in _img_set:
                kept_img.append(im)
                kept_ann.append(ann)
            else:
                logging.info(
                    'missing {} image {} in {} (remove from img(ann)_{}_list)'.format(tag, im, img_dir, tag)
                )
        return kept_img, kept_ann

    # BUG FIX: the original deleted items while iterating (which skips the
    # element after every deletion) and the test-list loop mistakenly deleted
    # from the *train* lists.
    img_train_list, ann_train_list = _drop_missing(img_train_list, ann_train_list, 'train')
    img_test_list, ann_test_list = _drop_missing(img_test_list, ann_test_list, 'test')

    ## check annotation and images
    n_train_images = len(img_train_list)
    n_test_images = len(img_test_list)
    n_images = n_train_images + n_test_images
    logging.info("n_images: {} n_train_images: {} n_test_images: {}".format(n_images, n_train_images, n_test_images))
    n_train_ann = len(ann_train_list)
    n_test_ann = len(ann_test_list)
    n_ann = n_train_ann + n_test_ann
    logging.info("n_ann: {} n_train_ann: {} n_test_ann: {}".format(n_ann, n_train_ann, n_test_ann))
    n_train_people = sum(len(a) for a in ann_train_list)
    n_test_people = sum(len(a) for a in ann_test_list)
    n_people = n_train_people + n_test_people
    logging.info("n_people: {} n_train_people: {} n_test_people: {}".format(n_people, n_train_people, n_test_people))

    # prepend the image directory so entries are full paths
    img_train_list = [os.path.join(img_dir, fn) for fn in img_train_list]
    img_test_list = [os.path.join(img_dir, fn) for fn in img_test_list]
    return img_train_list, ann_train_list, img_test_list, ann_test_list