New squad example (#8992)

* Add new SQUAD example

* Same with a task-specific Trainer

* Address review comment.

* Small fixes

* Initial work for XLNet

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Final clean up and working XLNet script

* Test and debug

* Final working version

* Add new SQUAD example

* Same with a task-specific Trainer

* Address review comment.

* Small fixes

* Initial work for XLNet

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Final clean up and working XLNet script

* Test and debug

* Final working version

* Add tick

* Update README

* Address review comments

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger 2020-12-08 14:39:29 -05:00 committed by GitHub
parent 7809eb82ae
commit 447808c85f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 2205 additions and 354 deletions

3
.gitignore vendored
View File

@ -159,3 +159,6 @@ tags
# pre-commit
.pre-commit*
# .lock
*.lock

View File

@ -55,7 +55,7 @@ git checkout tags/v3.4.0
| [**`text-classification`**](https://github.com/huggingface/transformers/tree/master/examples/text-classification) | GLUE, XNLI | ✅ | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/huggingface/notebooks/blob/master/examples/text_classification.ipynb)
| [**`token-classification`**](https://github.com/huggingface/transformers/tree/master/examples/token-classification) | CoNLL NER | ✅ | ✅ | ✅ | -
| [**`multiple-choice`**](https://github.com/huggingface/transformers/tree/master/examples/multiple-choice) | SWAG, RACE, ARC | ✅ | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ViktorAlm/notebooks/blob/master/MPC_GPU_Demo_for_TF_and_PT.ipynb)
| [**`question-answering`**](https://github.com/huggingface/transformers/tree/master/examples/question-answering) | SQuAD | ✅ | ✅ | - | -
| [**`question-answering`**](https://github.com/huggingface/transformers/tree/master/examples/question-answering) | SQuAD | ✅ | ✅ | | -
| [**`text-generation`**](https://github.com/huggingface/transformers/tree/master/examples/text-generation) | - | n/a | n/a | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)
| [**`distillation`**](https://github.com/huggingface/transformers/tree/master/examples/distillation) | All | - | - | - | -
| [**`summarization`**](https://github.com/huggingface/transformers/tree/master/examples/seq2seq) | CNN/Daily Mail | ✅ | - | - | -

View File

@ -7,33 +7,17 @@ Based on the script [`run_squad.py`](https://github.com/huggingface/transformers
#### Fine-tuning BERT on SQuAD1.0
This example code fine-tunes BERT on the SQuAD1.0 dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large)
on a single tesla V100 16GB. The data for SQuAD can be downloaded with the following links and should be saved in a
$SQUAD_DIR directory.
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
And for SQuAD2.0, you need to download:
- [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
- [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
- [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
on a single tesla V100 16GB.
```bash
export SQUAD_DIR=/path/to/SQUAD
python run_squad.py \
--model_type bert \
python run_qa.py \
--model_name_or_path bert-base-uncased \
--dataset_name squad \
--do_train \
--do_eval \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--per_gpu_train_batch_size 12 \
--per_device_train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2.0 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/debug_squad/
@ -53,20 +37,17 @@ Here is an example using distributed training on 8 V100 GPUs and Bert Whole Word
```bash
python -m torch.distributed.launch --nproc_per_node=8 ./examples/question-answering/run_squad.py \
--model_type bert \
--model_name_or_path bert-large-uncased-whole-word-masking \
--dataset_name squad \
--do_train \
--do_eval \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir ./examples/models/wwm_uncased_finetuned_squad/ \
--per_gpu_eval_batch_size=3 \
--per_gpu_train_batch_size=3 \
--per_device_eval_batch_size=3 \
--per_device_train_batch_size=3 \
```
Training with the previously defined hyper-parameters yields the following results:
@ -79,29 +60,25 @@ exact_match = 86.91
This fine-tuned model is available as a checkpoint under the reference
[`bert-large-uncased-whole-word-masking-finetuned-squad`](https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad).
#### Fine-tuning XLNet on SQuAD
#### Fine-tuning XLNet with beam search on SQuAD
This example code fine-tunes XLNet on both SQuAD1.0 and SQuAD2.0 dataset. See above to download the data for SQuAD .
This example code fine-tunes XLNet on both SQuAD1.0 and SQuAD2.0 dataset.
##### Command for SQuAD1.0:
```bash
export SQUAD_DIR=/path/to/SQUAD
python run_squad.py \
--model_type xlnet \
python run_qa_beam_search.py \
--model_name_or_path xlnet-large-cased \
--dataset_name squad \
--do_train \
--do_eval \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir ./wwm_cased_finetuned_squad/ \
--per_gpu_eval_batch_size=4 \
--per_gpu_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--per_device_train_batch_size=4 \
--save_steps 5000
```
@ -110,21 +87,19 @@ python run_squad.py \
```bash
export SQUAD_DIR=/path/to/SQUAD
python run_squad.py \
--model_type xlnet \
python run_qa_beam_search.py \
--model_name_or_path xlnet-large-cased \
--dataset_name squad_v2 \
--do_train \
--do_eval \
--version_2_with_negative \
--train_file $SQUAD_DIR/train-v2.0.json \
--predict_file $SQUAD_DIR/dev-v2.0.json \
--learning_rate 3e-5 \
--num_train_epochs 4 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir ./wwm_cased_finetuned_squad/ \
--per_gpu_eval_batch_size=2 \
--per_gpu_train_batch_size=2 \
--per_device_eval_batch_size=2 \
--per_device_train_batch_size=2 \
--save_steps 5000
```
@ -162,7 +137,7 @@ Larger batch size may improve the performance while costing more memory.
#### Fine-tuning BERT on SQuAD1.0 with relative position embeddings
The following examples show how to fine-tune BERT models with different relative position embeddings. The BERT model
`bert-base-uncased` was pre-trained with default absolute position embeddings. We provide the following pre-trained
`bert-base-uncased` was pretrained with default absolute position embeddings. We provide the following pretrained
models which were pre-trained on the same training data (BooksCorpus and English Wikipedia) as in the BERT model
training, but with different relative position embeddings.
@ -178,24 +153,19 @@ in Huang et al. [Improve Transformer Models with Better Relative Position Embedd
##### Base models fine-tuning
```bash
export SQUAD_DIR=/path/to/SQUAD
output_dir=relative_squad
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m torch.distributed.launch --nproc_per_node=8 ./examples/question-answering/run_squad.py \
--model_type bert \
--model_name_or_path zhiheng-huang/bert-base-uncased-embedding-relative-key-query \
--dataset_name squad \
--do_train \
--do_eval \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 512 \
--doc_stride 128 \
--output_dir ${output_dir} \
--per_gpu_eval_batch_size=60 \
--per_gpu_train_batch_size=6
--output_dir relative_squad \
--per_device_eval_batch_size=60 \
--per_device_train_batch_size=6
```
Training with the above command leads to the following results. It boosts the BERT default from f1 score of 88.52 to 90.54.
@ -211,22 +181,17 @@ gpu training leads to the f1 score of 90.71.
##### Large models fine-tuning
```bash
export SQUAD_DIR=/path/to/SQUAD
output_dir=relative_squad
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m torch.distributed.launch --nproc_per_node=8 ./examples/question-answering/run_squad.py \
--model_type bert \
--model_name_or_path zhiheng-huang/bert-large-uncased-whole-word-masking-embedding-relative-key-query \
--dataset_name squad \
--do_train \
--do_eval \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 512 \
--doc_stride 128 \
--output_dir ${output_dir} \
--output_dir relative_squad \
--per_gpu_eval_batch_size=6 \
--per_gpu_train_batch_size=2 \
--gradient_accumulation_steps 3
@ -251,5 +216,4 @@ python run_tf_squad.py \
--doc_stride 128
```
For the moment evaluation is not available in the Tensorflow Trainer only the training.

View File

@ -0,0 +1,469 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for question answering.
"""
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset, load_metric
import transformers
from trainer_qa import QuestionAnsweringTrainer
from transformers import (
AutoConfig,
AutoModelForQuestionAnswering,
AutoTokenizer,
DataCollatorWithPadding,
EvalPrediction,
HfArgumentParser,
PreTrainedTokenizerFast,
TrainingArguments,
default_data_collator,
set_seed,
)
from transformers.trainer_utils import is_main_process
from utils_qa import postprocess_qa_predictions
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_seq_length: int = field(
default=384,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
pad_to_max_length: bool = field(
default=True,
metadata={
"help": "Whether to pad all samples to `max_seq_length`. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
"be faster on GPU but will be slower on TPU)."
},
)
version_2_with_negative: bool = field(
default=False, metadata={"help": "If true, some of the examples do not have an answer."}
)
null_score_diff_threshold: float = field(
default=0.0,
metadata={
"help": "The threshold used to select the null answer: if the best answer has a score that is less than "
"the score of the null answer minus this threshold, the null answer is selected for this example. "
"Only useful when `version_2_with_negative=True`."
},
)
doc_stride: int = field(
default=128,
metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
)
n_best_size: int = field(
default=20,
metadata={"help": "The total number of n-best predictions to generate when looking for an answer."},
)
max_answer_length: int = field(
default=30,
metadata={
"help": "The maximum length of an answer that can be generated. This is needed because the start "
"and end predictions are not conditioned on one another."
},
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty."
"Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)
logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
# Set the verbosity to info of the Transformers logger (on main process only):
if is_main_process(training_args.local_rank):
transformers.utils.logging.set_verbosity_info()
logger.info("Training/evaluation parameters %s", training_args)
# Set seed before initializing model.
set_seed(training_args.seed)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
else:
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1]
datasets = load_dataset(extension, data_files=data_files, field="data")
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=True,
)
model = AutoModelForQuestionAnswering.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
)
# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
"This example script only works for models that have a fast tokenizer. Checkout the big table of models "
"at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this "
"requirement"
)
# Preprocessing the datasets.
# Preprocessing is slighlty different for training and evaluation.
if training_args.do_train:
column_names = datasets["train"].column_names
else:
column_names = datasets["validation"].column_names
question_column_name = "question" if "question" in column_names else column_names[0]
context_column_name = "context" if "context" in column_names else column_names[1]
answer_column_name = "answers" if "answers" in column_names else column_names[2]
# Padding side determines if we do (question|context) or (context|question).
pad_on_right = tokenizer.padding_side == "right"
# Training preprocessing
def prepare_train_features(examples):
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
# in one example possible giving several features when a context is long, each of those features having a
# context that overlaps a bit the context of the previous feature.
tokenized_examples = tokenizer(
examples[question_column_name if pad_on_right else context_column_name],
examples[context_column_name if pad_on_right else question_column_name],
truncation="only_second" if pad_on_right else "only_first",
max_length=data_args.max_seq_length,
stride=data_args.doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length" if data_args.pad_to_max_length else False,
)
# Since one example might give us several features if it has a long context, we need a map from a feature to
# its corresponding example. This key gives us just that.
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
# The offset mappings will give us a map from token to character position in the original context. This will
# help us compute the start_positions and end_positions.
offset_mapping = tokenized_examples.pop("offset_mapping")
# Let's label those examples!
tokenized_examples["start_positions"] = []
tokenized_examples["end_positions"] = []
for i, offsets in enumerate(offset_mapping):
# We will label impossible answers with the index of the CLS token.
input_ids = tokenized_examples["input_ids"][i]
cls_index = input_ids.index(tokenizer.cls_token_id)
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids = tokenized_examples.sequence_ids(i)
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = sample_mapping[i]
answers = examples[answer_column_name][sample_index]
# If no answers are given, set the cls_index as answer.
if len(answers["answer_start"]) == 0:
tokenized_examples["start_positions"].append(cls_index)
tokenized_examples["end_positions"].append(cls_index)
else:
# Start/end character index of the answer in the text.
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])
# Start token index of the current span in the text.
token_start_index = 0
while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
token_start_index += 1
# End token index of the current span in the text.
token_end_index = len(input_ids) - 1
while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
token_end_index -= 1
# Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
tokenized_examples["start_positions"].append(cls_index)
tokenized_examples["end_positions"].append(cls_index)
else:
# Otherwise move the token_start_index and token_end_index to the two ends of the answer.
# Note: we could go after the last offset if the answer is the last word (edge case).
while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
token_start_index += 1
tokenized_examples["start_positions"].append(token_start_index - 1)
while offsets[token_end_index][1] >= end_char:
token_end_index -= 1
tokenized_examples["end_positions"].append(token_end_index + 1)
return tokenized_examples
if training_args.do_train:
train_dataset = datasets["train"].map(
prepare_train_features,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
# Validation preprocessing
def prepare_validation_features(examples):
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
# in one example possible giving several features when a context is long, each of those features having a
# context that overlaps a bit the context of the previous feature.
tokenized_examples = tokenizer(
examples[question_column_name if pad_on_right else context_column_name],
examples[context_column_name if pad_on_right else question_column_name],
truncation="only_second" if pad_on_right else "only_first",
max_length=data_args.max_seq_length,
stride=data_args.doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length" if data_args.pad_to_max_length else False,
)
# Since one example might give us several features if it has a long context, we need a map from a feature to
# its corresponding example. This key gives us just that.
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
# For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
# corresponding example_id and we will store the offset mappings.
tokenized_examples["example_id"] = []
for i in range(len(tokenized_examples["input_ids"])):
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids = tokenized_examples.sequence_ids(i)
context_index = 1 if pad_on_right else 0
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = sample_mapping[i]
tokenized_examples["example_id"].append(examples["id"][sample_index])
# Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
# position is part of the context or not.
tokenized_examples["offset_mapping"][i] = [
(o if sequence_ids[k] == context_index else None)
for k, o in enumerate(tokenized_examples["offset_mapping"][i])
]
return tokenized_examples
if training_args.do_eval:
validation_dataset = datasets["validation"].map(
prepare_validation_features,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
# Data collator
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
# collator.
data_collator = default_data_collator if data_args.pad_to_max_length else DataCollatorWithPadding(tokenizer)
# Post-processing:
def post_processing_function(examples, features, predictions):
# Post-processing: we match the start logits and end logits to answers in the original context.
predictions = postprocess_qa_predictions(
examples=examples,
features=features,
predictions=predictions,
version_2_with_negative=data_args.version_2_with_negative,
n_best_size=data_args.n_best_size,
max_answer_length=data_args.max_answer_length,
null_score_diff_threshold=data_args.null_score_diff_threshold,
output_dir=training_args.output_dir,
is_world_process_zero=trainer.is_world_process_zero(),
)
# Format the result to the format the metric expects.
if data_args.version_2_with_negative:
formatted_predictions = [
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
]
else:
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
# TODO: Once the fix lands in a Datasets release, remove the _local here and the squad_v2_local folder.
current_dir = os.path.sep.join(os.path.join(__file__).split(os.path.sep)[:-1])
metric = load_metric(os.path.join(current_dir, "squad_v2_local") if data_args.version_2_with_negative else "squad")
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)
# Initialize our Trainer
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=validation_dataset if training_args.do_eval else None,
eval_examples=datasets["validation"] if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
post_process_function=post_processing_function,
compute_metrics=compute_metrics,
)
# Training
if training_args.do_train:
trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
)
trainer.save_model() # Saves the tokenizer too for easy upload
# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
results = trainer.evaluate()
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
if trainer.is_world_process_zero():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in results.items():
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
return results
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,512 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning XLNet for question answering with beam search.
"""
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset, load_metric
import transformers
from trainer_qa import QuestionAnsweringTrainer
from transformers import (
DataCollatorWithPadding,
EvalPrediction,
HfArgumentParser,
TrainingArguments,
XLNetConfig,
XLNetForQuestionAnswering,
XLNetTokenizerFast,
default_data_collator,
set_seed,
)
from transformers.trainer_utils import is_main_process
from utils_qa import postprocess_qa_predictions_with_beam_search
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_seq_length: int = field(
default=384,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
pad_to_max_length: bool = field(
default=True,
metadata={
"help": "Whether to pad all samples to `max_seq_length`. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
"be faster on GPU but will be slower on TPU)."
},
)
version_2_with_negative: bool = field(
default=False, metadata={"help": "If true, some of the examples do not have an answer."}
)
null_score_diff_threshold: float = field(
default=0.0,
metadata={
"help": "The threshold used to select the null answer: if the best answer has a score that is less than "
"the score of the null answer minus this threshold, the null answer is selected for this example. "
"Only useful when `version_2_with_negative=True`."
},
)
doc_stride: int = field(
default=128,
metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
)
n_best_size: int = field(
default=20,
metadata={"help": "The total number of n-best predictions to generate when looking for an answer."},
)
max_answer_length: int = field(
default=30,
metadata={
"help": "The maximum length of an answer that can be generated. This is needed because the start "
"and end predictions are not conditioned on one another."
},
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty."
"Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)
logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
# Set the verbosity to info of the Transformers logger (on main process only):
if is_main_process(training_args.local_rank):
transformers.utils.logging.set_verbosity_info()
logger.info("Training/evaluation parameters %s", training_args)
# Set seed before initializing model.
set_seed(training_args.seed)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
else:
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1]
datasets = load_dataset(extension, data_files=data_files, field="data")
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
config = XLNetConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
)
tokenizer = XLNetTokenizerFast.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
)
model = XLNetForQuestionAnswering.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
)
# Preprocessing the datasets.
# Preprocessing is slighlty different for training and evaluation.
if training_args.do_train:
column_names = datasets["train"].column_names
else:
column_names = datasets["validation"].column_names
question_column_name = "question" if "question" in column_names else column_names[0]
context_column_name = "context" if "context" in column_names else column_names[1]
answer_column_name = "answers" if "answers" in column_names else column_names[2]
# Padding side determines if we do (question|context) or (context|question).
pad_on_right = tokenizer.padding_side == "right"
# Training preprocessing
def prepare_train_features(examples):
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
# in one example possible giving several features when a context is long, each of those features having a
# context that overlaps a bit the context of the previous feature.
tokenized_examples = tokenizer(
examples[question_column_name if pad_on_right else context_column_name],
examples[context_column_name if pad_on_right else question_column_name],
truncation="only_second" if pad_on_right else "only_first",
max_length=data_args.max_seq_length,
stride=data_args.doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
return_special_tokens_mask=True,
return_token_type_ids=True,
padding="max_length",
)
# Since one example might give us several features if it has a long context, we need a map from a feature to
# its corresponding example. This key gives us just that.
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
# The offset mappings will give us a map from token to character position in the original context. This will
# help us compute the start_positions and end_positions.
offset_mapping = tokenized_examples.pop("offset_mapping")
# The special tokens will help us build the p_mask (which indicates the tokens that can't be in answers).
special_tokens = tokenized_examples.pop("special_tokens_mask")
# Let's label those examples!
tokenized_examples["start_positions"] = []
tokenized_examples["end_positions"] = []
tokenized_examples["is_impossible"] = []
tokenized_examples["cls_index"] = []
tokenized_examples["p_mask"] = []
for i, offsets in enumerate(offset_mapping):
# We will label impossible answers with the index of the CLS token.
input_ids = tokenized_examples["input_ids"][i]
cls_index = input_ids.index(tokenizer.cls_token_id)
tokenized_examples["cls_index"].append(cls_index)
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids = tokenized_examples["token_type_ids"][i]
for k, s in enumerate(special_tokens[i]):
if s:
sequence_ids[k] = 3
context_idx = 1 if pad_on_right else 0
# Build the p_mask: non special tokens and context gets 0.0, the others get 1.0.
# The cls token gets 1.0 too (for predictions of empty answers).
tokenized_examples["p_mask"].append(
[
0.0 if (not special_tokens[i][k] and s == context_idx) or k == cls_index else 1.0
for k, s in enumerate(sequence_ids)
]
)
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = sample_mapping[i]
answers = examples[answer_column_name][sample_index]
# If no answers are given, set the cls_index as answer.
if len(answers["answer_start"]) == 0:
tokenized_examples["start_positions"].append(cls_index)
tokenized_examples["end_positions"].append(cls_index)
tokenized_examples["is_impossible"].append(1.0)
else:
# Start/end character index of the answer in the text.
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])
# Start token index of the current span in the text.
token_start_index = 0
while sequence_ids[token_start_index] != context_idx:
token_start_index += 1
# End token index of the current span in the text.
token_end_index = len(input_ids) - 1
while sequence_ids[token_end_index] != context_idx:
token_end_index -= 1
# Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
tokenized_examples["start_positions"].append(cls_index)
tokenized_examples["end_positions"].append(cls_index)
tokenized_examples["is_impossible"].append(1.0)
else:
# Otherwise move the token_start_index and token_end_index to the two ends of the answer.
# Note: we could go after the last offset if the answer is the last word (edge case).
while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
token_start_index += 1
tokenized_examples["start_positions"].append(token_start_index - 1)
while offsets[token_end_index][1] >= end_char:
token_end_index -= 1
tokenized_examples["end_positions"].append(token_end_index + 1)
tokenized_examples["is_impossible"].append(0.0)
return tokenized_examples
if training_args.do_train:
train_dataset = datasets["train"].map(
prepare_train_features,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
# Validation preprocessing
def prepare_validation_features(examples):
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
# in one example possible giving several features when a context is long, each of those features having a
# context that overlaps a bit the context of the previous feature.
tokenized_examples = tokenizer(
examples[question_column_name if pad_on_right else context_column_name],
examples[context_column_name if pad_on_right else question_column_name],
truncation="only_second" if pad_on_right else "only_first",
max_length=data_args.max_seq_length,
stride=data_args.doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
return_special_tokens_mask=True,
return_token_type_ids=True,
padding="max_length",
)
# Since one example might give us several features if it has a long context, we need a map from a feature to
# its corresponding example. This key gives us just that.
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
# The special tokens will help us build the p_mask (which indicates the tokens that can't be in answers).
special_tokens = tokenized_examples.pop("special_tokens_mask")
# For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
# corresponding example_id and we will store the offset mappings.
tokenized_examples["example_id"] = []
# We still provide the index of the CLS token and the p_mask to the model, but not the is_impossible label.
tokenized_examples["cls_index"] = []
tokenized_examples["p_mask"] = []
for i, input_ids in enumerate(tokenized_examples["input_ids"]):
# Find the CLS token in the input ids.
cls_index = input_ids.index(tokenizer.cls_token_id)
tokenized_examples["cls_index"].append(cls_index)
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids = tokenized_examples["token_type_ids"][i]
for k, s in enumerate(special_tokens[i]):
if s:
sequence_ids[k] = 3
context_idx = 1 if pad_on_right else 0
# Build the p_mask: non special tokens and context gets 0.0, the others 1.0.
tokenized_examples["p_mask"].append(
[
0.0 if (not special_tokens[i][k] and s == context_idx) or k == cls_index else 1.0
for k, s in enumerate(sequence_ids)
]
)
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = sample_mapping[i]
tokenized_examples["example_id"].append(examples["id"][sample_index])
# Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
# position is part of the context or not.
tokenized_examples["offset_mapping"][i] = [
(o if sequence_ids[k] == context_idx else None)
for k, o in enumerate(tokenized_examples["offset_mapping"][i])
]
return tokenized_examples
if training_args.do_eval:
validation_dataset = datasets["validation"].map(
prepare_validation_features,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
# Data collator
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
# collator.
data_collator = default_data_collator if data_args.pad_to_max_length else DataCollatorWithPadding(tokenizer)
# Post-processing:
def post_processing_function(examples, features, predictions):
# Post-processing: we match the start logits and end logits to answers in the original context.
predictions, scores_diff_json = postprocess_qa_predictions_with_beam_search(
examples=examples,
features=features,
predictions=predictions,
version_2_with_negative=data_args.version_2_with_negative,
n_best_size=data_args.n_best_size,
max_answer_length=data_args.max_answer_length,
start_n_top=model.config.start_n_top,
end_n_top=model.config.end_n_top,
output_dir=training_args.output_dir,
is_world_process_zero=trainer.is_world_process_zero(),
)
# Format the result to the format the metric expects.
if data_args.version_2_with_negative:
formatted_predictions = [
{"id": k, "prediction_text": v, "no_answer_probability": scores_diff_json[k]}
for k, v in predictions.items()
]
else:
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
# TODO: Once the fix lands in a Datasets release, remove the _local here and the squad_v2_local folder.
current_dir = os.path.sep.join(os.path.join(__file__).split(os.path.sep)[:-1])
metric = load_metric(os.path.join(current_dir, "squad_v2_local") if data_args.version_2_with_negative else "squad")
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)
# Initialize our Trainer
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=validation_dataset if training_args.do_eval else None,
eval_examples=datasets["validation"] if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
post_process_function=post_processing_function,
compute_metrics=compute_metrics,
)
# Training
if training_args.do_train:
trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
)
trainer.save_model() # Saves the tokenizer too for easy upload
# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
results = trainer.evaluate()
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
if trainer.is_world_process_zero():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in results.items():
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
return results
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,322 @@
"""Official evaluation script for SQuAD version 2.0.
In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""
import argparse
import collections
import json
import os
import re
import string
import sys
import numpy as np
OPTS = None
def parse_args():
parser = argparse.ArgumentParser("Official evaluation script for SQuAD version 2.0.")
parser.add_argument("data_file", metavar="data.json", help="Input data JSON file.")
parser.add_argument("pred_file", metavar="pred.json", help="Model predictions.")
parser.add_argument(
"--out-file", "-o", metavar="eval.json", help="Write accuracy metrics to file (default is stdout)."
)
parser.add_argument(
"--na-prob-file", "-n", metavar="na_prob.json", help="Model estimates of probability of no answer."
)
parser.add_argument(
"--na-prob-thresh",
"-t",
type=float,
default=1.0,
help='Predict "" if no-answer probability exceeds this (default = 1.0).',
)
parser.add_argument(
"--out-image-dir", "-p", metavar="out_images", default=None, help="Save precision-recall curves to directory."
)
parser.add_argument("--verbose", "-v", action="store_true")
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def make_qid_to_has_ans(dataset):
qid_to_has_ans = {}
for article in dataset:
for p in article["paragraphs"]:
for qa in p["qas"]:
qid_to_has_ans[qa["id"]] = bool(qa["answers"]["text"])
return qid_to_has_ans
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
return re.sub(regex, " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def get_tokens(s):
if not s:
return []
return normalize_answer(s).split()
def compute_exact(a_gold, a_pred):
return int(normalize_answer(a_gold) == normalize_answer(a_pred))
def compute_f1(a_gold, a_pred):
gold_toks = get_tokens(a_gold)
pred_toks = get_tokens(a_pred)
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
num_same = sum(common.values())
if len(gold_toks) == 0 or len(pred_toks) == 0:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(gold_toks == pred_toks)
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(gold_toks)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def get_raw_scores(dataset, preds):
exact_scores = {}
f1_scores = {}
for article in dataset:
for p in article["paragraphs"]:
for qa in p["qas"]:
qid = qa["id"]
gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)]
if not gold_answers:
# For unanswerable questions, only correct answer is empty string
gold_answers = [""]
if qid not in preds:
print("Missing prediction for %s" % qid)
continue
a_pred = preds[qid]
# Take max over all gold answers
exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
return exact_scores, f1_scores
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
new_scores = {}
for qid, s in scores.items():
pred_na = na_probs[qid] > na_prob_thresh
if pred_na:
new_scores[qid] = float(not qid_to_has_ans[qid])
else:
new_scores[qid] = s
return new_scores
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
if not qid_list:
total = len(exact_scores)
return collections.OrderedDict(
[
("exact", 100.0 * sum(exact_scores.values()) / total),
("f1", 100.0 * sum(f1_scores.values()) / total),
("total", total),
]
)
else:
total = len(qid_list)
return collections.OrderedDict(
[
("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
("total", total),
]
)
def merge_eval(main_eval, new_eval, prefix):
for k in new_eval:
main_eval["%s_%s" % (prefix, k)] = new_eval[k]
def plot_pr_curve(precisions, recalls, out_image, title):
plt.step(recalls, precisions, color="b", alpha=0.2, where="post")
plt.fill_between(recalls, precisions, step="post", alpha=0.2, color="b")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.xlim([0.0, 1.05])
plt.ylim([0.0, 1.05])
plt.title(title)
plt.savefig(out_image)
plt.clf()
def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None):
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
true_pos = 0.0
cur_p = 1.0
cur_r = 0.0
precisions = [1.0]
recalls = [0.0]
avg_prec = 0.0
for i, qid in enumerate(qid_list):
if qid_to_has_ans[qid]:
true_pos += scores[qid]
cur_p = true_pos / float(i + 1)
cur_r = true_pos / float(num_true_pos)
if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
# i.e., if we can put a threshold after this point
avg_prec += cur_p * (cur_r - recalls[-1])
precisions.append(cur_p)
recalls.append(cur_r)
if out_image:
plot_pr_curve(precisions, recalls, out_image, title)
return {"ap": 100.0 * avg_prec}
def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir):
if out_image_dir and not os.path.exists(out_image_dir):
os.makedirs(out_image_dir)
num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
if num_true_pos == 0:
return
pr_exact = make_precision_recall_eval(
exact_raw,
na_probs,
num_true_pos,
qid_to_has_ans,
out_image=os.path.join(out_image_dir, "pr_exact.png"),
title="Precision-Recall curve for Exact Match score",
)
pr_f1 = make_precision_recall_eval(
f1_raw,
na_probs,
num_true_pos,
qid_to_has_ans,
out_image=os.path.join(out_image_dir, "pr_f1.png"),
title="Precision-Recall curve for F1 score",
)
oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
pr_oracle = make_precision_recall_eval(
oracle_scores,
na_probs,
num_true_pos,
qid_to_has_ans,
out_image=os.path.join(out_image_dir, "pr_oracle.png"),
title="Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)",
)
merge_eval(main_eval, pr_exact, "pr_exact")
merge_eval(main_eval, pr_f1, "pr_f1")
merge_eval(main_eval, pr_oracle, "pr_oracle")
def histogram_na_prob(na_probs, qid_list, image_dir, name):
if not qid_list:
return
x = [na_probs[k] for k in qid_list]
weights = np.ones_like(x) / float(len(x))
plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
plt.xlabel("Model probability of no-answer")
plt.ylabel("Proportion of dataset")
plt.title("Histogram of no-answer probability: %s" % name)
plt.savefig(os.path.join(image_dir, "na_prob_hist_%s.png" % name))
plt.clf()
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans
best_score = cur_score
best_thresh = 0.0
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
for i, qid in enumerate(qid_list):
if qid not in scores:
continue
if qid_to_has_ans[qid]:
diff = scores[qid]
else:
if preds[qid]:
diff = -1
else:
diff = 0
cur_score += diff
if cur_score > best_score:
best_score = cur_score
best_thresh = na_probs[qid]
return 100.0 * best_score / len(scores), best_thresh
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
main_eval["best_exact"] = best_exact
main_eval["best_exact_thresh"] = exact_thresh
main_eval["best_f1"] = best_f1
main_eval["best_f1_thresh"] = f1_thresh
def main():
with open(OPTS.data_file) as f:
dataset_json = json.load(f)
dataset = dataset_json["data"]
with open(OPTS.pred_file) as f:
preds = json.load(f)
if OPTS.na_prob_file:
with open(OPTS.na_prob_file) as f:
na_probs = json.load(f)
else:
na_probs = {k: 0.0 for k in preds}
qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
exact_raw, f1_raw = get_raw_scores(dataset, preds)
exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
out_eval = make_eval_dict(exact_thresh, f1_thresh)
if has_ans_qids:
has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
merge_eval(out_eval, has_ans_eval, "HasAns")
if no_ans_qids:
no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
merge_eval(out_eval, no_ans_eval, "NoAns")
if OPTS.na_prob_file:
find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
if OPTS.na_prob_file and OPTS.out_image_dir:
run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir)
histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns")
histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns")
if OPTS.out_file:
with open(OPTS.out_file, "w") as f:
json.dump(out_eval, f)
else:
print(json.dumps(out_eval, indent=2))
if __name__ == "__main__":
OPTS = parse_args()
if OPTS.out_image_dir:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
main()

View File

@ -0,0 +1,128 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SQuAD v2 metric. """
import datasets
from .evaluate import (
apply_no_ans_threshold,
find_all_best_thresh,
get_raw_scores,
make_eval_dict,
make_qid_to_has_ans,
merge_eval,
)
_CITATION = """\
@inproceedings{Rajpurkar2016SQuAD10,
title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text},
author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang},
booktitle={EMNLP},
year={2016}
}
"""
_DESCRIPTION = """
This metric wrap the official scoring script for version 2 of the Stanford Question
Answering Dataset (SQuAD).
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by
crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span,
from the corresponding reading passage, or the question might be unanswerable.
SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable questions
written adversarially by crowdworkers to look similar to answerable ones.
To do well on SQuAD2.0, systems must not only answer questions when possible, but also
determine when no answer is supported by the paragraph and abstain from answering.
"""
_KWARGS_DESCRIPTION = """
Computes SQuAD v2 scores (F1 and EM).
Args:
predictions: List of triple for question-answers to score with the following elements:
- the question-answer 'id' field as given in the references (see below)
- the text of the answer
- the probability that the question has no answer
references: List of question-answers dictionaries with the following key-values:
- 'id': id of the question-answer pair (see above),
- 'answers': a list of Dict {'text': text of the answer as a string}
no_answer_threshold: float
Probability threshold to decide that a question has no answer.
Returns:
'exact': Exact match (the normalized answer exactly match the gold answer)
'f1': The F-score of predicted tokens versus the gold answer
'total': Number of score considered
'HasAns_exact': Exact match (the normalized answer exactly match the gold answer)
'HasAns_f1': The F-score of predicted tokens versus the gold answer
'HasAns_total': Number of score considered
'NoAns_exact': Exact match (the normalized answer exactly match the gold answer)
'NoAns_f1': The F-score of predicted tokens versus the gold answer
'NoAns_total': Number of score considered
'best_exact': Best exact match (with varying threshold)
'best_exact_thresh': No-answer probability threshold associated to the best exact match
'best_f1': Best F1 (with varying threshold)
'best_f1_thresh': No-answer probability threshold associated to the best F1
"""
class SquadV2(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": {
"id": datasets.Value("string"),
"prediction_text": datasets.Value("string"),
"no_answer_probability": datasets.Value("float32"),
},
"references": {
"id": datasets.Value("string"),
"answers": datasets.features.Sequence(
{"text": datasets.Value("string"), "answer_start": datasets.Value("int32")}
),
},
}
),
codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
)
def _compute(self, predictions, references, no_answer_threshold=1.0):
no_answer_probabilities = dict((p["id"], p["no_answer_probability"]) for p in predictions)
dataset = [{"paragraphs": [{"qas": references}]}]
predictions = dict((p["id"], p["prediction_text"]) for p in predictions)
qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
exact_raw, f1_raw = get_raw_scores(dataset, predictions)
exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
out_eval = make_eval_dict(exact_thresh, f1_thresh)
if has_ans_qids:
has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
merge_eval(out_eval, has_ans_eval, "HasAns")
if no_ans_qids:
no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
merge_eval(out_eval, no_ans_eval, "NoAns")
find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, no_answer_probabilities, qid_to_has_ans)
return out_eval

View File

@ -0,0 +1,104 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A subclass of `Trainer` specific to Question-Answering tasks
"""
from transformers import Trainer, is_datasets_available, is_torch_tpu_available
from transformers.trainer_utils import PredictionOutput
if is_datasets_available():
import datasets
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
class QuestionAnsweringTrainer(Trainer):
def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
super().__init__(*args, **kwargs)
self.eval_examples = eval_examples
self.post_process_function = post_process_function
def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
eval_dataloader = self.get_eval_dataloader(eval_dataset)
eval_examples = self.eval_examples if eval_examples is None else eval_examples
# Temporarily disable metric computation, we will do it in the loop here.
compute_metrics = self.compute_metrics
self.compute_metrics = None
try:
output = self.prediction_loop(
eval_dataloader,
description="Evaluation",
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys,
)
finally:
self.compute_metrics = compute_metrics
# We might have removed columns from the dataset so we put them back.
if isinstance(eval_dataset, datasets.Dataset):
eval_dataset.set_format(type=eval_dataset.format["type"], columns=list(eval_dataset.features.keys()))
if self.post_process_function is not None and self.compute_metrics is not None:
eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
metrics = self.compute_metrics(eval_preds)
self.log(metrics)
else:
metrics = {}
if self.args.tpu_metrics_debug or self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
return metrics
def predict(self, test_dataset, test_examples, ignore_keys=None):
test_dataloader = self.get_test_dataloader(test_dataset)
# Temporarily disable metric computation, we will do it in the loop here.
compute_metrics = self.compute_metrics
self.compute_metrics = None
try:
output = self.prediction_loop(
test_dataloader,
description="Evaluation",
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys,
)
finally:
self.compute_metrics = compute_metrics
if self.post_process_function is None or self.compute_metrics is None:
return output
# We might have removed columns from the dataset so we put them back.
if isinstance(test_dataset, datasets.Dataset):
test_dataset.set_format(type=test_dataset.format["type"], columns=list(test_dataset.features.keys()))
eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions)
metrics = self.compute_metrics(eval_preds)
return PredictionOutput(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids, metrics=metrics)

View File

@ -0,0 +1,429 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Post-processing utilities for question answering.
"""
import collections
import json
import logging
import os
from typing import Optional, Tuple
import numpy as np
from tqdm.auto import tqdm
logger = logging.getLogger(__name__)
def postprocess_qa_predictions(
examples,
features,
predictions: Tuple[np.ndarray, np.ndarray],
version_2_with_negative: bool = False,
n_best_size: int = 20,
max_answer_length: int = 30,
null_score_diff_threshold: float = 0.0,
output_dir: Optional[str] = None,
prefix: Optional[str] = None,
is_world_process_zero: bool = True,
):
"""
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
original contexts. This is the base postprocessing functions for models that only return start and end logits.
Args:
examples: The non-preprocessed dataset (see the main script for more information).
features: The processed dataset (see the main script for more information).
predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
first dimension must match the number of elements of :obj:`features`.
version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the underlying dataset contains examples with no answers.
n_best_size (:obj:`int`, `optional`, defaults to 20):
The total number of n-best predictions to generate when looking for an answer.
max_answer_length (:obj:`int`, `optional`, defaults to 30):
The maximum length of an answer that can be generated. This is needed because the start and end predictions
are not conditioned on one another.
null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):
The threshold used to select the null answer: if the best answer has a score that is less than the score of
the null answer minus this threshold, the null answer is selected for this example (note that the score of
the null answer for an example giving several features is the minimum of the scores for the null answer on
each feature: all features must be aligned on the fact they `want` to predict a null answer).
Only useful when :obj:`version_2_with_negative` is :obj:`True`.
output_dir (:obj:`str`, `optional`):
If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
:obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
answers, are saved in `output_dir`.
prefix (:obj:`str`, `optional`):
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether this process is the main process or not (used to determine if logging/saves should be done).
"""
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
all_start_logits, all_end_logits = predictions
assert len(predictions[0]) == len(
features
), f"Got {len(predictions[0])} predicitions and {len(features)} features."
# Build a map example to its corresponding features.
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = collections.defaultdict(list)
for i, feature in enumerate(features):
features_per_example[example_id_to_index[feature["example_id"]]].append(i)
# The dictionaries we have to fill.
all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict()
if version_2_with_negative:
scores_diff_json = collections.OrderedDict()
# Logging.
logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
# Let's loop over all the examples!
for example_index, example in enumerate(tqdm(examples)):
# Those are the indices of the features associated to the current example.
feature_indices = features_per_example[example_index]
min_null_prediction = None
prelim_predictions = []
# Looping through all the features associated to the current example.
for feature_index in feature_indices:
# We grab the predictions of the model for this feature.
start_logits = all_start_logits[feature_index]
end_logits = all_end_logits[feature_index]
# This is what will allow us to map some the positions in our logits to span of texts in the original
# context.
offset_mapping = features[feature_index]["offset_mapping"]
# Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
# available in the current feature.
token_is_max_context = features[feature_index].get("token_is_max_context", None)
# Update minimum null prediction.
feature_null_score = start_logits[0] + end_logits[0]
if min_null_prediction is None or min_null_prediction["score"] < feature_null_score:
min_null_prediction = {
"offsets": (0, 0),
"score": feature_null_score,
"start_logit": start_logits[0],
"end_logit": end_logits[0],
}
# Go through all possibilities for the `n_best_size` greater start and end logits.
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
for start_index in start_indexes:
for end_index in end_indexes:
# Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
# to part of the input_ids that are not in the context.
if (
start_index >= len(offset_mapping)
or end_index >= len(offset_mapping)
or offset_mapping[start_index] is None
or offset_mapping[end_index] is None
):
continue
# Don't consider answers with a length that is either < 0 or > max_answer_length.
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
continue
# Don't consider answer that don't have the maximum context available (if such information is
# provided).
if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
continue
prelim_predictions.append(
{
"offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
"score": start_logits[start_index] + end_logits[end_index],
"start_logit": start_logits[start_index],
"end_logit": end_logits[end_index],
}
)
if version_2_with_negative:
# Add the minimum null prediction
prelim_predictions.append(min_null_prediction)
null_score = min_null_prediction["score"]
# Only keep the best `n_best_size` predictions.
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
# Add back the minimum null prediction if it was removed because of its low score.
if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
predictions.append(min_null_prediction)
# Use the offsets to gather the answer text in the original context.
context = example["context"]
for pred in predictions:
offsets = pred.pop("offsets")
pred["text"] = context[offsets[0] : offsets[1]]
# In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
# failure.
if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})
# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
# the LogSumExp trick).
scores = np.array([pred.pop("score") for pred in predictions])
exp_scores = np.exp(scores - np.max(scores))
probs = exp_scores / exp_scores.sum()
# Include the probabilities in our predictions.
for prob, pred in zip(probs, predictions):
pred["probability"] = prob
# Pick the best prediction. If the null answer is not possible, this is easy.
if not version_2_with_negative:
all_predictions[example["id"]] = predictions[0]["text"]
else:
# Otherwise we first need to find the best non-empty prediction.
i = 0
while predictions[i]["text"] == "":
i += 1
best_non_null_pred = predictions[i]
# Then we compare to the null prediction using the threshold.
score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable.
if score_diff > null_score_diff_threshold:
all_predictions[example["id"]] = ""
else:
all_predictions[example["id"]] = best_non_null_pred["text"]
# Make `predictions` JSON-serializable by casting np.float back to float.
all_nbest_json[example["id"]] = [
{k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()}
for pred in predictions
]
# If we have an output_dir, let's save all those dicts.
if output_dir is not None:
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
prediction_file = os.path.join(
output_dir, "predictions.json" if prefix is None else f"predictions_{prefix}".json
)
nbest_file = os.path.join(
output_dir, "nbest_predictions.json" if prefix is None else f"nbest_predictions_{prefix}".json
)
if version_2_with_negative:
null_odds_file = os.path.join(
output_dir, "null_odds.json" if prefix is None else f"null_odds_{prefix}".json
)
logger.info(f"Saving predictions to {prediction_file}.")
with open(prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")
logger.info(f"Saving nbest_preds to {nbest_file}.")
with open(nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
if version_2_with_negative:
logger.info(f"Saving null_odds to {null_odds_file}.")
with open(null_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
return all_predictions
def postprocess_qa_predictions_with_beam_search(
examples,
features,
predictions: Tuple[np.ndarray, np.ndarray],
version_2_with_negative: bool = False,
n_best_size: int = 20,
max_answer_length: int = 30,
start_n_top: int = 5,
end_n_top: int = 5,
output_dir: Optional[str] = None,
prefix: Optional[str] = None,
is_world_process_zero: bool = True,
):
"""
Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the
original contexts. This is the postprocessing functions for models that return start and end logits, indices, as well as
cls token predictions.
Args:
examples: The non-preprocessed dataset (see the main script for more information).
features: The processed dataset (see the main script for more information).
predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
first dimension must match the number of elements of :obj:`features`.
version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the underlying dataset contains examples with no answers.
n_best_size (:obj:`int`, `optional`, defaults to 20):
The total number of n-best predictions to generate when looking for an answer.
max_answer_length (:obj:`int`, `optional`, defaults to 30):
The maximum length of an answer that can be generated. This is needed because the start and end predictions
are not conditioned on one another.
start_n_top (:obj:`int`, `optional`, defaults to 5):
The number of top start logits too keep when searching for the :obj:`n_best_size` predictions.
end_n_top (:obj:`int`, `optional`, defaults to 5):
The number of top end logits too keep when searching for the :obj:`n_best_size` predictions.
output_dir (:obj:`str`, `optional`):
If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
:obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
answers, are saved in `output_dir`.
prefix (:obj:`str`, `optional`):
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether this process is the main process or not (used to determine if logging/saves should be done).
"""
assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
assert len(predictions[0]) == len(
features
), f"Got {len(predictions[0])} predicitions and {len(features)} features."
# Build a map example to its corresponding features.
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = collections.defaultdict(list)
for i, feature in enumerate(features):
features_per_example[example_id_to_index[feature["example_id"]]].append(i)
# The dictionaries we have to fill.
all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict()
scores_diff_json = collections.OrderedDict() if version_2_with_negative else None
# Logging.
logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
# Let's loop over all the examples!
for example_index, example in enumerate(tqdm(examples)):
# Those are the indices of the features associated to the current example.
feature_indices = features_per_example[example_index]
min_null_score = None
prelim_predictions = []
# Looping through all the features associated to the current example.
for feature_index in feature_indices:
# We grab the predictions of the model for this feature.
start_log_prob = start_top_log_probs[feature_index]
start_indexes = start_top_index[feature_index]
end_log_prob = end_top_log_probs[feature_index]
end_indexes = end_top_index[feature_index]
feature_null_score = cls_logits[feature_index]
# This is what will allow us to map some the positions in our logits to span of texts in the original
# context.
offset_mapping = features[feature_index]["offset_mapping"]
# Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
# available in the current feature.
token_is_max_context = features[feature_index].get("token_is_max_context", None)
# Update minimum null prediction
if min_null_score is None or feature_null_score < min_null_score:
min_null_score = feature_null_score
# Go through all possibilities for the `n_start_top`/`n_end_top` greater start and end logits.
for i in range(start_n_top):
for j in range(end_n_top):
start_index = start_indexes[i]
j_index = i * end_n_top + j
end_index = end_indexes[j_index]
# Don't consider out-of-scope answers (last part of the test should be unnecessary because of the
# p_mask but let's not take any risk)
if (
start_index >= len(offset_mapping)
or end_index >= len(offset_mapping)
or offset_mapping[start_index] is None
or offset_mapping[end_index] is None
):
continue
# Don't consider answers with a length negative or > max_answer_length.
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
continue
# Don't consider answer that don't have the maximum context available (if such information is
# provided).
if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
continue
prelim_predictions.append(
{
"offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
"score": start_log_prob[i] + end_log_prob[j_index],
"start_log_prob": start_log_prob[i],
"end_log_prob": end_log_prob[j_index],
}
)
# Only keep the best `n_best_size` predictions.
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
# Use the offsets to gather the answer text in the original context.
context = example["context"]
for pred in predictions:
offsets = pred.pop("offsets")
pred["text"] = context[offsets[0] : offsets[1]]
# In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
# failure.
if len(predictions) == 0:
predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
# the LogSumExp trick).
scores = np.array([pred.pop("score") for pred in predictions])
exp_scores = np.exp(scores - np.max(scores))
probs = exp_scores / exp_scores.sum()
# Include the probabilities in our predictions.
for prob, pred in zip(probs, predictions):
pred["probability"] = prob
# Pick the best prediction and set the probability for the null answer.
all_predictions[example["id"]] = predictions[0]["text"]
if version_2_with_negative:
scores_diff_json[example["id"]] = float(min_null_score)
# Make `predictions` JSON-serializable by casting np.float back to float.
all_nbest_json[example["id"]] = [
{k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()}
for pred in predictions
]
# If we have an output_dir, let's save all those dicts.
if output_dir is not None:
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
prediction_file = os.path.join(
output_dir, "predictions.json" if prefix is None else f"predictions_{prefix}".json
)
nbest_file = os.path.join(
output_dir, "nbest_predictions.json" if prefix is None else f"nbest_predictions_{prefix}".json
)
if version_2_with_negative:
null_odds_file = os.path.join(
output_dir, "null_odds.json" if prefix is None else f"null_odds_{prefix}".json
)
print(f"Saving predictions to {prediction_file}.")
with open(prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")
print(f"Saving nbest_preds to {nbest_file}.")
with open(nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
if version_2_with_negative:
print(f"Saving null_odds to {null_odds_file}.")
with open(null_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
return all_predictions, scores_diff_json

View File

@ -46,7 +46,7 @@ if SRC_DIRS is not None:
import run_mlm
import run_ner
import run_pl_glue
import run_squad
import run_qa as run_squad
logging.basicConfig(level=logging.DEBUG)
@ -213,8 +213,8 @@ class ExamplesTests(TestCasePlus):
--do_eval
--warmup_steps=2
--learning_rate=2e-4
--per_gpu_train_batch_size=2
--per_gpu_eval_batch_size=2
--per_device_train_batch_size=2
--per_device_eval_batch_size=2
--num_train_epochs=2
""".split()
@ -235,26 +235,25 @@ class ExamplesTests(TestCasePlus):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_squad.py
--model_type=distilbert
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
--data_dir=./tests/fixtures/tests_samples/SQUAD
--model_name_or_path bert-base-uncased
--version_2_with_negative
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--max_steps=10
--warmup_steps=2
--do_train
--do_eval
--version_2_with_negative
--learning_rate=2e-4
--per_gpu_train_batch_size=2
--per_gpu_eval_batch_size=1
--seed=42
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
""".split()
with patch.object(sys, "argv", testargs):
result = run_squad.main()
self.assertGreaterEqual(result["f1"], 25)
self.assertGreaterEqual(result["exact"], 21)
self.assertGreaterEqual(result["f1"], 30)
self.assertGreaterEqual(result["exact"], 30)
@require_torch_non_multi_gpu_but_fix_me
def test_generation(self):

View File

@ -1,140 +0,0 @@
{
"version": "v2.0",
"data": [{
"title": "Normans",
"paragraphs": [{
"qas": [{
"question": "In what country is Normandy located?",
"id": "56ddde6b9a695914005b9628",
"answers": [{
"text": "France",
"answer_start": 159
}],
"is_impossible": false
}, {
"question": "When were the Normans in Normandy?",
"id": "56ddde6b9a695914005b9629",
"answers": [{
"text": "10th and 11th centuries",
"answer_start": 94
}],
"is_impossible": false
}, {
"question": "From which countries did the Norse originate?",
"id": "56ddde6b9a695914005b962a",
"answers": [{
"text": "Denmark, Iceland and Norway",
"answer_start": 256
}],
"is_impossible": false
}, {
"plausible_answers": [{
"text": "Rollo",
"answer_start": 308
}],
"question": "Who did King Charles III swear fealty to?",
"id": "5ad39d53604f3c001a3fe8d3",
"answers": [],
"is_impossible": true
}, {
"plausible_answers": [{
"text": "10th century",
"answer_start": 671
}],
"question": "When did the Frankish identity emerge?",
"id": "5ad39d53604f3c001a3fe8d4",
"answers": [],
"is_impossible": true
}],
"context": "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries."
}, {
"qas": [{
"question": "Who was the duke in the battle of Hastings?",
"id": "56dddf4066d3e219004dad5f",
"answers": [{
"text": "William the Conqueror",
"answer_start": 1022
}],
"is_impossible": false
}, {
"plausible_answers": [{
"text": "Antioch",
"answer_start": 1295
}],
"question": "What principality did William the conquerer found?",
"id": "5ad3a266604f3c001a3fea2b",
"answers": [],
"is_impossible": true
}],
"context": "The Norman dynasty had a major political, cultural and military impact on medieval Europe and even the Near East. The Normans were famed for their martial spirit and eventually for their Christian piety, becoming exponents of the Catholic orthodoxy into which they assimilated. They adopted the Gallo-Romance language of the Frankish land they settled, their dialect becoming known as Norman, Normaund or Norman French, an important literary language. The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure. The Normans are noted both for their culture, such as their unique Romanesque architecture and musical traditions, and for their significant military accomplishments and innovations. Norman adventurers founded the Kingdom of Sicily under Roger II after conquering southern Italy on the Saracens and Byzantines, and an expedition on behalf of their duke, William the Conqueror, led to the Norman conquest of England at the Battle of Hastings in 1066. Norman cultural and military influence spread from these new European centres to the Crusader states of the Near East, where their prince Bohemond I founded the Principality of Antioch in the Levant, to Scotland and Wales in Great Britain, to Ireland, and to the coasts of north Africa and the Canary Islands."
}]
}, {
"title": "Computational_complexity_theory",
"paragraphs": [{
"qas": [{
"question": "What branch of theoretical computer science deals with broadly classifying computational problems by difficulty and class of relationship?",
"id": "56e16182e3433e1400422e28",
"answers": [{
"text": "Computational complexity theory",
"answer_start": 0
}],
"is_impossible": false
}, {
"plausible_answers": [{
"text": "algorithm",
"answer_start": 472
}],
"question": "What is a manual application of mathematical steps?",
"id": "5ad5316b5b96ef001a10ab76",
"answers": [],
"is_impossible": true
}],
"context": "Computational complexity theory is a branch of the theory of computation in theoretical computer science that focuses on classifying computational problems according to their inherent difficulty, and relating those classes to each other. A computational problem is understood to be a task that is in principle amenable to being solved by a computer, which is equivalent to stating that the problem may be solved by mechanical application of mathematical steps, such as an algorithm."
}, {
"qas": [{
"question": "What measure of a computational problem broadly defines the inherent difficulty of the solution?",
"id": "56e16839cd28a01900c67887",
"answers": [{
"text": "if its solution requires significant resources",
"answer_start": 46
}],
"is_impossible": false
}, {
"question": "What method is used to intuitively assess or quantify the amount of resources required to solve a computational problem?",
"id": "56e16839cd28a01900c67888",
"answers": [{
"text": "mathematical models of computation",
"answer_start": 176
}],
"is_impossible": false
}, {
"question": "What are two basic primary resources used to guage complexity?",
"id": "56e16839cd28a01900c67889",
"answers": [{
"text": "time and storage",
"answer_start": 305
}],
"is_impossible": false
}, {
"plausible_answers": [{
"text": "the number of gates in a circuit",
"answer_start": 436
}],
"question": "What unit is measured to determine circuit simplicity?",
"id": "5ad532575b96ef001a10ab7f",
"answers": [],
"is_impossible": true
}, {
"plausible_answers": [{
"text": "the number of processors",
"answer_start": 502
}],
"question": "What number is used in perpendicular computing?",
"id": "5ad532575b96ef001a10ab80",
"answers": [],
"is_impossible": true
}],
"context": "A problem is regarded as inherently difficult if its solution requires significant resources, whatever the algorithm used. The theory formalizes this intuition, by introducing mathematical models of computation to study these problems and quantifying the amount of resources needed to solve them, such as time and storage. Other complexity measures are also used, such as the amount of communication (used in communication complexity), the number of gates in a circuit (used in circuit complexity) and the number of processors (used in parallel computing). One of the roles of computational complexity theory is to determine the practical limits on what computers can and cannot do."
}]
}]
}

View File

@ -0,0 +1,201 @@
{
"version": 2.0,
"data": [
{
"id": "56ddde6b9a695914005b9628",
"question": "In what country is Normandy located?",
"context": "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.",
"answers": {
"answer_start": [
159,
159,
159,
159
],
"text": [
"France",
"France",
"France",
"France"
]
}
},
{
"id": "56ddde6b9a695914005b9629",
"question": "When were the Normans in Normandy?",
"context": "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.",
"answers": {
"answer_start": [
94,
87,
94,
94
],
"text": [
"10th and 11th centuries",
"in the 10th and 11th centuries",
"10th and 11th centuries",
"10th and 11th centuries"
]
}
},
{
"id": "56ddde6b9a695914005b962a",
"question": "From which countries did the Norse originate?",
"context": "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.",
"answers": {
"answer_start": [
256,
256,
256,
256
],
"text": [
"Denmark, Iceland and Norway",
"Denmark, Iceland and Norway",
"Denmark, Iceland and Norway",
"Denmark, Iceland and Norway"
]
}
},
{
"id": "5ad39d53604f3c001a3fe8d3",
"question": "Who did King Charles III swear fealty to?",
"context": "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.",
"answers": {
"answer_start": [],
"text": []
}
},
{
"id": "5ad39d53604f3c001a3fe8d4",
"question": "When did the Frankish identity emerge?",
"context": "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.",
"answers": {
"answer_start": [],
"text": []
}
},
{
"id": "56dddf4066d3e219004dad5f",
"question": "Who was the duke in the battle of Hastings?",
"context": "The Norman dynasty had a major political, cultural and military impact on medieval Europe and even the Near East. The Normans were famed for their martial spirit and eventually for their Christian piety, becoming exponents of the Catholic orthodoxy into which they assimilated. They adopted the Gallo-Romance language of the Frankish land they settled, their dialect becoming known as Norman, Normaund or Norman French, an important literary language. The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure. The Normans are noted both for their culture, such as their unique Romanesque architecture and musical traditions, and for their significant military accomplishments and innovations. Norman adventurers founded the Kingdom of Sicily under Roger II after conquering southern Italy on the Saracens and Byzantines, and an expedition on behalf of their duke, William the Conqueror, led to the Norman conquest of England at the Battle of Hastings in 1066. Norman cultural and military influence spread from these new European centres to the Crusader states of the Near East, where their prince Bohemond I founded the Principality of Antioch in the Levant, to Scotland and Wales in Great Britain, to Ireland, and to the coasts of north Africa and the Canary Islands.",
"answers": {
"answer_start": [
1022,
1022,
1022
],
"text": [
"William the Conqueror",
"William the Conqueror",
"William the Conqueror"
]
}
},
{
"id": "5ad3a266604f3c001a3fea2b",
"question": "What principality did William the conquerer found?",
"context": "The Norman dynasty had a major political, cultural and military impact on medieval Europe and even the Near East. The Normans were famed for their martial spirit and eventually for their Christian piety, becoming exponents of the Catholic orthodoxy into which they assimilated. They adopted the Gallo-Romance language of the Frankish land they settled, their dialect becoming known as Norman, Normaund or Norman French, an important literary language. The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure. The Normans are noted both for their culture, such as their unique Romanesque architecture and musical traditions, and for their significant military accomplishments and innovations. Norman adventurers founded the Kingdom of Sicily under Roger II after conquering southern Italy on the Saracens and Byzantines, and an expedition on behalf of their duke, William the Conqueror, led to the Norman conquest of England at the Battle of Hastings in 1066. Norman cultural and military influence spread from these new European centres to the Crusader states of the Near East, where their prince Bohemond I founded the Principality of Antioch in the Levant, to Scotland and Wales in Great Britain, to Ireland, and to the coasts of north Africa and the Canary Islands.",
"answers": {
"answer_start": [],
"text": []
}
},
{
"id": "56e16182e3433e1400422e28",
"question": "What branch of theoretical computer science deals with broadly classifying computational problems by difficulty and class of relationship?",
"context": "Computational complexity theory is a branch of the theory of computation in theoretical computer science that focuses on classifying computational problems according to their inherent difficulty, and relating those classes to each other. A computational problem is understood to be a task that is in principle amenable to being solved by a computer, which is equivalent to stating that the problem may be solved by mechanical application of mathematical steps, such as an algorithm.",
"answers": {
"answer_start": [
0,
0,
0
],
"text": [
"Computational complexity theory",
"Computational complexity theory",
"Computational complexity theory"
]
}
},
{
"id": "5ad5316b5b96ef001a10ab76",
"question": "What is a manual application of mathematical steps?",
"context": "Computational complexity theory is a branch of the theory of computation in theoretical computer science that focuses on classifying computational problems according to their inherent difficulty, and relating those classes to each other. A computational problem is understood to be a task that is in principle amenable to being solved by a computer, which is equivalent to stating that the problem may be solved by mechanical application of mathematical steps, such as an algorithm.",
"answers": {
"answer_start": [],
"text": []
}
},
{
"id": "56e16839cd28a01900c67887",
"question": "What measure of a computational problem broadly defines the inherent difficulty of the solution?",
"context": "A problem is regarded as inherently difficult if its solution requires significant resources, whatever the algorithm used. The theory formalizes this intuition, by introducing mathematical models of computation to study these problems and quantifying the amount of resources needed to solve them, such as time and storage. Other complexity measures are also used, such as the amount of communication (used in communication complexity), the number of gates in a circuit (used in circuit complexity) and the number of processors (used in parallel computing). One of the roles of computational complexity theory is to determine the practical limits on what computers can and cannot do.",
"answers": {
"answer_start": [
46,
49,
46
],
"text": [
"if its solution requires significant resources",
"its solution requires significant resources",
"if its solution requires significant resources"
]
}
},
{
"id": "56e16839cd28a01900c67888",
"question": "What method is used to intuitively assess or quantify the amount of resources required to solve a computational problem?",
"context": "A problem is regarded as inherently difficult if its solution requires significant resources, whatever the algorithm used. The theory formalizes this intuition, by introducing mathematical models of computation to study these problems and quantifying the amount of resources needed to solve them, such as time and storage. Other complexity measures are also used, such as the amount of communication (used in communication complexity), the number of gates in a circuit (used in circuit complexity) and the number of processors (used in parallel computing). One of the roles of computational complexity theory is to determine the practical limits on what computers can and cannot do.",
"answers": {
"answer_start": [
176,
176,
176
],
"text": [
"mathematical models of computation",
"mathematical models of computation",
"mathematical models of computation"
]
}
},
{
"id": "56e16839cd28a01900c67889",
"question": "What are two basic primary resources used to guage complexity?",
"context": "A problem is regarded as inherently difficult if its solution requires significant resources, whatever the algorithm used. The theory formalizes this intuition, by introducing mathematical models of computation to study these problems and quantifying the amount of resources needed to solve them, such as time and storage. Other complexity measures are also used, such as the amount of communication (used in communication complexity), the number of gates in a circuit (used in circuit complexity) and the number of processors (used in parallel computing). One of the roles of computational complexity theory is to determine the practical limits on what computers can and cannot do.",
"answers": {
"answer_start": [
305,
305,
305
],
"text": [
"time and storage",
"time and storage",
"time and storage"
]
}
},
{
"id": "5ad532575b96ef001a10ab7f",
"question": "What unit is measured to determine circuit simplicity?",
"context": "A problem is regarded as inherently difficult if its solution requires significant resources, whatever the algorithm used. The theory formalizes this intuition, by introducing mathematical models of computation to study these problems and quantifying the amount of resources needed to solve them, such as time and storage. Other complexity measures are also used, such as the amount of communication (used in communication complexity), the number of gates in a circuit (used in circuit complexity) and the number of processors (used in parallel computing). One of the roles of computational complexity theory is to determine the practical limits on what computers can and cannot do.",
"answers": {
"answer_start": [],
"text": []
}
},
{
"id": "5ad532575b96ef001a10ab80",
"question": "What number is used in perpendicular computing?",
"context": "A problem is regarded as inherently difficult if its solution requires significant resources, whatever the algorithm used. The theory formalizes this intuition, by introducing mathematical models of computation to study these problems and quantifying the amount of resources needed to solve them, such as time and storage. Other complexity measures are also used, such as the amount of communication (used in communication complexity), the number of gates in a circuit (used in circuit complexity) and the number of processors (used in parallel computing). One of the roles of computational complexity theory is to determine the practical limits on what computers can and cannot do.",
"answers": {
"answer_start": [],
"text": []
}
}
]
}

View File

@ -1,140 +0,0 @@
{
"version": "v2.0",
"data": [{
"title": "Normans",
"paragraphs": [{
"qas": [{
"question": "In what country is Normandy located?",
"id": "56ddde6b9a695914005b9628",
"answers": [{
"text": "France",
"answer_start": 159
}],
"is_impossible": false
}, {
"question": "When were the Normans in Normandy?",
"id": "56ddde6b9a695914005b9629",
"answers": [{
"text": "10th and 11th centuries",
"answer_start": 94
}],
"is_impossible": false
}, {
"question": "From which countries did the Norse originate?",
"id": "56ddde6b9a695914005b962a",
"answers": [{
"text": "Denmark, Iceland and Norway",
"answer_start": 256
}],
"is_impossible": false
}, {
"plausible_answers": [{
"text": "Rollo",
"answer_start": 308
}],
"question": "Who did King Charles III swear fealty to?",
"id": "5ad39d53604f3c001a3fe8d3",
"answers": [],
"is_impossible": true
}, {
"plausible_answers": [{
"text": "10th century",
"answer_start": 671
}],
"question": "When did the Frankish identity emerge?",
"id": "5ad39d53604f3c001a3fe8d4",
"answers": [],
"is_impossible": true
}],
"context": "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries."
}, {
"qas": [{
"question": "Who was the duke in the battle of Hastings?",
"id": "56dddf4066d3e219004dad5f",
"answers": [{
"text": "William the Conqueror",
"answer_start": 1022
}],
"is_impossible": false
}, {
"plausible_answers": [{
"text": "Antioch",
"answer_start": 1295
}],
"question": "What principality did William the conquerer found?",
"id": "5ad3a266604f3c001a3fea2b",
"answers": [],
"is_impossible": true
}],
"context": "The Norman dynasty had a major political, cultural and military impact on medieval Europe and even the Near East. The Normans were famed for their martial spirit and eventually for their Christian piety, becoming exponents of the Catholic orthodoxy into which they assimilated. They adopted the Gallo-Romance language of the Frankish land they settled, their dialect becoming known as Norman, Normaund or Norman French, an important literary language. The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure. The Normans are noted both for their culture, such as their unique Romanesque architecture and musical traditions, and for their significant military accomplishments and innovations. Norman adventurers founded the Kingdom of Sicily under Roger II after conquering southern Italy on the Saracens and Byzantines, and an expedition on behalf of their duke, William the Conqueror, led to the Norman conquest of England at the Battle of Hastings in 1066. Norman cultural and military influence spread from these new European centres to the Crusader states of the Near East, where their prince Bohemond I founded the Principality of Antioch in the Levant, to Scotland and Wales in Great Britain, to Ireland, and to the coasts of north Africa and the Canary Islands."
}]
}, {
"title": "Computational_complexity_theory",
"paragraphs": [{
"qas": [{
"question": "What branch of theoretical computer science deals with broadly classifying computational problems by difficulty and class of relationship?",
"id": "56e16182e3433e1400422e28",
"answers": [{
"text": "Computational complexity theory",
"answer_start": 0
}],
"is_impossible": false
}, {
"plausible_answers": [{
"text": "algorithm",
"answer_start": 472
}],
"question": "What is a manual application of mathematical steps?",
"id": "5ad5316b5b96ef001a10ab76",
"answers": [],
"is_impossible": true
}],
"context": "Computational complexity theory is a branch of the theory of computation in theoretical computer science that focuses on classifying computational problems according to their inherent difficulty, and relating those classes to each other. A computational problem is understood to be a task that is in principle amenable to being solved by a computer, which is equivalent to stating that the problem may be solved by mechanical application of mathematical steps, such as an algorithm."
}, {
"qas": [{
"question": "What measure of a computational problem broadly defines the inherent difficulty of the solution?",
"id": "56e16839cd28a01900c67887",
"answers": [{
"text": "if its solution requires significant resources",
"answer_start": 46
}],
"is_impossible": false
}, {
"question": "What method is used to intuitively assess or quantify the amount of resources required to solve a computational problem?",
"id": "56e16839cd28a01900c67888",
"answers": [{
"text": "mathematical models of computation",
"answer_start": 176
}],
"is_impossible": false
}, {
"question": "What are two basic primary resources used to guage complexity?",
"id": "56e16839cd28a01900c67889",
"answers": [{
"text": "time and storage",
"answer_start": 305
}],
"is_impossible": false
}, {
"plausible_answers": [{
"text": "the number of gates in a circuit",
"answer_start": 436
}],
"question": "What unit is measured to determine circuit simplicity?",
"id": "5ad532575b96ef001a10ab7f",
"answers": [],
"is_impossible": true
}, {
"plausible_answers": [{
"text": "the number of processors",
"answer_start": 502
}],
"question": "What number is used in perpendicular computing?",
"id": "5ad532575b96ef001a10ab80",
"answers": [],
"is_impossible": true
}],
"context": "A problem is regarded as inherently difficult if its solution requires significant resources, whatever the algorithm used. The theory formalizes this intuition, by introducing mathematical models of computation to study these problems and quantifying the amount of resources needed to solve them, such as time and storage. Other complexity measures are also used, such as the amount of communication (used in communication complexity), the number of gates in a circuit (used in circuit complexity) and the number of processors (used in parallel computing). One of the roles of computational complexity theory is to determine the practical limits on what computers can and cannot do."
}]
}]
}