fix reloading model for evaluation in examples

This commit is contained in:
thomwolf 2018-12-13 14:48:12 +01:00
parent 0f544625f4
commit 087798b7fa
4 changed files with 65 additions and 17 deletions

View File

@ -46,13 +46,13 @@ python -m pytest -sv tests/
This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:
- Seven PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
- Eight PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
- [`BertModel`](./pytorch_pretrained_bert/modeling.py#L537) - raw BERT Transformer model (**fully pre-trained**),
- [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L691) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
- [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
- [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
- [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
- [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
- [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
- [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L949) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
- [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1015) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
@ -156,7 +156,7 @@ Here is a detailed documentation of the classes in the package and how to use th
| Sub-section | Description |
|-|-|
| [Loading Google AI's pre-trained weigths](#Loading-Google-AIs-pre-trained-weigths-and-PyTorch-dump) | How to load Google AI's pre-trained weight or a PyTorch saved instance |
| [PyTorch models](#PyTorch-models) | API of the seven PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` |
| [PyTorch models](#PyTorch-models) | API of the eight PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForMultipleChoice` or `BertForQuestionAnswering` |
| [Tokenizer: `BertTokenizer`](#Tokenizer-BertTokenizer) | API of the `BertTokenizer` class|
| [Optimizer: `BertAdam`](#Optimizer-BertAdam) | API of the `BertAdam` class |
@ -170,7 +170,7 @@ model = BERT_CLASS.from_pretrain(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
where
- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the seven PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification` or `BertForQuestionAnswering`, and
- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the eight PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification`, `BertForMultipleChoice` or `BertForQuestionAnswering`, and
- `PRE_TRAINED_MODEL_NAME_OR_PATH` is either:
- the shortcut name of a Google AI's pre-trained model selected in the list:
@ -353,14 +353,13 @@ The optimizer accepts the following arguments:
BERT-base and BERT-large are respectively 110M and 340M parameters models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most case a batch size of 32).
To help with fine-tuning these models, we have included five techniques that you can activate in the fine-tuning scripts [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py): gradient-accumulation, multi-gpu training, distributed training, optimize on CPU and 16-bits training . For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
To help with fine-tuning these models, we have included several techniques that you can activate in the fine-tuning scripts [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py): gradient-accumulation, multi-gpu training, distributed training and 16-bits training . For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
Here is how to use these techniques in our scripts:
- **Gradient Accumulation**: Gradient accumulation can be used by supplying a integer greater than 1 to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and gradient will be accumulated over `gradient_accumulation_steps` steps.
- **Multi-GPU**: Multi-GPU is automatically activated when several GPUs are detected and the batches are splitted over the GPUs.
- **Distributed training**: Distributed training can be activated by supplying an integer greater or equal to 0 to the `--local_rank` argument (see below).
- **Optimize on CPU**: The Adam optimizer stores 2 moving average of the weights of the model. If you keep them on GPU 1 (typical behavior), your first GPU will have to store 3-times the size of the model. This is not optimal for large models like `BERT-large` and means your batch size is a lot lower than it could be. This option will perform the optimization and store the averages on the CPU/RAM to free more room on the GPU(s). As the most computational intensive operation is usually the backward pass, this doesn't have a significant impact on the training time. Activate this option with `--optimize_on_cpu` on the [`run_squad.py`](./examples/run_squad.py) script.
- **16-bits training**: 16-bits training, also called mixed-precision training, can reduce the memory requirement of your model on the GPU by using half-precision training, basically allowing to double the batch size. If you have a recent GPU (starting from NVIDIA Volta architecture) you should see no decrease in speed. A good introduction to Mixed precision training can be found [here](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) and a full documentation is [here](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html). In our scripts, this option can be activated by setting the `--fp16` flag and you can play with loss scaling using the `--loss_scaling` flag (see the previously linked documentation for details on loss scaling). If the loss scaling is too high (`Nan` in the gradients) it will be automatically scaled down until the value is acceptable. The default loss scaling is 128 which behaved nicely in our tests.
Note: To use *Distributed Training*, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see [the above mentioned blog post]((https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255)) for more details):
@ -371,16 +370,21 @@ Where `$THIS_MACHINE_INDEX` is an sequential index assigned to each of your mach
### Fine-tuning with BERT: running the examples
We showcase the same examples as [the original implementation](https://github.com/google-research/bert/): fine-tuning a sequence-level classifier on the MRPC classification corpus and a token-level classifier on the question answering dataset SQuAD.
We showcase several fine-tuning examples based on (and extended from) [the original implementation](https://github.com/google-research/bert/):
Before running these examples you should download the
- a sequence-level classifier on the MRPC classification corpus,
- a token-level classifier on the question answering dataset SQuAD, and
- a sequence-level multiple-choice classifier on the SWAG classification corpus.
#### MRPC
This example code fine-tunes BERT on the Microsoft Research Paraphrase
Corpus (MRPC) corpus and runs in less than 10 minutes on a single K-80 and in 27 seconds (!) on single tesla V100 16GB with apex installed.
Before running this example you should download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`. Please also download the `BERT-Base`
checkpoint, unzip it to some directory `$BERT_BASE_DIR`, and convert it to its PyTorch version as explained in the previous section.
This example code fine-tunes `BERT-Base` on the Microsoft Research Paraphrase
Corpus (MRPC) corpus and runs in less than 10 minutes on a single K-80.
and unpack it to some directory `$GLUE_DIR`.
```shell
export GLUE_DIR=/path/to/glue
@ -401,7 +405,29 @@ python run_classifier.py \
Our test ran on a few seeds with [the original implementation hyper-parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks) gave evaluation results between 84% and 88%.
The second example fine-tunes `BERT-Base` on the SQuAD question answering task.
**Fast run with apex and 16 bit precision: fine-tuning on MRPC in 27 seconds!**
First install apex as indicated [here](https://github.com/NVIDIA/apex).
Then run
```shell
export GLUE_DIR=/path/to/glue
python run_classifier.py \
--task_name MRPC \
--do_train \
--do_eval \
--do_lower_case \
--data_dir $GLUE_DIR/MRPC/ \
--bert_model bert-base-uncased \
--max_seq_length 128 \
--train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--output_dir /tmp/mrpc_output/
```
#### SQuAD
This example code fine-tunes BERT on the SQuAD dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large) on single tesla V100 16GB.
The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory.
@ -451,7 +477,7 @@ python run_swag.py \
--gradient_accumulation_steps 4
```
Training with the previous hyper-parameters gave us the following results:
Training with the previous hyper-parameters on a single GPU gave us the following results:
```
eval_accuracy = 0.8062081375587323
eval_loss = 0.5966546792367169

View File

@ -558,6 +558,9 @@ def main():
# Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file)
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict)
if args.fp16:
model.half()
model.to(device)
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = processor.get_dev_examples(args.data_dir)

View File

@ -923,6 +923,9 @@ def main():
# Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file)
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
if args.fp16:
model.half()
model.to(device)
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_squad_examples(

View File

@ -366,8 +366,7 @@ def main():
# Prepare model
model = BertForMultipleChoice.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
num_choices = 4
)
num_choices=4)
if args.fp16:
model.half()
model.to(device)
@ -452,6 +451,9 @@ def main():
loss = loss * args.loss_scale
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
if args.fp16:
optimizer.backward(loss)
@ -466,6 +468,20 @@ def main():
optimizer.zero_grad()
global_step += 1
# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
torch.save(model_to_save.state_dict(), output_model_file)
# Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file)
model = BertForMultipleChoice.from_pretrained(args.bert_model,
state_dict=model_state_dict,
num_choices=4)
if args.fp16:
model.half()
model.to(device)
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
eval_features = convert_examples_to_features(