# forked from TensorLayer/tensorlayer3
#! /usr/bin/python
# -*- coding: utf-8 -*-

import os

from tensorlayer import logging, nlp
from tensorlayer.files.utils import maybe_download_and_extract

__all__ = ['load_ptb_dataset']
def load_ptb_dataset(path='data'):
    """Load Penn TreeBank (PTB) dataset.

    It is used in many LANGUAGE MODELING papers,
    including "Empirical Evaluation and Combination of Advanced Language
    Modeling Techniques", "Recurrent Neural Network Regularization".
    It consists of 929k training words, 73k validation words, and 82k test
    words. It has 10k words in its vocabulary.

    Parameters
    ----------
    path : str
        The path that the data is downloaded to, defaults to ``data/ptb/``.

    Returns
    --------
    train_data, valid_data, test_data : list of int
        The training, validating and testing data in integer format.
    vocab_size : int
        The vocabulary size.

    Examples
    --------
    >>> train_data, valid_data, test_data, vocab_size = tl.files.load_ptb_dataset()

    References
    ---------------
    - ``tensorflow.models.rnn.ptb import reader``
    - `Manual download <http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz>`__

    Notes
    ------
    - If you want to get the raw data, see the source code.

    """
    path = os.path.join(path, 'ptb')
    logging.info("Load or Download Penn TreeBank (PTB) dataset > {}".format(path))

    # Maybe download and uncompress tar, or load existing files
    filename = 'simple-examples.tgz'
    url = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/'
    maybe_download_and_extract(filename, path, url, extract=True)

    data_path = os.path.join(path, 'simple-examples', 'data')
    train_path = os.path.join(data_path, "ptb.train.txt")
    valid_path = os.path.join(data_path, "ptb.valid.txt")
    test_path = os.path.join(data_path, "ptb.test.txt")

    # The vocabulary is built from the training split only; validation and
    # test words are mapped through the same word_to_id table.
    word_to_id = nlp.build_vocab(nlp.read_words(train_path))

    train_data = nlp.words_to_word_ids(nlp.read_words(train_path), word_to_id)
    valid_data = nlp.words_to_word_ids(nlp.read_words(valid_path), word_to_id)
    test_data = nlp.words_to_word_ids(nlp.read_words(test_path), word_to_id)
    vocab_size = len(word_to_id)

    return train_data, valid_data, test_data, vocab_size
|