579 lines
20 KiB
Python
579 lines
20 KiB
Python
# coding=utf-8
|
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" Multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
|
|
|
|
|
|
import csv
|
|
import glob
|
|
import json
|
|
import logging
|
|
import os
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import List, Optional
|
|
|
|
import tqdm
|
|
from filelock import FileLock
|
|
|
|
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
class InputExample:
    """
    One raw (untokenized) multiple choice example.

    Choice ``i`` is the pair ``(contexts[i], question + endings[i])``, so ``contexts`` and
    ``endings`` must have the same length (one entry per choice).

    Args:
        example_id: Unique id for the example.
        question: Untokenized question text; it forms the second sequence together with each ending.
        contexts: Untokenized first sequence for every choice (often the same passage repeated).
        endings: Untokenized answer options, one per choice.
        label: (Optional) Label string of the correct choice. Given for train and dev
            examples, but not for test examples.
    """

    example_id: str
    question: str
    contexts: List[str]  # first sequence, one per choice
    endings: List[str]  # candidate answers, aligned with ``contexts``
    label: Optional[str]  # None when the gold answer is unknown
|
|
|
|
|
|
@dataclass(frozen=True)
class InputFeatures:
    """
    Tokenized form of one multiple choice example.

    Each field holds one row per choice, and the field names are the same names as the
    corresponding inputs to a model.
    """

    example_id: str  # id copied from the source InputExample
    input_ids: List[List[int]]  # token ids, one row per choice
    attention_mask: Optional[List[List[int]]]  # None when the tokenizer returns no attention mask
    token_type_ids: Optional[List[List[int]]]  # None when the tokenizer returns no token type ids
    label: Optional[int]  # position of the gold label in the processor's label list
|
|
|
|
|
|
class Split(Enum):
    """Dataset split to load; a member's value is the split name that appears in cache file names."""

    train = "train"
    dev = "dev"
    test = "test"
|
|
|
|
|
|
if is_torch_available():
    import torch
    from torch.utils.data import Dataset

    class MultipleChoiceDataset(Dataset):
        """
        PyTorch dataset of `InputFeatures` for one split of a multiple choice task.

        Features are built from the raw files in ``data_dir`` and cached there with
        ``torch.save``. Later runs load the cache instead of re-tokenizing, unless
        ``overwrite_cache`` is set.

        This will be superseded by a framework-agnostic approach
        soon.
        """

        features: List[InputFeatures]

        def __init__(
            self,
            data_dir: str,
            tokenizer: PreTrainedTokenizer,
            task: str,
            max_seq_length: Optional[int] = None,
            overwrite_cache=False,
            mode: Split = Split.train,
        ):
            """
            Args:
                data_dir: directory with the raw task files; the feature cache is written here too.
                tokenizer: tokenizer used to encode every (context, question + ending) pair.
                task: key into ``processors`` selecting the data processor.
                max_seq_length: forwarded to the tokenizer as ``max_length``; also part of the cache name.
                overwrite_cache: rebuild the features even if a cache file already exists.
                mode: which split to load.

            Raises:
                KeyError: if ``task`` is not a key of ``processors``.
            """
            processor = processors[task]()

            cached_features_file = os.path.join(
                data_dir,
                "cached_{}_{}_{}_{}".format(
                    mode.value,
                    tokenizer.__class__.__name__,
                    str(max_seq_length),
                    task,
                ),
            )

            # Make sure only the first process in distributed training processes the dataset,
            # and the others will use the cache.
            lock_path = cached_features_file + ".lock"
            with FileLock(lock_path):
                if os.path.exists(cached_features_file) and not overwrite_cache:
                    logger.info(f"Loading features from cached file {cached_features_file}")
                    # The cache is a pickled list of InputFeatures objects, not plain tensors, so it
                    # needs weights_only=False (torch >= 2.6 defaults to weights_only=True and would
                    # refuse to load it). Unpickling can run arbitrary code: only load caches this
                    # code wrote itself.
                    try:
                        self.features = torch.load(cached_features_file, weights_only=False)
                    except TypeError:
                        # Older torch releases do not accept the weights_only argument.
                        self.features = torch.load(cached_features_file)
                else:
                    logger.info(f"Creating features from dataset file at {data_dir}")
                    label_list = processor.get_labels()
                    if mode == Split.dev:
                        examples = processor.get_dev_examples(data_dir)
                    elif mode == Split.test:
                        examples = processor.get_test_examples(data_dir)
                    else:
                        examples = processor.get_train_examples(data_dir)
                    # Report the split that was actually loaded (previously always said "Training").
                    logger.info("%s examples: %s", mode.value, len(examples))
                    self.features = convert_examples_to_features(
                        examples,
                        label_list,
                        max_seq_length,
                        tokenizer,
                    )
                    logger.info("Saving features into cached file %s", cached_features_file)
                    torch.save(self.features, cached_features_file)

        def __len__(self):
            return len(self.features)

        def __getitem__(self, i) -> InputFeatures:
            return self.features[i]
|
|
|
|
|
|
if is_tf_available():
    import tensorflow as tf

    class TFMultipleChoiceDataset:
        """
        TensorFlow dataset for one split of a multiple choice task.

        The split is converted to `InputFeatures` on every construction (there is no on-disk
        cache here) and exposed as a ``tf.data.Dataset`` of ``(inputs_dict, label)`` pairs.

        This will be superseded by a framework-agnostic approach
        soon.
        """

        features: List[InputFeatures]

        def __init__(
            self,
            data_dir: str,
            tokenizer: PreTrainedTokenizer,
            task: str,
            max_seq_length: Optional[int] = 128,
            overwrite_cache=False,
            mode: Split = Split.train,
        ):
            """
            Args:
                data_dir: directory with the raw task files.
                tokenizer: tokenizer used to encode every (context, question + ending) pair.
                task: key into ``processors`` selecting the data processor.
                max_seq_length: forwarded to the tokenizer as ``max_length``.
                overwrite_cache: accepted for signature parity with the PyTorch dataset; unused here.
                mode: which split to load.

            Raises:
                KeyError: if ``task`` is not a key of ``processors``.
            """
            processor = processors[task]()

            logger.info(f"Creating features from dataset file at {data_dir}")
            label_list = processor.get_labels()
            if mode == Split.dev:
                examples = processor.get_dev_examples(data_dir)
            elif mode == Split.test:
                examples = processor.get_test_examples(data_dir)
            else:
                examples = processor.get_train_examples(data_dir)
            # Report the split that was actually loaded (previously always said "Training").
            logger.info("%s examples: %s", mode.value, len(examples))

            self.features = convert_examples_to_features(
                examples,
                label_list,
                max_seq_length,
                tokenizer,
            )

            def gen():
                for ex_index, ex in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
                    if ex_index % 10000 == 0:
                        logger.info("Writing example %d of %d", ex_index, len(examples))

                    yield (
                        {
                            # InputFeatures.example_id is a string while the spec below declares
                            # tf.int32, so a constant placeholder is emitted instead.
                            "example_id": 0,
                            "input_ids": ex.input_ids,
                            "attention_mask": ex.attention_mask,
                            "token_type_ids": ex.token_type_ids,
                        },
                        ex.label,
                    )

            self.dataset = tf.data.Dataset.from_generator(
                gen,
                (
                    {
                        "example_id": tf.int32,
                        "input_ids": tf.int32,
                        "attention_mask": tf.int32,
                        "token_type_ids": tf.int32,
                    },
                    tf.int64,
                ),
                (
                    {
                        "example_id": tf.TensorShape([]),
                        # (num_choices, seq_len); both left dynamic.
                        "input_ids": tf.TensorShape([None, None]),
                        "attention_mask": tf.TensorShape([None, None]),
                        "token_type_ids": tf.TensorShape([None, None]),
                    },
                    tf.TensorShape([]),
                ),
            )

        def get_dataset(self):
            """Return the ``tf.data.Dataset``, annotated with its known cardinality (number of features)."""
            self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))

            return self.dataset

        def __len__(self):
            return len(self.features)

        def __getitem__(self, i) -> InputFeatures:
            return self.features[i]
|
|
|
|
|
|
class DataProcessor:
    """Common interface of the multiple choice data processors.

    A concrete processor reads the raw files of one data set and produces `InputExample`s
    for each split, plus the list of label strings. Every method here only raises
    NotImplementedError and must be overridden.
    """

    def get_train_examples(self, data_dir):
        """Return the train-split `InputExample`s read from ``data_dir``."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Return the dev-split `InputExample`s read from ``data_dir``."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Return the test-split `InputExample`s read from ``data_dir``."""
        raise NotImplementedError()

    def get_labels(self):
        """Return the label strings of this data set; a label's position is its integer id."""
        raise NotImplementedError()
|
|
|
|
|
|
class RaceProcessor(DataProcessor):
    """Processor for the RACE data set.

    Each split directory (``train``/``dev``/``test``) under ``data_dir`` has ``high`` and
    ``middle`` subdirectories of JSON files. Each file holds one ``article`` with parallel
    ``questions``/``options``/``answers`` lists.
    """

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._get_split_examples(data_dir, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._get_split_examples(data_dir, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._get_split_examples(data_dir, "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1", "2", "3"]

    def _get_split_examples(self, data_dir, split):
        """Read the ``high`` and ``middle`` parts of ``split`` and turn them into examples."""
        logger.info("LOOKING AT {} {}".format(data_dir, split))
        high = self._read_txt(os.path.join(data_dir, split, "high"))
        middle = self._read_txt(os.path.join(data_dir, split, "middle"))
        return self._create_examples(high + middle, split)

    def _read_txt(self, input_dir):
        """Load every ``*txt`` file in ``input_dir`` as JSON, storing the file path under ``race_id``.

        Files are processed in sorted order: ``glob`` returns them in arbitrary,
        filesystem-dependent order, which would otherwise make the example order
        non-reproducible.
        """
        lines = []
        files = sorted(glob.glob(os.path.join(input_dir, "*txt")))
        for file in tqdm.tqdm(files, desc="read files"):
            with open(file, "r", encoding="utf-8") as fin:
                data_raw = json.load(fin)
            data_raw["race_id"] = file
            lines.append(data_raw)
        return lines

    def _create_examples(self, lines, set_type):
        """Create one four-choice example per question, for any split.

        The article is repeated as the context of every choice. The letter answer is turned
        into a label string by its offset from "A".
        """
        examples = []
        for data_raw in lines:
            race_id = "%s-%s" % (set_type, data_raw["race_id"])
            article = data_raw["article"]
            for i, answer in enumerate(data_raw["answers"]):
                truth = str(ord(answer) - ord("A"))
                question = data_raw["questions"][i]
                options = data_raw["options"][i]

                examples.append(
                    InputExample(
                        example_id=race_id,
                        question=question,
                        contexts=[article, article, article, article],  # this is not efficient but convenient
                        endings=[options[0], options[1], options[2], options[3]],
                        label=truth,
                    )
                )
        return examples
|
|
|
|
|
|
class SynonymProcessor(DataProcessor):
    """Processor for the Synonym data set.

    Each split is a CSV file (``mctrain.csv``, ``mchp.csv``, ``mctest.csv``) whose rows are
    read as ``id, context, ending0, ..., ending4, label``, giving five choices per example.
    """

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} train".format(data_dir))
        return self._create_examples(self._read_csv(os.path.join(data_dir, "mctrain.csv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} dev".format(data_dir))
        return self._create_examples(self._read_csv(os.path.join(data_dir, "mchp.csv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        # Previously logged "dev" here by mistake.
        logger.info("LOOKING AT {} test".format(data_dir))

        return self._create_examples(self._read_csv(os.path.join(data_dir, "mctest.csv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1", "2", "3", "4"]

    def _read_csv(self, input_file):
        """Return all rows of ``input_file`` as lists of strings."""
        with open(input_file, "r", encoding="utf-8") as f:
            return list(csv.reader(f))

    def _create_examples(self, lines: List[List[str]], type: str):
        """Create one five-choice example per CSV row, for any split.

        The question is left empty. Column 1 is repeated as the context of each choice, and
        columns 2-6 are the endings. Every row, including the first, becomes an example;
        no header row is skipped.
        NOTE(review): confirm the Synonym CSV files really have no header row.
        """

        examples = [
            InputExample(
                example_id=line[0],
                question="",
                contexts=[line[1], line[1], line[1], line[1], line[1]],
                endings=[line[2], line[3], line[4], line[5], line[6]],
                label=line[7],
            )
            for line in lines
        ]

        return examples
|
|
|
|
|
|
class SwagProcessor(DataProcessor):
    """Processor for the SWAG data set.

    Reads ``train.csv`` and ``val.csv``, each with a header row. The test split is not
    supported because its file has no label column.
    """

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} train".format(data_dir))
        return self._create_examples(self._read_csv(os.path.join(data_dir, "train.csv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} dev".format(data_dir))
        return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class.

        Raises:
            ValueError: always, because the SWAG test file has no label column.
        """
        # Previously logged "dev" here by mistake; the unreachable return after the raise was removed.
        logger.info("LOOKING AT {} test".format(data_dir))
        raise ValueError(
            "For swag testing, the input file does not contain a label column. It can not be tested in current code "
            "setting!"
        )

    def get_labels(self):
        """See base class."""
        return ["0", "1", "2", "3"]

    def _read_csv(self, input_file):
        """Return all rows of ``input_file`` as lists of strings."""
        with open(input_file, "r", encoding="utf-8") as f:
            return list(csv.reader(f))

    def _create_examples(self, lines: List[List[str]], type: str):
        """Create one four-choice example per data row (the header row is skipped).

        Raises:
            ValueError: for ``type == "train"`` when the header row (first row) does not end
                with a ``label`` column, including when the file is empty or starts blank.
        """
        # Slice instead of lines[0][-1] so an empty file or blank first row reports the
        # intended ValueError rather than an IndexError.
        if type == "train" and (not lines or lines[0][-1:] != ["label"]):
            raise ValueError("For training, the input file must contain a label column.")

        examples = [
            InputExample(
                example_id=line[2],
                question=line[5],  # in the swag dataset, the
                # common beginning of each
                # choice is stored in "sent2".
                contexts=[line[4], line[4], line[4], line[4]],
                endings=[line[7], line[8], line[9], line[10]],
                label=line[11],
            )
            for line in lines[1:]  # we skip the line with the column names
        ]

        return examples
|
|
|
|
|
|
class ArcProcessor(DataProcessor):
    """Processor for the ARC data set (request from allennlp).

    Reads JSON-lines files and keeps only the questions with exactly four answer choices.
    """

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} train".format(data_dir))
        return self._create_examples(self._read_json(os.path.join(data_dir, "train.jsonl")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} dev".format(data_dir))
        return self._create_examples(self._read_json(os.path.join(data_dir, "dev.jsonl")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} test".format(data_dir))
        return self._create_examples(self._read_json(os.path.join(data_dir, "test.jsonl")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1", "2", "3"]

    def _read_json(self, input_file):
        """Return the raw lines of a JSON-lines file (one JSON object per line)."""
        with open(input_file, "r", encoding="utf-8") as fin:
            return fin.readlines()

    def _create_examples(self, lines, type):
        """Create examples from the raw JSON lines of any split.

        Each line must be a JSON object with ``id``, ``answerKey`` and ``question`` (holding
        ``stem`` and a ``choices`` list whose items have ``text`` and ``para``). Questions that
        do not have exactly four choices are counted and dropped.

        Raises:
            ValueError: if an ``answerKey`` is not one of "A"-"D" or "1"-"4", or if a train
                file yields fewer than two usable examples.
        """

        # There are two types of labels (letters or 1-based digits); normalize both to a
        # 0-based index. Exact tuple membership is used instead of substring tests like
        # `truth in "ABCD"`, which accept "" or "AB" and then crash in ord().
        def normalize(truth):
            if truth in ("A", "B", "C", "D"):
                return ord(truth) - ord("A")
            if truth in ("1", "2", "3", "4"):
                return int(truth) - 1
            logger.info("truth ERROR! %s", str(truth))
            return None

        examples = []
        three_choice = 0
        four_choice = 0
        five_choice = 0
        other_choices = 0
        # we deleted example which has more than or less than four choices
        for line in tqdm.tqdm(lines, desc="read arc data"):
            data_raw = json.loads(line.strip("\n"))
            question_choices = data_raw["question"]
            options = question_choices["choices"]
            num_choices = len(options)
            if num_choices == 3:
                three_choice += 1
                continue
            if num_choices == 5:
                five_choice += 1
                continue
            if num_choices != 4:
                other_choices += 1
                continue
            four_choice += 1

            truth = normalize(data_raw["answerKey"])
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if truth is None:
                raise ValueError(
                    "Unrecognized answerKey {!r} for example {}".format(data_raw["answerKey"], data_raw["id"])
                )

            examples.append(
                InputExample(
                    example_id=data_raw["id"],
                    question=question_choices["stem"],
                    contexts=[option["para"].replace("_", "") for option in options],
                    endings=[option["text"] for option in options],
                    label=str(truth),
                )
            )

        if type == "train" and len(examples) <= 1:
            raise ValueError("Expected more than one four-choice training example, got {}".format(len(examples)))
        logger.info("len examples: %s", str(len(examples)))
        logger.info("Three choices: %s", str(three_choice))
        logger.info("Five choices: %s", str(five_choice))
        logger.info("Other choices: %s", str(other_choices))
        logger.info("four choices: %s", str(four_choice))

        return examples
|
|
|
|
|
|
def convert_examples_to_features(
    examples: List[InputExample],
    label_list: List[str],
    max_length: int,
    tokenizer: PreTrainedTokenizer,
) -> List[InputFeatures]:
    """
    Convert `InputExample`s into `InputFeatures` by tokenizing every choice.

    For each choice the tokenizer encodes the pair ``(context, question + " " + ending)``. For
    cloze-style questions containing ``_``, the ending replaces the blank instead of being
    appended. Every pair is truncated and padded to ``max_length``.

    Args:
        examples: examples to convert.
        label_list: label strings; a label's position in this list becomes its integer id.
        max_length: forwarded to the tokenizer as ``max_length``.
        tokenizer: tokenizer used to encode the pairs.

    Returns:
        One `InputFeatures` per example, with one row of ids/masks per choice. Examples whose
        ``label`` is None (e.g. unlabeled test data) get ``label=None``.

    Raises:
        KeyError: if a non-None example label is not in ``label_list``.
    """

    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))
        choices_inputs = []
        for context, ending in zip(example.contexts, example.endings):
            text_a = context
            if "_" in example.question:
                # this is for cloze question
                text_b = example.question.replace("_", ending)
            else:
                text_b = example.question + " " + ending

            inputs = tokenizer(
                text_a,
                text_b,
                add_special_tokens=True,
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_overflowing_tokens=True,
            )
            if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
                logger.info(
                    "Attention! you are cropping tokens (swag task is ok). "
                    "If you are training ARC and RACE and you are poping question + options, "
                    "you need to try to use a bigger max seq length!"
                )

            choices_inputs.append(inputs)

        # InputExample.label is Optional; map a missing label to None instead of a KeyError.
        label = label_map[example.label] if example.label is not None else None

        input_ids = [x["input_ids"] for x in choices_inputs]
        # Not every tokenizer returns these keys; the feature field is None in that case.
        attention_mask = (
            [x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None
        )
        token_type_ids = (
            [x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None
        )

        features.append(
            InputFeatures(
                example_id=example.example_id,
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                label=label,
            )
        )

    for f in features[:2]:
        logger.info("*** Example ***")
        logger.info("feature: %s", f)

    return features
|
|
|
|
|
|
# Maps the ``task`` name accepted by the dataset classes to its DataProcessor subclass.
processors = {
    "race": RaceProcessor,
    "swag": SwagProcessor,
    "arc": ArcProcessor,
    "syn": SynonymProcessor,
}
|
|
# Number of answer choices (= labels) per task, keyed like ``processors``. This was
# previously a set literal ({"race", 4, ...}), which made per-task lookups impossible.
MULTIPLE_CHOICE_TASKS_NUM_LABELS = {"race": 4, "swag": 4, "arc": 4, "syn": 5}
|