This commit is contained in:
VictorSanh 2018-11-01 12:09:23 -04:00
parent f8e7c95db3
commit 90d360a7a9
1 changed file with 150 additions and 20 deletions


@@ -115,7 +115,11 @@ parser.add_argument("--iterations_per_loop",
default = 1000,
type = int,
help = "How many steps to make in each estimator call.")
parser.add_argument("--use_gpu",
default = True,
type = bool,
help = "Whether to use GPU")
### BEGIN - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
parser.add_argument("--use_tpu",
default = False,
@@ -416,25 +420,18 @@ def input_fn_builder(features, seq_length, is_training, drop_remainder):
batch_size = params["batch_size"]
num_examples = len(features)
# This is for demo purposes and does NOT scale to large data sets. We do
# not use Dataset.from_generator() because that uses tf.py_func which is
# not TPU compatible. The right way to load data is with TFRecordReader.
d = tf.data.Dataset.from_tensor_slices({
"input_ids":
torch.Tensor(all_input_ids, size=[num_examples, seq_length],
dtype=torch.int32, requires_grad=False),
"input_mask":
torch.Tensor(all_input_mask, size=[num_examples, seq_length],
dtype=torch.int32, requires_grad=False),
"segment_ids":
torch.Tensor(all_segment_ids, size=[num_examples, seq_length],
dtype=torch.int32, requires_grad=False),
"label_ids":
torch.Tensor(all_label_ids, size=[num_examples],
dtype=torch.int32, requires_grad=False)
})
device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
# torch.IntTensor() does not accept keyword arguments such as `device`; torch.tensor()
# does, and PyTorch embedding lookups and loss functions typically expect int64 (long)
# ids. requires_grad is False by default.
d = {"input_ids":
         torch.tensor(all_input_ids, dtype=torch.long, device=device),
     "input_mask":
         torch.tensor(all_input_mask, dtype=torch.long, device=device),
     "segment_ids":
         torch.tensor(all_segment_ids, dtype=torch.long, device=device),
     "label_ids":
         torch.tensor(all_label_ids, dtype=torch.long, device=device)
     }
if is_training:
    # TO DO: `d` is now a plain dict of tensors, so the tf.data-style repeat()/shuffle()
    # calls no longer apply here; shuffling and batching are typically delegated to a
    # DataLoader (see the sketch after this function).
    # d = d.repeat()
    # d = d.shuffle(buffer_size=100)
    pass
@@ -443,3 +440,136 @@ def input_fn_builder(features, seq_length, is_training, drop_remainder):
return d
return input_fn
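# The dict-of-tensors above still mimics the tf.data pipeline. A minimal sketch of an
# idiomatic PyTorch alternative, assuming each feature exposes input_ids / input_mask /
# segment_ids / label_id (as produced by convert_examples_to_features); the name
# build_dataloader is illustrative, not from this commit:
def build_dataloader(features, batch_size, is_training):
    from torch.utils.data import TensorDataset, DataLoader
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    # Shuffle only during training; the last partial batch is kept (no drop_remainder).
    return DataLoader(dataset, batch_size=batch_size, shuffle=is_training)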
def main():
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
}
if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
bert_config = modeling.BertConfig.from_json_file(args.bert_config_file)
if args.max_seq_length > bert_config.max_position_embeddings:
raise ValueError(
"Cannot use sequence length %d because the BERT model "
"was only trained up to sequence length %d" %
(args.max_seq_length, bert_config.max_position_embeddings))
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ConfigurationError(f"Output directory ({args.output_dir}) already exists and is "
f"not empty.")
os.makedirs(args.output_dir, exist_ok=True)
task_name = args.task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()
label_list = processor.get_labels()
tokenizer = tokenization.FullTokenizer(
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
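# Usage sketch (illustrative, not part of this commit): FullTokenizer performs WordPiece
# tokenization, and the resulting tokens are mapped to vocabulary ids before padding to
# max_seq_length, e.g.:
#
#     tokens = tokenizer.tokenize("The quick brown fox jumps over the lazy dog.")
#     token_ids = tokenizer.convert_tokens_to_ids(tokens)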
# tpu_cluster_resolver = None
# if FLAGS.use_tpu and FLAGS.tpu_name:
# tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
# FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
# is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
# run_config = tf.contrib.tpu.RunConfig(
# cluster=tpu_cluster_resolver,
# master=FLAGS.master,
# model_dir=FLAGS.output_dir,
# save_checkpoints_steps=FLAGS.save_checkpoints_steps,
# tpu_config=tf.contrib.tpu.TPUConfig(
# iterations_per_loop=FLAGS.iterations_per_loop,
# num_shards=FLAGS.num_tpu_cores,
# per_host_input_for_training=is_per_host))
train_examples = None
num_train_steps = None
num_warmup_steps = None
if args.do_train:
train_examples = processor.get_train_examples(args.data_dir)
num_train_steps = int(
len(train_examples) / args.train_batch_size * args.num_train_epochs)
num_warmup_steps = int(num_train_steps * args.warmup_proportion)
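# For illustration (example numbers, not from this commit): with MRPC's 3,668 training
# examples, train_batch_size=32, num_train_epochs=3.0 and warmup_proportion=0.1, this
# gives num_train_steps = int(3668 / 32 * 3.0) = 343 and num_warmup_steps = int(343 * 0.1) = 34.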
model_fn = model_fn_builder(
bert_config=bert_config,
num_labels=len(label_list),
init_checkpoint=args.init_checkpoint,
learning_rate=args.learning_rate,
num_train_steps=num_train_steps,
num_warmup_steps=num_warmup_steps,
use_gpu=args.use_gpu,
use_one_hot_embeddings=args.use_gpu) ### TO DO - to check when model_fn is written
# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU. - TO DO: replace the TPUEstimator below with an explicit PyTorch
# training loop (for batch in ...).
estimator = tf.contrib.tpu.TPUEstimator(
use_tpu=args.use_tpu,
model_fn=model_fn,
config=run_config,
train_batch_size=args.train_batch_size,
eval_batch_size=args.eval_batch_size)
if args.do_train:
train_features = convert_examples_to_features(
train_examples, label_list, args.max_seq_length, tokenizer)
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_examples))
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps)
train_input_fn = input_fn_builder(
features=train_features,
seq_length=args.max_seq_length,
is_training=True,
drop_remainder=True)
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
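# A rough sketch of the PyTorch loop that would eventually replace estimator.train();
# `model` is a hypothetical nn.Module returning per-class logits and build_dataloader is
# the sketch above; both are assumptions, not code from this commit:
#
#     device = torch.device("cuda" if args.use_gpu else "cpu")
#     train_dataloader = build_dataloader(train_features, args.train_batch_size, is_training=True)
#     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
#     loss_fn = torch.nn.CrossEntropyLoss()
#     model.to(device)
#     model.train()
#     for _ in range(int(args.num_train_epochs)):
#         for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
#             input_ids, input_mask, segment_ids, label_ids = (
#                 t.to(device) for t in (input_ids, input_mask, segment_ids, label_ids))
#             logits = model(input_ids, segment_ids, input_mask)
#             loss = loss_fn(logits, label_ids)
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()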
if args.do_eval:
eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer)
tf.logging.info("***** Running evaluation *****")
tf.logging.info(" Num examples = %d", len(eval_examples))
tf.logging.info(" Batch size = %d", args.eval_batch_size)
# This tells the estimator to run through the entire set.
eval_steps = None
# However, if running eval on the TPU, you will need to specify the
# number of steps.
if args.use_tpu:
# Eval will be slightly WRONG on the TPU because it will truncate
# the last batch.
eval_steps = int(len(eval_examples) / args.eval_batch_size)
eval_drop_remainder = True if args.use_tpu else False
eval_input_fn = input_fn_builder(
features=eval_features,
seq_length=args.max_seq_length,
is_training=False,
drop_remainder=eval_drop_remainder)
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))
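# A rough sketch of the PyTorch counterpart of estimator.evaluate(): iterate the eval
# DataLoader under torch.no_grad() and accumulate accuracy. `model` and build_dataloader
# are the same hypothetical pieces as in the training sketch; not code from this commit:
#
#     device = torch.device("cuda" if args.use_gpu else "cpu")
#     eval_dataloader = build_dataloader(eval_features, args.eval_batch_size, is_training=False)
#     model.eval()
#     correct, total = 0, 0
#     with torch.no_grad():
#         for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
#             input_ids, input_mask, segment_ids, label_ids = (
#                 t.to(device) for t in (input_ids, input_mask, segment_ids, label_ids))
#             logits = model(input_ids, segment_ids, input_mask)
#             correct += (logits.argmax(dim=-1) == label_ids).sum().item()
#             total += label_ids.size(0)
#     eval_accuracy = correct / total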
if __name__ == "__main__":
    main()