Added a --reduce_memory option to the training script to keep training
data on disc as a memmap rather than in memory
This commit is contained in:
parent
2bba7f810e
commit
7d1ae644ef
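
As context for the diff below, here is a minimal sketch (not part of the commit) of the pattern being introduced: np.memmap allocates an array backed by a file on disc, so it can stand in for np.zeros while keeping the bulk of the training data out of RAM. The sizes, token values, and file names here are illustrative assumptions only.

import numpy as np
from pathlib import Path
from tempfile import TemporaryDirectory

num_samples, seq_len = 100_000, 128               # assumed sizes, for illustration
temp_dir = TemporaryDirectory()                   # scratch space; remove with temp_dir.cleanup() when done
working_dir = Path(temp_dir.name)

# Disc-backed array: behaves like an ordinary ndarray, but data lives in the mapped file.
input_ids = np.memmap(working_dir / 'input_ids.memmap',
                      mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
input_ids[0, :3] = [101, 2023, 102]               # writes go to the file, not to a resident array

# The equivalent in-memory allocation needs num_samples * seq_len * 4 bytes of RAM up front.
in_memory_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)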
Training script:

@@ -6,6 +6,7 @@ import json
 import random
 import numpy as np
 from collections import namedtuple
+from tempfile import TemporaryDirectory
 
 from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -53,8 +54,7 @@ def convert_example_to_features(example, tokenizer, max_seq_length):
 
 
 class PregeneratedDataset(Dataset):
-    def __init__(self, training_path, epoch, tokenizer, num_data_epochs):
-        # TODO Add an option to memmap the training data if needed (see note in pregenerate_training_data)
+    def __init__(self, training_path, epoch, tokenizer, num_data_epochs, reduce_memory=False):
        self.vocab = tokenizer.vocab
        self.tokenizer = tokenizer
        self.epoch = epoch
@@ -65,11 +65,28 @@ class PregeneratedDataset(Dataset):
         metrics = json.loads(metrics_file.read_text())
         num_samples = metrics['num_training_examples']
         seq_len = metrics['max_seq_len']
-        input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
-        input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
-        segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
-        lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
-        is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool)
+        self.temp_dir = None
+        self.working_dir = None
+        if reduce_memory:
+            self.temp_dir = TemporaryDirectory()
+            self.working_dir = Path(self.temp_dir.name)
+            input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap',
+                                  mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
+            input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
+                                    shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
+            segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
+                                    shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
+            lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
+                                     shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
+            lm_label_ids[:] = -1
+            is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap',
+                                 shape=(num_samples,), mode='w+', dtype=np.bool)
+        else:
+            input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
+            input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
+            segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
+            lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
+            is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool)
         logging.info(f"Loading training examples for epoch {epoch}")
         with data_file.open() as f:
             for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")):
@@ -110,6 +127,8 @@ def main():
                         choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                  "bert-base-multilingual", "bert-base-chinese"])
     parser.add_argument("--do_lower_case", action="store_true")
+    parser.add_argument("--reduce_memory", action="store_true",
+                        help="Store training data as on-disc memmaps to massively reduce memory usage")
 
     parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for")
     parser.add_argument("--local_rank",
@@ -311,7 +330,5 @@ def main():
     torch.save(model_to_save.state_dict(), str(output_model_file))
 
-
-
 
 if __name__ == '__main__':
     main()
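
A hedged usage sketch of the changed dataset class (not from the commit): it assumes the snippet runs inside, or imports from, the training script where PregeneratedDataset is defined, and that a pregenerated data directory already exists. 'pregenerated_data/' and the epoch count are hypothetical.

from pathlib import Path

from torch.utils.data import DataLoader, RandomSampler
from pytorch_pretrained_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

# With reduce_memory=True the feature arrays become memmaps in a TemporaryDirectory
# instead of in-memory numpy arrays.
epoch_dataset = PregeneratedDataset(training_path=Path("pregenerated_data/"),
                                    epoch=0,
                                    tokenizer=tokenizer,
                                    num_data_epochs=3,
                                    reduce_memory=True)
train_dataloader = DataLoader(epoch_dataset,
                              sampler=RandomSampler(epoch_dataset),
                              batch_size=32)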

pregenerate_training_data.py:

@@ -11,15 +11,10 @@ import json
 
 
 class DocumentDatabase:
-    def __init__(self, reduce_memory=False, working_dir=None):
+    def __init__(self, reduce_memory=False):
         if reduce_memory:
-            if working_dir is None:
-                self.temp_dir = TemporaryDirectory()
-                self.working_dir = Path(self.temp_dir.name)
-            else:
-                self.temp_dir = None
-                self.working_dir = Path(working_dir)
-            self.working_dir.mkdir(parents=True, exist_ok=True)
+            self.temp_dir = TemporaryDirectory()
+            self.working_dir = Path(self.temp_dir.name)
             self.document_shelf_filepath = self.working_dir / 'shelf.db'
             self.document_shelf = shelve.open(str(self.document_shelf_filepath),
                                               flag='n', protocol=-1)
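
For reference, the DocumentDatabase side of --reduce_memory relies on the standard-library shelve module: documents are pickled into an on-disc database inside a temporary directory and read back on demand, rather than being held in a Python list. A rough, self-contained sketch of that pattern follows; the keys and document contents are assumptions, not taken from the script.

import shelve
from pathlib import Path
from tempfile import TemporaryDirectory

temp_dir = TemporaryDirectory()
shelf_path = Path(temp_dir.name) / 'shelf.db'

# flag='n' always creates a fresh, empty database; protocol=-1 uses the newest pickle protocol.
document_shelf = shelve.open(str(shelf_path), flag='n', protocol=-1)

document_shelf['0'] = [['a', 'tokenized', 'sentence'],
                       ['another', 'sentence', 'from', 'the', 'same', 'document']]
doc = document_shelf['0']          # fetched from disc, not from process memory

document_shelf.close()
temp_dir.cleanup()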
@@ -237,8 +232,6 @@ def main():
 
     parser.add_argument("--reduce_memory", action="store_true",
                         help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
-    parser.add_argument("--working_dir", type=Path, default=None,
-                        help="Temporary directory to use for --reduce_memory. If not set, uses TemporaryDirectory()")
 
     parser.add_argument("--epochs_to_generate", type=int, default=3,
                         help="Number of epochs of data to pregenerate")
@@ -254,7 +247,7 @@ def main():
 
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
-    docs = DocumentDatabase(reduce_memory=args.reduce_memory, working_dir=args.working_dir)
+    docs = DocumentDatabase(reduce_memory=args.reduce_memory)
     with args.train_corpus.open() as f:
         doc = []
         for line in tqdm(f, desc="Loading Dataset", unit=" lines"):