[Reformer] Add Masked LM Reformer (#5426)
* fix conflicts
* fix
* happy rebasing
This commit is contained in: parent f4323dbf8c, commit d16e36c7e5
docs/source/model_doc/reformer.rst
@@ -114,9 +114,15 @@ ReformerModelWithLMHead
     :members:
 
 
+ReformerForMaskedLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.ReformerForMaskedLM
+    :members:
+
 
 ReformerForQuestionAnswering
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.ReformerForQuestionAnswering
     :members:
src/transformers/__init__.py
@@ -366,6 +366,7 @@ if is_torch_available():
         ReformerAttention,
         ReformerLayer,
         ReformerModel,
+        ReformerForMaskedLM,
         ReformerModelWithLMHead,
         ReformerForQuestionAnswering,
         REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
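Note: with this export added, the new head is importable from the package root whenever torch is available. A quick sanity-check sketch (not part of the diff):

# Requires torch, mirroring the `if is_torch_available():` guard above.
from transformers import ReformerForMaskedLM

print(ReformerForMaskedLM.__module__)  # transformers.modeling_reformer at the time of this change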
src/transformers/modeling_auto.py
@@ -122,7 +122,12 @@ from .modeling_mobilebert import (
     MobileBertModel,
 )
 from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
-from .modeling_reformer import ReformerForQuestionAnswering, ReformerModel, ReformerModelWithLMHead
+from .modeling_reformer import (
+    ReformerForMaskedLM,
+    ReformerForQuestionAnswering,
+    ReformerModel,
+    ReformerModelWithLMHead,
+)
 from .modeling_retribert import RetriBertModel
 from .modeling_roberta import (
     RobertaForMaskedLM,
@@ -266,6 +271,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
         (FlaubertConfig, FlaubertWithLMHeadModel),
         (XLMConfig, XLMWithLMHeadModel),
         (ElectraConfig, ElectraForMaskedLM),
+        (ReformerConfig, ReformerForMaskedLM),
     ]
 )
 
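Note: registering `ReformerConfig` in `MODEL_FOR_MASKED_LM_MAPPING` is what makes the new head reachable through the auto classes. A minimal sketch (the config values are illustrative, not part of this diff):

# The auto class resolves a ReformerConfig to the newly registered head.
from transformers import AutoModelForMaskedLM, ReformerConfig

config = ReformerConfig(is_decoder=False)  # bi-directional self-attention, as the masked-LM head requires
model = AutoModelForMaskedLM.from_config(config)
print(type(model).__name__)  # ReformerForMaskedLM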
src/transformers/modeling_reformer.py
@@ -1704,6 +1704,7 @@ class ReformerModel(ReformerPreTrainedModel):
 class ReformerModelWithLMHead(ReformerPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
+        assert config.is_decoder, "If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`."
         self.reformer = ReformerModel(config)
         self.lm_head = ReformerOnlyLMHead(config)
 
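Note: the new assert makes the decoder requirement explicit for the causal LM head; it mirrors the opposite check in `ReformerForMaskedLM` below. A small illustrative sketch (config values are made up):

# The two heads now enforce opposite `is_decoder` settings at construction time.
from transformers import ReformerConfig, ReformerModelWithLMHead

ReformerModelWithLMHead(ReformerConfig(is_decoder=True))  # accepted
try:
    ReformerModelWithLMHead(ReformerConfig(is_decoder=False))  # rejected by the assert above
except AssertionError as err:
    print("rejected:", err)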
@@ -1791,6 +1792,87 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
         return inputs_dict
 
 
+@add_start_docstrings("""Reformer Model with a `language modeling` head on top. """, REFORMER_START_DOCSTRING)
+class ReformerForMaskedLM(ReformerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        assert (
+            not config.is_decoder
+        ), "If you want to use `ReformerForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
+        self.reformer = ReformerModel(config)
+        self.lm_head = ReformerOnlyLMHead(config)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
+    def tie_weights(self):
+        # word embeddings are not tied in Reformer
+        pass
+
+    @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/reformer-crime-and-punishment")
+    def forward(
+        self,
+        input_ids=None,
+        position_ids=None,
+        attention_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        num_hashes=None,
+        labels=None,
+        output_hidden_states=None,
+        output_attentions=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the masked language modeling loss.
+            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
+            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
+
+    Return:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Masked language modeling loss (cross entropy).
+        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        """
+
+        reformer_outputs = self.reformer(
+            input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            num_hashes=num_hashes,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+
+        sequence_output = reformer_outputs[0]
+        logits = self.lm_head(sequence_output)
+        outputs = (logits,) + reformer_outputs[1:]
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()  # -100 index = padding token
+            masked_lm_loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+            outputs = (masked_lm_loss,) + outputs
+
+        return outputs  # (mlm_loss), lm_logits, (hidden_states), (attentions)
+
+
 @add_start_docstrings(
     """Reformer Model with a span classification head on top for
     extractive question-answering tasks like SQuAD / TriviaQA ( a linear layer on
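Note: the forward pass follows the usual `*ForMaskedLM` contract: logits from `lm_head` over the sequence output, plus an optional cross-entropy loss in which label positions set to -100 are ignored. A minimal usage sketch (the config values, sequence length, and random inputs below are illustrative only, not taken from this diff):

# Illustrative sketch of the ReformerForMaskedLM forward contract.
import torch
from transformers import ReformerConfig, ReformerForMaskedLM

config = ReformerConfig(
    is_decoder=False,                # required by the assert in __init__
    axial_pos_embds=False,           # sidestep the axial position-shape constraint for this toy input
    attn_layers=["local", "local"],  # two local-attention layers keep the example small
    vocab_size=320,
)
model = ReformerForMaskedLM(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 64))  # 64 = local_attn_chunk_length, so no padding needed
labels = input_ids.clone()
labels[:, ::2] = -100                # ignore every other position in the loss

loss, prediction_scores = model(input_ids=input_ids, labels=labels)[:2]
print(loss.item(), prediction_scores.shape)  # scalar loss, torch.Size([1, 64, 320])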
tests/test_modeling_reformer.py
@@ -25,6 +25,7 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
 if is_torch_available():
     from transformers import (
         ReformerConfig,
+        ReformerForMaskedLM,
         ReformerModel,
         ReformerModelWithLMHead,
         ReformerTokenizer,
@@ -209,7 +210,24 @@ class ReformerModelTester:
         )
         self.check_loss_output(result)
 
-    def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, choice_labels, is_decoder):
+    def create_and_check_reformer_with_mlm(self, config, input_ids, input_mask, choice_labels):
+        config.is_decoder = False
+        model = ReformerForMaskedLM(config=config)
+        model.to(torch_device)
+        model.eval()
+        loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids)
+        result = {
+            "loss": loss,
+            "prediction_scores": prediction_scores,
+        }
+        self.parent.assertListEqual(
+            list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size],
+        )
+        self.check_loss_output(result)
+
+    def create_and_check_reformer_model_with_attn_mask(
+        self, config, input_ids, input_mask, choice_labels, is_decoder=False
+    ):
         # no special position embeddings
         config.axial_pos_embds = False
         config.is_decoder = is_decoder
@@ -250,7 +268,9 @@ class ReformerModelTester:
 
         self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
 
-    def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, choice_labels, is_decoder):
+    def create_and_check_reformer_layer_dropout_seed(
+        self, config, input_ids, input_mask, choice_labels, is_decoder=False
+    ):
         config.is_decoder = is_decoder
         layer = ReformerLayer(config).to(torch_device)
         layer.train()
@@ -441,17 +461,21 @@ class ReformerTesterMixin:
 
     def test_reformer_model_attn_masking(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True)
-        self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False)
+        self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, is_decoder=True)
+        self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, is_decoder=False)
 
     def test_reformer_with_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs)
 
+    def test_reformer_with_mlm(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_reformer_with_mlm(*config_and_inputs)
+
     def test_reformer_layer_training_dropout(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True)
-        self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False)
+        self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True)
+        self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False)
 
     def test_reformer_chunking_forward_equality(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -501,7 +525,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
         "batch_size": 13,
         "seq_length": 32,
         "is_training": True,
-        "is_decoder": False,
+        "is_decoder": True,
         "use_input_mask": True,
         "use_labels": True,
         "vocab_size": 32,
@@ -560,7 +584,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
         "use_input_mask": True,
         "use_labels": True,
         "is_training": False,
-        "is_decoder": False,
+        "is_decoder": True,
         "vocab_size": 32,
         "attention_head_size": 16,
         "hidden_size": 64,
@@ -910,7 +934,7 @@ class ReformerIntegrationTests(unittest.TestCase):
         config["num_buckets"] = [2, 4]
         config["is_decoder"] = False
         torch.manual_seed(0)
-        model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device)
+        model = ReformerForMaskedLM(ReformerConfig(**config)).to(torch_device)
         model.eval()
         input_ids, attn_mask = self._get_input_ids_and_mask()
         hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]