Added a --reduce_memory option to the training script to keep training
data on disc as a memmap rather than in memory
Matthew Carrigan 2019-03-21 17:02:18 +00:00
parent 2bba7f810e
commit 7d1ae644ef
2 changed files with 30 additions and 20 deletions
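
For context, the core technique in this commit is NumPy's memmap: an array backed by a file on disc, so it never has to fit in RAM. A minimal standalone sketch of the idea (not this repo's code; names and shapes are illustrative):

import numpy as np
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    path = Path(tmp) / 'input_ids.memmap'
    # mode='w+' creates the backing file and allows both reads and writes
    ids = np.memmap(path, mode='w+', dtype=np.int32, shape=(1000, 128))
    ids[0, :3] = [101, 2023, 102]  # writes go to file-backed pages, not the heap
    ids.flush()                    # persist any dirty pages to disc
    print(ids[0, :3])              # pages are faulted back in on demand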


@@ -6,6 +6,7 @@ import json
 import random
 import numpy as np
 from collections import namedtuple
+from tempfile import TemporaryDirectory
 from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -53,8 +54,7 @@ def convert_example_to_features(example, tokenizer, max_seq_length):
 class PregeneratedDataset(Dataset):
-    def __init__(self, training_path, epoch, tokenizer, num_data_epochs):
-        # TODO Add an option to memmap the training data if needed (see note in pregenerate_training_data)
+    def __init__(self, training_path, epoch, tokenizer, num_data_epochs, reduce_memory=False):
         self.vocab = tokenizer.vocab
         self.tokenizer = tokenizer
         self.epoch = epoch
@@ -65,11 +65,28 @@ class PregeneratedDataset(Dataset):
         metrics = json.loads(metrics_file.read_text())
         num_samples = metrics['num_training_examples']
         seq_len = metrics['max_seq_len']
-        input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
-        input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
-        segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
-        lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
-        is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool)
+        self.temp_dir = None
+        self.working_dir = None
+        if reduce_memory:
+            self.temp_dir = TemporaryDirectory()
+            self.working_dir = Path(self.temp_dir.name)
+            input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap',
+                                  mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
+            input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
+                                    shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
+            segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
+                                    shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
+            lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
+                                     shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
+            lm_label_ids[:] = -1
+            is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap',
+                                 shape=(num_samples,), mode='w+', dtype=np.bool)
+        else:
+            input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
+            input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
+            segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
+            lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
+            is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool)
         logging.info(f"Loading training examples for epoch {epoch}")
         with data_file.open() as f:
             for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")):
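
One detail in the hunk above is worth calling out: np.memmap has no np.full counterpart, and a fresh mode='w+' map starts out zero-filled, so the -1 sentinel for lm_label_ids is broadcast in with a separate slice assignment. A standalone illustration:

import numpy as np
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    lm_label_ids = np.memmap(Path(tmp) / 'lm_label_ids.memmap',
                             mode='w+', dtype=np.int32, shape=(4, 8))
    lm_label_ids[:] = -1  # fill the zero-initialised map with the sentinel
    assert int(lm_label_ids[0, 0]) == -1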
@@ -110,6 +127,8 @@ def main():
                         choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                  "bert-base-multilingual", "bert-base-chinese"])
     parser.add_argument("--do_lower_case", action="store_true")
+    parser.add_argument("--reduce_memory", action="store_true",
+                        help="Store training data as on-disc memmaps to massively reduce memory usage")
     parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for")
     parser.add_argument("--local_rank",
@@ -311,7 +330,5 @@ def main():
     torch.save(model_to_save.state_dict(), str(output_model_file))
 if __name__ == '__main__':
     main()
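
Downstream, nothing else has to change: a memmap row behaves like any ndarray row, and torch.tensor copies it into memory, so only the pages a batch actually touches are ever resident. A hypothetical sketch of that hand-off (not the file's real __getitem__):

import numpy as np
import torch
from torch.utils.data import Dataset

class MemmapBackedDataset(Dataset):
    """Hypothetical sketch: serving rows of a (possibly memmapped) array."""
    def __init__(self, input_ids):
        self.input_ids = input_ids  # plain ndarray or np.memmap, same interface

    def __len__(self):
        return self.input_ids.shape[0]

    def __getitem__(self, item):
        # torch.tensor copies the row, so the full array never loads at once
        return torch.tensor(self.input_ids[item].astype(np.int64))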


@@ -11,15 +11,10 @@ import json
 class DocumentDatabase:
-    def __init__(self, reduce_memory=False, working_dir=None):
+    def __init__(self, reduce_memory=False):
         if reduce_memory:
-            if working_dir is None:
-                self.temp_dir = TemporaryDirectory()
-                self.working_dir = Path(self.temp_dir.name)
-            else:
-                self.temp_dir = None
-                self.working_dir = Path(working_dir)
-                self.working_dir.mkdir(parents=True, exist_ok=True)
+            self.temp_dir = TemporaryDirectory()
+            self.working_dir = Path(self.temp_dir.name)
             self.document_shelf_filepath = self.working_dir / 'shelf.db'
             self.document_shelf = shelve.open(str(self.document_shelf_filepath),
                                               flag='n', protocol=-1)
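
In the pregeneration script the same reduce-memory idea uses shelve rather than memmap, since documents are variable-length lists of strings that pickle cleanly but do not fit a fixed-shape array. A minimal standalone sketch of the pattern (illustrative names):

import shelve
from pathlib import Path
from tempfile import TemporaryDirectory

temp_dir = TemporaryDirectory()
shelf_path = Path(temp_dir.name) / 'shelf.db'
# flag='n' always creates a fresh database; protocol=-1 picks the newest
# (most compact) pickle protocol
docs = shelve.open(str(shelf_path), flag='n', protocol=-1)
docs['0'] = ['first sentence of doc 0', 'second sentence of doc 0']
print(docs['0'])  # unpickled from disc on access; keys must be strings
docs.close()
temp_dir.cleanup()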
@@ -237,8 +232,6 @@ def main():
     parser.add_argument("--reduce_memory", action="store_true",
                         help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
-    parser.add_argument("--working_dir", type=Path, default=None,
-                        help="Temporary directory to use for --reduce_memory. If not set, uses TemporaryDirectory()")
     parser.add_argument("--epochs_to_generate", type=int, default=3,
                         help="Number of epochs of data to pregenerate")
@@ -254,7 +247,7 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
-    docs = DocumentDatabase(reduce_memory=args.reduce_memory, working_dir=args.working_dir)
+    docs = DocumentDatabase(reduce_memory=args.reduce_memory)
     with args.train_corpus.open() as f:
         doc = []
         for line in tqdm(f, desc="Loading Dataset", unit=" lines"):