Merge pull request #73 from huggingface/third-release

Third release
This commit is contained in:
Thomas Wolf 2018-11-30 23:10:30 +01:00 committed by GitHub
commit 66d50ca6ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 311 additions and 67 deletions

View File

@ -14,7 +14,7 @@ This implementation is provided with [Google's pre-trained models](https://githu
| [Doc](#doc) | Detailed documentation |
| [Examples](#examples) | Detailed examples on how to fine-tune Bert |
| [Notebooks](#notebooks) | Introduction on the provided Jupyter Notebooks |
| [TPU](#tup) | Notes on TPU support and pretraining scripts |
| [TPU](#tpu) | Notes on TPU support and pretraining scripts |
| [Command-line interface](#Command-line-interface) | Convert a TensorFlow checkpoint in a PyTorch dump |
## Installation
@ -46,13 +46,14 @@ python -m pytest -sv tests/
This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:
- Six PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
- [`BertModel`](./pytorch_pretrained_bert/modeling.py#L535) - raw BERT Transformer model (**fully pre-trained**),
- [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L689) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
- [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L750) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
- [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L618) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
- [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L812) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
- [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L877) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
- Seven PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
- [`BertModel`](./pytorch_pretrained_bert/modeling.py#L537) - raw BERT Transformer model (**fully pre-trained**),
- [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L691) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
- [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
- [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
- [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
- [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
- [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L946) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
- Three tokenizers (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file):
- `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.),
@ -153,7 +154,7 @@ Here is a detailed documentation of the classes in the package and how to use th
| Sub-section | Description |
|-|-|
| [Loading Google AI's pre-trained weigths](#Loading-Google-AIs-pre-trained-weigths-and-PyTorch-dump) | How to load Google AI's pre-trained weight or a PyTorch saved instance |
| [PyTorch models](#PyTorch-models) | API of the six PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` |
| [PyTorch models](#PyTorch-models) | API of the seven PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` |
| [Tokenizer: `BertTokenizer`](#Tokenizer-BertTokenizer) | API of the `BertTokenizer` class|
| [Optimizer: `BertAdam`](#Optimizer-BertAdam) | API of the `BertAdam` class |
@ -167,7 +168,7 @@ model = BERT_CLASS.from_pretrain(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
where
- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the six PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering`, and
- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the seven PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification` or `BertForQuestionAnswering`, and
- `PRE_TRAINED_MODEL_NAME_OR_PATH` is either:
- the shortcut name of a Google AI's pre-trained model selected in the list:
@ -175,17 +176,23 @@ where
- `bert-base-uncased`: 12-layer, 768-hidden, 12-heads, 110M parameters
- `bert-large-uncased`: 24-layer, 1024-hidden, 16-heads, 340M parameters
- `bert-base-cased`: 12-layer, 768-hidden, 12-heads , 110M parameters
- `bert-base-multilingual`: 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
- `bert-large-cased`: 24-layer, 1024-hidden, 16-heads, 340M parameters
- `bert-base-multilingual-uncased`: (Orig, not recommended) 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
- `bert-base-multilingual-cased`: **(New, recommended)** 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
- `bert-base-chinese`: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters
- a path or url to a pretrained model archive containing:
- `bert_config.json` a configuration file for the model, and
- `pytorch_model.bin` a PyTorch dump of a pre-trained instance `BertForPreTraining` (saved with the usual `torch.save()`)
- `bert_config.json` a configuration file for the model, and
- `pytorch_model.bin` a PyTorch dump of a pre-trained instance `BertForPreTraining` (saved with the usual `torch.save()`)
If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_bert/`).
- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. This option is useful in particular when you are using distributed training: to avoid concurrent access to the same weights you can set for example `cache_dir='./pretrained_model_{}'.format(args.local_rank)` (see the section on distributed training for more information)
`Uncased` means that the text has been lowercased before WordPiece tokenization, e.g., `John Smith` becomes `john smith`. The Uncased model also strips out any accent markers. `Cased` means that the true case and accent markers are preserved. Typically, the Uncased model is better unless you know that case information is important for your task (e.g., Named Entity Recognition or Part-of-Speech tagging). For information about the Multilingual and Chinese model, see the [Multilingual README](https://github.com/google-research/bert/blob/master/multilingual.md) or the original TensorFlow repository.
**When using an `uncased model`, make sure to pass `--do_lower_case` to the training scripts. (Or pass `do_lower_case=True` directly to FullTokenizer if you're using your own script.)**
Example:
```python
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
@ -271,7 +278,13 @@ The sequence-level classifier is a linear layer that takes as input the last hid
An example on how to use this class is given in the `run_classifier.py` script which can be used to fine-tune a single sequence (or pair of sequence) classifier using BERT, for example for the MRPC task.
#### 6. `BertForQuestionAnswering`
#### 6. `BertForTokenClassification`
`BertForTokenClassification` is a fine-tuning model that includes `BertModel` and a token-level classifier on top of the `BertModel`.
The token-level classifier is a linear layer that takes as input the last hidden state of the sequence.
#### 7. `BertForQuestionAnswering`
`BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifiers on top of the full sequence of last hidden states.

View File

@ -199,6 +199,7 @@ def main():
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
## Other parameters
parser.add_argument("--do_lower_case", default=False, action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
@ -227,7 +228,7 @@ def main():
layer_indexes = [int(x) for x in args.layers.split(",")]
tokenizer = BertTokenizer.from_pretrained(args.bert_model)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
examples = read_examples(args.input_file)

View File

@ -376,6 +376,10 @@ def main():
default=False,
action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--do_lower_case",
default=False,
action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--train_batch_size",
default=32,
type=int,
@ -473,7 +477,7 @@ def main():
processor = processors[task_name]()
label_list = processor.get_labels()
tokenizer = BertTokenizer.from_pretrained(args.bert_model)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
train_examples = None
num_train_steps = None
@ -542,7 +546,7 @@ def main():
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
loss = model(input_ids, segment_ids, input_mask, label_ids)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.fp16 and args.loss_scale != 1.0:

View File

@ -1,6 +1,7 @@
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForQuestionAnswering)
BertForSequenceClassification, BertForTokenClassification,
BertForQuestionAnswering)
from .optimization import BertAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE

View File

@ -42,7 +42,9 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-base-multilingual': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}
CONFIG_NAME = 'bert_config.json'
@ -476,7 +478,7 @@ class PreTrainedBertModel(nn.Module):
"associated to this path or url.".format(
pretrained_model_name,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
pretrained_model_name))
archive_file))
return None
if resolved_archive_file == archive_file:
logger.info("loading archive file {}".format(archive_file))
@ -557,7 +559,7 @@ class BertModel(PreTrainedBertModel):
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
to the last attention block,
to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated to the first character of the
input (`CLF`) to train on the Next-Sentence task (see BERT's paper).
@ -567,10 +569,10 @@ class BertModel(PreTrainedBertModel):
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.BertModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
@ -648,18 +650,18 @@ class BertForPreTraining(PreTrainedBertModel):
sentence classification loss.
if `masked_lm_labels` or `next_sentence_label` is `None`:
Outputs a tuple comprising
- the masked language modeling logits, and
- the next sentence classification logits.
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
- the next sentence classification logits of shape [batch_size, 2].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForPreTraining(config)
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
@ -712,17 +714,17 @@ class BertForMaskedLM(PreTrainedBertModel):
if `masked_lm_labels` is `None`:
Outputs the masked language modeling loss.
if `masked_lm_labels` is `None`:
Outputs the masked language modeling logits.
Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForMaskedLM(config)
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
@ -774,7 +776,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
Outputs the total_loss which is the sum of the masked language modeling loss and the next
sentence classification loss.
if `next_sentence_label` is `None`:
Outputs the next sentence classification logits.
Outputs the next sentence classification logits of shape [batch_size, 2].
Example usage:
```python
@ -783,8 +785,8 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForNextSentencePrediction(config)
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
@ -836,17 +838,17 @@ class BertForSequenceClassification(PreTrainedBertModel):
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits.
Outputs the classification logits of shape [batch_size, num_labels].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_labels = 2
@ -870,7 +872,73 @@ class BertForSequenceClassification(PreTrainedBertModel):
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss, logits
return loss
else:
return logits
class BertForTokenClassification(PreTrainedBertModel):
"""BERT model for token-level classification.
This module is composed of the BERT model with a linear layer on top of
the full hidden state of the last layer.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_labels].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_labels = 2
model = BertForTokenClassification(config, num_labels)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels=2):
super(BertForTokenClassification, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
else:
return logits
@ -914,17 +982,17 @@ class BertForQuestionAnswering(PreTrainedBertModel):
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
if `start_positions` or `end_positions` is `None`:
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
position tokens.
position tokens of shape [batch_size, sequence_length].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForQuestionAnswering(config)
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)

View File

@ -34,9 +34,12 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-base-multilingual': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-vocab.txt",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
}
VOCAB_NAME = 'vocab.txt'
def load_vocab(vocab_file):
@ -98,7 +101,7 @@ class BertTokenizer(object):
return tokens
@classmethod
def from_pretrained(cls, pretrained_model_name, do_lower_case=True):
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
@ -107,16 +110,11 @@ class BertTokenizer(object):
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
else:
vocab_file = pretrained_model_name
if os.path.isdir(vocab_file):
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file)
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, do_lower_case)
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except FileNotFoundError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
@ -124,8 +122,15 @@ class BertTokenizer(object):
"associated to this path or url.".format(
pretrained_model_name,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name))
tokenizer = None
vocab_file))
return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
return tokenizer

View File

@ -2,7 +2,7 @@ from setuptools import find_packages, setup
setup(
name="pytorch_pretrained_bert",
version="0.2.0",
version="0.3.0",
author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors",
author_email="thomas@huggingface.co",
description="PyTorch version of Google AI BERT model with script to load Google pre-trained models",

View File

@ -22,7 +22,10 @@ import random
import torch
from pytorch_pretrained_bert import BertConfig, BertModel
from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification)
class BertModelTest(unittest.TestCase):
@ -35,6 +38,7 @@ class BertModelTest(unittest.TestCase):
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
@ -45,7 +49,9 @@ class BertModelTest(unittest.TestCase):
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
scope=None):
self.parent = parent
self.batch_size = batch_size
@ -53,6 +59,7 @@ class BertModelTest(unittest.TestCase):
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
@ -63,10 +70,12 @@ class BertModelTest(unittest.TestCase):
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scope = scope
def create_model(self):
def prepare_config_and_inputs(self):
input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
@ -77,6 +86,12 @@ class BertModelTest(unittest.TestCase):
if self.use_token_type_ids:
token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
if self.use_labels:
sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
config = BertConfig(
vocab_size_or_config_json_file=self.vocab_size,
hidden_size=self.hidden_size,
@ -90,10 +105,16 @@ class BertModelTest(unittest.TestCase):
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
model = BertModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
outputs = {
"sequence_output": all_encoder_layers[-1],
"pooled_output": pooled_output,
@ -101,13 +122,119 @@ class BertModelTest(unittest.TestCase):
}
return outputs
def check_output(self, result):
def check_bert_model_output(self, result):
self.parent.assertListEqual(
[size for layer in result["all_encoder_layers"] for size in layer.size()],
[self.batch_size, self.seq_length, self.hidden_size] * self.num_hidden_layers)
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
model = BertForMaskedLM(config=config)
loss = model(input_ids, token_type_ids, input_mask, token_labels)
prediction_scores = model(input_ids, token_type_ids, input_mask)
outputs = {
"loss": loss,
"prediction_scores": prediction_scores,
}
return outputs
def check_bert_for_masked_lm_output(self, result):
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
model = BertForNextSentencePrediction(config=config)
loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
seq_relationship_score = model(input_ids, token_type_ids, input_mask)
outputs = {
"loss": loss,
"seq_relationship_score": seq_relationship_score,
}
return outputs
def check_bert_for_next_sequence_prediction_output(self, result):
self.parent.assertListEqual(
list(result["seq_relationship_score"].size()),
[self.batch_size, 2])
def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
model = BertForPreTraining(config=config)
loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask)
outputs = {
"loss": loss,
"prediction_scores": prediction_scores,
"seq_relationship_score": seq_relationship_score,
}
return outputs
def check_bert_for_pretraining_output(self, result):
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(result["seq_relationship_score"].size()),
[self.batch_size, 2])
def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
model = BertForQuestionAnswering(config=config)
loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
outputs = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
return outputs
def check_bert_for_question_answering_output(self, result):
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
logits = model(input_ids, token_type_ids, input_mask)
outputs = {
"loss": loss,
"logits": logits,
}
return outputs
def check_bert_for_sequence_classification_output(self, result):
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_labels])
def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
model = BertForTokenClassification(config=config, num_labels=self.num_labels)
loss = model(input_ids, token_type_ids, input_mask, token_labels)
logits = model(input_ids, token_type_ids, input_mask)
outputs = {
"loss": loss,
"logits": logits,
}
return outputs
def check_bert_for_token_classification_output(self, result):
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.seq_length, self.num_labels])
def test_default(self):
self.run_tester(BertModelTest.BertModelTester(self))
@ -118,8 +245,33 @@ class BertModelTest(unittest.TestCase):
self.assertEqual(obj["hidden_size"], 37)
def run_tester(self, tester):
output_result = tester.create_model()
tester.check_output(output_result)
config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_bert_model(*config_and_inputs)
tester.check_bert_model_output(output_result)
output_result = tester.create_bert_for_masked_lm(*config_and_inputs)
tester.check_bert_for_masked_lm_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_next_sequence_prediction(*config_and_inputs)
tester.check_bert_for_next_sequence_prediction_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_pretraining(*config_and_inputs)
tester.check_bert_for_pretraining_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_question_answering(*config_and_inputs)
tester.check_bert_for_question_answering_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_sequence_classification(*config_and_inputs)
tester.check_bert_for_sequence_classification_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_token_classification(*config_and_inputs)
tester.check_bert_for_token_classification_output(output_result)
tester.check_loss_output(output_result)
@classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):