[Reformer] Add Masked LM Reformer (#5426)

* fix conflicts

* fix

* happy rebasing
Patrick von Platen 2020-07-01 22:43:18 +02:00 committed by GitHub
parent f4323dbf8c
commit d16e36c7e5
5 changed files with 130 additions and 11 deletions

docs/source/model_doc/reformer.rst

@@ -114,9 +114,15 @@ ReformerModelWithLMHead
    :members:

ReformerForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ReformerForMaskedLM
    :members:

ReformerForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ReformerForQuestionAnswering
    :members:

src/transformers/__init__.py

@@ -366,6 +366,7 @@ if is_torch_available():
        ReformerAttention,
        ReformerLayer,
        ReformerModel,
        ReformerForMaskedLM,
        ReformerModelWithLMHead,
        ReformerForQuestionAnswering,
        REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,

src/transformers/modeling_auto.py

@@ -122,7 +122,12 @@ from .modeling_mobilebert import (
    MobileBertModel,
)
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_reformer import ReformerForQuestionAnswering, ReformerModel, ReformerModelWithLMHead
from .modeling_reformer import (
    ReformerForMaskedLM,
    ReformerForQuestionAnswering,
    ReformerModel,
    ReformerModelWithLMHead,
)
from .modeling_retribert import RetriBertModel
from .modeling_roberta import (
    RobertaForMaskedLM,
@@ -266,6 +271,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (ElectraConfig, ElectraForMaskedLM),
        (ReformerConfig, ReformerForMaskedLM),
    ]
)

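With (ReformerConfig, ReformerForMaskedLM) registered in MODEL_FOR_MASKED_LM_MAPPING, the masked-LM auto class can resolve a Reformer configuration to the new head. A minimal sketch, assuming the AutoModelForMaskedLM class from the v3.0 auto-model API and default ReformerConfig hyper-parameters:

from transformers import AutoModelForMaskedLM, ReformerConfig, ReformerForMaskedLM

# The masked-LM head requires bi-directional self-attention, i.e. is_decoder=False.
config = ReformerConfig(is_decoder=False)

# AutoModelForMaskedLM looks the config class up in MODEL_FOR_MASKED_LM_MAPPING
# and instantiates the mapped model class.
model = AutoModelForMaskedLM.from_config(config)
assert isinstance(model, ReformerForMaskedLM)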
src/transformers/modeling_reformer.py

@@ -1704,6 +1704,7 @@ class ReformerModel(ReformerPreTrainedModel):
class ReformerModelWithLMHead(ReformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        assert config.is_decoder, "If you want to use `ReformerLMHeadModel` make sure that `is_decoder=True`."
        self.reformer = ReformerModel(config)
        self.lm_head = ReformerOnlyLMHead(config)
@@ -1791,6 +1792,87 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
        return inputs_dict


@add_start_docstrings("""Reformer Model with a `language modeling` head on top. """, REFORMER_START_DOCSTRING)
class ReformerForMaskedLM(ReformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        assert (
            not config.is_decoder
        ), "If you want to use `ReformerForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
        self.reformer = ReformerModel(config)
        self.lm_head = ReformerOnlyLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def tie_weights(self):
        # word embeddings are not tied in Reformer
        pass

    @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/reformer-crime-and-punishment")
    def forward(
        self,
        input_ids=None,
        position_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        num_hashes=None,
        labels=None,
        output_hidden_states=None,
        output_attentions=None,
    ):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss (cross entropy).
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
        reformer_outputs = self.reformer(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            num_hashes=num_hashes,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )

        sequence_output = reformer_outputs[0]
        logits = self.lm_head(sequence_output)
        outputs = (logits,) + reformer_outputs[1:]

        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
            outputs = (masked_lm_loss,) + outputs

        return outputs  # (mlm_loss), lm_logits, (hidden_states), (attentions)
@add_start_docstrings(
    """Reformer Model with a span classification head on top for
    extractive question-answering tasks like SQuAD / TriviaQA (a linear layer on

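For reference, a minimal end-to-end sketch of the new head: a toy bi-directional ReformerConfig, a forward pass with labels, and the -100 ignore index described in the docstring above. All hyper-parameter values are illustrative assumptions, not taken from this commit or from any released checkpoint.

import torch
from transformers import ReformerConfig, ReformerForMaskedLM

# Toy, bi-directional configuration; the head asserts is_decoder=False.
config = ReformerConfig(
    vocab_size=320,
    hidden_size=64,
    num_attention_heads=2,
    attention_head_size=16,
    feed_forward_size=128,
    attn_layers=["local", "local"],
    local_attn_chunk_length=16,
    axial_pos_embds=False,  # plain learned position embeddings for the toy setup
    is_decoder=False,       # bi-directional self-attention, required by ReformerForMaskedLM
)
model = ReformerForMaskedLM(config)
model.eval()

# Sequence length is a multiple of the chunk length, so no internal padding is needed.
input_ids = torch.randint(0, config.vocab_size, (1, 32))

# Only positions with a label in [0, vocab_size) contribute to the loss;
# positions set to -100 are ignored, as described in the docstring.
labels = input_ids.clone()
labels[:, :16] = -100

with torch.no_grad():
    loss, prediction_scores = model(input_ids=input_ids, labels=labels)

print(loss.item(), prediction_scores.shape)  # scalar loss, (1, 32, vocab_size)

prediction_scores keeps the full vocabulary dimension for every position; the -100 entries only exclude those positions from the cross-entropy loss.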
tests/test_modeling_reformer.py

@@ -25,6 +25,7 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
    from transformers import (
        ReformerConfig,
        ReformerForMaskedLM,
        ReformerModel,
        ReformerModelWithLMHead,
        ReformerTokenizer,
@@ -209,7 +210,24 @@ class ReformerModelTester:
        )
        self.check_loss_output(result)

    def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, choice_labels, is_decoder):
    def create_and_check_reformer_with_mlm(self, config, input_ids, input_mask, choice_labels):
        config.is_decoder = False
        model = ReformerForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids)
        result = {
            "loss": loss,
            "prediction_scores": prediction_scores,
        }
        self.parent.assertListEqual(
            list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size],
        )
        self.check_loss_output(result)

    def create_and_check_reformer_model_with_attn_mask(
        self, config, input_ids, input_mask, choice_labels, is_decoder=False
    ):
        # no special position embeddings
        config.axial_pos_embds = False
        config.is_decoder = is_decoder
@@ -250,7 +268,9 @@ class ReformerModelTester:
        self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))

    def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, choice_labels, is_decoder):
    def create_and_check_reformer_layer_dropout_seed(
        self, config, input_ids, input_mask, choice_labels, is_decoder=False
    ):
        config.is_decoder = is_decoder
        layer = ReformerLayer(config).to(torch_device)
        layer.train()
@@ -441,17 +461,21 @@ class ReformerTesterMixin:
    def test_reformer_model_attn_masking(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True)
        self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False)
        self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, is_decoder=True)
        self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, is_decoder=False)

    def test_reformer_with_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs)

    def test_reformer_with_mlm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_reformer_with_mlm(*config_and_inputs)

    def test_reformer_layer_training_dropout(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True)
        self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False)
        self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True)
        self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False)

    def test_reformer_chunking_forward_equality(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -501,7 +525,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
        "batch_size": 13,
        "seq_length": 32,
        "is_training": True,
        "is_decoder": False,
        "is_decoder": True,
        "use_input_mask": True,
        "use_labels": True,
        "vocab_size": 32,
@@ -560,7 +584,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
        "use_input_mask": True,
        "use_labels": True,
        "is_training": False,
        "is_decoder": False,
        "is_decoder": True,
        "vocab_size": 32,
        "attention_head_size": 16,
        "hidden_size": 64,
@@ -910,7 +934,7 @@ class ReformerIntegrationTests(unittest.TestCase):
        config["num_buckets"] = [2, 4]
        config["is_decoder"] = False
        torch.manual_seed(0)
        model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device)
        model = ReformerForMaskedLM(ReformerConfig(**config)).to(torch_device)
        model.eval()
        input_ids, attn_mask = self._get_input_ids_and_mask()
        hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]