`input_fn_builder` WIP

This commit is contained in:
VictorSanh 2018-11-01 02:10:46 -04:00
parent f8e347b557
commit 555b7d66c9
1 changed file with 62 additions and 1 deletion


@@ -23,6 +23,7 @@ import os
# import modeling_pytorch
# import optimization
import tokenization_pytorch
import torch
import logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -381,4 +382,64 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    if len(tokens_a) > len(tokens_b):
        tokens_a.pop()
    else:
        tokens_b.pop()

def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 labels, num_labels, use_one_hot_embeddings):
    raise NotImplementedError()


def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps,
                     use_one_hot_embeddings):
    raise NotImplementedError()

### ATTENTION - I removed the `use_tpu` argument
def input_fn_builder(features, seq_length, is_training, drop_remainder):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""  ### ATTENTION - To rewrite ###
    all_input_ids = []
    all_input_mask = []
    all_segment_ids = []
    all_label_ids = []

    for feature in features:
        all_input_ids.append(feature.input_ids)
        all_input_mask.append(feature.input_mask)
        all_segment_ids.append(feature.segment_ids)
        all_label_ids.append(feature.label_id)

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        num_examples = len(features)

        # This is for demo purposes and does NOT scale to large data sets. We do
        # not use Dataset.from_generator() because that uses tf.py_func which is
        # not TPU compatible. The right way to load data is with TFRecordReader.
        # Note: torch.Tensor() takes no `size`/`dtype`/`requires_grad` keywords;
        # torch.tensor() infers the [num_examples, seq_length] shape from the
        # nested lists, and integer tensors never carry gradients. Handing torch
        # tensors to tf.data may rely on implicit NumPy conversion; see the
        # pure-PyTorch sketch after the diff.
        d = tf.data.Dataset.from_tensor_slices({
            "input_ids":
                torch.tensor(all_input_ids, dtype=torch.int32),
            "input_mask":
                torch.tensor(all_input_mask, dtype=torch.int32),
            "segment_ids":
                torch.tensor(all_segment_ids, dtype=torch.int32),
            "label_ids":
                torch.tensor(all_label_ids, dtype=torch.int32),
        })

        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
        return d

    return input_fn
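
The `### ATTENTION - To rewrite ###` marker flags this TF-style `input_fn` as a placeholder in the PyTorch port. A minimal sketch of what a pure-PyTorch equivalent could look like, assuming the same `features` objects and using `torch.utils.data`; the `build_dataloader` name is hypothetical and not part of this commit:

import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)

def build_dataloader(features, batch_size, is_training, drop_last):
    # Hypothetical pure-PyTorch replacement for the TF-style input_fn above.
    # torch.long because input ids feed an embedding lookup.
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_input_mask,
                            all_segment_ids, all_label_ids)
    # RandomSampler reshuffles each epoch, mirroring d.shuffle(); DataLoader's
    # drop_last mirrors drop_remainder in d.batch().
    sampler = RandomSampler(dataset) if is_training else SequentialSampler(dataset)
    return DataLoader(dataset, sampler=sampler, batch_size=batch_size,
                      drop_last=drop_last)

Unlike `d.repeat()`, a `DataLoader` is simply re-iterated once per epoch in the training loop, so the infinite-repeat semantics move to the caller.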