run_squad WIP
This commit is contained in:
parent
c0065af6cb
commit
e61db0d1c0
|
@ -316,7 +316,7 @@ def read_examples(input_file):
|
|||
return examples
|
||||
|
||||
|
||||
def main(_):
|
||||
def main():
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
layer_indexes = [int(x) for x in args.layers.split(",")]
|
||||
|
@ -387,4 +387,4 @@ def main(_):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
||||
main()
|
||||
|
|
|
@ -441,8 +441,8 @@ class BertForSequenceClassification(nn.Module):
|
|||
|
||||
class BertForQuestionAnswering(nn.Module):
|
||||
"""BERT model for Question Answering (span extraction).
|
||||
This module is composed of the BERT model with linear layers on top of
|
||||
the sequence output.
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the sequence output that computes start_logits and end_logits
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
|
@ -455,7 +455,7 @@ class BertForQuestionAnswering(nn.Module):
|
|||
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
||||
|
||||
model = BertForQuestionAnswering(config)
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config):
|
||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
@ -29,6 +30,14 @@ import six
|
|||
import tensorflow as tf
|
||||
import argparse
|
||||
|
||||
from modeling_pytorch import BertConfig, BertForQuestionAnswering
|
||||
from optimization_pytorch import BERTAdam
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
|
@ -94,6 +103,10 @@ parser.add_argument("--num_tpu_cores", default=8, type=int, help="Only used if `
|
|||
parser.add_argument("--verbose_logging", default=False, type=bool,
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
||||
parser.add_argument("--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
help = "local_rank for distributed training on gpus")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -926,8 +939,15 @@ def _compute_softmax(scores):
|
|||
return probs
|
||||
|
||||
|
||||
def main(_):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
def main():
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
else:
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
n_gpu = 1
|
||||
# print("Initializing the distributed backend: NCCL")
|
||||
print("device", device, "n_gpu", n_gpu)
|
||||
|
||||
if not args.do_train and not args.do_predict:
|
||||
raise ValueError("At least one of `do_train` or `do_predict` must be True.")
|
||||
|
@ -941,7 +961,7 @@ def main(_):
|
|||
raise ValueError(
|
||||
"If `do_predict` is True, then `predict_file` must be specified.")
|
||||
|
||||
bert_config = modeling.BertConfig.from_json_file(args.bert_config_file)
|
||||
bert_config = BertConfig.from_json_file(args.bert_config_file)
|
||||
|
||||
if args.max_seq_length > bert_config.max_position_embeddings:
|
||||
raise ValueError(
|
||||
|
@ -949,54 +969,69 @@ def main(_):
|
|||
"was only trained up to sequence length %d" %
|
||||
(args.max_seq_length, bert_config.max_position_embeddings))
|
||||
|
||||
tf.gfile.MakeDirs(args.output_dir)
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||
raise ValueError(f"Output directory ({args.output_dir}) already exists and is "
|
||||
f"not empty.")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
|
||||
|
||||
tpu_cluster_resolver = None
|
||||
if args.use_tpu and args.tpu_name:
|
||||
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
|
||||
args.tpu_name, zone=args.tpu_zone, project=args.gcp_project)
|
||||
# tpu_cluster_resolver = None
|
||||
# if args.use_tpu and args.tpu_name:
|
||||
# tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
|
||||
# args.tpu_name, zone=args.tpu_zone, project=args.gcp_project)
|
||||
|
||||
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||
run_config = tf.contrib.tpu.RunConfig(
|
||||
cluster=tpu_cluster_resolver,
|
||||
master=args.master,
|
||||
model_dir=args.output_dir,
|
||||
save_checkpoints_steps=args.save_checkpoints_steps,
|
||||
tpu_config=tf.contrib.tpu.TPUConfig(
|
||||
iterations_per_loop=args.iterations_per_loop,
|
||||
num_shards=args.num_tpu_cores,
|
||||
per_host_input_for_training=is_per_host))
|
||||
# is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||
# run_config = tf.contrib.tpu.RunConfig(
|
||||
# cluster=tpu_cluster_resolver,
|
||||
# master=args.master,
|
||||
# model_dir=args.output_dir,
|
||||
# save_checkpoints_steps=args.save_checkpoints_steps,
|
||||
# tpu_config=tf.contrib.tpu.TPUConfig(
|
||||
# iterations_per_loop=args.iterations_per_loop,
|
||||
# num_shards=args.num_tpu_cores,
|
||||
# per_host_input_for_training=is_per_host))
|
||||
|
||||
train_examples = None
|
||||
num_train_steps = None
|
||||
num_warmup_steps = None
|
||||
# num_warmup_steps = None
|
||||
if args.do_train:
|
||||
train_examples = read_squad_examples(
|
||||
input_file=args.train_file, is_training=True)
|
||||
num_train_steps = int(
|
||||
len(train_examples) / args.train_batch_size * args.num_train_epochs)
|
||||
num_warmup_steps = int(num_train_steps * args.warmup_proportion)
|
||||
# num_warmup_steps = int(num_train_steps * args.warmup_proportion)
|
||||
|
||||
model_fn = model_fn_builder(
|
||||
bert_config=bert_config,
|
||||
init_checkpoint=args.init_checkpoint,
|
||||
learning_rate=args.learning_rate,
|
||||
num_train_steps=num_train_steps,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
use_tpu=args.use_tpu,
|
||||
use_one_hot_embeddings=args.use_tpu)
|
||||
# model_fn = model_fn_builder(
|
||||
# bert_config=bert_config,
|
||||
# init_checkpoint=args.init_checkpoint,
|
||||
# learning_rate=args.learning_rate,
|
||||
# num_train_steps=num_train_steps,
|
||||
# num_warmup_steps=num_warmup_steps,
|
||||
# use_tpu=args.use_tpu,
|
||||
# use_one_hot_embeddings=args.use_tpu)
|
||||
|
||||
# If TPU is not available, this will fall back to normal Estimator on CPU
|
||||
# or GPU.
|
||||
estimator = tf.contrib.tpu.TPUEstimator(
|
||||
use_tpu=args.use_tpu,
|
||||
model_fn=model_fn,
|
||||
config=run_config,
|
||||
train_batch_size=args.train_batch_size,
|
||||
predict_batch_size=args.predict_batch_size)
|
||||
# estimator = tf.contrib.tpu.TPUEstimator(
|
||||
# use_tpu=args.use_tpu,
|
||||
# model_fn=model_fn,
|
||||
# config=run_config,
|
||||
# train_batch_size=args.train_batch_size,
|
||||
# predict_batch_size=args.predict_batch_size)
|
||||
|
||||
model = BertForQuestionAnswering(bert_config)
|
||||
if args.init_checkpoint is not None:
|
||||
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||
model.to(device)
|
||||
|
||||
optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
|
||||
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
|
||||
],
|
||||
lr=args.learning_rate, schedule='warmup_linear',
|
||||
warmup=args.warmup_proportion,
|
||||
t_total=num_train_steps)
|
||||
|
||||
if args.do_train:
|
||||
train_features = convert_examples_to_features(
|
||||
|
@ -1067,4 +1102,4 @@ def main(_):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue