Merge pull request #96 from rodgzilla/multiple-choice-code
BertForMultipleChoice and Swag dataset example.
This commit is contained in:
commit
ffe9075f48
48
README.md
48
README.md
|
@ -52,8 +52,9 @@ This package comprises the following classes that can be imported in Python and
|
||||||
- [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
|
- [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
|
||||||
- [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
|
- [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
|
||||||
- [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
|
- [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
|
||||||
- [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
|
- [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
|
||||||
- [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L946) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
|
- [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L949) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
|
||||||
|
- [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1015) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
|
||||||
|
|
||||||
- Three tokenizers (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file):
|
- Three tokenizers (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file):
|
||||||
- `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.),
|
- `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.),
|
||||||
|
@ -68,10 +69,11 @@ This package comprises the following classes that can be imported in Python and
|
||||||
|
|
||||||
The repository further comprises:
|
The repository further comprises:
|
||||||
|
|
||||||
- Three examples on how to use Bert (in the [`examples` folder](./examples)):
|
- Four examples on how to use Bert (in the [`examples` folder](./examples)):
|
||||||
- [`extract_features.py`](./examples/extract_features.py) - Show how to extract hidden states from an instance of `BertModel`,
|
- [`extract_features.py`](./examples/extract_features.py) - Show how to extract hidden states from an instance of `BertModel`,
|
||||||
- [`run_classifier.py`](./examples/run_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task,
|
- [`run_classifier.py`](./examples/run_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task,
|
||||||
- [`run_squad.py`](./examples/run_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 task.
|
- [`run_squad.py`](./examples/run_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 task.
|
||||||
|
- [`run_swag.py`](./examples/run_swag.py) - Show how to fine-tune an instance of `BertForMultipleChoice` on Swag task.
|
||||||
|
|
||||||
These examples are detailed in the [Examples](#examples) section of this readme.
|
These examples are detailed in the [Examples](#examples) section of this readme.
|
||||||
|
|
||||||
|
@ -278,13 +280,23 @@ The sequence-level classifier is a linear layer that takes as input the last hid
|
||||||
|
|
||||||
An example on how to use this class is given in the [`run_classifier.py`](./examples/run_classifier.py) script which can be used to fine-tune a single sequence (or pair of sequence) classifier using BERT, for example for the MRPC task.
|
An example on how to use this class is given in the [`run_classifier.py`](./examples/run_classifier.py) script which can be used to fine-tune a single sequence (or pair of sequence) classifier using BERT, for example for the MRPC task.
|
||||||
|
|
||||||
#### 6. `BertForTokenClassification`
|
#### 6. `BertForMultipleChoice`
|
||||||
|
|
||||||
|
`BertForMultipleChoice` is a fine-tuning model that includes `BertModel` and a linear layer on top of the `BertModel`.
|
||||||
|
|
||||||
|
The linear layer outputs a single value for each choice of a multiple choice problem, then all the outputs corresponding to an instance are passed through a softmax to get the model choice.
|
||||||
|
|
||||||
|
This implementation is largely inspired by the work of OpenAI in [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) and the answer of Jacob Devlin in the following [issue](https://github.com/google-research/bert/issues/38).
|
||||||
|
|
||||||
|
An example on how to use this class is given in the [`run_swag.py`](./examples/run_swag.py) script which can be used to fine-tune a multiple choice classifier using BERT, for example for the Swag task.
|
||||||
|
|
||||||
|
#### 7. `BertForTokenClassification`
|
||||||
|
|
||||||
`BertForTokenClassification` is a fine-tuning model that includes `BertModel` and a token-level classifier on top of the `BertModel`.
|
`BertForTokenClassification` is a fine-tuning model that includes `BertModel` and a token-level classifier on top of the `BertModel`.
|
||||||
|
|
||||||
The token-level classifier is a linear layer that takes as input the last hidden state of the sequence.
|
The token-level classifier is a linear layer that takes as input the last hidden state of the sequence.
|
||||||
|
|
||||||
#### 7. `BertForQuestionAnswering`
|
#### 8. `BertForQuestionAnswering`
|
||||||
|
|
||||||
`BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifiers on top of the full sequence of last hidden states.
|
`BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifiers on top of the full sequence of last hidden states.
|
||||||
|
|
||||||
|
@ -420,6 +432,32 @@ Training with the previous hyper-parameters gave us the following results:
|
||||||
{"f1": 88.52381567990474, "exact_match": 81.22043519394512}
|
{"f1": 88.52381567990474, "exact_match": 81.22043519394512}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The data for Swag can be downloaded by cloning the following [repository](https://github.com/rowanz/swagaf)
|
||||||
|
|
||||||
|
```shell
|
||||||
|
export SWAG_DIR=/path/to/SWAG
|
||||||
|
|
||||||
|
python run_swag.py \
|
||||||
|
--bert_model bert-base-uncased \
|
||||||
|
--do_train \
|
||||||
|
--do_eval \
|
||||||
|
--data_dir $SWAG_DIR/data
|
||||||
|
--train_batch_size 16 \
|
||||||
|
--learning_rate 2e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_seq_length 80 \
|
||||||
|
--output_dir /tmp/swag_output/
|
||||||
|
--gradient_accumulation_steps 4
|
||||||
|
```
|
||||||
|
|
||||||
|
Training with the previous hyper-parameters gave us the following results:
|
||||||
|
```
|
||||||
|
eval_accuracy = 0.8062081375587323
|
||||||
|
eval_loss = 0.5966546792367169
|
||||||
|
global_step = 13788
|
||||||
|
loss = 0.06423990014260186
|
||||||
|
```
|
||||||
|
|
||||||
## Fine-tuning BERT-large on GPUs
|
## Fine-tuning BERT-large on GPUs
|
||||||
|
|
||||||
The options we list above allow to fine-tune BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation.
|
The options we list above allow to fine-tune BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation.
|
||||||
|
|
|
@ -0,0 +1,544 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""BERT finetuning runner."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
import csv
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
|
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
|
||||||
|
from pytorch_pretrained_bert.optimization import BertAdam
|
||||||
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
|
|
||||||
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
|
level = logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SwagExample(object):
|
||||||
|
"""A single training/test example for the SWAG dataset."""
|
||||||
|
def __init__(self,
|
||||||
|
swag_id,
|
||||||
|
context_sentence,
|
||||||
|
start_ending,
|
||||||
|
ending_0,
|
||||||
|
ending_1,
|
||||||
|
ending_2,
|
||||||
|
ending_3,
|
||||||
|
label = None):
|
||||||
|
self.swag_id = swag_id
|
||||||
|
self.context_sentence = context_sentence
|
||||||
|
self.start_ending = start_ending
|
||||||
|
self.endings = [
|
||||||
|
ending_0,
|
||||||
|
ending_1,
|
||||||
|
ending_2,
|
||||||
|
ending_3,
|
||||||
|
]
|
||||||
|
self.label = label
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
l = [
|
||||||
|
f"swag_id: {self.swag_id}",
|
||||||
|
f"context_sentence: {self.context_sentence}",
|
||||||
|
f"start_ending: {self.start_ending}",
|
||||||
|
f"ending_0: {self.endings[0]}",
|
||||||
|
f"ending_1: {self.endings[1]}",
|
||||||
|
f"ending_2: {self.endings[2]}",
|
||||||
|
f"ending_3: {self.endings[3]}",
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.label is not None:
|
||||||
|
l.append(f"label: {self.label}")
|
||||||
|
|
||||||
|
return ", ".join(l)
|
||||||
|
|
||||||
|
|
||||||
|
class InputFeatures(object):
|
||||||
|
def __init__(self,
|
||||||
|
example_id,
|
||||||
|
choices_features,
|
||||||
|
label
|
||||||
|
|
||||||
|
):
|
||||||
|
self.example_id = example_id
|
||||||
|
self.choices_features = [
|
||||||
|
{
|
||||||
|
'input_ids': input_ids,
|
||||||
|
'input_mask': input_mask,
|
||||||
|
'segment_ids': segment_ids
|
||||||
|
}
|
||||||
|
for _, input_ids, input_mask, segment_ids in choices_features
|
||||||
|
]
|
||||||
|
self.label = label
|
||||||
|
|
||||||
|
|
||||||
|
def read_swag_examples(input_file, is_training):
|
||||||
|
with open(input_file, 'r') as f:
|
||||||
|
reader = csv.reader(f)
|
||||||
|
lines = list(reader)
|
||||||
|
|
||||||
|
if is_training and lines[0][-1] != 'label':
|
||||||
|
raise ValueError(
|
||||||
|
"For training, the input file must contain a label column."
|
||||||
|
)
|
||||||
|
|
||||||
|
examples = [
|
||||||
|
SwagExample(
|
||||||
|
swag_id = line[2],
|
||||||
|
context_sentence = line[4],
|
||||||
|
start_ending = line[5], # in the swag dataset, the
|
||||||
|
# common beginning of each
|
||||||
|
# choice is stored in "sent2".
|
||||||
|
ending_0 = line[7],
|
||||||
|
ending_1 = line[8],
|
||||||
|
ending_2 = line[9],
|
||||||
|
ending_3 = line[10],
|
||||||
|
label = int(line[11]) if is_training else None
|
||||||
|
) for line in lines[1:] # we skip the line with the column names
|
||||||
|
]
|
||||||
|
|
||||||
|
return examples
|
||||||
|
|
||||||
|
def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||||
|
is_training):
|
||||||
|
"""Loads a data file into a list of `InputBatch`s."""
|
||||||
|
|
||||||
|
# Swag is a multiple choice task. To perform this task using Bert,
|
||||||
|
# we will use the formatting proposed in "Improving Language
|
||||||
|
# Understanding by Generative Pre-Training" and suggested by
|
||||||
|
# @jacobdevlin-google in this issue
|
||||||
|
# https://github.com/google-research/bert/issues/38.
|
||||||
|
#
|
||||||
|
# Each choice will correspond to a sample on which we run the
|
||||||
|
# inference. For a given Swag example, we will create the 4
|
||||||
|
# following inputs:
|
||||||
|
# - [CLS] context [SEP] choice_1 [SEP]
|
||||||
|
# - [CLS] context [SEP] choice_2 [SEP]
|
||||||
|
# - [CLS] context [SEP] choice_3 [SEP]
|
||||||
|
# - [CLS] context [SEP] choice_4 [SEP]
|
||||||
|
# The model will output a single value for each input. To get the
|
||||||
|
# final decision of the model, we will run a softmax over these 4
|
||||||
|
# outputs.
|
||||||
|
features = []
|
||||||
|
for example_index, example in enumerate(examples):
|
||||||
|
context_tokens = tokenizer.tokenize(example.context_sentence)
|
||||||
|
start_ending_tokens = tokenizer.tokenize(example.start_ending)
|
||||||
|
|
||||||
|
choices_features = []
|
||||||
|
for ending_index, ending in enumerate(example.endings):
|
||||||
|
# We create a copy of the context tokens in order to be
|
||||||
|
# able to shrink it according to ending_tokens
|
||||||
|
context_tokens_choice = context_tokens[:]
|
||||||
|
ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
|
||||||
|
# Modifies `context_tokens_choice` and `ending_tokens` in
|
||||||
|
# place so that the total length is less than the
|
||||||
|
# specified length. Account for [CLS], [SEP], [SEP] with
|
||||||
|
# "- 3"
|
||||||
|
_truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3)
|
||||||
|
|
||||||
|
tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
|
||||||
|
segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)
|
||||||
|
|
||||||
|
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
input_mask = [1] * len(input_ids)
|
||||||
|
|
||||||
|
# Zero-pad up to the sequence length.
|
||||||
|
padding = [0] * (max_seq_length - len(input_ids))
|
||||||
|
input_ids += padding
|
||||||
|
input_mask += padding
|
||||||
|
segment_ids += padding
|
||||||
|
|
||||||
|
assert len(input_ids) == max_seq_length
|
||||||
|
assert len(input_mask) == max_seq_length
|
||||||
|
assert len(segment_ids) == max_seq_length
|
||||||
|
|
||||||
|
choices_features.append((tokens, input_ids, input_mask, segment_ids))
|
||||||
|
|
||||||
|
label = example.label
|
||||||
|
if example_index < 5:
|
||||||
|
logger.info("*** Example ***")
|
||||||
|
logger.info(f"swag_id: {example.swag_id}")
|
||||||
|
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
|
||||||
|
logger.info(f"choice: {choice_idx}")
|
||||||
|
logger.info(f"tokens: {' '.join(tokens)}")
|
||||||
|
logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
|
||||||
|
logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
|
||||||
|
logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
|
||||||
|
if is_training:
|
||||||
|
logger.info(f"label: {label}")
|
||||||
|
|
||||||
|
features.append(
|
||||||
|
InputFeatures(
|
||||||
|
example_id = example.swag_id,
|
||||||
|
choices_features = choices_features,
|
||||||
|
label = label
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||||
|
"""Truncates a sequence pair in place to the maximum length."""
|
||||||
|
|
||||||
|
# This is a simple heuristic which will always truncate the longer sequence
|
||||||
|
# one token at a time. This makes more sense than truncating an equal percent
|
||||||
|
# of tokens from each, since if one sequence is very short then each token
|
||||||
|
# that's truncated likely contains more information than a longer sequence.
|
||||||
|
while True:
|
||||||
|
total_length = len(tokens_a) + len(tokens_b)
|
||||||
|
if total_length <= max_length:
|
||||||
|
break
|
||||||
|
if len(tokens_a) > len(tokens_b):
|
||||||
|
tokens_a.pop()
|
||||||
|
else:
|
||||||
|
tokens_b.pop()
|
||||||
|
|
||||||
|
def accuracy(out, labels):
|
||||||
|
outputs = np.argmax(out, axis=1)
|
||||||
|
return np.sum(outputs == labels)
|
||||||
|
|
||||||
|
def select_field(features, field):
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
choice[field]
|
||||||
|
for choice in feature.choices_features
|
||||||
|
]
|
||||||
|
for feature in features
|
||||||
|
]
|
||||||
|
|
||||||
|
def copy_optimizer_params_to_model(named_params_model, named_params_optimizer):
|
||||||
|
""" Utility function for optimize_on_cpu and 16-bits training.
|
||||||
|
Copy the parameters optimized on CPU/RAM back to the model on GPU
|
||||||
|
"""
|
||||||
|
for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
|
||||||
|
if name_opti != name_model:
|
||||||
|
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
|
||||||
|
raise ValueError
|
||||||
|
param_model.data.copy_(param_opti.data)
|
||||||
|
|
||||||
|
def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False):
|
||||||
|
""" Utility function for optimize_on_cpu and 16-bits training.
|
||||||
|
Copy the gradient of the GPU parameters to the CPU/RAMM copy of the model
|
||||||
|
"""
|
||||||
|
is_nan = False
|
||||||
|
for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
|
||||||
|
if name_opti != name_model:
|
||||||
|
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
|
||||||
|
raise ValueError
|
||||||
|
if param_model.grad is not None:
|
||||||
|
if test_nan and torch.isnan(param_model.grad).sum() > 0:
|
||||||
|
is_nan = True
|
||||||
|
if param_opti.grad is None:
|
||||||
|
param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size()))
|
||||||
|
param_opti.grad.data.copy_(param_model.grad.data)
|
||||||
|
else:
|
||||||
|
param_opti.grad = None
|
||||||
|
return is_nan
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
## Required parameters
|
||||||
|
parser.add_argument("--data_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The input data dir. Should contain the .csv files (or other data files) for the task.")
|
||||||
|
parser.add_argument("--bert_model", default=None, type=str, required=True,
|
||||||
|
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||||
|
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
|
||||||
|
parser.add_argument("--output_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The output directory where the model checkpoints will be written.")
|
||||||
|
|
||||||
|
## Other parameters
|
||||||
|
parser.add_argument("--max_seq_length",
|
||||||
|
default=128,
|
||||||
|
type=int,
|
||||||
|
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||||
|
"Sequences longer than this will be truncated, and sequences shorter \n"
|
||||||
|
"than this will be padded.")
|
||||||
|
parser.add_argument("--do_train",
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help="Whether to run training.")
|
||||||
|
parser.add_argument("--do_eval",
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help="Whether to run eval on the dev set.")
|
||||||
|
parser.add_argument("--do_lower_case",
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help="Set this flag if you are using an uncased model.")
|
||||||
|
parser.add_argument("--train_batch_size",
|
||||||
|
default=32,
|
||||||
|
type=int,
|
||||||
|
help="Total batch size for training.")
|
||||||
|
parser.add_argument("--eval_batch_size",
|
||||||
|
default=8,
|
||||||
|
type=int,
|
||||||
|
help="Total batch size for eval.")
|
||||||
|
parser.add_argument("--learning_rate",
|
||||||
|
default=5e-5,
|
||||||
|
type=float,
|
||||||
|
help="The initial learning rate for Adam.")
|
||||||
|
parser.add_argument("--num_train_epochs",
|
||||||
|
default=3.0,
|
||||||
|
type=float,
|
||||||
|
help="Total number of training epochs to perform.")
|
||||||
|
parser.add_argument("--warmup_proportion",
|
||||||
|
default=0.1,
|
||||||
|
type=float,
|
||||||
|
help="Proportion of training to perform linear learning rate warmup for. "
|
||||||
|
"E.g., 0.1 = 10%% of training.")
|
||||||
|
parser.add_argument("--no_cuda",
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help="Whether not to use CUDA when available")
|
||||||
|
parser.add_argument("--local_rank",
|
||||||
|
type=int,
|
||||||
|
default=-1,
|
||||||
|
help="local_rank for distributed training on gpus")
|
||||||
|
parser.add_argument('--seed',
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="random seed for initialization")
|
||||||
|
parser.add_argument('--gradient_accumulation_steps',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||||
|
parser.add_argument('--optimize_on_cpu',
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help="Whether to perform optimization and keep the optimizer averages on CPU")
|
||||||
|
parser.add_argument('--fp16',
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help="Whether to use 16-bit float precision instead of 32-bit")
|
||||||
|
parser.add_argument('--loss_scale',
|
||||||
|
type=float, default=128,
|
||||||
|
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.local_rank == -1 or args.no_cuda:
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||||
|
n_gpu = torch.cuda.device_count()
|
||||||
|
else:
|
||||||
|
device = torch.device("cuda", args.local_rank)
|
||||||
|
n_gpu = 1
|
||||||
|
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||||
|
torch.distributed.init_process_group(backend='nccl')
|
||||||
|
if args.fp16:
|
||||||
|
logger.info("16-bits training currently not supported in distributed training")
|
||||||
|
args.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496)
|
||||||
|
logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
|
||||||
|
|
||||||
|
if args.gradient_accumulation_steps < 1:
|
||||||
|
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
||||||
|
args.gradient_accumulation_steps))
|
||||||
|
|
||||||
|
args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
|
||||||
|
|
||||||
|
random.seed(args.seed)
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
if n_gpu > 0:
|
||||||
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
if not args.do_train and not args.do_eval:
|
||||||
|
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||||
|
|
||||||
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||||
|
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
|
train_examples = None
|
||||||
|
num_train_steps = None
|
||||||
|
if args.do_train:
|
||||||
|
train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
|
||||||
|
num_train_steps = int(
|
||||||
|
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
||||||
|
|
||||||
|
# Prepare model
|
||||||
|
model = BertForMultipleChoice.from_pretrained(args.bert_model,
|
||||||
|
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
|
||||||
|
num_choices = 4
|
||||||
|
)
|
||||||
|
if args.fp16:
|
||||||
|
model.half()
|
||||||
|
model.to(device)
|
||||||
|
if args.local_rank != -1:
|
||||||
|
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||||
|
output_device=args.local_rank)
|
||||||
|
elif n_gpu > 1:
|
||||||
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
|
# Prepare optimizer
|
||||||
|
if args.fp16:
|
||||||
|
param_optimizer = [(n, param.clone().detach().to('cpu').float().requires_grad_()) \
|
||||||
|
for n, param in model.named_parameters()]
|
||||||
|
elif args.optimize_on_cpu:
|
||||||
|
param_optimizer = [(n, param.clone().detach().to('cpu').requires_grad_()) \
|
||||||
|
for n, param in model.named_parameters()]
|
||||||
|
else:
|
||||||
|
param_optimizer = list(model.named_parameters())
|
||||||
|
no_decay = ['bias', 'gamma', 'beta']
|
||||||
|
optimizer_grouped_parameters = [
|
||||||
|
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
|
||||||
|
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
|
||||||
|
]
|
||||||
|
t_total = num_train_steps
|
||||||
|
if args.local_rank != -1:
|
||||||
|
t_total = t_total // torch.distributed.get_world_size()
|
||||||
|
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||||
|
lr=args.learning_rate,
|
||||||
|
warmup=args.warmup_proportion,
|
||||||
|
t_total=t_total)
|
||||||
|
|
||||||
|
global_step = 0
|
||||||
|
if args.do_train:
|
||||||
|
train_features = convert_examples_to_features(
|
||||||
|
train_examples, tokenizer, args.max_seq_length, True)
|
||||||
|
logger.info("***** Running training *****")
|
||||||
|
logger.info(" Num examples = %d", len(train_examples))
|
||||||
|
logger.info(" Batch size = %d", args.train_batch_size)
|
||||||
|
logger.info(" Num steps = %d", num_train_steps)
|
||||||
|
all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
|
||||||
|
all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
|
||||||
|
all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
|
||||||
|
all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
|
||||||
|
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
|
||||||
|
if args.local_rank == -1:
|
||||||
|
train_sampler = RandomSampler(train_data)
|
||||||
|
else:
|
||||||
|
train_sampler = DistributedSampler(train_data)
|
||||||
|
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
|
|
||||||
|
model.train()
|
||||||
|
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||||
|
tr_loss = 0
|
||||||
|
nb_tr_examples, nb_tr_steps = 0, 0
|
||||||
|
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||||
|
batch = tuple(t.to(device) for t in batch)
|
||||||
|
input_ids, input_mask, segment_ids, label_ids = batch
|
||||||
|
loss = model(input_ids, segment_ids, input_mask, label_ids)
|
||||||
|
if n_gpu > 1:
|
||||||
|
loss = loss.mean() # mean() to average on multi-gpu.
|
||||||
|
if args.fp16 and args.loss_scale != 1.0:
|
||||||
|
# rescale loss for fp16 training
|
||||||
|
# see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
|
||||||
|
loss = loss * args.loss_scale
|
||||||
|
if args.gradient_accumulation_steps > 1:
|
||||||
|
loss = loss / args.gradient_accumulation_steps
|
||||||
|
loss.backward()
|
||||||
|
tr_loss += loss.item()
|
||||||
|
nb_tr_examples += input_ids.size(0)
|
||||||
|
nb_tr_steps += 1
|
||||||
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
|
if args.fp16 or args.optimize_on_cpu:
|
||||||
|
if args.fp16 and args.loss_scale != 1.0:
|
||||||
|
# scale down gradients for fp16 training
|
||||||
|
for param in model.parameters():
|
||||||
|
if param.grad is not None:
|
||||||
|
param.grad.data = param.grad.data / args.loss_scale
|
||||||
|
is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
|
||||||
|
if is_nan:
|
||||||
|
logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
|
||||||
|
args.loss_scale = args.loss_scale / 2
|
||||||
|
model.zero_grad()
|
||||||
|
continue
|
||||||
|
optimizer.step()
|
||||||
|
copy_optimizer_params_to_model(model.named_parameters(), param_optimizer)
|
||||||
|
else:
|
||||||
|
optimizer.step()
|
||||||
|
model.zero_grad()
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
|
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
|
eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
|
||||||
|
eval_features = convert_examples_to_features(
|
||||||
|
eval_examples, tokenizer, args.max_seq_length, True)
|
||||||
|
logger.info("***** Running evaluation *****")
|
||||||
|
logger.info(" Num examples = %d", len(eval_examples))
|
||||||
|
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||||
|
all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long)
|
||||||
|
all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long)
|
||||||
|
all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long)
|
||||||
|
all_label = torch.tensor([f.label for f in eval_features], dtype=torch.long)
|
||||||
|
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
|
||||||
|
# Run prediction for full data
|
||||||
|
eval_sampler = SequentialSampler(eval_data)
|
||||||
|
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
eval_loss, eval_accuracy = 0, 0
|
||||||
|
nb_eval_steps, nb_eval_examples = 0, 0
|
||||||
|
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
|
||||||
|
input_ids = input_ids.to(device)
|
||||||
|
input_mask = input_mask.to(device)
|
||||||
|
segment_ids = segment_ids.to(device)
|
||||||
|
label_ids = label_ids.to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
|
||||||
|
logits = model(input_ids, segment_ids, input_mask)
|
||||||
|
|
||||||
|
logits = logits.detach().cpu().numpy()
|
||||||
|
label_ids = label_ids.to('cpu').numpy()
|
||||||
|
tmp_eval_accuracy = accuracy(logits, label_ids)
|
||||||
|
|
||||||
|
eval_loss += tmp_eval_loss.mean().item()
|
||||||
|
eval_accuracy += tmp_eval_accuracy
|
||||||
|
|
||||||
|
nb_eval_examples += input_ids.size(0)
|
||||||
|
nb_eval_steps += 1
|
||||||
|
|
||||||
|
eval_loss = eval_loss / nb_eval_steps
|
||||||
|
eval_accuracy = eval_accuracy / nb_eval_examples
|
||||||
|
|
||||||
|
result = {'eval_loss': eval_loss,
|
||||||
|
'eval_accuracy': eval_accuracy,
|
||||||
|
'global_step': global_step,
|
||||||
|
'loss': tr_loss/nb_tr_steps}
|
||||||
|
|
||||||
|
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||||
|
with open(output_eval_file, "w") as writer:
|
||||||
|
logger.info("***** Eval results *****")
|
||||||
|
for key in sorted(result.keys()):
|
||||||
|
logger.info(" %s = %s", key, str(result[key]))
|
||||||
|
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -1,7 +1,7 @@
|
||||||
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
||||||
from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
||||||
BertForMaskedLM, BertForNextSentencePrediction,
|
BertForMaskedLM, BertForNextSentencePrediction,
|
||||||
BertForSequenceClassification, BertForTokenClassification,
|
BertForSequenceClassification, BertForMultipleChoice,
|
||||||
BertForQuestionAnswering)
|
BertForTokenClassification, BertForQuestionAnswering)
|
||||||
from .optimization import BertAdam
|
from .optimization import BertAdam
|
||||||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
|
|
|
@ -877,6 +877,75 @@ class BertForSequenceClassification(PreTrainedBertModel):
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class BertForMultipleChoice(PreTrainedBertModel):
|
||||||
|
"""BERT model for multiple choice tasks.
|
||||||
|
This module is composed of the BERT model with a linear layer on top of
|
||||||
|
the pooled output.
|
||||||
|
|
||||||
|
Params:
|
||||||
|
`config`: a BertConfig class instance with the configuration to build a new model.
|
||||||
|
`num_choices`: the number of classes for the classifier. Default = 2.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
|
||||||
|
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||||
|
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
||||||
|
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
|
||||||
|
with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
|
||||||
|
and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
|
||||||
|
`attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
|
||||||
|
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
||||||
|
input sequence length in the current batch. It's the mask that we typically use for attention when
|
||||||
|
a batch has varying length sentences.
|
||||||
|
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
|
||||||
|
with indices selected in [0, ..., num_choices].
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
if `labels` is not `None`:
|
||||||
|
Outputs the CrossEntropy classification loss of the output with the labels.
|
||||||
|
if `labels` is `None`:
|
||||||
|
Outputs the classification logits of shape [batch_size, num_labels].
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```python
|
||||||
|
# Already been converted into WordPiece token ids
|
||||||
|
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
|
||||||
|
input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
|
||||||
|
token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
|
||||||
|
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||||
|
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||||
|
|
||||||
|
num_choices = 2
|
||||||
|
|
||||||
|
model = BertForMultipleChoice(config, num_choices)
|
||||||
|
logits = model(input_ids, token_type_ids, input_mask)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
def __init__(self, config, num_choices=2):
|
||||||
|
super(BertForMultipleChoice, self).__init__(config)
|
||||||
|
self.num_choices = num_choices
|
||||||
|
self.bert = BertModel(config)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||||
|
self.apply(self.init_bert_weights)
|
||||||
|
|
||||||
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
||||||
|
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||||
|
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
|
||||||
|
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
|
||||||
|
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
|
||||||
|
pooled_output = self.dropout(pooled_output)
|
||||||
|
logits = self.classifier(pooled_output)
|
||||||
|
reshaped_logits = logits.view(-1, self.num_choices)
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(reshaped_logits, labels)
|
||||||
|
return loss
|
||||||
|
else:
|
||||||
|
return reshaped_logits
|
||||||
|
|
||||||
|
|
||||||
class BertForTokenClassification(PreTrainedBertModel):
|
class BertForTokenClassification(PreTrainedBertModel):
|
||||||
"""BERT model for token-level classification.
|
"""BERT model for token-level classification.
|
||||||
This module is composed of the BERT model with a linear layer on top of
|
This module is composed of the BERT model with a linear layer on top of
|
||||||
|
|
Loading…
Reference in New Issue