update readme

thomwolf 2018-11-04 21:26:03 +01:00
parent efb44a8310
commit 6cc651778a
1 changed files with 92 additions and 36 deletions

128
README.md

@ -1,23 +1,22 @@
# PyTorch implementation of Google AI's BERT model # PyTorch implementation of Google AI's BERT model with a script to load Google's pre-trained models
## Introduction ## Introduction
This is an op-for-op PyTorch reimplementation of the [TensorFlow code](https://github.com/google-research/bert) released by Google AI with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805). This repository contains an op-for-op PyTorch reimplementation of [Google's TensorFlow code for the BERT model](https://github.com/google-research/bert) that was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
This implementation can load any pre-trained TensorFlow BERT checkpoint in a PyTorch model (see below). This implementation can load any pre-trained TensorFlow BERT checkpoint (in particular [Google's pre-trained models](https://github.com/google-research/bert)) and a conversion script is provided (see below).
There are a few differences with the TensorFlow model: ## Loading a TensorFlow checkpoint (e.g. Google's pre-trained models)
- this PyTorch implementation support multi-GPU and distributed training (see below), You can convert any TensorFlow checkpoint for BERT (in particular the pre-trained weights released by GoogleAI) in a PyTorch save file by using [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py).
- the current stable version of PyTorch (0.4.1) doesn't support TPU training and as a consequence, the pre-training script are not included in this repo (see below). TPU support is supposed to be available in PyTorch v1.0. We will update the repository with TPU-adapted pre-training scripts at that time. In the meantime, you can use the TensorFlow version to train a model on TPU and import a TensorFlow checkpoint as described below.
## Loading a TensorFlow checkpoint (in particular Google's pre-trained models) in the Pytorch model This script takes as input a TensorFlow checkpoint (`bert_model.ckpt`) and the associated configuration file (`bert_config.json`), creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint into the PyTorch model, and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`).
You can convert any TensorFlow checkpoint, and in particular the pre-trained weights released by GoogleAI, by using `convert_tf_checkpoint_to_pytorch.py`. To run this specific script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`).
This script takes as input a TensorFlow checkpoint (`bert_model.ckpt`) load it in the PyTorch model and save the model in a standard PyTorch model save file that can be imported using the usual `torch.load()` command (see the `run_classifier.py` script for an example). You can find Google's pre-trained models in [Google's TensorFlow repository for BERT](https://github.com/google-research/bert).
TensorFlow pre-trained models can be found in the [original TensorFlow code](https://github.com/google-research/bert). Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model: Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model:
```shell ```shell
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
@ -28,17 +27,90 @@ python convert_tf_checkpoint_to_pytorch.py \
--pytorch_dump_path=$BERT_BASE_DIR/pytorch_model.bin --pytorch_dump_path=$BERT_BASE_DIR/pytorch_model.bin
``` ```
## Multi-GPU and Distributed Training ## PyTorch models for BERT
Multi-GPU is automatically activated in the scripts when multiple GPUs are detected. This repository contains three PyTorch models that you can find in [`modeling.py`](modeling.py):
Distributed training is activated by suppying a `--local_rank` arguments to the `run_classifier.py` or the `run_squad.py` scripts. - `BertModel` - the basic model
- `BertForSequenceClassification` - the model with a sequence classification head
- `BertForQuestionAnswering` - the model with a token classification head
For more information on how to use distributed training with PyTorch, you can read [this simple introduction](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) we wrote earlier this month. ### 1. `BertModel`
`BertModel` is the basic BERT model with a layer of token, position and sequence embeddings followed by a series of identical self-attention blocks (12 for BERT-base, 24 for BERT-large).
This model outputs a tuple of:
- `all_encoder_layers`: a list of the full sequences of hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), and
- `pooled_output`: the output of a classifier pretrained on top of the hidden state associated with the first token of the input to classify the Next-Sentence task (see BERT's paper).
An example of how to use this class is given in the `extract_features.py` script, which can be used to extract the hidden states of the model for a given input.
### 2. `BertForSequenceClassification`
`BertForSequenceClassification` is a fine-tuning model that includes `BertModel` and a sequence (or pair of sequences) classifier on top of the `BertModel`.
The sequence classifier is a linear layer that takes as input the last hidden state of the first token in the input sequence (see Figures 3a and 3b in the BERT paper).
An example of how to use this class is given in the `run_classifier.py` script, which can be used to fine-tune a single sequence (or pair of sequences) classifier using BERT, for example for the MRPC task.
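To make the role of the linear layer concrete, here is a minimal pure-Python sketch of a classification head applied to the hidden state of the first token. It is not the code from `modeling.py`: the function name `classify_sequence` and the `weight` and `bias` parameters are hypothetical, and the numbers are made up for illustration.

```python
def classify_sequence(last_hidden_states, weight, bias):
    """Return (logits, predicted_label) for one sequence.

    last_hidden_states: list of per-token vectors (seq_len x hidden_size)
    weight: num_labels x hidden_size matrix (hypothetical name)
    bias: num_labels vector (hypothetical name)
    """
    # Only the hidden state of the first token feeds the classifier.
    first_token = last_hidden_states[0]
    # Linear layer: one dot product plus bias per label.
    logits = [
        sum(w * h for w, h in zip(row, first_token)) + b
        for row, b in zip(weight, bias)
    ]
    # The predicted label is the index of the largest logit.
    predicted = max(range(len(logits)), key=logits.__getitem__)
    return logits, predicted


hidden = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]  # 3 tokens, hidden_size = 2
weight = [[1.0, 0.0], [0.0, 1.0]]                # 2 labels
bias = [0.0, 0.5]
logits, label = classify_sequence(hidden, weight, bias)
```

In the real model the same computation would run on tensors for a whole batch at once; the sketch only shows which part of the sequence output the classifier looks at.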
### 3. `BertForQuestionAnswering`
`BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a two-class token classifier on top of the full sequence of last hidden states.
The token classifier takes as input the full sequence of last hidden states and computes two scores for each token. These can, for example, be the scores that a given token is a `start_span` or an `end_span` token respectively (see Figures 3c and 3d in the BERT paper).
An example of how to use this class is given in the `run_squad.py` script, which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task.
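As an illustration of how the two per-token scores can be turned into an answer span, the sketch below picks the `(start, end)` pair with the highest summed score, subject to `end >= start` and a maximum length. This is one common decoding strategy and an assumption on our part, not necessarily what `run_squad.py` does; `best_span` and `max_answer_length` are hypothetical names.

```python
def best_span(start_scores, end_scores, max_answer_length=30):
    """Return ((start, end), score) maximizing start score + end score.

    Only spans with end >= start and at most `max_answer_length`
    tokens are considered.
    """
    best = (0, 0)
    best_score = float("-inf")
    for start, s_score in enumerate(start_scores):
        # Last candidate end index (exclusive) allowed for this start.
        last = min(len(end_scores), start + max_answer_length)
        for end in range(start, last):
            score = s_score + end_scores[end]
            if score > best_score:
                best_score = score
                best = (start, end)
    return best, best_score


start_scores = [0.1, 2.0, 0.3, 0.2]  # one start score per token
end_scores = [0.0, 0.5, 1.5, 3.0]    # one end score per token
span, score = best_span(start_scores, end_scores, max_answer_length=2)
```

With the length limit of 2 tokens the best span covers tokens 1 to 2; without the limit, the higher end score of token 3 would win instead.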
## Installation, requirements, test
This code was tested on Python 3.5+. The requirements are:
- PyTorch (>= 0.4.0)
- tqdm
To install the dependencies:
````bash
pip install -r ./requirements.txt
````
A series of tests is included in the [`tests` folder](./tests) and can be run using `pytest` (install pytest if needed: `pip install pytest`). You can run the tests with the command:
```bash
pytest -sv ./tests/
```
## Training on large batches: gradient accumulation, multi-GPU and distributed training
BERT-base and BERT-large are 110M- and 340M-parameter models respectively, and it can be difficult to fine-tune them on a single GPU with the batch size recommended for good performance (in most cases a `batch_size` of 32).
To help with fine-tuning, we have included three techniques that you can leverage in the fine-tuning scripts `run_classifier.py` and `run_squad.py`: gradient accumulation, multi-GPU and distributed training. For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that we published earlier this month.
Here are the details:
- **Gradient Accumulation**: Gradient accumulation can be used by supplying an integer greater than one to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and the gradient will be accumulated over `gradient_accumulation_steps` steps.
- **Multi-GPU**: Multi-GPU is automatically activated when several GPUs are detected, and the batches are split across the GPUs.
- **Distributed training**: Distributed training can be activated by supplying an integer greater than or equal to zero to the `--local_rank` argument. To use distributed training, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see the above blog post for more details):
```bash
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=$THIS_MACHINE_INDEX --master_addr="192.168.1.1" --master_port=1234 run_classifier.py (--arg1 --arg2 --arg3 and all other arguments of the run_classifier script)
```
Where `$THIS_MACHINE_INDEX` is a sequential index assigned to each of your machines (0, 1, 2...) and the machine with rank 0 has the IP address `192.168.1.1` and an open port `1234`.
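For the gradient accumulation option in the list above, the following pure-Python sketch shows why splitting a batch into micro-batches and accumulating their scaled gradients can reproduce the full-batch gradient. It uses a toy one-parameter model with a mean squared error loss. The function names are hypothetical, it assumes the batch size is divisible by `gradient_accumulation_steps`, and whether the fine-tuning scripts scale the loss in exactly this way is an assumption.

```python
def mse_grad(w, batch):
    """Gradient of mean((w * x - y) ** 2) with respect to w over `batch`."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, batch, gradient_accumulation_steps):
    """Accumulate micro-batch gradients, each scaled by 1 / steps.

    Assumes len(batch) is divisible by gradient_accumulation_steps.
    """
    micro_size = len(batch) // gradient_accumulation_steps
    total = 0.0
    for step in range(gradient_accumulation_steps):
        micro = batch[step * micro_size:(step + 1) * micro_size]
        # Dividing by the number of steps keeps the sum equal to the
        # gradient of the mean loss over the whole batch.
        total += mse_grad(w, micro) / gradient_accumulation_steps
    return total


batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0)]  # (x, y) pairs
w = 0.5
full = mse_grad(w, batch)
accumulated = accumulated_grad(w, batch, gradient_accumulation_steps=2)
```

Both values agree, so the optimizer step taken after the last micro-batch matches a step on the full batch while only one micro-batch needs to fit in memory at a time.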
## TPU support and pretraining scripts
TPUs are not supported by the current stable release of PyTorch (0.4.1). However, the next version of PyTorch (v1.0) should support training on TPUs and is expected to be released soon (see the recent [official announcement](https://cloud.google.com/blog/products/ai-machine-learning/introducing-pytorch-across-google-cloud)).
We will add TPU support when this next release is published.
The original TensorFlow code further comprises two scripts for pre-training BERT: [create_pretraining_data.py](https://github.com/google-research/bert/blob/master/create_pretraining_data.py) and [run_pretraining.py](https://github.com/google-research/bert/blob/master/run_pretraining.py).
Since pre-training BERT is a particularly expensive operation that basically requires one or several TPUs to be completed in a reasonable amount of time (see details [here](https://github.com/google-research/bert#pre-training-with-bert)), we have decided to wait for the inclusion of TPU support in PyTorch to convert these pre-training scripts.
## Fine-tuning with BERT: running the examples ## Fine-tuning with BERT: running the examples
We showcase the same examples as in the original implementation: fine-tuning on the MRPC classification corpus and the question answering dataset SQUAD. We showcase the same examples as [the original implementation](https://github.com/google-research/bert/): fine-tuning a sequence-level classifier on the MRPC classification corpus and a token-level classifier on the question answering dataset SQuAD.
Before running these examples you should download the Before running these examples you should download the
[GLUE data](https://gluebenchmark.com/tasks) by running [GLUE data](https://gluebenchmark.com/tasks) by running
@ -68,7 +140,7 @@ python run_classifier.py \
--output_dir /tmp/mrpc_output/ --output_dir /tmp/mrpc_output/
``` ```
The next example fine-tunes `BERT-Base` on the SQuAD question answering task. The next example fine-tunes `BERT-Base` on the SQuAD question answering task. This example runs in about 4 hours on a multi-GPU K-80.
The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory. The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory.
@ -95,26 +167,10 @@ python run_squad.py \
--output_dir=../debug_squad/ --output_dir=../debug_squad/
``` ```
## Comparing TensorFlow and PyTorch models ## Comparing the PyTorch model and the TensorFlow model predictions
We also include [a simple Jupyter Notebook](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/Comparing%20TF%20and%20PT%20models.ipynb) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model (within the rounding errors and the differing backend implementations of the operations, in our case we found a standard deviation of about 4e-7 on the last hidden state of the 12th layer). Please follow the instructions in the Notebook to run it. We also include [a simple Jupyter Notebook](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/Comparing%20TF%20and%20PT%20models.ipynb) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.
## Note on pre-training This Notebook extracts the full-sequence hidden states of each model and computes the standard deviation between them.
The original TensorFlow code comprise two scripts that can be used for pre-training BERT: [create_pretraining_data.py](https://github.com/google-research/bert/blob/master/create_pretraining_data.py) and [run_pretraining.py](https://github.com/google-research/bert/blob/master/run_pretraining.py). In our case we found a standard deviation of about 4e-7 on the last hidden state of the 12th layer. Please follow the instructions in the Notebook to run it.
As the authors notice, pre-training BERT is particularly expensive and requires TPU to run in a reasonable amout of time (see [here](https://github.com/google-research/bert#pre-training-with-bert)).
We have decided to wait for the up-coming release of PyTorch v1.0 which is expected support training on TPU for porting these scripts (see the recent [official announcement](https://cloud.google.com/blog/products/ai-machine-learning/introducing-pytorch-across-google-cloud)).
## Requirements
The main dependencies of this code are:
- PyTorch (>= 0.4.0)
- tqdm
To install the dependencies:
````bash
pip install -r ./requirements.txt
````