* first copy & past commit from Bert and morgans LSH code

* add easy way to compare to trax original code

* translate most of function

* make trax lsh self attention deterministic with numpy seed + copy paste code

* add same config

* add same config

* make layer init work

* implemented hash_vectors function for lsh attention

* continue reformer translation

* hf LSHSelfAttentionLayer gives same output as trax layer

* refactor code

* refactor code

* refactor code

* refactor

* refactor + add reformer config

* delete bogus file

* split reformer attention layer into two layers

* save intermediate step

* save intermediate step

* make test work

* add complete reformer block layer

* finish reformer layer

* implement causal and self mask

* clean reformer test and refactor code

* fix merge conflicts

* fix merge conflicts

* update init

* fix device for GPU

* fix chunk length init for tests

* include morgans optimization

* improve memory a bit

* improve comment

* factorize num_buckets

* better testing parameters

* make whole model work

* make lm model work

* add t5 copy paste tokenizer

* add chunking feed forward

* clean config

* add improved assert statements

* make tokenizer work

* improve test

* correct typo

* extend config

* add complexer test

* add new axial position embeddings

* add local block attention layer

* clean tests

* refactor

* better testing

* save intermediate progress

* clean test file

* make shorter input length work for model

* allow variable input length

* refactor

* make forward pass for pretrained model work

* add generation possibility

* finish dropout and init

* make style

* refactor

* add first version of RevNet Layers

* make forward pass work and add convert file

* make uploaded model forward pass work

* make uploaded model forward pass work

* refactor code

* add namedtuples and cache buckets

* correct head masks

* refactor

* made reformer more flexible

* make style

* remove set max length

* add attention masks

* fix up tests

* fix lsh attention mask

* make random seed optional for the moment

* improve memory in reformer

* add tests

* make style

* make sure masks work correctly

* detach gradients

* save intermediate

* correct backprob through gather

* make style

* change back num hashes

* rename to labels

* fix rotation shape

* fix detach

* update

* fix trainer

* fix backward dropout

* make reformer more flexible

* fix conflict

* fix

* fix

* add tests for fixed seed in reformer layer

* fix trainer typo

* fix typo in activations

* add fp16 tests

* add fp16 training

* support fp16

* correct gradient bug in reformer

* add fast gelu

* re-add dropout for embedding dropout

* better naming

* better naming

* renaming

* finalize test branch

* finalize tests

* add more tests

* finish tests

* fix

* fix type trainer

* fix fp16 tests

* fix tests

* fix tests

* fix tests

* fix issue with dropout

* fix dropout seeds

* correct random seed on gpu

* finalize random seed for dropout

* finalize random seed for dropout

* remove duplicate line

* correct half precision bug

* make style

* refactor

* refactor

* docstring

* remove sinusoidal position encodings for reformer

* move chunking to modeling_utils

* make style

* clean config

* make style

* fix tests

* fix auto tests

* pretrained models

* fix docstring

* update conversion file

* Update pretrained_models.rst

* fix rst

* fix rst

* update copyright

* fix test path

* fix test path

* fix small issue in test

* include reformer in generation tests

* add docs for axial position encoding

* finish docs

* Update convert_reformer_trax_checkpoint_to_pytorch.py

* remove isort

* include sams comments

* remove wrong comment in utils

* correct typos

* fix typo

* Update reformer.rst

* applied morgans optimization

* make style

* make gpu compatible

* remove bogus file

* big test refactor

* add example for chunking

* fix typo

* add to README
This commit is contained in:
Patrick von Platen 2020-05-07 10:17:01 +02:00 committed by GitHub
parent 877fc56410
commit dca34695d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 3608 additions and 23 deletions

View File

@ -163,8 +163,9 @@ At some point in the future, you'll be able to seamlessly move from pre-training
16. **[BART](https://huggingface.co/transformers/model_doc/bart.html)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/pdf/1910.13461.pdf) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
17. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
18. **[DialoGPT](https://huggingface.co/transformers/model_doc/dialogpt.html)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
18. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
19. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
19. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
20. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
21. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Peason R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).

View File

@ -143,3 +143,14 @@ positional embeddings.
Absolute positional embeddings are selected in the range ``[0, config.max_position_embeddings - 1]``. Some models
use other types of positional embeddings, such as sinusoidal position embeddings or relative position embeddings.
Feed Forward Chunking
--------------------------
In transformers two feed forward layers usually follows the self attention layer in each residual attention block. The intermediate embedding size of the feed forward layers is often bigger than the hidden size of the model (*e.g.* for ``bert-base-uncased``).
For an input of size ``[batch_size, sequence_length]``, the memory required to store the intermediate feed forward embeddings ``[batch_size, sequence_length, config.intermediate_size]`` can account for a large fraction of the memory use. The authors of `Reformer: The Efficient Transformer <https://arxiv.org/abs/2001.04451>`_ noticed that since the computation is independent of the ``sequence_length`` dimension, it is mathematically equivalent to compute the output embeddings of both feed forward layers ``[batch_size, config.hidden_size]_0, ..., [batch_size, config.hidden_size]_n`` individually and concat them afterward to ``[batch_size, sequence_length, config.hidden_size]`` with ``n = sequence_length``, which trades increased computation time against reduced memory use, but yields a mathematically **equivalent** result.
For models employing the function :func:`~.transformers.apply_chunking_to_forward`, the ``chunk_size`` defines the number of output embeddings that are computed in parallel and thus defines the trade-off between memory and time complexity.
If ``chunk_size`` is set to 0, no feed forward chunking is done.

View File

@ -107,3 +107,4 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
model_doc/t5
model_doc/electra
model_doc/dialogpt
model_doc/reformer

View File

@ -14,6 +14,12 @@ The base class ``PreTrainedModel`` implements the common methods for loading/sav
.. autoclass:: transformers.PreTrainedModel
:members:
``Helper Functions``
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: transformers.apply_chunking_to_forward
``TFPreTrainedModel``
~~~~~~~~~~~~~~~~~~~~~

View File

@ -0,0 +1,114 @@
Reformer
----------------------------------------------------
**DISCLAIMER:** This model is still a work in progress, if you see something strange,
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`_
Overview
~~~~~
The Reformer model was presented in `Reformer: The Efficient Transformer <https://https://arxiv.org/abs/2001.04451.pdf>`_ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
Here the abstract:
*Large Transformer models routinely achieve state-of-the-art results on a number of tasks but training these models can be prohibitively costly, especially on long sequences. We introduce two techniques to improve the efficiency of Transformers. For one, we replace dot-product attention by one that uses locality-sensitive hashing, changing its complexity from O(L^2) to O(Llog(L)), where L is the length of the sequence. Furthermore, we use reversible residual layers instead of the standard residuals, which allows storing activations only once in the training process instead of N times, where N is the number of layers. The resulting model, the Reformer, performs on par with Transformer models while being much more memory-efficient and much faster on long sequences.*
The Authors' code can be found `here <https://github.com/google/trax/tree/master/trax/models/reformer>`_ .
Axial Positional Encodings
~~~~~~~~~~~~~~~~~~~~
Axial Positional Encodings were first implemented in Google's `trax library <https://github.com/google/trax/blob/4d99ad4965bab1deba227539758d59f0df0fef48/trax/layers/research/position_encodings.py#L29>`_ and developed by the authors of this model's paper. In models that are treating very long input sequences, the conventional position id encodings store an embedings vector of size :math:`d` being the ``config.hidden_size`` for every position :math:`i, \ldots, n_s`, with :math:`n_s` being ``config.max_embedding_size``. *E.g.*, having a sequence length of :math:`n_s = 2^{19} \approx 0.5M` and a ``config.hidden_size`` of :math:`d = 2^{10} \approx 1000` would result in a position encoding matrix:
.. math::
X_{i,j}, \text{ with } i \in \left[1,\ldots, d\right] \text{ and } j \in \left[1,\ldots, n_s\right]
which alone has over 500M parameters to store. Axial positional encodings factorize :math:`X_{i,j}` into two matrices:
.. math::
X^{1}_{i,j}, \text{ with } i \in \left[1,\ldots, d^1\right] \text{ and } j \in \left[1,\ldots, n_s^1\right]
and
.. math::
X^{2}_{i,j}, \text{ with } i \in \left[1,\ldots, d^2\right] \text{ and } j \in \left[1,\ldots, n_s^2\right]
with:
.. math::
d = d^1 + d^2 \text{ and } n_s = n_s^1 \times n_s^2 .
Therefore the following holds:
.. math::
X_{i,j} = \begin{cases}
X^{1}_{i, k}, & \text{if }\ i < d^1 \text{ with } k = j \mod n_s^1 \\
X^{2}_{i - d^1, l}, & \text{if } i \ge d^1 \text{ with } l = \lfloor\frac{j}{n_s^1}\rfloor
\end{cases}
Intuitively, this means that a position embedding vector :math:`x_j \in \mathbb{R}^{d}` is now the composition of two factorized embedding vectors: :math:`x^1_{k, l} + x^2_{l, k}`, where as the ``config.max_embedding_size`` dimension :math:`j` is factorized into :math:`k \text{ and } l`.
This design ensures that each position embedding vector :math:`x_j` is unique.
Using the above example again, axial position encoding with :math:`d^1 = 2^5, d^2 = 2^5, n_s^1 = 2^9, n_s^2 = 2^{10}` can drastically reduced the number of parameters to :math:`2^{14} + 2^{15} \approx 49000` parameters.
In practice, the parameter ``config.axial_pos_embds_dim`` is set to ``list``:math:`(d^1, d^2)` which sum has to be equal to ``config.hidden_size`` and ``config.axial_pos_shape`` is set to ``list``:math:`(n_s^1, n_s^2)` and which product has to be equal to ``config.max_embedding_size`` which during training has to be equal to the ``sequence length`` of the ``input_ids``.
LSH Self Attention
~~~~~~~~~~~~~~~~~~~~
In Locality sensitive hashing (LSH) self attention the key and query projection weights are tied. Therefore, the key query embedding vectors are also tied.
LSH self attention uses the locality sensitive
hashing mechanism proposed in `Practical and Optimal LSH for Angular Distance <https://arxiv.org/abs/1509.02897>`_ to assign each of the tied key query embedding vectors to one of ``config.num_buckets`` possible buckets. The premise is that the more "similar" key query embedding vectors (in terms of *cosine similarity*) are to each other, the more likely they are assigned to the same bucket.
The accuracy of the LSH mechanism can be improved by increasing ``config.num_hashes`` or directly the argument ``num_hashes`` of the forward function so that the output of the LSH self attention better approximates the output of the "normal" full self attention.
The buckets are then sorted and chunked into query key embedding vector chunks each of length ``config.lsh_chunk_length``. For each chunk, the query embedding vectors attend to its key vectors (which are tied to themselves) and to the key embedding vectors of ``config.lsh_num_chunks_before`` previous neighboring chunks and ``config.lsh_num_chunks_after`` following neighboring chunks.
For more information, see the `original Paper <https://arxiv.org/abs/2001.04451>`_ or this great `blog post <https://www.pragmatic.ml/reformer-deep-dive/>`_.
Note that ``config.num_buckets`` can also be factorized into a ``list``:math:`(n_{\text{buckets}}^1, n_{\text{buckets}}^2)`. This way instead of assigning the query key embedding vectors to one of :math:`(1,\ldots, n_{\text{buckets}})` they are assigned to one of :math:`(1-1,\ldots, n_{\text{buckets}}^1-1, \ldots, 1-n_{\text{buckets}}^2, \ldots, n_{\text{buckets}}^1-n_{\text{buckets}}^2)`. This is crucial for very long sequences to save memory.
It is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length, a good value for ``num_buckets`` are calculated on the fly.
Using LSH self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \log(n_s))`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length.
Local Self Attention
~~~~~~~~~~~~~~~~~~~~
Local self attention is essentially a "normal" self attention layer with
key, query and value projections, but is chunked so that in each chunk of length ``config.local_chunk_length`` the query embedding vectors only attends to the key embedding vectors in its chunk and to the key embedding vectors of ``config.local_num_chunks_before`` previous neighboring chunks and ``config.local_num_chunks_after`` following neighboring chunks.
Using Local self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \log(n_s))`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length.
Training
~~~~~~~~~~~~~~~~~~~~
During training, we must ensure that the sequence length is set to a value that can be divided by the least common multiple of ``config.lsh_chunk_length`` and ``config.local_chunk_length`` and that the parameters of the Axial Positional Encodings are correctly set as described above. Reformer is very memory efficient so that the model can easily be trained on sequences as long as 64000 tokens.
For training, the ``ReformerModelWithLMHead`` should be used as follows:
::
input_ids = tokenizer.encode('This is a sentence from the training data', return_tensors='pt')
loss = model(input_ids, labels=input_ids)[0]
ReformerConfig
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.ReformerConfig
:members:
ReformerTokenizer
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.ReformerTokenizer
:members:
ReformerModel
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.ReformerModel
:members:
ReformerModelWithLMHead
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.ReformerModelWithLMHead
:members:

View File

@ -296,3 +296,6 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| | ``DialoGPT-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters |
| | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Reformer | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters |
| | | | Trained on English text: Crime and Punishment novel by Fyodor Dostoyevsky |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+

View File

@ -47,6 +47,7 @@ from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_marian import MarianConfig
from .configuration_mmbt import MMBTConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
@ -138,6 +139,7 @@ from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_reformer import ReformerTokenizer
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
@ -159,7 +161,7 @@ if is_sklearn_available():
# Modeling
if is_torch_available():
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering, apply_chunking_to_forward
from .modeling_auto import (
AutoModel,
AutoModelForPreTraining,
@ -190,6 +192,7 @@ if is_torch_available():
BertForQuestionAnswering,
load_tf_weights_in_bert,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BertLayer,
)
from .modeling_openai import (
OpenAIGPTPreTrainedModel,
@ -320,6 +323,14 @@ if is_torch_available():
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_reformer import (
ReformerAttention,
ReformerLayer,
ReformerModel,
ReformerModelWithLMHead,
REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
)
# Optimization
from .optimization import (
AdamW,

View File

@ -34,12 +34,18 @@ if torch.__version__ < "1.4.0":
else:
gelu = F.gelu
def gelu_fast(x):
return 0.5 * x * (1 + torch.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))
ACT2FN = {
"relu": F.relu,
"swish": swish,
"gelu": gelu,
"tanh": torch.tanh,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
}

View File

@ -29,6 +29,7 @@ from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_reformer import ReformerConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
@ -73,6 +74,7 @@ CONFIG_MAPPING = OrderedDict(
("camembert", CamembertConfig,),
("xlm-roberta", XLMRobertaConfig,),
("bart", BartConfig,),
("reformer", ReformerConfig,),
("roberta", RobertaConfig,),
("flaubert", FlaubertConfig,),
("bert", BertConfig,),
@ -130,6 +132,7 @@ class AutoConfig:
- contains `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model)
- contains `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model)
- contains `reformer`: :class:`~transformers.ReformerConfig` (Reformer model)
- contains `bert`: :class:`~transformers.BertConfig` (Bert model)
- contains `openai-gpt`: :class:`~transformers.OpenAIGPTConfig` (OpenAI GPT model)
- contains `gpt2`: :class:`~transformers.GPT2Config` (OpenAI GPT-2 model)

View File

@ -0,0 +1,210 @@
# coding=utf-8
# Copyright 2020 The Trax Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Reformer model configuration """
import logging
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/config.json"
}
class ReformerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.ReformerModel`.
It is used to instantiate an Reformer model according to the specified arguments, defining the model
architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
for more information.
Args:
attention_head_size (:obj:`int`, optional, defaults to 64):
Dimensionality of the projected key, query and value vectors
attn_layers (:obj:`list(str)`, optional, defaults to ["local", "lsh", "local", "lsh", "local", "lsh"]):
List of attention layer types in ascending order. It can be chosen between a
LSHSelfAttention layer ("lsh") and a LocalSelfAttention layer ("local").
For more information on LSHSelfAttention layer, see `LSH Self Attention <reformer.html#lsh-self-attention>`__ .
For more information on LocalSelfAttention layer, see `Local Self Attention <reformer.html#local-sensitive-hashing-self-attention>`__ .
axial_pos_embds (:obj:`bool`, optional, defaults to True):
If `True` use axial position embeddings. For more information on how axial position embeddings work, see `Axial Position Encodings <reformer.html#axial-positional-encodings>`__
axial_norm_std (:obj:`float`, optional, defaluts to 1.0):
The standard deviation of the normal_initializer for initializing the weight matrices of the axial positional encodings.
axial_pos_shape (:obj:`list(int)`, optional, defaults to `[64, 64]`):
The position dims of the axial position encodings.
During training the product of the position dims has to equal the sequence length.
For more information on how axial position embeddings work, see `Axial Position Encodings <reformer.html#axial-positional-encodings>`__ncodings.
axial_pos_embds_dim (:obj:`list(int)`, optional, defaults to `[64, 192]`):
The embedding dims of the axial position encodings.
The sum of the embedding dims has to equal the hidden size.
For more information on how axial position embeddings work, see `Axial Position Encodings <reformer.html#axial-positional-encodings>`__ncodings.
chunk_size_lm_head (:obj:`int`, optional, defaults to 0):
The chunk size of the final language model feed forward head layer.
A chunk size of 0 means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
chunk_size_feed_forward (:obj:`int`, optional, defaults to 0):
The chunk size of all feed forward layers in the residual attention blocks.
A chunk size of 0 means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
eos_token_id (:obj:`int`, optional, defaults to 2):
The token id for the <EOS> token.
feed_forward_size (:obj:`int`, optional, defaults to 512):
Dimensionality of the "feed_forward" (i.e., feed-forward) layer in the residual attention block.
hash_seed (:obj:`int`, optional, defaults to `None`):
Seed that can be used to make local sensitive hashing in LSHSelfAttention deterministic. This should only be set for testing purposed. For evaluation and training purposes `hash_seed` should be set to `None` to ensure fully random rotations in local sensitive hashing scheme.
hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "relu"):
The non-linear activation function (function or string) in the feed forward layer in the residual attention block.
If string, "gelu", "relu", "swish", "gelu_new" and "gelu_fast" are supported.
hidden_dropout_prob (:obj:`float`, optional, defaults to 0.05):
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
hidden_size (:obj:`int`, optional, defaults to 256):
Dimensionality of the output hidden states of the residual attention blocks.
initializer_range (:obj:`float`, optional, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
is_decoder (:obj:`bool`, optional, defaults to False):
If `is_decoder` is True, a causal mask is used in addition to `attention_mask`.
When using the Reformer for causal language modeling, `is_decoder` is set to `True`.
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
The epsilon used by the layer normalization layers.
local_chunk_length (:obj:`int`, optional, defaults to 64):
Length of chunk which attends to itself in LocalSelfAttention. Chunking reduces memory complexity from sequence length x sequence length (self attention) to chunk length x chunk length x sequence length / chunk length (chunked self attention).
local_num_chunks_before (:obj:`int`, optional, defaults to 1):
Number of previous neighbouring chunks to attend to in LocalSelfAttention layer to itself.
local_num_chunks_after (:obj:`int`, optional, defaults to 0):
Number of following neighbouring chunks to attend to in LocalSelfAttention layer in addition to itself.
local_attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout ratio for the attention probabilities in LocalSelfAttention.
lsh_chunk_length (:obj:`int`, optional, defaults to 64):
Length of chunk which attends to itself in LSHSelfAttention. Chunking reduces memory complexity from sequence length x sequence length (self attention) to chunk length x chunk length x sequence length / chunk length (chunked self attention).
lsh_num_chunks_before (:obj:`int`, optional, defaults to 1):
Number of previous neighbouring chunks to attend to in LSHSelfAttention layer to itself.
lsh_num_chunks_after (:obj:`int`, optional, defaults to 0):
Number of following neighbouring chunks to attend to in LSHSelfAttention layer to itself.
lsh_attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout ratio for the attention probabilities in LSHSelfAttention.
max_position_embeddings (:obj:`int`, optional, defaults to 4096):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
num_attention_heads (:obj:`int`, optional, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `64`):
Number of buckets, the key query vectors can be "hashed into" using the locality sensitive hashing scheme. Each query key vector is hashed into a hash in `1, ..., num_buckets`.
The number of buckets can also be factorized into a list for improved memory complexity. In this case, each query key vector is hashed into a hash in `1-1, 1-2, ..., num_buckets[0]-1, ..., num_buckets[0]-num_buckets[1]` if `num_buckets` is factorized into two factors.
The number of buckets (or the product the factors) should approximately equal sequence length / lsh_chunk_length.
num_hashes (:obj:`int`, optional, defaults to 1):
Number of hashing rounds (e.g. number of random rotations) in Local Sensitive Hashing scheme.
The higher `num_hashes`, the more accurate the `LSHSelfAttention` becomes, but also the more memory and time intensive the hashing becomes.
pad_token_id (:obj:`int`, optional, defaults to 0):
The token id for the <PAD> token.
vocab_size (:obj:`int`, optional, defaults to 320):
Vocabulary size of the Reformer model. Defines the different tokens that
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.ReformerModel`.
Example::
from transformers import ReformerModel, ReformerConfig
# Initializing a Reformer configuration
configuration = ReformerConfig()
# Initializing a Reformer model
model = ReformerModel(configuration)
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "reformer"
def __init__(
self,
attention_head_size=64,
attn_layers=["local", "lsh", "local", "lsh", "local", "lsh"],
axial_norm_std=1.0,
axial_pos_embds=True,
axial_pos_shape=[64, 64],
axial_pos_embds_dim=[64, 192],
chunk_size_lm_head=0,
chunk_size_feed_forward=0,
eos_token_id=2,
feed_forward_size=512,
hash_seed=None,
hidden_act="relu",
hidden_dropout_prob=0.05,
hidden_size=256,
initializer_range=0.02,
is_decoder=False,
layer_norm_eps=1e-12,
local_num_chunks_before=1,
local_num_chunks_after=0,
local_attention_probs_dropout_prob=0.05,
local_attn_chunk_length=64,
lsh_attn_chunk_length=64,
lsh_attention_probs_dropout_prob=0.0,
lsh_num_chunks_before=1,
lsh_num_chunks_after=0,
max_position_embeddings=4096,
num_attention_heads=2,
num_buckets=32,
num_hashes=1,
pad_token_id=0,
vocab_size=320,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_decoder=is_decoder, **kwargs)
self.hash_seed = hash_seed
self.vocab_size = vocab_size
self.attention_head_size = attention_head_size
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_hashes = num_hashes
self.num_hidden_layers = len(attn_layers)
self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets
self.lsh_attn_chunk_length = lsh_attn_chunk_length
self.local_attn_chunk_length = local_attn_chunk_length
self.lsh_num_chunks_after = lsh_num_chunks_after
self.lsh_num_chunks_before = lsh_num_chunks_before
self.local_num_chunks_after = local_num_chunks_after
self.local_num_chunks_before = local_num_chunks_before
self.hidden_act = hidden_act
self.feed_forward_size = feed_forward_size
self.hidden_dropout_prob = hidden_dropout_prob
self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob
self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.axial_pos_embds = axial_pos_embds
self.axial_pos_shape = tuple(axial_pos_shape)
self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
self.axial_norm_std = axial_norm_std
self.chunk_size_lm_head = chunk_size_lm_head
self.chunk_size_feed_forward = chunk_size_feed_forward
self.attn_layers = attn_layers

View File

@ -0,0 +1,211 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Reformer checkpoint."""
import argparse
import logging
import pickle
import numpy as np
import torch
from transformers import ReformerConfig, ReformerModelWithLMHead
logging.basicConfig(level=logging.INFO)
def set_param(torch_layer, weight, bias=None):
# set parameter of one layer
assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch_layer)
torch_layer.weight = torch.nn.Parameter(weight)
if bias is not None:
assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch_layer)
torch_layer.bias = torch.nn.Parameter(bias)
def set_layer_weights_in_torch_lsh(weights, torch_layer, hidden_size):
# set torch weights for 1-to-1 comparison
np_query_key = np.asarray(weights[0])
np_value = np.asarray(weights[1])
np_dense = np.asarray(weights[2])
set_param(
torch_layer.self_attention.query_key,
torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1),
)
def set_layer_weights_in_torch_local(weights, torch_layer, hidden_size):
# set torch weights for 1-to-1 comparison
np_query = np.asarray(weights[0])
np_key = np.asarray(weights[1])
np_value = np.asarray(weights[2])
np_dense = np.asarray(weights[3])
set_param(
torch_layer.self_attention.query, torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.self_attention.key, torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1),
)
def set_block_weights_in_torch(weights, torch_block, hidden_size):
# layernorm 1
layer_norm_1 = weights[0][0][0]
layer_norm_1_weight = np.asarray(layer_norm_1[0])
layer_norm_1_bias = np.asarray(layer_norm_1[1])
set_param(
torch_block.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias),
)
# lsh weights + output
attn_weights = weights[0][1]
if len(attn_weights) < 4:
set_layer_weights_in_torch_lsh(attn_weights, torch_block.attention, hidden_size)
else:
set_layer_weights_in_torch_local(attn_weights, torch_block.attention, hidden_size)
# intermediate weighs
intermediate_weights = weights[2][0][2][2]
# Chunked Feed Forward
if len(intermediate_weights) == 4:
intermediate_weights = intermediate_weights[2]
# layernorm 2
layer_norm_2_weight = np.asarray(intermediate_weights[0][0])
layer_norm_2_bias = np.asarray(intermediate_weights[0][1])
set_param(
torch_block.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias),
)
# intermediate dense
inter_dense_weight = np.asarray(intermediate_weights[1][0])
inter_dense_bias = np.asarray(intermediate_weights[1][1])
set_param(
torch_block.feed_forward.dense.dense,
torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(),
torch.tensor(inter_dense_bias),
)
# intermediate out
out_dense_weight = np.asarray(intermediate_weights[4][0])
out_dense_bias = np.asarray(intermediate_weights[4][1])
set_param(
torch_block.feed_forward.output.dense,
torch.tensor(out_dense_weight).transpose(0, 1).contiguous(),
torch.tensor(out_dense_bias),
)
def set_model_weights_in_torch(weights, torch_model, hidden_size):
# reformer model
torch_model_reformer = torch_model.reformer
# word embeds
word_embeddings = np.asarray(weights[1])
set_param(
torch_model_reformer.embeddings.word_embeddings, torch.tensor(word_embeddings),
)
if isinstance(weights[3], tuple):
position_embeddings = torch_model_reformer.embeddings.position_embeddings
for emb_idx in range(len(position_embeddings.weights)):
emb_weights = np.asarray(weights[3][emb_idx][0])
assert position_embeddings.weights[emb_idx].shape == emb_weights.shape, "{} emb does not match".format(
position_embeddings[emb_idx]
)
position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights))
trax_layer_weights = weights[5]
assert len(torch_model_reformer.encoder.layers) * 4 + 1 == len(
trax_layer_weights
), "HF and trax model do not have the same number of layers"
for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers):
block_weights = trax_layer_weights[4 * layer_idx : 4 * (layer_idx + 1)]
set_block_weights_in_torch(block_weights, layer, hidden_size)
# output weights
out_weights = weights[6]
# output layer norm
layer_norm_out_weight = np.asarray(out_weights[0][0])
layer_norm_out_bias = np.asarray(out_weights[0][1])
set_param(
torch_model_reformer.encoder.layer_norm,
torch.tensor(layer_norm_out_weight),
torch.tensor(layer_norm_out_bias),
)
# output embeddings
output_embed_weights = np.asarray(out_weights[2][0])
output_embed_bias = np.asarray(out_weights[2][1])
set_param(
torch_model.lm_head.decoder,
torch.tensor(output_embed_weights).transpose(0, 1).contiguous(),
torch.tensor(output_embed_bias),
)
def convert_trax_checkpoint_to_pytorch(trax_model_pkl_path, config_file, pytorch_dump_path):
# Initialise PyTorch model
config = ReformerConfig.from_json_file(config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = ReformerModelWithLMHead(config)
with open(trax_model_pkl_path, "rb") as f:
model_weights = pickle.load(f)["weights"]
set_model_weights_in_torch(model_weights, model, config.hidden_size)
# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--trax_model_pkl_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained Reformer model. \n"
"This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_trax_checkpoint_to_pytorch(args.trax_model_pkl_path, args.config_file, args.pytorch_dump_path)

View File

@ -31,6 +31,7 @@ from .configuration_auto import (
FlaubertConfig,
GPT2Config,
OpenAIGPTConfig,
ReformerConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
@ -97,6 +98,7 @@ from .modeling_flaubert import (
)
from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model
from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
from .modeling_roberta import (
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM,
@ -179,6 +181,7 @@ MODEL_MAPPING = OrderedDict(
(XLMConfig, XLMModel),
(CTRLConfig, CTRLModel),
(ElectraConfig, ElectraModel),
(ReformerConfig, ReformerModel),
]
)
@ -222,6 +225,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(CTRLConfig, CTRLLMHeadModel),
(ElectraConfig, ElectraForMaskedLM),
(EncoderDecoderConfig, EncoderDecoderModel),
(ReformerConfig, ReformerModelWithLMHead),
]
)

File diff suppressed because it is too large Load Diff

View File

@ -13,8 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
import inspect
import logging
import os
from typing import Callable, Tuple
@ -175,7 +175,7 @@ class ModuleUtilsMixin:
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def get_head_mask(self, head_mask, num_hidden_layers):
def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
"""
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@ -189,6 +189,8 @@ class ModuleUtilsMixin:
"""
if head_mask is not None:
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
if is_attention_chunked is True:
head_mask = head_mask.unsqueeze(-1)
else:
head_mask = [None] * num_hidden_layers
@ -786,6 +788,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
attention_mask=None,
decoder_start_token_id=None,
use_cache=None,
**model_specific_kwargs
):
r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
@ -863,6 +866,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
use_cache: (`optional`) bool
If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
model_specific_kwargs: (`optional`) dict
Additional model specific kwargs will be forwarded to the `forward` function of the model.
Return:
output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
@ -1116,6 +1122,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_specific_kwargs=model_specific_kwargs,
)
else:
output = self._generate_no_beam_search(
@ -1138,6 +1145,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_specific_kwargs=model_specific_kwargs,
)
return output
@ -1163,6 +1171,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_outputs,
attention_mask,
use_cache,
model_specific_kwargs,
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequence are generated independantly.
@ -1175,7 +1184,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
)
outputs = self(**model_inputs)
@ -1288,6 +1297,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_outputs,
attention_mask,
use_cache,
model_specific_kwargs,
):
""" Generate sequences for each example with beam search.
"""
@ -1314,7 +1324,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
@ -2087,3 +2097,66 @@ def prune_layer(layer, index, dim=None):
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
else:
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
def apply_chunking_to_forward(
chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
) -> torch.Tensor:
"""
This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`.
It then applies a layer `forward_fn` to each chunk independently to save memory.
If the `forward_fn` is independent across the `chunk_dim` this function will yield the
same result as not applying it.
Args:
chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size`
chunk_dim: int - the dimension over which the input_tensors should be chunked
forward_fn: fn - the forward fn of the model
input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked
Returns:
a Tensor with the same shape the foward_fn would have given if applied
Examples::
# rename the usual forward() fn to forward_chunk()
def forward_chunk(self, hidden_states):
hidden_states = self.decoder(hidden_states)
return hidden_states
# implement a chunked forward function
def forward(self, hidden_states):
return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
"""
assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
tensor_shape = input_tensors[0].shape
assert all(
input_tensor.shape == tensor_shape for input_tensor in input_tensors
), "All input tenors have to be of the same shape"
# inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability
num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
assert num_args_in_forward_chunk_fn == len(
input_tensors
), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
num_args_in_forward_chunk_fn, len(input_tensors)
)
if chunk_size > 0:
assert (
input_tensors[0].shape[chunk_dim] % chunk_size == 0
), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
input_tensors[0][chunk_dim], chunk_size
)
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
# chunk input tensor into tuples
input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
# apply forward fn to every tuple
output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
# concatenate output at same dimension
return torch.cat(output_chunks, dim=chunk_dim)
return forward_fn(*input_tensors)

View File

@ -30,6 +30,7 @@ from .configuration_auto import (
FlaubertConfig,
GPT2Config,
OpenAIGPTConfig,
ReformerConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
@ -49,6 +50,7 @@ from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_reformer import ReformerTokenizer
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLTokenizerFast
@ -69,6 +71,7 @@ TOKENIZER_MAPPING = OrderedDict(
(XLMRobertaConfig, (XLMRobertaTokenizer, None)),
(BartConfig, (BartTokenizer, None)),
(RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
(ReformerConfig, (ReformerTokenizer, None)),
(ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
(BertConfig, (BertTokenizer, BertTokenizerFast)),
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),

View File

@ -0,0 +1,179 @@
# coding=utf-8
# Copyright 2020 The Trax Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model Reformer."""
import logging
import os
from shutil import copyfile
from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
SPIECE_UNDERLINE = ""
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model"
}
}
####################################################
# Mapping from model shortcut names to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"google/reformer-crime-and-punishment": 524288,
}
class ReformerTokenizer(PreTrainedTokenizer):
"""
Constructs an Reformer tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__ .
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
should refer to the superclass for more information regarding methods.
Args:
vocab_file (:obj:`string`):
`SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a `.spm` extension) that
contains the vocabulary necessary to instantiate a tokenizer.
eos_token (:obj:`string`, `optional`, defaults to "</s>"):
The end of sequence token.
.. note::
When building a sequence using special tokens, this is not the token that is used for the end
of sequence. The token used is the :obj:`sep_token`.
unk_token (:obj:`string`, `optional`, defaults to "<unk>"):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (:obj:`string`, `optional`, defaults to "<pad>"):
The token used for padding, for example when batching sequences of different lengths.
additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`None`):
Additional special tokens used by the tokenizer.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(
self,
vocab_file,
eos_token="</s>",
unk_token="<unk>",
pad_token="<pad>",
additional_special_tokens=[],
**kwargs
):
super().__init__(
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
additional_special_tokens=additional_special_tokens,
**kwargs,
)
try:
import sentencepiece as spm
except ImportError:
logger.warning(
"You need to install SentencePiece to use ReformerTokenizer:"
"https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
raise
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
@property
def vocab_size(self):
return self.sp_model.get_piece_size()
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
return state
def __setstate__(self, d):
self.__dict__ = d
try:
import sentencepiece as spm
except ImportError:
logger.warning(
"You need to install SentencePiece to use ReformerTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
raise
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file)
def _tokenize(self, text, sample=False):
""" Take as input a string and return a list of strings (tokens) for words/sub-words
"""
if not sample:
pieces = self.sp_model.EncodeAsPieces(text)
else:
pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
return pieces
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index < self.sp_model.get_piece_size():
token = self.sp_model.IdToPiece(index)
return token
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
out_string = self.sp_model.decode_pieces(tokens)
return out_string
def save_vocabulary(self, save_directory):
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
to a directory.
"""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)

View File

@ -22,6 +22,8 @@ class TestActivations(unittest.TestCase):
get_activation("swish")
get_activation("relu")
get_activation("tanh")
get_activation("gelu_new")
get_activation("gelu_fast")
with self.assertRaises(KeyError):
get_activation("bogus")
with self.assertRaises(KeyError):

View File

@ -125,6 +125,9 @@ class ModelTesterMixin:
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
config.output_attentions = True
@ -138,10 +141,17 @@ class ModelTesterMixin:
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
if self.is_encoder_decoder:
@ -175,10 +185,16 @@ class ModelTesterMixin:
self_attentions = outputs[-1]
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
if chunk_length is not None:
self.assertListEqual(
list(self_attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_torchscript(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -465,14 +481,16 @@ class ModelTesterMixin:
self.assertEqual(model.config.output_attentions, False)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
if hasattr(self.model_tester, "encoder_seq_length"):
seq_length = self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
seq_length = seq_length * self.model_tester.chunk_length
else:
seq_length = self.model_tester.seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[
self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "encoder_seq_length")
else self.model_tester.seq_length,
self.model_tester.hidden_size,
],
list(hidden_states[0].shape[-2:]), [seq_length, self.model_tester.hidden_size],
)
def test_resize_tokens_embeddings(self):
@ -485,6 +503,9 @@ class ModelTesterMixin:
model = model_class(config)
model.to(torch_device)
if self.model_tester.is_training is False:
model.eval()
model_vocab_size = config.vocab_size
# Retrieve the embeddings and clone theme
model_embed = model.resize_token_embeddings(model_vocab_size)
@ -628,9 +649,13 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
# make sure that input_ids is at most of size 15
input_ids = input_ids[..., :15]
# iterate over all generative models
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
if config.bos_token_id is None:
# if bos token id is not defined, model needs input_ids
@ -669,8 +694,12 @@ class ModelTesterMixin:
torch_device
)
# make sure that input_ids is at most of size 15
input_ids = input_ids[..., :15]
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
if config.bos_token_id is None:
# if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
@ -750,7 +779,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
def floats_tensor(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor of the shape within the vocab size."""
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng

View File

@ -0,0 +1,955 @@
# coding=utf-8 # Copyright 2020 Huggingface
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():
from transformers import (
ReformerConfig,
ReformerModel,
ReformerModelWithLMHead,
ReformerTokenizer,
ReformerLayer,
REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
)
import torch
class ReformerModelTester:
def __init__(
self,
parent,
batch_size=None,
seq_length=None,
is_training=None,
is_decoder=None,
use_input_mask=None,
vocab_size=None,
attention_head_size=None,
hidden_size=None,
num_attention_heads=None,
local_attn_chunk_length=None,
local_num_chunks_before=None,
local_num_chunks_after=None,
num_buckets=None,
num_hashes=1,
lsh_attn_chunk_length=None,
lsh_num_chunks_before=None,
lsh_num_chunks_after=None,
chunk_size_lm_head=None,
chunk_size_feed_forward=None,
feed_forward_size=None,
hidden_act=None,
hidden_dropout_prob=None,
local_attention_probs_dropout_prob=None,
lsh_attention_probs_dropout_prob=None,
max_position_embeddings=None,
initializer_range=None,
axial_norm_std=None,
layer_norm_eps=None,
axial_pos_embds=None,
axial_pos_shape=None,
axial_pos_embds_dim=None,
attn_layers=None,
pad_token_id=None,
eos_token_id=None,
scope=None,
hash_seed=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.is_decoder = is_decoder
self.use_input_mask = use_input_mask
self.vocab_size = vocab_size
self.attention_head_size = attention_head_size
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = len(attn_layers)
self.local_attn_chunk_length = local_attn_chunk_length
self.local_num_chunks_after = local_num_chunks_after
self.local_num_chunks_before = local_num_chunks_before
self.num_hashes = num_hashes
self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets
self.lsh_attn_chunk_length = lsh_attn_chunk_length
self.lsh_num_chunks_after = lsh_num_chunks_after
self.lsh_num_chunks_before = lsh_num_chunks_before
self.hidden_act = hidden_act
self.feed_forward_size = feed_forward_size
self.hidden_dropout_prob = hidden_dropout_prob
self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob
self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.axial_pos_embds = axial_pos_embds
self.axial_pos_shape = tuple(axial_pos_shape)
self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
self.axial_norm_std = axial_norm_std
self.chunk_size_lm_head = chunk_size_lm_head
self.chunk_size_feed_forward = chunk_size_feed_forward
self.scope = scope
self.attn_layers = attn_layers
self.pad_token_id = pad_token_id
self.hash_seed = hash_seed
attn_chunk_length = local_attn_chunk_length if local_attn_chunk_length is not None else lsh_attn_chunk_length
num_chunks_after = local_num_chunks_after if local_num_chunks_after is not None else lsh_num_chunks_after
num_chunks_before = local_num_chunks_before if local_num_chunks_before is not None else lsh_num_chunks_before
self.encoder_seq_length = seq_length // attn_chunk_length + (self.seq_length % attn_chunk_length != 0)
self.key_length = (num_chunks_before + num_chunks_after + 1) * attn_chunk_length
self.chunk_length = attn_chunk_length
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
config = ReformerConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
feed_forward_size=self.feed_forward_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
local_attention_probs_dropout_prob=self.local_attention_probs_dropout_prob,
lsh_attention_probs_dropout_prob=self.lsh_attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
is_decoder=self.is_decoder,
axial_pos_embds=self.axial_pos_embds,
axial_pos_shape=self.axial_pos_shape,
axial_pos_embds_dim=self.axial_pos_embds_dim,
local_attn_chunk_length=self.local_attn_chunk_length,
local_num_chunks_after=self.local_num_chunks_after,
local_num_chunks_before=self.local_num_chunks_before,
num_hashes=self.num_hashes,
num_buckets=self.num_buckets,
lsh_attn_chunk_length=self.lsh_attn_chunk_length,
lsh_num_chunks_after=self.lsh_num_chunks_after,
lsh_num_chunks_before=self.lsh_num_chunks_before,
attn_layers=self.attn_layers,
pad_token_id=self.pad_token_id,
hash_seed=self.hash_seed,
)
return (
config,
input_ids,
input_mask,
)
def check_loss_output(self, result):
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_reformer_model(
self, config, input_ids, input_mask,
):
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
(sequence_output,) = model(input_ids, attention_mask=input_mask)
(sequence_output,) = model(input_ids)
result = {
"sequence_output": sequence_output,
}
# 2 * hidden_size because we use reversible resnet layers
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
)
def create_and_check_reformer_model_with_lm_backward(
self, config, input_ids, input_mask,
):
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
loss.backward()
def create_and_check_reformer_with_lm(
self, config, input_ids, input_mask,
):
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.check_loss_output(result)
def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder):
# no special position embeddings
config.axial_pos_embds = False
config.is_decoder = is_decoder
if self.lsh_attn_chunk_length is not None:
# need to set chunk length equal sequence length to be certain that chunking works
config.lsh_attn_chunk_length = self.seq_length
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
# set all position encodings to zero so that postions don't matter
with torch.no_grad():
embedding = model.embeddings.position_embeddings.embedding
embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device))
embedding.weight.requires_grad = False
half_seq_len = self.seq_length // 2
roll = self.chunk_length
half_input_ids = input_ids[:, :half_seq_len]
# normal padded
attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,)
input_ids_padded = torch.cat(
[half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1,
)
# shifted padded
input_ids_roll = torch.cat(
[half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1,
)
input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1)
attn_mask_roll = torch.roll(attn_mask, roll, dims=-1)
output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len]
output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][:, roll : half_seq_len + roll]
self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder):
config.is_decoder = is_decoder
layer = ReformerLayer(config).to(torch_device)
layer.train()
shape = (
self.batch_size,
self.seq_length,
config.hidden_size,
) # Batch x SeqLen x hiddenSize
# get random tensors
hidden_states = floats_tensor(shape)
prev_attn_output = floats_tensor(shape)
# now the random seeds for attention and feed forward is initialized
# forward tensors with dropout
layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask)
next_attn_output = layer_outputs.attn_output
next_hidden_states = layer_outputs.hidden_states
torch.manual_seed(layer.attention_seed)
attn_outputs = layer.attention(hidden_states, attention_mask=input_mask)
self.parent.assertTrue(
torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3,)
)
torch.manual_seed(layer.feed_forward_seed)
feed_forward_hidden_states = layer.feed_forward(next_attn_output)
self.parent.assertTrue(
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
)
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask):
torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0]
config.chunk_size_lm_head = 1
config.chunk_size_feed_forward = 1
torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask):
if not self.is_training:
return
# disable dropout
config.hidden_dropout_prob = 0
config.local_attention_probs_dropout_prob = 0
config.lsh_attention_probs_dropout_prob = 0
torch.manual_seed(0)
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.train()
model.zero_grad()
loss_no_chunk, output_no_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[:2]
loss_no_chunk.backward()
grad_slice_word_no_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
grad_slice_position_factor_1_no_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
grad_slice_position_factor_2_no_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
config.chunk_size_lm_head = 1
config.chunk_size_feed_forward = 1
torch.manual_seed(0)
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.train()
model.zero_grad()
loss_chunk, output_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[:2]
loss_chunk.backward()
grad_slice_word_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
grad_slice_position_factor_1_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
grad_slice_position_factor_2_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
self.parent.assertTrue(torch.allclose(loss_chunk, loss_no_chunk, atol=1e-3))
self.parent.assertTrue(torch.allclose(grad_slice_word_no_chunk, grad_slice_word_chunk, atol=1e-3))
self.parent.assertTrue(
torch.allclose(grad_slice_position_factor_1_chunk, grad_slice_position_factor_1_no_chunk, atol=1e-3)
)
self.parent.assertTrue(
torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3)
)
def create_and_check_reformer_random_seed(self, config, input_ids, input_mask):
layer = ReformerLayer(config).to(torch_device)
layer.train()
shape = (
self.batch_size,
self.seq_length,
config.hidden_size,
) # Batch x SeqLen x hiddenSize
hidden_states = floats_tensor(shape)
attn_output = floats_tensor(shape)
seeds = []
for _ in range(100):
layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask)
attn_output = layer_outputs.attn_output
hidden_states = layer_outputs.hidden_states
torch.manual_seed(layer.attention_seed)
seeds.append(layer.attention_seed)
self.parent.assertGreater(len(set(seeds)), 70)
seeds = []
for _ in range(100):
layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask)
attn_output = layer_outputs.attn_output
hidden_states = layer_outputs.hidden_states
torch.manual_seed(layer.feed_forward_seed)
seeds.append(layer.feed_forward_seed)
self.parent.assertGreater(len(set(seeds)), 70)
def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask):
model = ReformerModel(config=config)
model.to(torch_device)
model.half()
model.eval()
output = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask):
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.half()
model.eval()
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
self.parent.assertFalse(torch.isnan(output).any().item())
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask,) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
class ReformerTesterMixin:
"""
Reformer Local and Reformer LSH run essentially the same tests
"""
def test_config(self):
self.config_tester.run_common_tests()
def test_reformer_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model(*config_and_inputs)
def test_reformer_lm_model_backward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_with_lm_backward(*config_and_inputs)
def test_reformer_model_attn_masking(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True)
self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False)
def test_reformer_with_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs)
def test_reformer_layer_training_dropout(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True)
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False)
def test_reformer_chunking_forward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs)
def test_reformer_chunking_backward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
@slow
def test_dropout_random_seed_is_changing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_random_seed(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_reformer_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_reformer_model_fp16_generate(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)
@require_torch
class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest.TestCase):
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_torchscript = False
def prepare_kwargs(self):
return {
"batch_size": 13,
"seq_length": 32,
"is_training": True,
"is_decoder": False,
"use_input_mask": True,
"vocab_size": 32,
"attention_head_size": 16,
"hidden_size": 32,
"num_attention_heads": 2,
"local_attn_chunk_length": 4,
"local_num_chunks_before": 1,
"local_num_chunks_after": 0,
"chunk_size_lm_head": 0,
"chunk_size_feed_forward": 0,
"feed_forward_size": 32,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"local_attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"initializer_range": 0.02,
"axial_norm_std": 1.0,
"layer_norm_eps": 1e-12,
"axial_pos_embds": True,
"axial_pos_shape": [4, 8],
"axial_pos_embds_dim": [16, 16],
"attn_layers": ["local", "local", "local", "local"],
"pad_token_id": 0,
"eos_token_id": 2,
"scope": None,
"hash_seed": 0,
}
def setUp(self):
tester_kwargs = self.prepare_kwargs()
self.model_tester = ReformerModelTester(self, **tester_kwargs)
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
@slow
def test_model_from_pretrained(self):
for model_name in list(REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = ReformerModelWithLMHead.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch
class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase, ReformerTesterMixin):
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_torchscript = False
def prepare_kwargs(self):
return {
"batch_size": 13,
"seq_length": 13,
"use_input_mask": True,
"is_training": False,
"is_decoder": False,
"vocab_size": 32,
"attention_head_size": 16,
"hidden_size": 64,
"num_attention_heads": 2,
"num_buckets": 2,
"num_hashes": 4,
"lsh_attn_chunk_length": 4,
"lsh_num_chunks_before": 2,
"lsh_num_chunks_after": 3,
"chunk_size_lm_head": 5,
"chunk_size_feed_forward": 6,
"feed_forward_size": 32,
"hidden_act": "relu",
"hidden_dropout_prob": 0.1,
"lsh_attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"initializer_range": 0.02,
"axial_norm_std": 1.0,
"layer_norm_eps": 1e-12,
"axial_pos_embds": True,
"axial_pos_shape": [4, 8],
"axial_pos_embds_dim": [16, 48],
"attn_layers": ["lsh", "lsh", "lsh", "lsh"],
"pad_token_id": 0,
"eos_token_id": 2,
"scope": None,
"hash_seed": 0,
}
def setUp(self):
tester_kwargs = self.prepare_kwargs()
self.model_tester = ReformerModelTester(self, **tester_kwargs)
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
@require_torch
class ReformerIntegrationTests(unittest.TestCase):
"""
These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/04/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `branch_to_save_trax_integration_tests`.
"""
def _get_basic_config_and_input(self):
config = {
"vocab_size": 320,
"attention_head_size": 8,
"hidden_size": 16,
"num_attention_heads": 2,
"num_buckets": 2,
"num_hashes": 4,
"lsh_attn_chunk_length": 4,
"local_attn_chunk_length": 4,
"lsh_num_chunks_before": 1,
"lsh_num_chunks_after": 0,
"local_num_chunks_before": 1,
"local_num_chunks_after": 0,
"chunk_size_lm_head": 0,
"chunk_size_feed_forward": 0,
"feed_forward_size": 32,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"lsh_attention_probs_dropout_prob": 0.0,
"local_attention_probs_dropout_prob": 0.0,
"max_position_embeddings": 32,
"initializer_range": 0.02,
"axial_norm_std": 1.0,
"layer_norm_eps": 1e-12,
"sinusoidal_pos_embds": False,
"axial_pos_embds": True,
"axial_pos_shape": [4, 8],
"axial_pos_embds_dim": [8, 8],
"hash_seed": 0,
"is_decoder": True,
}
return config
def _get_hidden_states(self):
return torch.tensor(
[
[
[
1.90826353e00,
-1.45999730e00,
-6.20405462e-01,
1.52503433e00,
-3.64464232e-01,
-8.27359235e-01,
8.39670803e-01,
2.44492178e-01,
4.98332758e-01,
2.69175139e00,
-7.08081422e-03,
1.04915401e00,
-1.83476661e00,
7.67220476e-01,
2.98580543e-01,
2.84803992e-02,
],
[
-2.66374286e-02,
4.33497576e-01,
3.10386309e-01,
5.46039944e-01,
-2.47292666e-04,
-7.52305019e-01,
2.39162103e-01,
7.25216186e-01,
-7.58357372e-01,
4.20635998e-01,
-4.04739919e-02,
1.59924145e-01,
2.05135748e00,
-1.15997978e00,
5.37166397e-01,
2.62873606e-01,
],
[
1.85247482e-01,
7.07046037e-01,
-6.77089715e-01,
-2.24209655e00,
-3.75307980e-02,
-8.59380874e-01,
-2.81027884e00,
1.01276376e00,
-1.69438001e00,
4.17574660e-01,
-1.49196962e00,
-1.76483717e00,
-1.94566312e-01,
-1.71183858e00,
7.72903565e-01,
-1.11557056e00,
],
[
9.46069193e-01,
1.53417623e-01,
-9.58686996e-01,
1.18126669e-01,
1.75967724e00,
1.62194590e00,
-5.74108159e-01,
6.79920443e-01,
5.44028163e-01,
2.05466114e-01,
-3.63045868e-01,
2.41865062e-01,
3.20348382e-01,
-9.05611176e-01,
-1.92690727e-01,
-1.19917547e00,
],
]
],
dtype=torch.float32,
device=torch_device,
)
def _get_attn_mask(self):
return torch.tensor([[0, 1, 0, 0]], dtype=torch.long, device=torch_device)
def _get_input_ids_and_mask(self):
mask = torch.tensor(
[
[1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0],
],
dtype=torch.long,
device=torch_device,
)
input_ids = torch.tensor(
[
[
89,
279,
286,
84,
194,
316,
182,
28,
283,
37,
169,
7,
253,
267,
107,
250,
44,
7,
102,
62,
3,
243,
171,
265,
302,
48,
164,
264,
148,
229,
280,
150,
],
[
9,
192,
66,
112,
163,
83,
135,
70,
224,
96,
31,
80,
196,
80,
63,
22,
85,
100,
47,
283,
0,
163,
126,
143,
195,
82,
53,
82,
18,
27,
182,
52,
],
],
dtype=torch.long,
device=torch_device,
)
return input_ids, mask
def test_lsh_layer_forward(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["lsh"]
config["is_decoder"] = False
hidden_states = self._get_hidden_states()
torch.manual_seed(0)
layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
layer.eval()
reformer_output = layer(prev_attn_output=hidden_states.clone(), hidden_states=hidden_states)
output_slice = reformer_output.hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[1.6879, -1.3083, -0.4708, 1.3555, -0.6292], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_lsh_layer_forward_complex(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["lsh"]
config["num_buckets"] = [2, 4]
attn_mask = self._get_attn_mask()
hidden_states = self._get_hidden_states()
torch.manual_seed(0)
layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
layer.eval()
reformer_output = layer(
prev_attn_output=hidden_states.clone(), hidden_states=hidden_states, attention_mask=attn_mask,
)
output_slice = reformer_output.hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[1.6439, -1.2306, -0.5108, 1.3006, -0.6537], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_local_layer_forward(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["local"]
config["is_decoder"] = False
hidden_states = self._get_hidden_states()
torch.manual_seed(0)
layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
layer.eval()
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states)
output_slice = reformer_output.hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[1.4212, -2.0576, -0.9688, 1.4599, -0.1344], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_local_layer_forward_complex(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["local"]
attn_mask = self._get_attn_mask()
hidden_states = self._get_hidden_states()
torch.manual_seed(0)
layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
layer.eval()
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,)
output_slice = reformer_output.hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_lsh_model_forward(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["lsh", "lsh", "lsh", "lsh"]
config["num_buckets"] = [2, 4]
torch.manual_seed(0)
model = ReformerModel(ReformerConfig(**config)).to(torch_device)
model.eval()
input_ids, attn_mask = self._get_input_ids_and_mask()
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
output_slice = hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[-0.9896, -0.9396, -1.0831, -0.0597, 0.2456], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_local_model_forward(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["local", "local", "local", "local"]
torch.manual_seed(0)
model = ReformerModel(ReformerConfig(**config)).to(torch_device)
model.eval()
input_ids, attn_mask = self._get_input_ids_and_mask()
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
output_slice = hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[-1.6791, 0.7171, 0.1594, 0.4063, 1.2584], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_lm_model_forward(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["local", "lsh", "local", "lsh", "local", "lsh"]
config["num_buckets"] = [2, 4]
config["is_decoder"] = False
torch.manual_seed(0)
model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device)
model.eval()
input_ids, attn_mask = self._get_input_ids_and_mask()
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
output_slice = hidden_states[1, -1, :5]
expected_output_slice = torch.tensor(
[0.0324, -0.0121, 0.0615, 0.0031, -0.0297], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_local_lm_model_grad(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["local", "local", "local", "local"]
config["hidden_dropout_prob"] = 0.0
config["local_attention_probs_dropout_prob"] = 0.0
torch.manual_seed(0)
model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device)
model.train()
model.zero_grad()
input_ids, _ = self._get_input_ids_and_mask()
loss = model(input_ids=input_ids, labels=input_ids)[0]
self.assertTrue(torch.allclose(loss, torch.tensor(5.7786, dtype=torch.float, device=torch_device), atol=1e-3))
loss.backward()
# check last grads to cover all proable errors
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
expected_grad_slice_word = torch.tensor(
[-0.0005, 0.0001, 0.0002, 0.0003, 0.0006], dtype=torch.float, device=torch_device,
)
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
expected_grad_slice_pos_fac_1 = torch.tensor(
[0.0037, -1.3793, -1.0231, -1.5230, -2.5306], dtype=torch.float, device=torch_device,
)
grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
expected_grad_slice_pos_fac_2 = torch.tensor(
[-1.3165, 0.5168, 0.7785, 1.0811, -0.9830], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3))
self.assertTrue(torch.allclose(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3))
self.assertTrue(torch.allclose(grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3))
def test_lsh_lm_model_grad(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["lsh", "lsh", "lsh", "lsh"]
config["hidden_dropout_prob"] = 0.0
config["lsh_attention_probs_dropout_prob"] = 0.0
config["num_buckets"] = [2, 4]
config["num_hashes"] = 6
torch.manual_seed(0)
model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device)
model.train()
model.zero_grad()
input_ids, _ = self._get_input_ids_and_mask()
loss = model(input_ids=input_ids, labels=input_ids)[0]
self.assertTrue(torch.allclose(loss, torch.tensor(5.7819, dtype=torch.float, device=torch_device), atol=1e-3))
loss.backward()
# check last grads to cover all proable errors
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
expected_grad_slice_word = torch.tensor(
[2.6357e-05, 4.3358e-04, -8.4985e-04, 1.0094e-04, 3.8954e-04], dtype=torch.float, device=torch_device,
)
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
expected_grad_slice_pos_fac_1 = torch.tensor(
[-0.0984, 0.6283, 0.4282, 1.2960, 0.6897], dtype=torch.float, device=torch_device,
)
grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
expected_grad_slice_pos_fac_2 = torch.tensor(
[0.4626, -0.0231, -0.0172, 0.1081, 0.3805], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3))
self.assertTrue(torch.allclose(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3))
self.assertTrue(torch.allclose(grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3))
@slow
def test_pretrained_generate_crime_and_punish(self):
model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").to(torch_device)
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model.eval()
input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device)
output_ids = model.generate(
input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8
)
output_text = tokenizer.decode(output_ids[0])
self.assertEqual(
output_text,
"A few months later state expression in his ideas, at the first entrance. He was positively for an inst",
)