`input_fn_builder` WIP
This commit is contained in:
parent
f8e347b557
commit
555b7d66c9
|
@ -23,6 +23,7 @@ import os
|
|||
# import modeling_pytorch
|
||||
# import optimization
|
||||
import tokenization_pytorch
|
||||
import torch
|
||||
|
||||
import logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
|
@ -381,4 +382,64 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
|||
if len(tokens_a) > len(tokens_b):
|
||||
tokens_a.pop()
|
||||
else:
|
||||
tokens_b.pop()
|
||||
tokens_b.pop()
|
||||
|
||||
|
||||
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
|
||||
labels, num_labels, use_one_hot_embeddings):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
|
||||
num_train_steps, num_warmup_steps,
|
||||
use_one_hot_embeddings):
|
||||
raise NotImplementedError()
|
||||
### ATTENTION - I removed the `use_tpu` argument
|
||||
|
||||
|
||||
def input_fn_builder(features, seq_length, is_training, drop_remainder):
|
||||
"""Creates an `input_fn` closure to be passed to TPUEstimator.""" ### ATTENTION - To rewrite ###
|
||||
|
||||
all_input_ids = []
|
||||
all_input_mask = []
|
||||
all_segment_ids = []
|
||||
all_label_ids = []
|
||||
|
||||
for feature in features:
|
||||
all_input_ids.append(feature.input_ids)
|
||||
all_input_mask.append(feature.input_mask)
|
||||
all_segment_ids.append(feature.segment_ids)
|
||||
all_label_ids.append(feature.label_id)
|
||||
|
||||
def input_fn(params):
|
||||
"""The actual input function."""
|
||||
batch_size = params["batch_size"]
|
||||
|
||||
num_examples = len(features)
|
||||
|
||||
# This is for demo purposes and does NOT scale to large data sets. We do
|
||||
# not use Dataset.from_generator() because that uses tf.py_func which is
|
||||
# not TPU compatible. The right way to load data is with TFRecordReader.
|
||||
d = tf.data.Dataset.from_tensor_slices({
|
||||
"input_ids":
|
||||
torch.Tensor(all_input_ids, size=[num_examples, seq_length],
|
||||
dtype=torch.int32, requires_grad=False),
|
||||
"input_mask":
|
||||
torch.Tensor(all_input_mask, size=[num_examples, seq_length],
|
||||
dtype=torch.int32, requires_grad=False),
|
||||
"segment_ids":
|
||||
torch.Tensor(all_segment_ids, size=[num_examples, seq_length],
|
||||
dtype=torch.int32, requires_grad=False),
|
||||
"label_ids":
|
||||
torch.Tensor(all_label_ids, size=[num_examples],
|
||||
dtype=torch.int32, requires_grad=False)
|
||||
})
|
||||
|
||||
if is_training:
|
||||
d = d.repeat()
|
||||
d = d.shuffle(buffer_size=100)
|
||||
|
||||
d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
|
||||
return d
|
||||
|
||||
return input_fn
|
||||
|
|
Loading…
Reference in New Issue