Reduced memory usage for pregenerating the data a lot by writing it
out on the fly without shuffling - the Sampler in the finetuning script will shuffle for us.
This commit is contained in:
parent
6a9038ba53
commit
0ae59e662d
|
@ -73,7 +73,10 @@ class PregeneratedDataset(Dataset):
|
|||
logging.info(f"Loading training examples for epoch {epoch}")
|
||||
with data_file.open() as f:
|
||||
for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")):
|
||||
example = json.loads(line.rstrip())
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue # Skip trailing blank lines etc.
|
||||
example = json.loads(line)
|
||||
features = convert_example_to_features(example, tokenizer, seq_len)
|
||||
input_ids[i] = features.input_ids
|
||||
segment_ids[i] = features.segment_ids
|
||||
|
|
|
@ -242,24 +242,22 @@ def main():
|
|||
# When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
|
||||
# Google BERT doesn't do this, and as a result oversamples shorter docs
|
||||
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
|
||||
epoch_instances = []
|
||||
for doc_idx in trange(len(docs), desc="Document"):
|
||||
doc_instances = create_instances_from_document(
|
||||
docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
|
||||
masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
|
||||
vocab_list=vocab_list)
|
||||
doc_instances = [json.dumps(instance) for instance in doc_instances]
|
||||
epoch_instances.extend(doc_instances)
|
||||
|
||||
shuffle(epoch_instances)
|
||||
epoch_file = args.output_dir / f"epoch_{epoch}.json"
|
||||
epoch_filename = args.output_dir / f"epoch_{epoch}.json"
|
||||
num_instances = 0
|
||||
with epoch_filename.open('w') as epoch_file:
|
||||
for doc_idx in trange(len(docs), desc="Document"):
|
||||
doc_instances = create_instances_from_document(
|
||||
docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
|
||||
masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
|
||||
vocab_list=vocab_list)
|
||||
doc_instances = [json.dumps(instance) for instance in doc_instances]
|
||||
for instance in doc_instances:
|
||||
epoch_file.write(instance + '\n')
|
||||
num_instances += 1
|
||||
metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
|
||||
with epoch_file.open('w') as out_file:
|
||||
for instance in epoch_instances:
|
||||
out_file.write(instance + '\n')
|
||||
with metrics_file.open('w') as metrics_file:
|
||||
metrics = {
|
||||
"num_training_examples": len(epoch_instances),
|
||||
"num_training_examples": num_instances,
|
||||
"max_seq_len": args.max_seq_len
|
||||
}
|
||||
metrics_file.write(json.dumps(metrics))
|
||||
|
|
Loading…
Reference in New Issue