tensorlayer3/tensorlayer/files/dataset_loaders/celebA_dataset.py

44 lines
1.3 KiB
Python

#! /usr/bin/python
# -*- coding: utf-8 -*-
import os
import zipfile
from tensorlayer import logging
from tensorlayer.files.utils import (download_file_from_google_drive, exists_or_mkdir, load_file_list)
__all__ = ['load_celebA_dataset']
def load_celebA_dataset(path='data'):
"""Load CelebA dataset
Return a list of image path.
Parameters
-----------
path : str
The path that the data is downloaded to, defaults is ``data/celebA/``.
"""
data_dir = 'celebA'
filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
save_path = os.path.join(path, filename)
image_path = os.path.join(path, data_dir)
if os.path.exists(image_path):
logging.info('[*] {} already exists'.format(save_path))
else:
exists_or_mkdir(path)
download_file_from_google_drive(drive_id, save_path)
zip_dir = ''
with zipfile.ZipFile(save_path) as zf:
zip_dir = zf.namelist()[0]
zf.extractall(path)
os.remove(save_path)
os.rename(os.path.join(path, zip_dir), image_path)
data_files = load_file_list(path=image_path, regx='\\.jpg', printable=False)
for i, _v in enumerate(data_files):
data_files[i] = os.path.join(image_path, data_files[i])
return data_files