tensorlayer3/tensorlayer/files/dataset_loaders/ptb_dataset.py

73 lines
2.6 KiB
Python

#! /usr/bin/python
# -*- coding: utf-8 -*-
import os
from tensorlayer import logging, nlp
from tensorlayer.files.utils import maybe_download_and_extract
__all__ = ['load_ptb_dataset']
def load_ptb_dataset(path='data'):
"""Load Penn TreeBank (PTB) dataset.
It is used in many LANGUAGE MODELING papers,
including "Empirical Evaluation and Combination of Advanced Language
Modeling Techniques", "Recurrent Neural Network Regularization".
It consists of 929k training words, 73k validation words, and 82k test
words. It has 10k words in its vocabulary.
Parameters
----------
path : str
The path that the data is downloaded to, defaults is ``data/ptb/``.
Returns
--------
train_data, valid_data, test_data : list of int
The training, validating and testing data in integer format.
vocab_size : int
The vocabulary size.
Examples
--------
>>> train_data, valid_data, test_data, vocab_size = tl.files.load_ptb_dataset()
References
---------------
- ``tensorflow.models.rnn.ptb import reader``
- `Manual download <http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz>`__
Notes
------
- If you want to get the raw data, see the source code.
"""
path = os.path.join(path, 'ptb')
logging.info("Load or Download Penn TreeBank (PTB) dataset > {}".format(path))
#Maybe dowload and uncompress tar, or load exsisting files
filename = 'simple-examples.tgz'
url = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/'
maybe_download_and_extract(filename, path, url, extract=True)
data_path = os.path.join(path, 'simple-examples', 'data')
train_path = os.path.join(data_path, "ptb.train.txt")
valid_path = os.path.join(data_path, "ptb.valid.txt")
test_path = os.path.join(data_path, "ptb.test.txt")
word_to_id = nlp.build_vocab(nlp.read_words(train_path))
train_data = nlp.words_to_word_ids(nlp.read_words(train_path), word_to_id)
valid_data = nlp.words_to_word_ids(nlp.read_words(valid_path), word_to_id)
test_data = nlp.words_to_word_ids(nlp.read_words(test_path), word_to_id)
vocab_size = len(word_to_id)
# logging.info(nlp.read_words(train_path)) # ... 'according', 'to', 'mr.', '<unk>', '<eos>']
# logging.info(train_data) # ... 214, 5, 23, 1, 2]
# logging.info(word_to_id) # ... 'beyond': 1295, 'anti-nuclear': 9599, 'trouble': 1520, '<eos>': 2 ... }
# logging.info(vocabulary) # 10000
# exit()
return train_data, valid_data, test_data, vocab_size