Merge pull request #96 from rodgzilla/multiple-choice-code

BertForMultipleChoice and Swag dataset example.
This commit is contained in:
Thomas Wolf 2018-12-13 12:05:11 +01:00 committed by GitHub
commit ffe9075f48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 658 additions and 7 deletions

View File

@ -52,8 +52,9 @@ This package comprises the following classes that can be imported in Python and
- [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**), - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
- [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**), - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
- [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**), - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
- [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**), - [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
- [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L946) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**). - [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L949) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
- [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1015) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
- Three tokenizers (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file): - Three tokenizers (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file):
- `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.), - `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.),
@ -68,10 +69,11 @@ This package comprises the following classes that can be imported in Python and
The repository further comprises: The repository further comprises:
- Three examples on how to use Bert (in the [`examples` folder](./examples)): - Four examples on how to use Bert (in the [`examples` folder](./examples)):
- [`extract_features.py`](./examples/extract_features.py) - Show how to extract hidden states from an instance of `BertModel`, - [`extract_features.py`](./examples/extract_features.py) - Show how to extract hidden states from an instance of `BertModel`,
- [`run_classifier.py`](./examples/run_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task, - [`run_classifier.py`](./examples/run_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task,
- [`run_squad.py`](./examples/run_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 task. - [`run_squad.py`](./examples/run_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 task.
- [`run_swag.py`](./examples/run_swag.py) - Show how to fine-tune an instance of `BertForMultipleChoice` on Swag task.
These examples are detailed in the [Examples](#examples) section of this readme. These examples are detailed in the [Examples](#examples) section of this readme.
@ -278,13 +280,23 @@ The sequence-level classifier is a linear layer that takes as input the last hid
An example on how to use this class is given in the [`run_classifier.py`](./examples/run_classifier.py) script which can be used to fine-tune a single sequence (or pair of sequence) classifier using BERT, for example for the MRPC task. An example on how to use this class is given in the [`run_classifier.py`](./examples/run_classifier.py) script which can be used to fine-tune a single sequence (or pair of sequence) classifier using BERT, for example for the MRPC task.
#### 6. `BertForTokenClassification` #### 6. `BertForMultipleChoice`
`BertForMultipleChoice` is a fine-tuning model that includes `BertModel` and a linear layer on top of the `BertModel`.
The linear layer outputs a single value for each choice of a multiple choice problem, then all the outputs corresponding to an instance are passed through a softmax to get the model choice.
This implementation is largely inspired by the work of OpenAI in [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) and the answer of Jacob Devlin in the following [issue](https://github.com/google-research/bert/issues/38).
An example on how to use this class is given in the [`run_swag.py`](./examples/run_swag.py) script which can be used to fine-tune a multiple choice classifier using BERT, for example for the Swag task.
#### 7. `BertForTokenClassification`
`BertForTokenClassification` is a fine-tuning model that includes `BertModel` and a token-level classifier on top of the `BertModel`. `BertForTokenClassification` is a fine-tuning model that includes `BertModel` and a token-level classifier on top of the `BertModel`.
The token-level classifier is a linear layer that takes as input the last hidden state of the sequence. The token-level classifier is a linear layer that takes as input the last hidden state of the sequence.
#### 7. `BertForQuestionAnswering` #### 8. `BertForQuestionAnswering`
`BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifiers on top of the full sequence of last hidden states. `BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifiers on top of the full sequence of last hidden states.
@ -420,6 +432,32 @@ Training with the previous hyper-parameters gave us the following results:
{"f1": 88.52381567990474, "exact_match": 81.22043519394512} {"f1": 88.52381567990474, "exact_match": 81.22043519394512}
``` ```
The data for Swag can be downloaded by cloning the following [repository](https://github.com/rowanz/swagaf)
```shell
export SWAG_DIR=/path/to/SWAG
python run_swag.py \
--bert_model bert-base-uncased \
--do_train \
--do_eval \
--data_dir $SWAG_DIR/data
--train_batch_size 16 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--max_seq_length 80 \
--output_dir /tmp/swag_output/
--gradient_accumulation_steps 4
```
Training with the previous hyper-parameters gave us the following results:
```
eval_accuracy = 0.8062081375587323
eval_loss = 0.5966546792367169
global_step = 13788
loss = 0.06423990014260186
```
## Fine-tuning BERT-large on GPUs ## Fine-tuning BERT-large on GPUs
The options we list above allow to fine-tune BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation. The options we list above allow to fine-tune BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation.

544
examples/run_swag.py Normal file
View File

@ -0,0 +1,544 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""
import logging
import os
import argparse
import random
from tqdm import tqdm, trange
import csv
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
class SwagExample(object):
"""A single training/test example for the SWAG dataset."""
def __init__(self,
swag_id,
context_sentence,
start_ending,
ending_0,
ending_1,
ending_2,
ending_3,
label = None):
self.swag_id = swag_id
self.context_sentence = context_sentence
self.start_ending = start_ending
self.endings = [
ending_0,
ending_1,
ending_2,
ending_3,
]
self.label = label
def __str__(self):
return self.__repr__()
def __repr__(self):
l = [
f"swag_id: {self.swag_id}",
f"context_sentence: {self.context_sentence}",
f"start_ending: {self.start_ending}",
f"ending_0: {self.endings[0]}",
f"ending_1: {self.endings[1]}",
f"ending_2: {self.endings[2]}",
f"ending_3: {self.endings[3]}",
]
if self.label is not None:
l.append(f"label: {self.label}")
return ", ".join(l)
class InputFeatures(object):
def __init__(self,
example_id,
choices_features,
label
):
self.example_id = example_id
self.choices_features = [
{
'input_ids': input_ids,
'input_mask': input_mask,
'segment_ids': segment_ids
}
for _, input_ids, input_mask, segment_ids in choices_features
]
self.label = label
def read_swag_examples(input_file, is_training):
with open(input_file, 'r') as f:
reader = csv.reader(f)
lines = list(reader)
if is_training and lines[0][-1] != 'label':
raise ValueError(
"For training, the input file must contain a label column."
)
examples = [
SwagExample(
swag_id = line[2],
context_sentence = line[4],
start_ending = line[5], # in the swag dataset, the
# common beginning of each
# choice is stored in "sent2".
ending_0 = line[7],
ending_1 = line[8],
ending_2 = line[9],
ending_3 = line[10],
label = int(line[11]) if is_training else None
) for line in lines[1:] # we skip the line with the column names
]
return examples
def convert_examples_to_features(examples, tokenizer, max_seq_length,
is_training):
"""Loads a data file into a list of `InputBatch`s."""
# Swag is a multiple choice task. To perform this task using Bert,
# we will use the formatting proposed in "Improving Language
# Understanding by Generative Pre-Training" and suggested by
# @jacobdevlin-google in this issue
# https://github.com/google-research/bert/issues/38.
#
# Each choice will correspond to a sample on which we run the
# inference. For a given Swag example, we will create the 4
# following inputs:
# - [CLS] context [SEP] choice_1 [SEP]
# - [CLS] context [SEP] choice_2 [SEP]
# - [CLS] context [SEP] choice_3 [SEP]
# - [CLS] context [SEP] choice_4 [SEP]
# The model will output a single value for each input. To get the
# final decision of the model, we will run a softmax over these 4
# outputs.
features = []
for example_index, example in enumerate(examples):
context_tokens = tokenizer.tokenize(example.context_sentence)
start_ending_tokens = tokenizer.tokenize(example.start_ending)
choices_features = []
for ending_index, ending in enumerate(example.endings):
# We create a copy of the context tokens in order to be
# able to shrink it according to ending_tokens
context_tokens_choice = context_tokens[:]
ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
# Modifies `context_tokens_choice` and `ending_tokens` in
# place so that the total length is less than the
# specified length. Account for [CLS], [SEP], [SEP] with
# "- 3"
_truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3)
tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
input_mask += padding
segment_ids += padding
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
choices_features.append((tokens, input_ids, input_mask, segment_ids))
label = example.label
if example_index < 5:
logger.info("*** Example ***")
logger.info(f"swag_id: {example.swag_id}")
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
logger.info(f"choice: {choice_idx}")
logger.info(f"tokens: {' '.join(tokens)}")
logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
if is_training:
logger.info(f"label: {label}")
features.append(
InputFeatures(
example_id = example.swag_id,
choices_features = choices_features,
label = label
)
)
return features
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def accuracy(out, labels):
outputs = np.argmax(out, axis=1)
return np.sum(outputs == labels)
def select_field(features, field):
return [
[
choice[field]
for choice in feature.choices_features
]
for feature in features
]
def copy_optimizer_params_to_model(named_params_model, named_params_optimizer):
""" Utility function for optimize_on_cpu and 16-bits training.
Copy the parameters optimized on CPU/RAM back to the model on GPU
"""
for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
if name_opti != name_model:
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
raise ValueError
param_model.data.copy_(param_opti.data)
def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False):
""" Utility function for optimize_on_cpu and 16-bits training.
Copy the gradient of the GPU parameters to the CPU/RAMM copy of the model
"""
is_nan = False
for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
if name_opti != name_model:
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
raise ValueError
if param_model.grad is not None:
if test_nan and torch.isnan(param_model.grad).sum() > 0:
is_nan = True
if param_opti.grad is None:
param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size()))
param_opti.grad.data.copy_(param_model.grad.data)
else:
param_opti.grad = None
return is_nan
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the .csv files (or other data files) for the task.")
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
parser.add_argument("--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model checkpoints will be written.")
## Other parameters
parser.add_argument("--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
parser.add_argument("--do_train",
default=False,
action='store_true',
help="Whether to run training.")
parser.add_argument("--do_eval",
default=False,
action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--do_lower_case",
default=False,
action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--train_batch_size",
default=32,
type=int,
help="Total batch size for training.")
parser.add_argument("--eval_batch_size",
default=8,
type=int,
help="Total batch size for eval.")
parser.add_argument("--learning_rate",
default=5e-5,
type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs",
default=3.0,
type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion",
default=0.1,
type=float,
help="Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10%% of training.")
parser.add_argument("--no_cuda",
default=False,
action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--local_rank",
type=int,
default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--optimize_on_cpu',
default=False,
action='store_true',
help="Whether to perform optimization and keep the optimizer averages on CPU")
parser.add_argument('--fp16',
default=False,
action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',
type=float, default=128,
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
args = parser.parse_args()
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
device = torch.device("cuda", args.local_rank)
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
if args.fp16:
logger.info("16-bits training currently not supported in distributed training")
args.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496)
logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
train_examples = None
num_train_steps = None
if args.do_train:
train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
num_train_steps = int(
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
# Prepare model
model = BertForMultipleChoice.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
num_choices = 4
)
if args.fp16:
model.half()
model.to(device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
# Prepare optimizer
if args.fp16:
param_optimizer = [(n, param.clone().detach().to('cpu').float().requires_grad_()) \
for n, param in model.named_parameters()]
elif args.optimize_on_cpu:
param_optimizer = [(n, param.clone().detach().to('cpu').requires_grad_()) \
for n, param in model.named_parameters()]
else:
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=t_total)
global_step = 0
if args.do_train:
train_features = convert_examples_to_features(
train_examples, tokenizer, args.max_seq_length, True)
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_examples))
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps)
all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
if args.local_rank == -1:
train_sampler = RandomSampler(train_data)
else:
train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train()
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
loss = model(input_ids, segment_ids, input_mask, label_ids)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.fp16 and args.loss_scale != 1.0:
# rescale loss for fp16 training
# see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
loss = loss * args.loss_scale
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
loss.backward()
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16 or args.optimize_on_cpu:
if args.fp16 and args.loss_scale != 1.0:
# scale down gradients for fp16 training
for param in model.parameters():
if param.grad is not None:
param.grad.data = param.grad.data / args.loss_scale
is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
if is_nan:
logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
args.loss_scale = args.loss_scale / 2
model.zero_grad()
continue
optimizer.step()
copy_optimizer_params_to_model(model.named_parameters(), param_optimizer)
else:
optimizer.step()
model.zero_grad()
global_step += 1
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
eval_features = convert_examples_to_features(
eval_examples, tokenizer, args.max_seq_length, True)
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long)
all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long)
all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long)
all_label = torch.tensor([f.label for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
model.eval()
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
with torch.no_grad():
tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
logits = model(input_ids, segment_ids, input_mask)
logits = logits.detach().cpu().numpy()
label_ids = label_ids.to('cpu').numpy()
tmp_eval_accuracy = accuracy(logits, label_ids)
eval_loss += tmp_eval_loss.mean().item()
eval_accuracy += tmp_eval_accuracy
nb_eval_examples += input_ids.size(0)
nb_eval_steps += 1
eval_loss = eval_loss / nb_eval_steps
eval_accuracy = eval_accuracy / nb_eval_examples
result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy,
'global_step': global_step,
'loss': tr_loss/nb_tr_steps}
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
if __name__ == "__main__":
main()

View File

@ -1,7 +1,7 @@
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining, from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction, BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForTokenClassification, BertForSequenceClassification, BertForMultipleChoice,
BertForQuestionAnswering) BertForTokenClassification, BertForQuestionAnswering)
from .optimization import BertAdam from .optimization import BertAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE

View File

@ -877,6 +877,75 @@ class BertForSequenceClassification(PreTrainedBertModel):
return logits return logits
class BertForMultipleChoice(PreTrainedBertModel):
"""BERT model for multiple choice tasks.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_choices`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_choices].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits of shape [batch_size, num_labels].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_choices = 2
model = BertForMultipleChoice(config, num_choices)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_choices=2):
super(BertForMultipleChoice, self).__init__(config)
self.num_choices = num_choices
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, self.num_choices)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
return loss
else:
return reshaped_logits
class BertForTokenClassification(PreTrainedBertModel): class BertForTokenClassification(PreTrainedBertModel):
"""BERT model for token-level classification. """BERT model for token-level classification.
This module is composed of the BERT model with a linear layer on top of This module is composed of the BERT model with a linear layer on top of