Merge branch 'master' into improved_testing

This commit is contained in:
Thomas Wolf 2019-08-30 13:40:35 +02:00 committed by GitHub
commit 50e615f43d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
43 changed files with 3103 additions and 56 deletions

View File

@ -26,9 +26,27 @@ jobs:
- run: sudo pip install pytest codecov pytest-cov
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
- run: codecov
deploy_doc:
working_directory: ~/pytorch-transformers
docker:
- image: circleci/python:3.5
steps:
- add_ssh_keys:
fingerprints:
- "5b:7a:95:18:07:8c:aa:76:4c:60:35:88:ad:60:56:71"
- checkout
- run: sudo pip install -r docs/requirements.txt
- run: sudo pip install -r requirements.txt
- run: cd docs && make clean && make html && scp -r -oStrictHostKeyChecking=no _build/html/* $doc:$dir
workflow_filters: &workflow_filters
filters:
branches:
only:
- master
workflows:
version: 2
build_and_test:
jobs:
- build_py3
- build_py2
version: 2
build_and_test:
jobs:
- build_py3
- build_py2
- deploy_doc: *workflow_filters

1
.gitignore vendored
View File

@ -129,4 +129,5 @@ proc_data
runs
examples/runs
# data
data

View File

@ -12,7 +12,9 @@ The library currently contains PyTorch implementations, pre-trained model weight
4. **[Transformer-XL](https://github.com/kimiyoung/transformer-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
5. **[XLNet](https://github.com/zihangdai/xlnet/)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
6. **[XLM](https://github.com/facebookresearch/XLM/)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
7. **[RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)** (from Facebook), a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du et al.
7. **[RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
8. **[DistilBERT](https://github.com/huggingface/pytorch-transformers/tree/master/examples/distillation)** (from HuggingFace), released together with the blogpost [Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT](https://medium.com/huggingface/distilbert-8cf3380435b5
) by Victor Sanh, Lysandre Debut and Thomas Wolf.
These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Peason R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/pytorch-transformers/examples.html).
@ -393,8 +395,8 @@ for batch in train_data:
loss = model(batch)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
scheduler.step()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
```

View File

@ -12,8 +12,8 @@ Examples
- How to use gradient-accumulation, multi-gpu training, distributed training, optimize on CPU and 16-bits training to train Bert models
* - `Fine-tuning with BERT: running the examples <#fine-tuning-bert-examples>`_
- Running the examples in `examples <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples>`_\ : ``extract_classif.py``\ , ``run_bert_classifier.py``\ , ``run_bert_squad.py`` and ``run_lm_finetuning.py``
* - `Fine-tuning with OpenAI GPT, Transformer-XL and GPT-2 <#fine-tuning>`_
- Running the examples in `examples <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples>`_\ : ``run_openai_gpt.py``\ , ``run_transfo_xl.py`` and ``run_gpt2.py``
* - `Fine-tuning with OpenAI GPT, Transformer-XL, GPT-2 as well as BERT and RoBERTa <#fine-tuning>`_
- Running the examples in `examples <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples>`_\ : ``run_openai_gpt.py``\ , ``run_transfo_xl.py``, ``run_gpt2.py`` and ``run_lm_finetuning.py``
* - `Fine-tuning BERT-large on GPUs <#fine-tuning-bert-large>`_
- How to fine tune ``BERT large``
@ -384,7 +384,7 @@ Training with the previous hyper-parameters on a single GPU gave us the followin
LM Fine-tuning
~~~~~~~~~~~~~~
The data should be a text file in the same format as `sample_text.txt <./samples/sample_text.txt>`_ (one sentence per line, docs separated by empty line).
The data should be a text file in the same format as `sample_text.txt <./pytorch_transformers/tests/fixtures/sample_text.txt/sample_text.txt>`_ (one sentence per line, docs separated by empty line).
You can download an `exemplary training corpus <https://ext-bert-sample.obs.eu-de.otc.t-systems.com/small_wiki_sentence_corpus.txt>`_ generated from wikipedia articles and split into ~500k sentences with spaCy.
Training one epoch on this corpus takes about 1:20h on 4 x NVIDIA Tesla P100 with ``train_batch_size=200`` and ``max_seq_length=128``\ :
@ -395,12 +395,13 @@ Thank to the work of @Rocketknight1 and @tholor there are now **several scripts*
OpenAI GPT, Transformer-XL and GPT-2: running the examples
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We provide three examples of scripts for OpenAI GPT, Transformer-XL and OpenAI GPT-2 based on (and extended from) the respective original implementations:
We provide three examples of scripts for OpenAI GPT, Transformer-XL, OpenAI GPT-2, BERT and RoBERTa based on (and extended from) the respective original implementations:
* fine-tuning OpenAI GPT on the ROCStories dataset
* evaluating Transformer-XL on Wikitext 103
* unconditional and conditional generation from a pre-trained OpenAI GPT-2 model
* fine-tuning GPT/GPT-2 on a causal language modeling task and BERT/RoBERTa on a masked language modeling task
Fine-tuning OpenAI GPT on the RocStories dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -454,7 +455,47 @@ Unconditional generation:
python run_gpt2.py --unconditional
The same option as in the original scripts are provided, please refere to the code of the example and the original repository of OpenAI.
The same option as in the original scripts are provided, please refer to the code of the example and the original repository of OpenAI.
Causal LM fine-tuning on GPT/GPT-2, Masked LM fine-tuning on BERT/RoBERTa
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Before running the following examples you should download the `WikiText-2 dataset <https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/>`__ and unpack it to some directory `$WIKITEXT_2_DATASET`
The following results were obtained using the `raw` WikiText-2 (no tokens were replaced before the tokenization).
This example fine-tunes GPT-2 on the WikiText-2 dataset. The loss function is a causal language modeling loss (perplexity).
.. code-block:: bash
export WIKITEXT_2_DATASET=/path/to/wikitext_dataset
python run_lm_finetuning.py
--output_dir=output
--model_type=gpt2
--model_name_or_path=gpt2
--do_train
--train_data_file=$WIKITEXT_2_DATASET/wiki.train.raw
--do_eval
--eval_data_file=$WIKITEXT_2_DATASET/wiki.test.raw
This takes about half an hour to train on a single K80 GPU and about one minute for the evaluation to run.
It reaches a score of about 20 perplexity once fine-tuned on the dataset.
This example fine-tunes RoBERTa on the WikiText-2 dataset. The loss function is a masked language modeling loss (masked perplexity).
The `--mlm` flag is necessary to fine-tune BERT/RoBERTa on masked language modeling.
.. code-block:: bash
export WIKITEXT_2_DATASET=/path/to/wikitext_dataset
python run_lm_finetuning.py
--output_dir=output
--model_type=roberta
--model_name_or_path=roberta-base
--do_train
--train_data_file=$WIKITEXT_2_DATASET/wiki.train.raw
--do_eval
--eval_data_file=$WIKITEXT_2_DATASET/wiki.test.raw
--mlm
.. _fine-tuning-BERT-large:

View File

@ -48,3 +48,4 @@ The library currently contains PyTorch implementations, pre-trained model weight
model_doc/xlm
model_doc/xlnet
model_doc/roberta
model_doc/distilbert

View File

@ -0,0 +1,43 @@
DistilBERT
----------------------------------------------------
``DistilBertConfig``
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.DistilBertConfig
:members:
``DistilBertTokenizer``
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.DistilBertTokenizer
:members:
``DistilBertModel``
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.DistilBertModel
:members:
``DistilBertForMaskedLM``
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.DistilBertForMaskedLM
:members:
``DistilBertForSequenceClassification``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.DistilBertForSequenceClassification
:members:
``DistilBertForQuestionAnswering``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.DistilBertForQuestionAnswering
:members:

View File

@ -111,5 +111,13 @@ Here is the full list of the currently provided pretrained models together with
| | | | ``roberta-large`` fine-tuned on `MNLI <http://www.nyu.edu/projects/bowman/multinli/>`__. |
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/roberta>`__) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| DistilBERT | ``distilbert-base-uncased`` | | 6-layer, 768-hidden, 12-heads, 66M parameters |
| | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint |
| | | (see `details <https://medium.com/huggingface/distilbert-8cf3380435b5>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``distilbert-base-uncased-distilled-squad`` | | 6-layer, 768-hidden, 12-heads, 66M parameters |
| | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint, with an additional linear layer. |
| | | (see `details <https://medium.com/huggingface/distilbert-8cf3380435b5>`__) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
.. <https://huggingface.co/pytorch-transformers/examples.html>`__

View File

@ -0,0 +1,100 @@
# DistilBERT
This folder contains the original code used to train DistilBERT as well as examples showcasing how to use DistilBERT.
## What is DistilBERT
DistilBERT stands for Distillated-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on Bert architecture. It has 40% less parameters than `bert-base-uncased`, runs 60% faster while preserving over 95% of Bert's performances as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique to compress a large model called the teacher into a smaller model called the student. By distillating Bert, we obtain a smaller Transformer model that bears a lot of similarities with the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option to put large-scaled trained Transformer model into production.
For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5
).
## How to use DistilBERT
PyTorch-Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility to train and release a multilingual version of DistilBERT):
- `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain Bert (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-uncased` version of Bert. The model has 6 layers, 768 dimension and 12 heads, totalizing 66M parameters.
- `distilbert-base-uncased-distilled-squad`: A finetuned version of `distilbert-base-uncased` finetuned using (a second step of) knwoledge distillation on SQuAD 1.0. This model reaches a F1 score of 86.2 on the dev set (for comparison, Bert `bert-base-uncased` version reaches a 88.5 F1 score).
Using DistilBERT is very similar to using BERT. DistilBERT share the same tokenizer as BERT's `bert-base-uncased` even though we provide a link to this tokenizer under the `DistilBertTokenizer` name to have a consistent naming between the library models.
```python
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
```
## How to train DistilBERT
In the following, we will explain how you can train your own compressed model.
### A. Preparing the data
The weights we release are trained using a concatenation of Toronto Book Corpus and English Wikipedia (same training data as the English version of BERT).
To avoid processing the data several time, we do it once and for all before the training. From now on, will suppose that you have a text file `dump.txt` which contains one sequence per line (a sequence being composed of one of several coherent sentences).
First, we will binarize the data, i.e. tokenize the data and convert each token in an index in our model's vocabulary.
```bash
python scripts/binarized_data.py \
--file_path data/dump.txt \
--bert_tokenizer bert-base-uncased \
--dump_file data/binarized_text
```
Our implementation of masked language modeling loss follows [XLM](https://github.com/facebookresearch/XLM)'s one and smoothes the probability of masking with a factor that put more emphasis on rare words. Thus we count the occurences of each tokens in the data:
```bash
python scripts/token_counts.py \
--data_file data/binarized_text.bert-base-uncased.pickle \
--token_counts_dump data/token_counts.bert-base-uncased.pickle
```
### B. Training
Training with distillation is really simple once you have pre-processed the data:
```bash
python train.py \
--dump_path serialization_dir/my_first_training \
--data_file data/binarized_text.bert-base-uncased.pickle \
--token_counts data/token_counts.bert-base-uncased.pickle \
--force # overwrites the `dump_path` if it already exists.
```
By default, this will launch a training on a single GPU (even if more are available on the cluster). Other parameters are available in the command line, please look in `train.py` or run `python train.py --help` to list them.
We highly encourage you to distributed training for training DistilBert as the training corpus is quite large. Here's an example that runs a distributed training on a single node having 4 GPUs:
```bash
export NODE_RANK=0
export N_NODES=1
export N_GPU_NODE=4
export WORLD_SIZE=4
export MASTER_PORT=<AN_OPEN_PORT>
export MASTER_ADDR=<I.P.>
pkill -f 'python -u train.py'
python -m torch.distributed.launch \
--nproc_per_node=$N_GPU_NODE \
--nnodes=$N_NODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
train.py \
--force \
--n_gpu $WORLD_SIZE \
--data_file data/dump_concat_wiki_toronto_bk.bert-base-uncased.pickle \
--token_counts data/token_counts_concat_wiki_toronto_bk.bert-base-uncased.pickle \
--dump_path serialization_dir/with_transform/last_word
```
**Tips** Starting distillated training with good initialization of the model weights is crucial to reach decent performance. In our experiments, we initialized our model from a few layers of the teacher (Bert) itself! Please refer to `scripts/extract_for_distil.py` to create a valid initialization checkpoint and use `--from_pretrained_weights` and `--from_pretrained_config` arguments to use this initialization for the distilled training!
Happy distillation!

View File

@ -0,0 +1,201 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Dataloaders to train DistilBERT
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
from typing import List
import math
from itertools import chain
from collections import Counter
import numpy as np
import torch
from utils import logger
class Dataset:
def __init__(self,
params,
data):
self.params = params
self.tokens_per_batch = params.tokens_per_batch
self.batch_size = params.batch_size
self.shuffle = params.shuffle
self.group_by_size = params.group_by_size
self.token_ids = np.array(data)
self.lengths = np.uint16([len(t) for t in data])
self.check()
self.remove_long_sequences()
self.remove_empty_sequences()
self.check()
self.print_statistics()
def __len__(self):
return len(self.lengths)
def check(self):
"""
Some sanity checks
"""
assert len(self.token_ids) == len(self.lengths)
def remove_long_sequences(self):
"""
Sequences that are too long are splitted by chunk of max_position_embeddings.
"""
indices = self.lengths >= self.params.max_position_embeddings
logger.info(f'Splitting {sum(indices)} too long sequences.')
def divide_chunks(l, n):
return [l[i:i + n] for i in range(0, len(l), n)]
new_tok_ids = []
new_lengths = []
cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token']
max_len = self.params.max_position_embeddings
for seq_, len_ in zip(self.token_ids, self.lengths):
if len_ <= max_len:
new_tok_ids.append(seq_)
new_lengths.append(len_)
else:
sub_seqs = []
for sub_s in divide_chunks(seq_, max_len-2):
if sub_s[0] != cls_id:
sub_s = np.insert(sub_s, 0, cls_id)
if sub_s[-1] != sep_id:
sub_s = np.insert(sub_s, len(sub_s), cls_id)
assert len(sub_s) <= max_len
sub_seqs.append(sub_s)
new_tok_ids.extend(sub_seqs)
new_lengths.extend([len(l) for l in sub_seqs])
self.token_ids = np.array(new_tok_ids)
self.lengths = np.array(new_lengths)
def remove_empty_sequences(self):
"""
Too short sequences are simply removed. This could be tunedd.
"""
init_size = len(self)
indices = self.lengths > 5
self.token_ids = self.token_ids[indices]
self.lengths = self.lengths[indices]
new_size = len(self)
logger.info(f'Remove {init_size - new_size} too short (<=5 tokens) sequences.')
def print_statistics(self):
"""
Print some statistics on the corpus. Only the master process.
"""
if not self.params.is_master:
return
logger.info(f'{len(self)} sequences')
# data_len = sum(self.lengths)
# nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
# logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
# unk_idx = self.params.special_tok_ids['unk_token']
# nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids])
# logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)')
def select_data(self, a: int, b: int):
"""
Select a subportion of the data.
"""
n_sequences = len(self)
assert 0 <= a < b <= n_sequences, ValueError(f'`0 <= a < b <= n_sequences` is not met with a={a} and b={b}')
logger.info(f'Selecting sequences from {a} to {b} (excluded).')
self.token_ids = self.token_ids[a:b]
self.lengths = self.lengths[a:b]
self.check()
def split(self):
"""
Distributed training: split the data accross the processes.
"""
assert self.params.n_gpu > 1
logger.info('Splitting the data accross the processuses.')
n_seq = len(self)
n_seq_per_procesus = n_seq // self.params.world_size
a = n_seq_per_procesus * self.params.global_rank
b = a + n_seq_per_procesus
self.select_data(a=a, b=b)
def batch_sequences(self,
token_ids: List[List[int]],
lengths: List[int]):
"""
Do the padding and transform into torch.tensor.
"""
assert len(token_ids) == len(lengths)
# Max for paddings
max_seq_len_ = max(lengths)
# Pad token ids
pad_idx = self.params.special_tok_ids['pad_token']
tk_ = [list(t.astype(int)) + [pad_idx]*(max_seq_len_-len(t)) for t in token_ids]
assert len(tk_) == len(token_ids)
assert all(len(t) == max_seq_len_ for t in tk_)
tk_t = torch.tensor(tk_) # (bs, max_seq_len_)
lg_t = torch.tensor(lengths.astype(int)) # (bs)
return tk_t, lg_t
def get_batches_iterator(self,
batches):
"""
Return an iterator over batches.
"""
for sequences_ids in batches:
token_ids, lengths = self.batch_sequences(self.token_ids[sequences_ids],
self.lengths[sequences_ids])
yield (token_ids, lengths)
def get_iterator(self,
seed: int = None):
"""
Return a data iterator.
"""
rng = np.random.RandomState(seed)
n_sequences = len(self)
indices = np.arange(n_sequences)
if self.group_by_size:
indices = indices[np.argsort(self.lengths[indices], kind='mergesort')]
if self.tokens_per_batch == -1:
batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size))
else:
assert self.tokens_per_batch > 0
batch_ids = np.cumsum(self.lengths[indices]) // self.tokens_per_batch
_, bounds = np.unique(batch_ids, return_index=True)
batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
if bounds[-1] < len(indices):
batches.append(indices[bounds[-1]:])
if self.shuffle:
rng.shuffle(batches)
assert n_sequences == sum([len(x) for x in batches])
assert self.lengths[indices].sum() == sum([self.lengths[x].sum() for x in batches])
return self.get_batches_iterator(batches=batches)

View File

@ -0,0 +1,448 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The distiller to distil DistilBERT
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import os
import math
from tensorboardX import SummaryWriter
from tqdm import trange, tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_transformers import AdamW, WarmupLinearSchedule
from utils import logger
from dataset import Dataset
class Distiller:
def __init__(self,
params: dict,
dataloader: Dataset,
token_probs: torch.tensor,
student: nn.Module,
teacher: nn.Module):
logger.info('Initializing Distiller')
self.params = params
self.dump_path = params.dump_path
self.multi_gpu = params.multi_gpu
self.fp16 = params.fp16
self.student = student
self.teacher = teacher
self.dataloader = dataloader
if self.params.n_gpu > 1:
self.dataloader.split()
self.get_iterator(seed=params.seed)
self.temperature = params.temperature
assert self.temperature > 0.
self.alpha_ce = params.alpha_ce
self.alpha_mlm = params.alpha_mlm
self.alpha_mse = params.alpha_mse
assert self.alpha_ce >= 0.
assert self.alpha_mlm >= 0.
assert self.alpha_mse >= 0.
assert self.alpha_ce + self.alpha_mlm + self.alpha_mse > 0.
self.mlm_mask_prop = params.mlm_mask_prop
assert 0.0 <= self.mlm_mask_prop <= 1.0
assert params.word_mask + params.word_keep + params.word_rand == 1.0
self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
if self.fp16:
self.pred_probs = self.pred_probs.half()
self.token_probs = self.token_probs.half()
self.epoch = 0
self.n_iter = 0
self.n_total_iter = 0
self.n_sequences_epoch = 0
self.total_loss_epoch = 0
self.last_loss = 0
self.last_loss_ce = 0
self.last_loss_mlm = 0
self.last_loss_mse = 0
self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
self.mse_loss_fct = nn.MSELoss(reduction='sum')
logger.info('--- Initializing model optimizer')
assert params.gradient_accumulation_steps >= 1
self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': params.weight_decay},
{'params': [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': 0.0}
]
logger.info("------ Number of trainable parameters (student): %i" % sum([p.numel() for p in self.student.parameters() if p.requires_grad]))
logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
self.optimizer = AdamW(optimizer_grouped_parameters,
lr=params.learning_rate,
eps=params.adam_epsilon,
betas=(0.9, 0.98))
self.scheduler = WarmupLinearSchedule(self.optimizer,
warmup_steps=warmup_steps,
t_total=num_train_optimization_steps)
if self.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
self.student, self.optimizer = amp.initialize(self.student,
self.optimizer,
opt_level=self.params.fp16_opt_level)
self.teacher = self.teacher.half()
if self.multi_gpu:
if self.fp16:
from apex.parallel import DistributedDataParallel
logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
self.student = DistributedDataParallel(self.student)
else:
from torch.nn.parallel import DistributedDataParallel
logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
self.student = DistributedDataParallel(self.student,
device_ids=[params.local_rank],
output_device=params.local_rank)
self.is_master = params.is_master
if self.is_master:
logger.info('--- Initializing Tensorboard')
self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)
def get_iterator(self,
seed: int = None):
"""
Initialize the data iterator.
Each process has its own data iterator (iterating on his own random portion of the dataset).
Input:
------
seed: `int` - The random seed.
"""
logger.info('--- Initializing Data Iterator')
self.data_iterator = self.dataloader.get_iterator(seed=seed)
def get_batch(self):
"""
Call the data iterator to output a new batch.
If the data iterator went through the whole dataset, create a new iterator.
"""
assert hasattr(self, 'data_iterator')
try:
x = next(self.data_iterator)
except StopIteration:
logger.warning('--- Went through the whole dataset. Creating new data iterator.')
self.data_iterator = self.dataloader.get_iterator()
x = next(self.data_iterator)
return x
def prepare_batch(self,
batch):
"""
Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the masked label for MLM.
Input:
------
batch: `Tuple`
token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
Output:
-------
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
mlm_labels: `torch.tensor(bs, seq_length)` - The masked languge modeling labels. There is a -1 where there is nothing to predict.
"""
token_ids, lengths = batch
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
assert token_ids.size(0) == lengths.size(0)
attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
bs, max_seq_len = token_ids.size()
mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
x_prob = self.token_probs[token_ids.flatten()]
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.uint8, device=token_ids.device)
pred_mask[tgt_ids] = 1
pred_mask = pred_mask.view(bs, max_seq_len)
pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0
# mask a number of words == 0 [8] (faster with fp16)
if self.fp16:
n1 = pred_mask.sum().item()
if n1 > 8:
pred_mask = pred_mask.view(-1)
n2 = max(n1 % 8, 8 * (n1 // 8))
if n2 != n1:
pred_mask[torch.nonzero(pred_mask).view(-1)[:n1-n2]] = 0
pred_mask = pred_mask.view(bs, max_seq_len)
assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()
_token_ids_real = token_ids[pred_mask]
_token_ids_rand = _token_ids_real.clone().random_(self.params.vocab_size)
_token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
_token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
mlm_labels[1-pred_mask] = -1
return token_ids, attn_mask, mlm_labels
def round_batch(self,
x: torch.tensor,
lengths: torch.tensor):
"""
For float16 only.
Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
Input:
------
x: `torch.tensor(bs, seq_length)` - The token ids.
lengths: `torch.tensor(bs, seq_length)` - The lengths of each of the sequence in the batch.
Output:
-------
x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
lengths: `torch.tensor(new_bs, new_seq_length)` - The updated lengths.
"""
if not self.fp16 or len(lengths) < 8:
return x, lengths
# number of sentences == 0 [8]
bs1 = len(lengths)
bs2 = 8 * (bs1 // 8)
assert bs2 > 0 and bs2 % 8 == 0
if bs1 != bs2:
idx = torch.randperm(bs1)[:bs2]
lengths = lengths[idx]
slen = lengths.max().item()
x = x[idx, :slen]
else:
idx = None
# sequence length == 0 [8]
ml1 = x.size(1)
if ml1 % 8 != 0:
pad = 8 - (ml1 % 8)
ml2 = ml1 + pad
pad_id = self.params.special_tok_ids['pad_token']
padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
x = torch.cat([x, padding_tensor], 1)
assert x.size() == (bs2, ml2)
assert x.size(0) % 8 == 0
assert x.size(1) % 8 == 0
return x, lengths
def train(self):
"""
The real training loop.
"""
if self.is_master: logger.info('Starting training')
self.student.train()
self.teacher.eval()
for _ in range(self.params.n_epoch):
if self.is_master: logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')
iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
for __ in range(self.num_steps_epoch):
batch = self.get_batch()
if self.params.n_gpu > 0:
batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch)
self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels)
iter_bar.update()
iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}'})
iter_bar.close()
if self.is_master: logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
self.end_epoch()
if self.is_master: logger.info('Training is finished')
def step(self,
input_ids: torch.tensor,
attention_mask: torch.tensor,
mlm_labels: torch.tensor):
"""
One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
and possibly a parameter update (depending on the gradient accumulation).
Input:
------
input_ids: `torch.tensor(bs, seq_length)` - The token ids.
attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
"""
s_logits = self.student(input_ids=input_ids, attention_mask=attention_mask)[0] # (bs, seq_length, voc_size)
with torch.no_grad():
t_logits = self.teacher(input_ids=input_ids, attention_mask=attention_mask)[0] # (bs, seq_length, voc_size)
assert s_logits.size() == t_logits.size()
#https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
#https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
if self.params.restrict_ce_to_mask:
mask = (mlm_labels>-1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
else:
mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
assert t_logits_slct.size() == s_logits_slct.size()
loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
loss = self.alpha_ce*loss_ce
if self.alpha_mlm > 0.:
loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1))
loss += self.alpha_mlm * loss_mlm
if self.alpha_mse > 0.:
loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction
loss += self.alpha_mse * loss_mse
self.total_loss_epoch += loss.item()
self.last_loss = loss.item()
self.last_loss_ce = loss_ce.item()
if self.alpha_mlm > 0.:
self.last_loss_mlm = loss_mlm.item()
if self.alpha_mse > 0.:
self.last_loss_mse = loss_mse.item()
self.optimize(loss)
self.n_sequences_epoch += input_ids.size(0)
def optimize(self,
loss):
"""
Normalization on the loss (gradient accumulation or distributed training), followed by
backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
Also update the metrics for tensorboard.
"""
# Check for NaN
if (loss != loss).data.any():
logger.error('NaN detected')
exit()
if self.multi_gpu:
loss = loss.mean()
if self.params.gradient_accumulation_steps > 1:
loss = loss / self.params.gradient_accumulation_steps
if self.fp16:
from apex import amp
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
self.iter()
if self.n_iter % self.params.gradient_accumulation_steps == 0:
if self.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
self.scheduler.step()
self.optimizer.step()
self.optimizer.zero_grad()
def iter(self):
"""
Update global counts, write to tensorboard and save checkpoint.
"""
self.n_iter += 1
self.n_total_iter += 1
if self.n_total_iter % self.params.log_interval == 0:
self.log_tensorboard()
if self.n_total_iter % self.params.checkpoint_interval == 0:
self.save_checkpoint()
def log_tensorboard(self):
"""
Log into tensorboard. Only by the master process.
"""
if not self.is_master:
return
for param_name, param in self.student.named_parameters():
self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter)
if param.grad is None:
continue
self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(),global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
if self.alpha_mlm > 0.:
self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
if self.alpha_mse > 0.:
self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)
def end_epoch(self):
"""
Finally arrived at the end of epoch (full pass on dataset).
Do some tensorboard logging and checkpoint saving.
"""
logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.')
if self.is_master:
self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth')
self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.epoch)
self.epoch += 1
self.n_sequences_epoch = 0
self.n_iter = 0
self.total_loss_epoch = 0
def save_checkpoint(self,
checkpoint_name: str = 'checkpoint.pth'):
"""
Save the current state. Only by the master process.
"""
if not self.is_master:
return
mdl_to_save = self.student.module if hasattr(self.student, 'module') else self.student
mdl_to_save.config.save_pretrained(self.dump_path)
state_dict = mdl_to_save.state_dict()
torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))

View File

@ -0,0 +1 @@
gitpython==3.0.2

View File

@ -0,0 +1,77 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training DistilBERT.
"""
import argparse
import pickle
import random
import time
import numpy as np
from pytorch_transformers import BertTokenizer
from ..utils import logger
def main():
parser = argparse.ArgumentParser(description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids).")
parser.add_argument('--file_path', type=str, default='data/dump.txt',
help='The path to the data.')
parser.add_argument('--bert_tokenizer', type=str, default='bert-base-uncased',
help="The tokenizer to use.")
parser.add_argument('--dump_file', type=str, default='data/dump',
help='The dump file prefix.')
args = parser.parse_args()
logger.info(f'Loading Tokenizer ({args.bert_tokenizer})')
bert_tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)
logger.info(f'Loading text from {args.file_path}')
with open(args.file_path, 'r', encoding='utf8') as fp:
data = fp.readlines()
logger.info(f'Start encoding')
logger.info(f'{len(data)} examples to process.')
rslt = []
iter = 0
interval = 10000
start = time.time()
for text in data:
text = f'[CLS] {text.strip()} [SEP]'
token_ids = bert_tokenizer.encode(text)
rslt.append(token_ids)
iter += 1
if iter % interval == 0:
end = time.time()
logger.info(f'{iter} examples processed. - {(end-start)/interval:.2f}s/expl')
start = time.time()
logger.info('Finished binarization')
logger.info(f'{len(data)} examples processed.')
dp_file = f'{args.dump_file}.{args.bert_tokenizer}.pickle'
rslt_ = [np.uint16(d) for d in rslt]
random.shuffle(rslt_)
logger.info(f'Dump to {dp_file}')
with open(dp_file, 'wb') as handle:
pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,76 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training DistilBERT.
"""
from pytorch_transformers import BertForPreTraining
import torch
import argparse
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForPreTraining for Transfer Learned Distillation")
parser.add_argument("--bert_model", default='bert-base-uncased', type=str)
parser.add_argument("--dump_checkpoint", default='serialization_dir/transfer_learning_checkpoint_0247911.pth', type=str)
parser.add_argument("--vocab_transform", action='store_true')
args = parser.parse_args()
model = BertForPreTraining.from_pretrained(args.bert_model)
state_dict = model.state_dict()
compressed_sd = {}
for w in ['word_embeddings', 'position_embeddings']:
compressed_sd[f'distilbert.embeddings.{w}.weight'] = \
state_dict[f'bert.embeddings.{w}.weight']
for w in ['weight', 'bias']:
compressed_sd[f'distilbert.embeddings.LayerNorm.{w}'] = \
state_dict[f'bert.embeddings.LayerNorm.{w}']
std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]:
for w in ['weight', 'bias']:
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.query.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.key.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.value.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.dense.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.intermediate.dense.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.output.dense.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
std_idx += 1
compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
if args.vocab_transform:
for w in ['weight', 'bias']:
compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
print(f'N layers selected for distillation: {std_idx}')
print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')
print(f'Save transfered checkpoint to {args.dump_checkpoint}.')
torch.save(compressed_sd, args.dump_checkpoint)

View File

@ -0,0 +1,47 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training DistilBERT.
"""
from collections import Counter
import argparse
import pickle
from utils import logger
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)")
parser.add_argument("--data_file", type=str, default="data/dump.bert-base-uncased.pickle",
help="The binarized dataset.")
parser.add_argument("--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle",
help="The dump file.")
parser.add_argument("--vocab_size", default=30522, type=int)
args = parser.parse_args()
logger.info(f'Loading data from {args.data_file}')
with open(args.data_file, 'rb') as fp:
data = pickle.load(fp)
logger.info('Counting occurences for MLM.')
counter = Counter()
for tk_ids in data:
counter.update(tk_ids)
counts = [0]*args.vocab_size
for k, v in counter.items():
counts[k] = v
logger.info(f'Dump to {args.token_counts_dump}')
with open(args.token_counts_dump, 'wb') as handle:
pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)

View File

@ -0,0 +1,236 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training DistilBERT.
"""
import os
import argparse
import pickle
import json
import shutil
import numpy as np
import torch
from pytorch_transformers import BertTokenizer, BertForMaskedLM
from pytorch_transformers import DistilBertForMaskedLM, DistilBertConfig
from distiller import Distiller
from utils import git_log, logger, init_gpu_params, set_seed
from dataset import Dataset
def main():
parser = argparse.ArgumentParser(description="Training")
parser.add_argument("--dump_path", type=str, required=True,
help="The output directory (log, checkpoints, parameters, etc.)")
parser.add_argument("--data_file", type=str, required=True,
help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.")
parser.add_argument("--token_counts", type=str, required=True,
help="The token counts in the data_file for MLM.")
parser.add_argument("--force", action='store_true',
help="Overwrite dump_path if it already exists.")
parser.add_argument("--vocab_size", default=30522, type=int,
help="The vocabulary size.")
parser.add_argument("--max_position_embeddings", default=512, type=int,
help="Maximum sequence length we can model (including [CLS] and [SEP]).")
parser.add_argument("--sinusoidal_pos_embds", action='store_false',
help="If true, the position embeddings are simply fixed with sinusoidal embeddings.")
parser.add_argument("--n_layers", default=6, type=int,
help="Number of Transformer blocks.")
parser.add_argument("--n_heads", default=12, type=int,
help="Number of heads in the self-attention module.")
parser.add_argument("--dim", default=768, type=int,
help="Dimension through the network. Must be divisible by n_heads")
parser.add_argument("--hidden_dim", default=3072, type=int,
help="Intermediate dimension in the FFN.")
parser.add_argument("--dropout", default=0.1, type=float,
help="Dropout.")
parser.add_argument("--attention_dropout", default=0.1, type=float,
help="Dropout in self-attention.")
parser.add_argument("--activation", default='gelu', type=str,
help="Activation to use in self-attention")
parser.add_argument("--tie_weights_", action='store_false',
help="If true, we tie the embeddings matrix with the projection over the vocabulary matrix. Default is true.")
parser.add_argument("--from_pretrained_weights", default=None, type=str,
help="Load student initialization checkpoint.")
parser.add_argument("--from_pretrained_config", default=None, type=str,
help="Load student initialization architecture config.")
parser.add_argument("--bert_model", default='bert-base-uncased', type=str,
help="The teacher BERT model.")
parser.add_argument("--temperature", default=2., type=float,
help="Temperature for the softmax temperature.")
parser.add_argument("--alpha_ce", default=0.5, type=float,
help="Linear weight for the distillation loss. Must be >=0.")
parser.add_argument("--alpha_mlm", default=0.5, type=float,
help="Linear weight for the MLM loss. Must be >=0.")
parser.add_argument("--alpha_mse", default=0.0, type=float,
help="Linear weight of the MSE loss. Must be >=0.")
parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
help="Proportion of tokens for which we need to make a prediction.")
parser.add_argument("--word_mask", default=0.8, type=float,
help="Proportion of tokens to mask out.")
parser.add_argument("--word_keep", default=0.1, type=float,
help="Proportion of tokens to keep.")
parser.add_argument("--word_rand", default=0.1, type=float,
help="Proportion of tokens to randomly replace.")
parser.add_argument("--mlm_smoothing", default=0.7, type=float,
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).")
parser.add_argument("--restrict_ce_to_mask", action='store_true',
help="If true, compute the distilation loss only the [MLM] prediction distribution.")
parser.add_argument("--n_epoch", type=int, default=3,
help="Number of pass on the whole dataset.")
parser.add_argument("--batch_size", type=int, default=5,
help="Batch size (for each process).")
parser.add_argument("--tokens_per_batch", type=int, default=-1,
help="If specified, modify the batches so that they have approximately this number of tokens.")
parser.add_argument("--shuffle", action='store_false',
help="If true, shuffle the sequence order. Default is true.")
parser.add_argument("--group_by_size", action='store_false',
help="If true, group sequences that have similar length into the same batch. Default is true.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=50,
help="Gradient accumulation for larger training batches.")
parser.add_argument("--warmup_prop", default=0.05, type=float,
help="Linear warmup proportion.")
parser.add_argument("--weight_decay", default=0.0, type=float,
help="Weight deay if we apply some.")
parser.add_argument("--learning_rate", default=5e-4, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--adam_epsilon", default=1e-6, type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=5.0, type=float,
help="Max gradient norm.")
parser.add_argument("--initializer_range", default=0.02, type=float,
help="Random initialization range.")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument("--n_gpu", type=int, default=1,
help="Number of GPUs in the node.")
parser.add_argument("--local_rank", type=int, default=-1,
help="Distributed training - Local rank")
parser.add_argument("--seed", type=int, default=56,
help="Random seed")
parser.add_argument("--log_interval", type=int, default=500,
help="Tensorboard logging interval.")
parser.add_argument("--checkpoint_interval", type=int, default=4000,
help="Checkpoint interval.")
args = parser.parse_args()
## ARGS ##
init_gpu_params(args)
set_seed(args)
if args.is_master:
if os.path.exists(args.dump_path):
if not args.force:
raise ValueError(f'Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it'
'Use `--force` if you want to overwrite it')
else:
shutil.rmtree(args.dump_path)
if not os.path.exists(args.dump_path):
os.makedirs(args.dump_path)
logger.info(f'Experiment will be dumped and logged in {args.dump_path}')
### SAVE PARAMS ###
logger.info(f'Param: {args}')
with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
json.dump(vars(args), f, indent=4)
git_log(args.dump_path)
assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \
(args.from_pretrained_weights is not None and args.from_pretrained_config is not None)
### TOKENIZER ###
bert_tokenizer = BertTokenizer.from_pretrained(args.bert_model)
special_tok_ids = {}
for tok_name, tok_symbol in bert_tokenizer.special_tokens_map.items():
idx = bert_tokenizer.all_special_tokens.index(tok_symbol)
special_tok_ids[tok_name] = bert_tokenizer.all_special_ids[idx]
logger.info(f'Special tokens {special_tok_ids}')
args.special_tok_ids = special_tok_ids
## DATA LOADER ##
logger.info(f'Loading data from {args.data_file}')
with open(args.data_file, 'rb') as fp:
data = pickle.load(fp)
assert os.path.isfile(args.token_counts)
logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
with open(args.token_counts, 'rb') as fp:
counts = pickle.load(fp)
assert len(counts) == args.vocab_size
token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
for idx in special_tok_ids.values():
token_probs[idx] = 0. # do not predict special tokens
token_probs = torch.from_numpy(token_probs)
train_dataloader = Dataset(params=args, data=data)
logger.info(f'Data loader created.')
## STUDENT ##
if args.from_pretrained_weights is not None:
assert os.path.isfile(os.path.join(args.from_pretrained_weights))
assert os.path.isfile(os.path.join(args.from_pretrained_config))
logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
config=stu_architecture_config)
else:
args.vocab_size_or_config_json_file = args.vocab_size
stu_architecture_config = DistilBertConfig(**vars(args))
student = DistilBertForMaskedLM(stu_architecture_config)
if args.n_gpu > 0:
student.to(f'cuda:{args.local_rank}')
logger.info(f'Student loaded.')
## TEACHER ##
teacher = BertForMaskedLM.from_pretrained(args.bert_model)
if args.n_gpu > 0:
teacher.to(f'cuda:{args.local_rank}')
logger.info(f'Teacher loaded from {args.bert_model}.')
## DISTILLER ##
torch.cuda.empty_cache()
distiller = Distiller(params=args,
dataloader=train_dataloader,
token_probs=token_probs,
student=student,
teacher=teacher)
distiller.train()
logger.info("Let's go get some drinks.")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,129 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utils to train DistilBERT
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import git
import json
import os
import socket
import torch
import numpy as np
import logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
def git_log(folder_path: str):
"""
Log commit info.
"""
repo = git.Repo(search_parent_directories=True)
repo_infos = {
'repo_id': str(repo),
'repo_sha': str(repo.head.object.hexsha),
'repo_branch': str(repo.active_branch)
}
with open(os.path.join(folder_path, 'git_log.json'), 'w') as f:
json.dump(repo_infos, f, indent=4)
def init_gpu_params(params):
"""
Handle single and multi-GPU / multi-node.
"""
if params.n_gpu <= 0:
params.local_rank = 0
params.master_port = -1
params.is_master = True
params.multi_gpu = False
return
assert torch.cuda.is_available()
logger.info('Initializing GPUs')
if params.n_gpu > 1:
assert params.local_rank != -1
params.world_size = int(os.environ['WORLD_SIZE'])
params.n_gpu_per_node = int(os.environ['N_GPU_NODE'])
params.global_rank = int(os.environ['RANK'])
# number of nodes / node ID
params.n_nodes = params.world_size // params.n_gpu_per_node
params.node_id = params.global_rank // params.n_gpu_per_node
params.multi_gpu = True
assert params.n_nodes == int(os.environ['N_NODES'])
assert params.node_id == int(os.environ['NODE_RANK'])
# local job (single GPU)
else:
assert params.local_rank == -1
params.n_nodes = 1
params.node_id = 0
params.local_rank = 0
params.global_rank = 0
params.world_size = 1
params.n_gpu_per_node = 1
params.multi_gpu = False
# sanity checks
assert params.n_nodes >= 1
assert 0 <= params.node_id < params.n_nodes
assert 0 <= params.local_rank <= params.global_rank < params.world_size
assert params.world_size == params.n_nodes * params.n_gpu_per_node
# define whether this is the master process / if we are in multi-node distributed mode
params.is_master = params.node_id == 0 and params.local_rank == 0
params.multi_node = params.n_nodes > 1
# summary
PREFIX = f"--- Global rank: {params.global_rank} - "
logger.info(PREFIX + "Number of nodes: %i" % params.n_nodes)
logger.info(PREFIX + "Node ID : %i" % params.node_id)
logger.info(PREFIX + "Local rank : %i" % params.local_rank)
logger.info(PREFIX + "World size : %i" % params.world_size)
logger.info(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node)
logger.info(PREFIX + "Master : %s" % str(params.is_master))
logger.info(PREFIX + "Multi-node : %s" % str(params.multi_node))
logger.info(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu))
logger.info(PREFIX + "Hostname : %s" % socket.gethostname())
# set GPU device
torch.cuda.set_device(params.local_rank)
# initialize multi-GPU
if params.multi_gpu:
logger.info("Initializing PyTorch distributed")
torch.distributed.init_process_group(
init_method='env://',
backend='nccl',
)
def set_seed(args):
"""
Set the random seed.
"""
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)

View File

@ -128,7 +128,7 @@ def train(args, train_dataset, model, tokenizer):
batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM and RoBERTa don't use segment_ids
'labels': batch[3]}
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
@ -251,7 +251,7 @@ def evaluate(args, model, tokenizer, prefix=""):
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
if args.local_rank not in [-1, 0]:
if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
processor = processors[task]()
@ -286,7 +286,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
if args.local_rank == 0:
if args.local_rank == 0 and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
# Convert to Tensors and build dataset

View File

@ -0,0 +1,497 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for language modeling on WikiText-2 (GPT, GPT-2, BERT, RoBERTa).
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
using a masked language modeling (MLM) loss.
"""
from __future__ import absolute_import, division, print_function
import argparse
import glob
import logging
import os
import pickle
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange
from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
BertConfig, BertForMaskedLM, BertTokenizer,
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
logger = logging.getLogger(__name__)
MODEL_CLASSES = {
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
}
class TextDataset(Dataset):
def __init__(self, tokenizer, file_path='train', block_size=512):
assert os.path.isfile(file_path)
directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}')
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
with open(cached_features_file, 'rb') as handle:
self.examples = pickle.load(handle)
else:
logger.info("Creating features from dataset file at %s", directory)
self.examples = []
with open(file_path, encoding="utf-8") as f:
text = f.read()
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
while len(tokenized_text) >= block_size: # Truncate in block of block_size
self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size]))
tokenized_text = tokenized_text[block_size:]
# Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
# If your dataset is small, first you should loook for a bigger one :-) and second you
# can change this behavior by adding (model specific) padding.
logger.info("Saving features into cached file %s", cached_features_file)
with open(cached_features_file, 'wb') as handle:
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
def __len__(self):
return len(self.examples)
def __getitem__(self, item):
return torch.tensor(self.examples[item])
def load_and_cache_examples(args, tokenizer, evaluate=False):
dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
return dataset
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def mask_tokens(inputs, tokenizer, args):
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
labels = inputs.clone()
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).bool()
labels[~masked_indices] = -1 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
def train(args, train_dataset, model, tokenizer):
""" Train the model """
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter()
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
if args.max_steps > 0:
t_total = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
else:
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
inputs = inputs.to(args.device)
labels = labels.to(args.device)
model.train()
outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
scheduler.step() # Update learning rate schedule
model.zero_grad()
global_step += 1
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
# Log metrics
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(args, model, tokenizer)
for key, value in results.items():
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
logging_loss = tr_loss
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
# Save model checkpoint
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
if args.max_steps > 0 and global_step > args.max_steps:
train_iterator.close()
break
if args.local_rank in [-1, 0]:
tb_writer.close()
return global_step, tr_loss / global_step
def evaluate(args, model, tokenizer, prefix=""):
# Loop to handle MNLI double evaluation (matched, mis-matched)
eval_output_dir = args.output_dir
results = {}
eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
os.makedirs(eval_output_dir)
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
eval_loss = 0.0
nb_eval_steps = 0
model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
batch = batch.to(args.device)
with torch.no_grad():
outputs = model(batch, masked_lm_labels=batch) if args.mlm else model(batch, labels=batch)
lm_loss = outputs[0]
eval_loss += lm_loss.mean().item()
nb_eval_steps += 1
eval_loss = eval_loss / nb_eval_steps
perplexity = torch.exp(torch.tensor(eval_loss))
result = {
"perplexity": perplexity
}
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
return results
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--train_data_file", default=None, type=str, required=True,
help="The input training data file (a text file).")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")
## Other parameters
parser.add_argument("--eval_data_file", default=None, type=str,
help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
parser.add_argument("--model_type", default="bert", type=str,
help="The model architecture to be fine-tuned.")
parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str,
help="The model checkpoint for weights initialization.")
parser.add_argument("--mlm", action='store_true',
help="Train with masked-language modeling loss instead of language modeling.")
parser.add_argument("--mlm_probability", type=float, default=0.15,
help="Ratio of tokens to mask for masked language modeling loss")
parser.add_argument("--config_name", default="", type=str,
help="Optional pretrained config name or path if not the same as model_name_or_path")
parser.add_argument("--tokenizer_name", default="", type=str,
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
parser.add_argument("--cache_dir", default="", type=str,
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
parser.add_argument("--block_size", default=-1, type=int,
help="Optional input sequence length after tokenization."
"The training dataset will be truncated in block of this size for training."
"Default to the model max input length for single sentence inputs (take into account special tokens).")
parser.add_argument("--do_train", action='store_true',
help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--evaluate_during_training", action='store_true',
help="Run evaluation during training at each logging step.")
parser.add_argument("--do_lower_case", action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
help="Batch size per GPU/CPU for training.")
parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int,
help="Batch size per GPU/CPU for evaluation.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float,
help="Weight deay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
help="Max gradient norm.")
parser.add_argument("--num_train_epochs", default=1.0, type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--max_steps", default=-1, type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
parser.add_argument("--warmup_steps", default=0, type=int,
help="Linear warmup over warmup_steps.")
parser.add_argument('--logging_steps', type=int, default=50,
help="Log every X updates steps.")
parser.add_argument('--save_steps', type=int, default=50,
help="Save checkpoint every X updates steps.")
parser.add_argument("--eval_all_checkpoints", action='store_true',
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")
parser.add_argument('--overwrite_output_dir', action='store_true',
help="Overwrite the content of the output directory")
parser.add_argument('--overwrite_cache', action='store_true',
help="Overwrite the cached training and evaluation sets")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument("--local_rank", type=int, default=-1,
help="For distributed training: local_rank")
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
args = parser.parse_args()
if args.model_type in ["bert", "roberta"] and not args.mlm:
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
"flag (masked language modeling).")
if args.eval_data_file is None and args.do_eval:
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
"or remove the --do_eval argument.")
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
# Setup distant debugging if needed
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
# Setup CUDA, GPU & distributed training
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = torch.cuda.device_count()
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend='nccl')
args.n_gpu = 1
args.device = device
# Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
# Set seed
set_seed(args)
# Load pretrained model and tokenizer
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
if args.block_size <= 0:
args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model
args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
model.to(args.device)
if args.local_rank == 0:
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
logger.info("Training/evaluation parameters %s", args)
# Training
if args.do_train:
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
if args.local_rank == 0:
torch.distributed.barrier()
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Create output directory if needed
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
model.to(args.device)
# Evaluation
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
model = model_class.from_pretrained(checkpoint)
model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=global_step)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
results.update(result)
return results
if __name__ == "__main__":
main()

View File

@ -157,8 +157,8 @@ def train(args, train_dataset, model, tokenizer):
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
scheduler.step() # Update learning rate schedule
optimizer.step()
scheduler.step() # Update learning rate schedule
model.zero_grad()
global_step += 1
@ -272,7 +272,7 @@ def evaluate(args, model, tokenizer, prefix=""):
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
if args.local_rank not in [-1, 0]:
if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
# Load data features from cache or dataset file
@ -299,7 +299,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
if args.local_rank == 0:
if args.local_rank == 0 and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
# Convert to Tensors and build dataset

View File

@ -81,7 +81,7 @@ class ExamplesTests(unittest.TestCase):
"--do_train",
"--do_eval",
"--version_2_with_negative",
"--learning_rate=1e-4",
"--learning_rate=2e-4",
"--per_gpu_train_batch_size=2",
"--per_gpu_eval_batch_size=1",
"--overwrite_output_dir",

View File

@ -7,6 +7,7 @@ from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
from .tokenization_xlm import XLMTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_distilbert import DistilBertTokenizer
from .tokenization_utils import (PreTrainedTokenizer)
@ -40,6 +41,9 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_distilbert import (DistilBertConfig, DistilBertForMaskedLM, DistilBertModel,
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)

View File

@ -30,6 +30,7 @@ from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel
from .modeling_xlnet import XLNetConfig, XLNetModel
from .modeling_xlm import XLMConfig, XLMModel
from .modeling_roberta import RobertaConfig, RobertaModel
from .modeling_distilbert import DistilBertConfig, DistilBertModel
from .modeling_utils import PreTrainedModel, SequenceSummary
@ -110,7 +111,9 @@ class AutoConfig(object):
assert unused_kwargs == {'foo': False}
"""
if 'roberta' in pretrained_model_name_or_path:
if 'distilbert' in pretrained_model_name_or_path:
return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
@ -225,7 +228,9 @@ class AutoModel(object):
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 'roberta' in pretrained_model_name_or_path:
if 'distilbert' in pretrained_model_name_or_path:
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

View File

@ -216,7 +216,7 @@ class BertConfig(PretrainedConfig):
self.layer_norm_eps = layer_norm_eps
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
" or the path to a pretrained model config file (str)")
@ -449,7 +449,7 @@ class BertEncoder(nn.Module):
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # outputs, (hidden states), (attentions)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class BertPooler(nn.Module):

View File

@ -0,0 +1,751 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch DistilBERT model
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
and in part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import math
import copy
import sys
from io import open
import itertools
import numpy as np
import torch
import torch.nn as nn
from pytorch_transformers.modeling_utils import PretrainedConfig, PreTrainedModel, add_start_docstrings, prune_linear_layer
import logging
logger = logging.getLogger(__name__)
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-pytorch_model.bin",
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-pytorch_model.bin"
}
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json"
}
class DistilBertConfig(PretrainedConfig):
pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=30522,
max_position_embeddings=512,
sinusoidal_pos_embds=True,
n_layers=6,
n_heads=12,
dim=768,
hidden_dim=4*768,
dropout=0.1,
attention_dropout=0.1,
activation='gelu',
initializer_range=0.02,
tie_weights_=True,
qa_dropout=0.1,
seq_classif_dropout=0.2,
**kwargs):
super(DistilBertConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.max_position_embeddings = max_position_embeddings
self.sinusoidal_pos_embds = sinusoidal_pos_embds
self.n_layers = n_layers
self.n_heads = n_heads
self.dim = dim
self.hidden_dim = hidden_dim
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation = activation
self.initializer_range = initializer_range
self.tie_weights_ = tie_weights_
self.qa_dropout = qa_dropout
self.seq_classif_dropout = seq_classif_dropout
else:
raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)")
@property
def hidden_size(self):
return self.dim
@property
def num_attention_heads(self):
return self.n_heads
@property
def num_hidden_layers(self):
return self.n_layers
### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
def gelu(x):
return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
def create_sinusoidal_embeddings(n_pos, dim, out):
position_enc = np.array([
[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
for pos in range(n_pos)
])
out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
out.detach_()
out.requires_grad = False
class Embeddings(nn.Module):
def __init__(self,
config):
super(Embeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=0)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
if config.sinusoidal_pos_embds:
create_sinusoidal_embeddings(n_pos=config.max_position_embeddings,
dim=config.dim,
out=self.position_embeddings.weight)
self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
self.dropout = nn.Dropout(config.dropout)
def forward(self, input_ids):
"""
Parameters
----------
input_ids: torch.tensor(bs, max_seq_length)
The token ids to embed.
Outputs
-------
embeddings: torch.tensor(bs, max_seq_length, dim)
The embedded tokens (plus position embeddings, no token_type embeddings)
"""
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
word_embeddings = self.word_embeddings(input_ids) # (bs, max_seq_length, dim)
position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
embeddings = word_embeddings + position_embeddings # (bs, max_seq_length, dim)
embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim)
embeddings = self.dropout(embeddings) # (bs, max_seq_length, dim)
return embeddings
class MultiHeadSelfAttention(nn.Module):
def __init__(self, config):
super(MultiHeadSelfAttention, self).__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.dropout = nn.Dropout(p=config.attention_dropout)
self.output_attentions = config.output_attentions
assert self.dim % self.n_heads == 0
self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
def prune_heads(self, heads):
attention_head_size = self.dim // self.n_heads
if len(heads) == 0:
return
mask = torch.ones(self.n_heads, attention_head_size)
for head in heads:
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
# Prune linear layers
self.q_lin = prune_linear_layer(self.q_lin, index)
self.k_lin = prune_linear_layer(self.k_lin, index)
self.v_lin = prune_linear_layer(self.v_lin, index)
self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
# Update hyper params
self.n_heads = self.n_heads - len(heads)
self.dim = attention_head_size * self.n_heads
def forward(self, query, key, value, mask, head_mask = None):
"""
Parameters
----------
query: torch.tensor(bs, seq_length, dim)
key: torch.tensor(bs, seq_length, dim)
value: torch.tensor(bs, seq_length, dim)
mask: torch.tensor(bs, seq_length)
Outputs
-------
weights: torch.tensor(bs, n_heads, seq_length, seq_length)
Attention weights
context: torch.tensor(bs, seq_length, dim)
Contextualized layer. Optional: only if `output_attentions=True`
"""
bs, q_length, dim = query.size()
k_length = key.size(1)
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
# assert key.size() == value.size()
dim_per_head = self.dim // self.n_heads
assert 2 <= mask.dim() <= 3
causal = (mask.dim() == 3)
mask_reshp = (bs, 1, 1, k_length)
def shape(x):
""" separate heads """
return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
def unshape(x):
""" group heads """
return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2,3)) # (bs, n_heads, q_length, k_length)
mask = (mask==0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, q_length, k_length)
weights = nn.Softmax(dim=-1)(scores) # (bs, n_heads, q_length, k_length)
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
# Mask heads if we want to
if head_mask is not None:
weights = weights * head_mask
context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
context = unshape(context) # (bs, q_length, dim)
context = self.out_lin(context) # (bs, q_length, dim)
if self.output_attentions:
return (context, weights)
else:
return (context,)
class FFN(nn.Module):
def __init__(self, config):
super(FFN, self).__init__()
self.dropout = nn.Dropout(p=config.dropout)
self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
assert config.activation in ['relu', 'gelu'], "activation ({}) must be in ['relu', 'gelu']".format(config.activation)
self.activation = gelu if config.activation == 'gelu' else nn.ReLU()
def forward(self, input):
x = self.lin1(input)
x = self.activation(x)
x = self.lin2(x)
x = self.dropout(x)
return x
class TransformerBlock(nn.Module):
def __init__(self, config):
super(TransformerBlock, self).__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.hidden_dim = config.hidden_dim
self.dropout = nn.Dropout(p=config.dropout)
self.activation = config.activation
self.output_attentions = config.output_attentions
assert config.dim % config.n_heads == 0
self.attention = MultiHeadSelfAttention(config)
self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
self.ffn = FFN(config)
self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
def forward(self, x, attn_mask=None, head_mask=None):
"""
Parameters
----------
x: torch.tensor(bs, seq_length, dim)
attn_mask: torch.tensor(bs, seq_length)
Outputs
-------
sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length)
The attention weights
ffn_output: torch.tensor(bs, seq_length, dim)
The output of the transformer block contextualization.
"""
# Self-Attention
sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask, head_mask=head_mask)
if self.output_attentions:
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
assert type(sa_output) == tuple
sa_output = sa_output[0]
sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim)
# Feed Forward Network
ffn_output = self.ffn(sa_output) # (bs, seq_length, dim)
ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim)
output = (ffn_output,)
if self.output_attentions:
output = (sa_weights,) + output
return output
class Transformer(nn.Module):
def __init__(self, config):
super(Transformer, self).__init__()
self.n_layers = config.n_layers
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
layer = TransformerBlock(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])
def forward(self, x, attn_mask=None, head_mask=None):
"""
Parameters
----------
x: torch.tensor(bs, seq_length, dim)
Input sequence embedded.
attn_mask: torch.tensor(bs, seq_length)
Attention mask on the sequence.
Outputs
-------
hidden_state: torch.tensor(bs, seq_length, dim)
Sequence of hiddens states in the last (top) layer
all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
Tuple of length n_layers with the hidden states from each layer.
Optional: only if output_hidden_states=True
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
Tuple of length n_layers with the attention weights from each layer
Optional: only if output_attentions=True
"""
all_hidden_states = ()
all_attentions = ()
hidden_state = x
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
layer_outputs = layer_module(x=hidden_state,
attn_mask=attn_mask,
head_mask=head_mask[i])
hidden_state = layer_outputs[-1]
if self.output_attentions:
assert len(layer_outputs) == 2
attentions = layer_outputs[0]
all_attentions = all_attentions + (attentions,)
else:
assert len(layer_outputs) == 1
# Add last layer
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
outputs = (hidden_state,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
class DistilBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = DistilBertConfig
pretrained_model_archive_map = DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = None
base_model_prefix = "distilbert"
def __init__(self, *inputs, **kwargs):
super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, nn.Embedding):
if module.weight.requires_grad:
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
DISTILBERT_START_DOCSTRING = r"""
DistilBERT is a small, fast, cheap and light Transformer model
trained by distilling Bert base. It has 40% less parameters than
`bert-base-uncased`, runs 60% faster while preserving over 95% of
Bert's performances as measured on the GLUE language understanding benchmark.
Here are the differences between the interface of Bert and DistilBert:
- DistilBert doesn't have `token_type_ids`, you don't need to indicate which token belongs to which segment. Just separate your segments with the separation token `tokenizer.sep_token` (or `[SEP]`)
- DistilBert doesn't have options to select the input positions (`position_ids` input). This could be added if necessary though, just let's us know if you need this option.
For more information on DistilBERT, please refer to our
`detailed blog post`_
.. _`detailed blog post`:
https://medium.com/huggingface/distilbert-8cf3380435b5
Parameters:
config (:class:`~pytorch_transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
DISTILBERT_INPUTS_DOCSTRING = r"""
Inputs:
**input_ids** ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The input sequences should start with `[CLS]` and end with `[SEP]` tokens.
For now, ONLY BertTokenizer(`bert-base-uncased`) is supported and you should use this tokenizer when using DistilBERT.
**attention_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare DistilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class DistilBertModel(DistilBertPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the output of the last layer of the model.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config):
super(DistilBertModel, self).__init__(config)
self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder
self.apply(self.init_weights)
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.embeddings.word_embeddings = new_embeddings
return self.embeddings.word_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.transformer.layer[layer].attention.prune_heads(heads)
def forward(self,
input_ids, attention_mask=None, head_mask=None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids) # (bs, seq_length)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids) # (bs, seq_length, dim)
tfmr_output = self.transformer(x=embedding_output,
attn_mask=attention_mask,
head_mask=head_mask)
hidden_state = tfmr_output[0]
output = (hidden_state, ) + tfmr_output[1:]
return output # last-layer hidden-state, (all hidden_states), (all attentions)
@add_start_docstrings("""DistilBert Model with a `masked language modeling` head on top. """,
DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2]
"""
def __init__(self, config):
super(DistilBertForMaskedLM, self).__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.distilbert = DistilBertModel(config)
self.vocab_transform = nn.Linear(config.dim, config.dim)
self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
self.apply(self.init_weights)
self.tie_weights()
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.vocab_projector,
self.distilbert.embeddings.word_embeddings)
def forward(self, input_ids, attention_mask=None, masked_lm_labels=None, head_mask=None):
dlbrt_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
prediction_logits = gelu(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size)
outputs = (prediction_logits, ) + dlbrt_output[1:]
if masked_lm_labels is not None:
mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)),
masked_lm_labels.view(-1))
outputs = (mlm_loss,) + outputs
return outputs # (mlm_loss), prediction_logits, (all hidden_states), (all attentions)
@add_start_docstrings("""DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
def __init__(self, config):
super(DistilBertForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
self.distilbert = DistilBertModel(config)
self.pre_classifier = nn.Linear(config.dim, config.dim)
self.classifier = nn.Linear(config.dim, config.num_labels)
self.dropout = nn.Dropout(config.seq_classif_dropout)
self.apply(self.init_weights)
def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None):
distilbert_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = nn.ReLU()(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output) # (bs, dim)
logits = self.classifier(pooled_output) # (bs, dim)
outputs = (logits,) + distilbert_output[1:]
if labels is not None:
if self.num_labels == 1:
loss_fct = nn.MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings("""DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
loss, start_scores, end_scores = outputs[:2]
"""
def __init__(self, config):
super(DistilBertForQuestionAnswering, self).__init__(config)
self.distilbert = DistilBertModel(config)
self.qa_outputs = nn.Linear(config.dim, config.num_labels)
assert config.num_labels == 2
self.dropout = nn.Dropout(config.qa_dropout)
self.apply(self.init_weights)
def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
distilbert_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim)
logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) # (bs, max_query_len)
end_logits = end_logits.squeeze(-1) # (bs, max_query_len)
outputs = (start_logits, end_logits,) + distilbert_output[1:]
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)

View File

@ -657,9 +657,6 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
`multiple_choice_labels`: optional multiple choice labels: ``torch.LongTensor`` of shape [batch_size]
with indices selected in [0, ..., num_choices].
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Language modeling loss.

View File

@ -98,15 +98,15 @@ ROBERTA_INPUTS_DOCSTRING = r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
To match pre-training, RoBERTa input sequence should be formatted with [CLS] and [SEP] tokens as follows:
To match pre-training, RoBERTa input sequence should be formatted with <s> and </s> tokens as follows:
(a) For sequence pairs:
``tokens: [CLS] is this jack ##son ##ville ? [SEP][SEP] no it is not . [SEP]``
``tokens: <s> Is this Jacksonville ? </s> </s> No it is not . </s>``
(b) For single sequences:
``tokens: [CLS] the dog is hairy . [SEP]``
``tokens: <s> the dog is hairy . </s>``
Fully encoded sequences or sequence pairs can be obtained using the RobertaTokenizer.encode function with
the ``add_special_tokens`` parameter set to ``True``.

View File

@ -285,7 +285,7 @@ class TransfoXLConfig(PretrainedConfig):
self.init_std = init_std
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
" or the path to a pretrained model config file (str)")
@property
def max_position_embeddings(self):
@ -1142,10 +1142,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
else:
mask_shift_len = qlen
dec_attn_mask = (torch.triu(all_ones, 1+mlen)
+ torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
+ torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1
else:
dec_attn_mask = torch.triu(
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]
hids = []
attentions = []

View File

@ -166,7 +166,7 @@ class PretrainedConfig(object):
# redirect to the cache, if necessary
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError:
except EnvironmentError as e:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
@ -179,7 +179,7 @@ class PretrainedConfig(object):
pretrained_model_name_or_path,
', '.join(cls.pretrained_config_archive_map.keys()),
config_file))
return None
raise e
if resolved_config_file == config_file:
logger.info("loading configuration file {}".format(config_file))
else:
@ -473,7 +473,7 @@ class PreTrainedModel(nn.Module):
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError:
except EnvironmentError as e:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
@ -486,7 +486,7 @@ class PreTrainedModel(nn.Module):
pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()),
archive_file))
return None
raise e
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else:

View File

@ -178,7 +178,7 @@ class XLMConfig(PretrainedConfig):
self.end_n_top = end_n_top
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
" or the path to a pretrained model config file (str)")
@property
def vocab_size(self):

View File

@ -306,7 +306,7 @@ class XLNetConfig(PretrainedConfig):
self.end_n_top = end_n_top
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
" or the path to a pretrained model config file (str)")
@property
def max_position_embeddings(self):
@ -677,8 +677,11 @@ XLNET_INPUTS_DOCSTRING = r"""
``1`` for tokens that are MASKED, ``0`` for tokens that are NOT MASKED.
**mems**: (`optional`)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
that contains pre-computed hidden-states (key and values in the attention blocks) as output by the model
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
To activate mems you need to set up config.mem_len to a positive value which will be the max number of tokens in
the memory output by the model. E.g. `model = XLNetModel.from_pretrained('xlnet-base-case, mem_len=1024)` will
instantiate a model which can use up to 1024 tokens of memory (in addition to the input it self).
**perm_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, sequence_length)``:
Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
If ``perm_mask[k, i, j] = 0``, i attend to j in batch k;
@ -705,7 +708,8 @@ class XLNetModel(XLNetPreTrainedModel):
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
@ -859,7 +863,7 @@ class XLNetModel(XLNetPreTrainedModel):
target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
qlen, bsz = input_ids.shape[0], input_ids.shape[1]
mlen = mems[0].shape[0] if mems is not None else 0
mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
klen = mlen + qlen
dtype_float = next(self.parameters()).dtype
@ -1011,7 +1015,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
@ -1091,7 +1096,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
@ -1189,7 +1195,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:

View File

@ -49,6 +49,7 @@ class CommonTestCases:
test_torchscript = True
test_pruning = True
test_resize_embeddings = True
test_head_masking = True
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -159,6 +160,9 @@ class CommonTestCases:
def test_headmasking(self):
if not self.test_head_masking:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_attentions = True
@ -285,6 +289,9 @@ class CommonTestCases:
self.assertTrue(models_equal)
def test_tie_model_weights(self):
if not self.test_torchscript:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_same_values(layer_1, layer_2):

View File

@ -0,0 +1,217 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pytest
from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
from pytorch_transformers.modeling_distilbert import DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
class DistilBertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering,
DistilBertForSequenceClassification)
test_pruning = True
test_torchscript = True
test_resize_embeddings = True
test_head_masking = True
class DistilBertModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = DistilBertConfig(
vocab_size_or_config_json_file=self.vocab_size,
dim=self.hidden_size,
n_layers=self.num_hidden_layers,
n_heads=self.num_attention_heads,
hidden_dim=self.intermediate_size,
hidden_act=self.hidden_act,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range)
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_and_check_distilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DistilBertModel(config=config)
model.eval()
(sequence_output,) = model(input_ids, input_mask)
(sequence_output,) = model(input_ids)
result = {
"sequence_output": sequence_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
def create_and_check_distilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DistilBertForMaskedLM(config=config)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, masked_lm_labels=token_labels)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_distilbert_for_question_answering(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DistilBertForQuestionAnswering(config=config)
model.eval()
loss, start_logits, end_logits = model(input_ids, input_mask, sequence_labels, sequence_labels)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_distilbert_for_sequence_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_labels = self.num_labels
model = DistilBertForSequenceClassification(config)
model.eval()
loss, logits = model(input_ids, input_mask, sequence_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_labels])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids, 'attention_mask': input_mask}
return config, inputs_dict
def setUp(self):
self.model_tester = DistilBertModelTest.DistilBertModelTester(self)
self.config_tester = ConfigTester(self, config_class=DistilBertConfig, dim=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_distilbert_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_masked_lm(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)
# @pytest.mark.slow
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/pytorch_transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# shutil.rmtree(cache_dir)
# self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()

View File

@ -42,7 +42,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def get_tokenizer(self):
return BertTokenizer.from_pretrained(self.tmpdirname)
return self.tokenizer_class.from_pretrained(self.tmpdirname)
def get_input_output_texts(self):
input_text = u"UNwant\u00E9d,running"
@ -50,7 +50,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = BertTokenizer(self.vocab_file)
tokenizer = self.tokenizer_class(self.vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
@ -126,7 +126,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
self.assertFalse(_is_punctuation(u" "))
def test_sequence_builders(self):
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
text = tokenizer.encode("sequence builders")
text_2 = tokenizer.encode("multi-sequence build")

View File

@ -0,0 +1,46 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
from io import open
from pytorch_transformers.tokenization_distilbert import (DistilBertTokenizer)
from .tokenization_tests_commons import CommonTestCases
from .tokenization_bert_test import BertTokenizationTest
class DistilBertTokenizationTest(BertTokenizationTest):
tokenizer_class = DistilBertTokenizer
def get_tokenizer(self):
return DistilBertTokenizer.from_pretrained(self.tmpdirname)
def test_sequence_builders(self):
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
text = tokenizer.encode("sequence builders")
text_2 = tokenizer.encode("multi-sequence build")
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
assert encoded_sentence == [101] + text + [102]
assert encoded_pair == [101] + text + [102] + text_2 + [102]
if __name__ == '__main__':
unittest.main()

View File

@ -125,6 +125,9 @@ class BertTokenizer(PreTrainedTokenizer):
super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, **kwargs)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "

View File

@ -0,0 +1,62 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for DistilBERT."""
from __future__ import absolute_import, division, print_function, unicode_literals
import collections
import logging
import os
import unicodedata
from io import open
from .tokenization_bert import BertTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'distilbert-base-uncased': 512,
'distilbert-base-uncased-distilled-squad': 512,
}
class DistilBertTokenizer(BertTokenizer):
r"""
Constructs a DistilBertTokenizer.
:class:`~pytorch_transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the
minimum of this value (if specified) and the underlying BERT model's sequence length.
never_split: List of tokens which will never be split during tokenization. Only has an effect when
do_wordpiece_only=False
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

View File

@ -112,6 +112,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>",
bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v: k for k, v in self.encoder.items()}

View File

@ -87,6 +87,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
try:
import ftfy
from spacy.lang.en import English

View File

@ -84,14 +84,14 @@ class RobertaTokenizer(GPT2Tokenizer):
def add_special_tokens_single_sentence(self, token_ids):
"""
Adds special tokens to a sequence for sequence classification tasks.
A RoBERTa sequence has the following format: [CLS] X [SEP]
A RoBERTa sequence has the following format: <s> X </s>
"""
return [self.cls_token_id] + token_ids + [self.sep_token_id]
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
"""
Adds special tokens to a sequence pair for sequence classification tasks.
A RoBERTa sequence pair has the following format: [CLS] A [SEP][SEP] B [SEP]
A RoBERTa sequence pair has the following format: <s> A </s></s> B </s>
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]

View File

@ -73,6 +73,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token,
additional_special_tokens=additional_special_tokens,
**kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
if never_split is None:
never_split = self.all_special_tokens
if special is None:

View File

@ -222,6 +222,9 @@ class PreTrainedTokenizer(object):
self._additional_special_tokens = []
self.max_len = max_len if max_len is not None else int(1e12)
self.max_len_single_sentence = self.max_len
self.max_len_sentences_pair = self.max_len
self.added_tokens_encoder = {}
self.added_tokens_decoder = {}
@ -349,7 +352,7 @@ class PreTrainedTokenizer(object):
resolved_vocab_files[file_id] = None
else:
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError:
except EnvironmentError as e:
if pretrained_model_name_or_path in s3_models:
logger.error("Couldn't reach server to download vocabulary.")
else:
@ -359,7 +362,7 @@ class PreTrainedTokenizer(object):
"at this path or url.".format(
pretrained_model_name_or_path, ', '.join(s3_models),
pretrained_model_name_or_path, str(vocab_files.keys())))
return None
raise e
for file_id, file_path in vocab_files.items():
if file_path == resolved_vocab_files[file_id]:
@ -653,10 +656,12 @@ class PreTrainedTokenizer(object):
return first_sentence_tokens, second_sentence_tokens
def add_special_tokens_single_sentence(self, token_ids):
raise NotImplementedError
logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
return token_ids
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
raise NotImplementedError
logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
return token_ids_0 + token_ids_1
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
""" Converts a single index or a sequence of indices (integers) in a token "
@ -699,9 +704,9 @@ class PreTrainedTokenizer(object):
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
text = self.convert_tokens_to_string(filtered_tokens)
if self.sep_token is not None and self.sep_token in text:
text = text.replace(self.cls_token, self.sep_token)
split_text = list(filter(lambda sentence: len(sentence) > 0, text.split(self.sep_token)))
if self._sep_token is not None and self._sep_token in text:
text = text.replace(self._cls_token, self._sep_token)
split_text = list(filter(lambda sentence: len(sentence) > 0, text.split(self._sep_token)))
if clean_up_tokenization_spaces:
clean_text = [self.clean_up_tokenization(text) for text in split_text]
return clean_text

View File

@ -122,6 +122,10 @@ class XLMTokenizer(PreTrainedTokenizer):
cls_token=cls_token, mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
try:
import ftfy
from spacy.lang.en import English

View File

@ -71,6 +71,10 @@ class XLNetTokenizer(PreTrainedTokenizer):
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, additional_special_tokens=
additional_special_tokens, **kwargs)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
try:
import sentencepiece as spm
except ImportError: