#! /usr/bin/python
# -*- coding: utf-8 -*-

import gzip
import os

import numpy as np
import six.moves.cPickle as pickle

from tensorlayer.files.utils import maybe_download_and_extract

__all__ = ['load_imdb_dataset']


def load_imdb_dataset(
        path='data', nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113, start_char=1, oov_char=2,
        index_from=3
):
"""Load IMDB dataset.
|
|
|
|
Parameters
|
|
----------
|
|
path : str
|
|
The path that the data is downloaded to, defaults is ``data/imdb/``.
|
|
nb_words : int
|
|
Number of words to get.
|
|
skip_top : int
|
|
Top most frequent words to ignore (they will appear as oov_char value in the sequence data).
|
|
maxlen : int
|
|
Maximum sequence length. Any longer sequence will be truncated.
|
|
seed : int
|
|
Seed for reproducible data shuffling.
|
|
start_char : int
|
|
The start of a sequence will be marked with this character. Set to 1 because 0 is usually the padding character.
|
|
oov_char : int
|
|
Words that were cut out because of the num_words or skip_top limit will be replaced with this character.
|
|
index_from : int
|
|
Index actual words with this index and higher.
|
|
|
|
Examples
|
|
--------
|
|
>>> X_train, y_train, X_test, y_test = tl.files.load_imdb_dataset(
|
|
... nb_words=20000, test_split=0.2)
|
|
>>> print('X_train.shape', X_train.shape)
|
|
(20000,) [[1, 62, 74, ... 1033, 507, 27],[1, 60, 33, ... 13, 1053, 7]..]
|
|
>>> print('y_train.shape', y_train.shape)
|
|
(20000,) [1 0 0 ..., 1 0 1]
|
|
|
|
References
|
|
-----------
|
|
- `Modified from keras. <https://github.com/fchollet/keras/blob/master/keras/datasets/imdb.py>`__
|
|
|
|
"""
|
|
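    # Download imdb.pkl into `path` if it is not cached there already.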
    path = os.path.join(path, 'imdb')

    filename = "imdb.pkl"
    url = 'https://s3.amazonaws.com/text-datasets/'
    maybe_download_and_extract(filename, path, url)

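    # The pickle stores two parallel lists: the reviews as lists of word indices (X)
    # and their binary sentiment labels (labels).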
    if filename.endswith(".gz"):
        f = gzip.open(os.path.join(path, filename), 'rb')
    else:
        f = open(os.path.join(path, filename), 'rb')

    X, labels = pickle.load(f)
    f.close()

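    # Re-seed before each shuffle so X and labels are permuted in the same order.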
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(labels)

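    # Prepend the start marker and shift every word index up by `index_from`,
    # keeping the low indices free for the padding/start/OOV markers.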
    if start_char is not None:
        X = [[start_char] + [w + index_from for w in x] for x in X]
    elif index_from:
        X = [[w + index_from for w in x] for x in X]

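    # Keep only the sequences shorter than `maxlen`, together with their labels.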
    if maxlen:
        new_X = []
        new_labels = []
        for x, y in zip(X, labels):
            if len(x) < maxlen:
                new_X.append(x)
                new_labels.append(y)
        X = new_X
        labels = new_labels
    if not X:
        raise Exception(
            'After filtering for sequences shorter than maxlen=' + str(maxlen) + ', no sequence was kept. '
            'Increase maxlen.'
        )
    if not nb_words:
        nb_words = max([max(x) for x in X])

    # by convention, use 2 as OOV word
    # reserve 'index_from' (=3 by default) characters: 0 (padding), 1 (start), 2 (OOV)
    if oov_char is not None:
        X = [[oov_char if (w >= nb_words or w < skip_top) else w for w in x] for x in X]
    else:
        # without an OOV marker, drop out-of-vocabulary words instead of replacing them
        nX = []
        for x in X:
            nx = []
            for w in x:
                if skip_top <= w < nb_words:
                    nx.append(w)
            nX.append(nx)
        X = nX

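    # Hold out the last `test_split` fraction of the (shuffled) data as the test set.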
    X_train = np.array(X[:int(len(X) * (1 - test_split))])
    y_train = np.array(labels[:int(len(X) * (1 - test_split))])

    X_test = np.array(X[int(len(X) * (1 - test_split)):])
    y_test = np.array(labels[int(len(X) * (1 - test_split)):])

    return X_train, y_train, X_test, y_test