run_squad WIP
This commit is contained in:
parent
c0065af6cb
commit
e61db0d1c0
|
@ -316,7 +316,7 @@ def read_examples(input_file):
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main():
|
||||||
tf.logging.set_verbosity(tf.logging.INFO)
|
tf.logging.set_verbosity(tf.logging.INFO)
|
||||||
|
|
||||||
layer_indexes = [int(x) for x in args.layers.split(",")]
|
layer_indexes = [int(x) for x in args.layers.split(",")]
|
||||||
|
@ -387,4 +387,4 @@ def main(_):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.app.run()
|
main()
|
||||||
|
|
|
@ -441,8 +441,8 @@ class BertForSequenceClassification(nn.Module):
|
||||||
|
|
||||||
class BertForQuestionAnswering(nn.Module):
|
class BertForQuestionAnswering(nn.Module):
|
||||||
"""BERT model for Question Answering (span extraction).
|
"""BERT model for Question Answering (span extraction).
|
||||||
This module is composed of the BERT model with linear layers on top of
|
This module is composed of the BERT model with a linear layer on top of
|
||||||
the sequence output.
|
the sequence output that computes start_logits and end_logits
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
```python
|
```python
|
||||||
|
@ -455,7 +455,7 @@ class BertForQuestionAnswering(nn.Module):
|
||||||
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
||||||
|
|
||||||
model = BertForQuestionAnswering(config)
|
model = BertForQuestionAnswering(config)
|
||||||
logits = model(input_ids, token_type_ids, input_mask)
|
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import logging
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
@ -29,6 +30,14 @@ import six
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
from modeling_pytorch import BertConfig, BertForQuestionAnswering
|
||||||
|
from optimization_pytorch import BERTAdam
|
||||||
|
|
||||||
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
|
level = logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
|
@ -94,6 +103,10 @@ parser.add_argument("--num_tpu_cores", default=8, type=int, help="Only used if `
|
||||||
parser.add_argument("--verbose_logging", default=False, type=bool,
|
parser.add_argument("--verbose_logging", default=False, type=bool,
|
||||||
help="If true, all of the warnings related to data processing will be printed. "
|
help="If true, all of the warnings related to data processing will be printed. "
|
||||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
"A number of warnings are expected for a normal SQuAD evaluation.")
|
||||||
|
parser.add_argument("--local_rank",
|
||||||
|
type=int,
|
||||||
|
default=-1,
|
||||||
|
help = "local_rank for distributed training on gpus")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -926,8 +939,15 @@ def _compute_softmax(scores):
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main():
|
||||||
tf.logging.set_verbosity(tf.logging.INFO)
|
if args.local_rank == -1 or args.no_cuda:
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||||
|
n_gpu = torch.cuda.device_count()
|
||||||
|
else:
|
||||||
|
device = torch.device("cuda", args.local_rank)
|
||||||
|
n_gpu = 1
|
||||||
|
# print("Initializing the distributed backend: NCCL")
|
||||||
|
print("device", device, "n_gpu", n_gpu)
|
||||||
|
|
||||||
if not args.do_train and not args.do_predict:
|
if not args.do_train and not args.do_predict:
|
||||||
raise ValueError("At least one of `do_train` or `do_predict` must be True.")
|
raise ValueError("At least one of `do_train` or `do_predict` must be True.")
|
||||||
|
@ -941,7 +961,7 @@ def main(_):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If `do_predict` is True, then `predict_file` must be specified.")
|
"If `do_predict` is True, then `predict_file` must be specified.")
|
||||||
|
|
||||||
bert_config = modeling.BertConfig.from_json_file(args.bert_config_file)
|
bert_config = BertConfig.from_json_file(args.bert_config_file)
|
||||||
|
|
||||||
if args.max_seq_length > bert_config.max_position_embeddings:
|
if args.max_seq_length > bert_config.max_position_embeddings:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -949,54 +969,69 @@ def main(_):
|
||||||
"was only trained up to sequence length %d" %
|
"was only trained up to sequence length %d" %
|
||||||
(args.max_seq_length, bert_config.max_position_embeddings))
|
(args.max_seq_length, bert_config.max_position_embeddings))
|
||||||
|
|
||||||
tf.gfile.MakeDirs(args.output_dir)
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||||
|
raise ValueError(f"Output directory ({args.output_dir}) already exists and is "
|
||||||
|
f"not empty.")
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
tokenizer = tokenization.FullTokenizer(
|
tokenizer = tokenization.FullTokenizer(
|
||||||
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
|
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
tpu_cluster_resolver = None
|
# tpu_cluster_resolver = None
|
||||||
if args.use_tpu and args.tpu_name:
|
# if args.use_tpu and args.tpu_name:
|
||||||
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
|
# tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
|
||||||
args.tpu_name, zone=args.tpu_zone, project=args.gcp_project)
|
# args.tpu_name, zone=args.tpu_zone, project=args.gcp_project)
|
||||||
|
|
||||||
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
# is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||||
run_config = tf.contrib.tpu.RunConfig(
|
# run_config = tf.contrib.tpu.RunConfig(
|
||||||
cluster=tpu_cluster_resolver,
|
# cluster=tpu_cluster_resolver,
|
||||||
master=args.master,
|
# master=args.master,
|
||||||
model_dir=args.output_dir,
|
# model_dir=args.output_dir,
|
||||||
save_checkpoints_steps=args.save_checkpoints_steps,
|
# save_checkpoints_steps=args.save_checkpoints_steps,
|
||||||
tpu_config=tf.contrib.tpu.TPUConfig(
|
# tpu_config=tf.contrib.tpu.TPUConfig(
|
||||||
iterations_per_loop=args.iterations_per_loop,
|
# iterations_per_loop=args.iterations_per_loop,
|
||||||
num_shards=args.num_tpu_cores,
|
# num_shards=args.num_tpu_cores,
|
||||||
per_host_input_for_training=is_per_host))
|
# per_host_input_for_training=is_per_host))
|
||||||
|
|
||||||
train_examples = None
|
train_examples = None
|
||||||
num_train_steps = None
|
num_train_steps = None
|
||||||
num_warmup_steps = None
|
# num_warmup_steps = None
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_examples = read_squad_examples(
|
train_examples = read_squad_examples(
|
||||||
input_file=args.train_file, is_training=True)
|
input_file=args.train_file, is_training=True)
|
||||||
num_train_steps = int(
|
num_train_steps = int(
|
||||||
len(train_examples) / args.train_batch_size * args.num_train_epochs)
|
len(train_examples) / args.train_batch_size * args.num_train_epochs)
|
||||||
num_warmup_steps = int(num_train_steps * args.warmup_proportion)
|
# num_warmup_steps = int(num_train_steps * args.warmup_proportion)
|
||||||
|
|
||||||
model_fn = model_fn_builder(
|
# model_fn = model_fn_builder(
|
||||||
bert_config=bert_config,
|
# bert_config=bert_config,
|
||||||
init_checkpoint=args.init_checkpoint,
|
# init_checkpoint=args.init_checkpoint,
|
||||||
learning_rate=args.learning_rate,
|
# learning_rate=args.learning_rate,
|
||||||
num_train_steps=num_train_steps,
|
# num_train_steps=num_train_steps,
|
||||||
num_warmup_steps=num_warmup_steps,
|
# num_warmup_steps=num_warmup_steps,
|
||||||
use_tpu=args.use_tpu,
|
# use_tpu=args.use_tpu,
|
||||||
use_one_hot_embeddings=args.use_tpu)
|
# use_one_hot_embeddings=args.use_tpu)
|
||||||
|
|
||||||
# If TPU is not available, this will fall back to normal Estimator on CPU
|
# If TPU is not available, this will fall back to normal Estimator on CPU
|
||||||
# or GPU.
|
# or GPU.
|
||||||
estimator = tf.contrib.tpu.TPUEstimator(
|
# estimator = tf.contrib.tpu.TPUEstimator(
|
||||||
use_tpu=args.use_tpu,
|
# use_tpu=args.use_tpu,
|
||||||
model_fn=model_fn,
|
# model_fn=model_fn,
|
||||||
config=run_config,
|
# config=run_config,
|
||||||
train_batch_size=args.train_batch_size,
|
# train_batch_size=args.train_batch_size,
|
||||||
predict_batch_size=args.predict_batch_size)
|
# predict_batch_size=args.predict_batch_size)
|
||||||
|
|
||||||
|
model = BertForQuestionAnswering(bert_config)
|
||||||
|
if args.init_checkpoint is not None:
|
||||||
|
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
|
||||||
|
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
|
||||||
|
],
|
||||||
|
lr=args.learning_rate, schedule='warmup_linear',
|
||||||
|
warmup=args.warmup_proportion,
|
||||||
|
t_total=num_train_steps)
|
||||||
|
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_features = convert_examples_to_features(
|
train_features = convert_examples_to_features(
|
||||||
|
@ -1067,4 +1102,4 @@ def main(_):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.app.run()
|
main()
|
||||||
|
|
Loading…
Reference in New Issue