Added a --reduce_memory option to the training script to keep training data on disc as a memmap rather than in memory
This commit is contained in:
Matthew Carrigan 2019-03-21 17:02:18 +00:00
parent 2bba7f810e
commit 7d1ae644ef
2 changed files with 30 additions and 20 deletions
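For context, the memory saving comes from numpy's memmap, which backs an array with a file on disc so only the pages actually being read or written are resident in RAM. The following is a minimal, illustrative sketch of that pattern and is not code from the repository; the scratch directory, file name, and shape are placeholders.

# Minimal illustration of the on-disc array pattern behind --reduce_memory.
# All names here (directory, file name, shape) are placeholders.
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np

temp_dir = TemporaryDirectory()          # backing files are removed when this is cleaned up
working_dir = Path(temp_dir.name)

# mode='w+' creates (or overwrites) the backing file and allows reads and writes.
arr = np.memmap(working_dir / 'example.memmap',
                mode='w+', dtype=np.int32, shape=(1000, 128))
arr[0] = 7          # writes go to the file via the page cache, not the Python heap
arr.flush()         # optionally force dirty pages out to disc

print(arr[0][:5])   # reads back like an ordinary ndarray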

View File

@@ -6,6 +6,7 @@ import json
 import random
 import numpy as np
 from collections import namedtuple
+from tempfile import TemporaryDirectory
 from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -53,8 +54,7 @@ def convert_example_to_features(example, tokenizer, max_seq_length):
 class PregeneratedDataset(Dataset):
-    def __init__(self, training_path, epoch, tokenizer, num_data_epochs):
-        # TODO Add an option to memmap the training data if needed (see note in pregenerate_training_data)
+    def __init__(self, training_path, epoch, tokenizer, num_data_epochs, reduce_memory=False):
         self.vocab = tokenizer.vocab
         self.tokenizer = tokenizer
         self.epoch = epoch
@@ -65,11 +65,28 @@ class PregeneratedDataset(Dataset):
         metrics = json.loads(metrics_file.read_text())
         num_samples = metrics['num_training_examples']
         seq_len = metrics['max_seq_len']
-        input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
-        input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
-        segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
-        lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
-        is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool)
+        self.temp_dir = None
+        self.working_dir = None
+        if reduce_memory:
+            self.temp_dir = TemporaryDirectory()
+            self.working_dir = Path(self.temp_dir.name)
+            input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap',
+                                  mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
+            input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
+                                    shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
+            segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
+                                    shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
+            lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
+                                     shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
+            lm_label_ids[:] = -1
+            is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap',
+                                 shape=(num_samples,), mode='w+', dtype=np.bool)
+        else:
+            input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
+            input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
+            segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
+            lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
+            is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool)
         logging.info(f"Loading training examples for epoch {epoch}")
         with data_file.open() as f:
             for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")):
@@ -110,6 +127,8 @@ def main():
                                  choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                           "bert-base-multilingual", "bert-base-chinese"])
     parser.add_argument("--do_lower_case", action="store_true")
+    parser.add_argument("--reduce_memory", action="store_true",
+                        help="Store training data as on-disc memmaps to massively reduce memory usage")
     parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for")
     parser.add_argument("--local_rank",
@@ -311,7 +330,5 @@ def main():
     torch.save(model_to_save.state_dict(), str(output_model_file))
 if __name__ == '__main__':
     main()
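Because np.memmap subclasses ndarray, the rest of the dataset code can index these arrays exactly as it would the in-memory versions, and only the rows touched by the current batch are paged in. As a hedged sketch only (this __getitem__-style helper is illustrative and not part of this diff), reading one example back out might look like:

# Illustrative only: handing one pregenerated example to PyTorch.
# Works identically whether the arrays are plain ndarrays or on-disc memmaps.
import numpy as np
import torch

def example_getitem(input_ids, input_masks, segment_ids, lm_label_ids, is_nexts, item):
    return (torch.tensor(input_ids[item].astype(np.int64)),
            torch.tensor(input_masks[item].astype(np.int64)),
            torch.tensor(segment_ids[item].astype(np.int64)),
            torch.tensor(lm_label_ids[item].astype(np.int64)),
            torch.tensor(int(is_nexts[item])))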

View File

@@ -11,15 +11,10 @@ import json
 class DocumentDatabase:
-    def __init__(self, reduce_memory=False, working_dir=None):
+    def __init__(self, reduce_memory=False):
         if reduce_memory:
-            if working_dir is None:
-                self.temp_dir = TemporaryDirectory()
-                self.working_dir = Path(self.temp_dir.name)
-            else:
-                self.temp_dir = None
-                self.working_dir = Path(working_dir)
-                self.working_dir.mkdir(parents=True, exist_ok=True)
+            self.temp_dir = TemporaryDirectory()
+            self.working_dir = Path(self.temp_dir.name)
             self.document_shelf_filepath = self.working_dir / 'shelf.db'
             self.document_shelf = shelve.open(str(self.document_shelf_filepath),
                                               flag='n', protocol=-1)
@@ -237,8 +232,6 @@ def main():
     parser.add_argument("--reduce_memory", action="store_true",
                         help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
-    parser.add_argument("--working_dir", type=Path, default=None,
-                        help="Temporary directory to use for --reduce_memory. If not set, uses TemporaryDirectory()")
     parser.add_argument("--epochs_to_generate", type=int, default=3,
                         help="Number of epochs of data to pregenerate")
@@ -254,7 +247,7 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
-    docs = DocumentDatabase(reduce_memory=args.reduce_memory, working_dir=args.working_dir)
+    docs = DocumentDatabase(reduce_memory=args.reduce_memory)
     with args.train_corpus.open() as f:
         doc = []
         for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
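The pregeneration script applies the same idea to the document store: with --reduce_memory, documents live in a shelve database inside a TemporaryDirectory rather than in a Python list, and the dropped --working_dir option means that scratch location is now always managed automatically. A minimal sketch of that on-disc pattern, with placeholder keys and a placeholder document (not code from the repository):

# Minimal sketch of keeping documents on disc via shelve, as DocumentDatabase does.
# The key and the sample document below are placeholders.
import shelve
from pathlib import Path
from tempfile import TemporaryDirectory

temp_dir = TemporaryDirectory()
shelf_path = Path(temp_dir.name) / 'shelf.db'

# flag='n' always creates a fresh database; protocol=-1 uses the newest pickle protocol.
with shelve.open(str(shelf_path), flag='n', protocol=-1) as shelf:
    shelf['0'] = ["this is one tokenized document", "split into sentences"]
    print(len(shelf), shelf['0'][1])   # entries are pickled to disc, not held in RAM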