Split LMBert model in two (#4874)

* Split LMBert model in two

* Fix example

* Remove lm_labels

* Adapt tests, refactor prepare_for_generation

* Fix merge

* Hide BertLMHeadModel
Sylvain Gugger 2020-06-10 18:26:42 -04:00 committed by GitHub
parent f6da8b2200
commit 1e2631d6f8
4 changed files with 163 additions and 95 deletions

src/transformers/modeling_bert.py

@@ -873,7 +873,116 @@ class BertForPreTraining(BertPreTrainedModel):
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
# TODO: Split with a different BertWithLMHead to get rid of `lm_labels` here and in encoder_decoder.
@add_start_docstrings(
"""Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."
self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config)
self.init_weights()
def get_output_embeddings(self):
return self.cls.predictions.decoder
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
**kwargs
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the left-to-right language modeling loss (next word prediction).
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Next token prediction loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import BertTokenizer, BertLMHeadModel
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertLMHeadModel.from_pretrained('bert-base-uncased', is_decoder=True)
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=input_ids)
loss, prediction_scores = outputs[:2]
"""
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (ltr_lm_loss,) + outputs
return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
return {"input_ids": input_ids, "attention_mask": attention_mask}
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
def __init__(self, config):
@@ -899,7 +1008,6 @@ class BertForMaskedLM(BertPreTrainedModel):
labels=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
lm_labels=None,
output_attentions=None,
**kwargs
):
@@ -909,11 +1017,6 @@ class BertForMaskedLM(BertPreTrainedModel):
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the left-to-right language modeling loss (next word prediction).
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.
@@ -921,8 +1024,6 @@ class BertForMaskedLM(BertPreTrainedModel):
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_labels` is provided):
Next token prediction loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
@@ -957,6 +1058,7 @@ class BertForMaskedLM(BertPreTrainedModel):
DeprecationWarning,
)
labels = kwargs.pop("masked_lm_labels")
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
outputs = self.bert(
@@ -976,46 +1078,24 @@ class BertForMaskedLM(BertPreTrainedModel):
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
# Although this may seem awkward, BertForMaskedLM supports two scenarios:
# 1. If a tensor that contains the indices of masked labels is provided,
# the cross-entropy is the MLM cross-entropy that measures the likelihood
# of predictions for masked words.
# 2. If `lm_labels` is provided we are in a causal scenario where we
# try to predict the next token for each input in the decoder.
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (masked_lm_loss,) + outputs
if lm_labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :].contiguous()
lm_labels = lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
outputs = (ltr_lm_loss,) + outputs
return outputs # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
effective_batch_size = input_shape[0]
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# if the model does not use a causal mask then add a dummy token
if self.config.is_decoder is False:
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
attention_mask = torch.cat(
[attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
)
dummy_token = torch.full(
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
)
input_ids = torch.cat([input_ids, dummy_token], dim=1)
# add a dummy token
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
dummy_token = torch.full(
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
)
input_ids = torch.cat([input_ids, dummy_token], dim=1)
return {"input_ids": input_ids, "attention_mask": attention_mask}

src/transformers/modeling_encoder_decoder.py

@@ -192,7 +192,6 @@ class EncoderDecoderModel(PreTrainedModel):
decoder_head_mask=None,
decoder_inputs_embeds=None,
labels=None,
lm_labels=None,
**kwargs,
):
@@ -239,11 +238,6 @@ class EncoderDecoderModel(PreTrainedModel):
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the left-to-right language modeling loss (next word prediction) for the decoder.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
@@ -293,7 +287,6 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
lm_labels=lm_labels,
labels=labels,
**kwargs_decoder,
)
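The net effect on the encoder-decoder side: the decoder loss flows through `labels` alone, and the decoder is expected to be the new causal class rather than `BertForMaskedLM`. A sketch mirroring the updated tests below (not part of the diff; cross-attention weights are freshly initialized when a plain BERT checkpoint is loaded with `is_decoder=True`, so the loss value is only meaningful after fine-tuning):
# Sketch only, not part of the diff.
from transformers import BertModel, BertTokenizer, EncoderDecoderModel
from transformers.modeling_bert import BertLMHeadModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
encoder = BertModel.from_pretrained("bert-base-uncased")
decoder = BertLMHeadModel.from_pretrained("bert-base-uncased", is_decoder=True)
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")
decoder_input_ids = tokenizer.encode("Hello, my dog", return_tensors="pt")

# Passing the decoder inputs as labels is enough: the decoder shifts them itself.
outputs = model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    labels=decoder_input_ids,
)
loss = outputs[0]  # the loss that `lm_labels` used to produce
loss.backward()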

tests/test_modeling_bert.py

@@ -35,7 +35,7 @@ if is_torch_available():
BertForTokenClassification,
BertForMultipleChoice,
)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BertLMHeadModel
class BertModelTester:
@@ -211,6 +211,33 @@ class BertModelTester:
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_bert_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = BertLMHeadModel(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result)
def create_and_check_bert_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -229,7 +256,7 @@ class BertModelTester:
)
self.check_loss_output(result)
def create_and_check_bert_model_for_masked_lm_as_decoder(
def create_and_check_bert_model_for_causal_lm_as_decoder(
self,
config,
input_ids,
@@ -241,7 +268,7 @@ class BertModelTester:
encoder_hidden_states,
encoder_attention_mask,
):
model = BertForMaskedLM(config=config)
model = BertLMHeadModel(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
@@ -461,13 +488,17 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
encoder_attention_mask,
)
def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_bert_for_causal_lm(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
def test_for_masked_lm_decoder(self):
def test_for_causal_lm_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_bert_model_for_masked_lm_as_decoder(*config_and_inputs)
self.model_tester.create_and_check_bert_model_for_causal_lm_as_decoder(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
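The new and renamed checks above cover the causal class both standalone and as a decoder; the reverse direction is also worth noting: routing `lm_labels` through `BertForMaskedLM` now fails fast. A tiny self-contained check (not part of the diff; the small config values are arbitrary and the weights are random):
# Sketch only, not part of the diff.
import torch
from transformers import BertConfig, BertForMaskedLM

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertForMaskedLM(config)
input_ids = torch.randint(0, 100, (1, 8))

try:
    model(input_ids, lm_labels=input_ids)
except AssertionError as err:
    print(err)  # points the caller at the causal-LM class instead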

tests/test_modeling_encoder_decoder.py

@@ -27,7 +27,8 @@ from .utils import require_torch, slow, torch_device
if is_torch_available():
from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel, EncoderDecoderConfig
from transformers import BertModel, EncoderDecoderModel, EncoderDecoderConfig
from transformers.modeling_bert import BertLMHeadModel
import numpy as np
import torch
@@ -70,7 +71,6 @@ class EncoderDecoderModelTest(unittest.TestCase):
"decoder_token_labels": decoder_token_labels,
"decoder_choice_labels": decoder_choice_labels,
"encoder_hidden_states": encoder_hidden_states,
"lm_labels": decoder_token_labels,
"labels": decoder_token_labels,
}
@@ -116,7 +116,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
decoder_model = BertLMHeadModel(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
@@ -153,7 +153,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
decoder_model = BertLMHeadModel(decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
enc_dec_model.to(torch_device)
@@ -179,7 +179,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
decoder_model = BertLMHeadModel(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
enc_dec_model.eval()
@@ -220,7 +220,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
decoder_model = BertLMHeadModel(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
enc_dec_model.eval()
@@ -269,7 +269,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
decoder_model = BertLMHeadModel(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
@@ -288,41 +288,9 @@ class EncoderDecoderModelTest(unittest.TestCase):
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
def create_and_check_bert_encoder_decoder_model_lm_labels(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
lm_labels,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
lm_labels=lm_labels,
)
lm_loss = outputs_encoder_decoder[0]
self.check_loss_output(lm_loss)
# check that backprop works
lm_loss.backward()
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
def create_and_check_bert_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
decoder_model = BertLMHeadModel(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
@@ -356,10 +324,6 @@ class EncoderDecoderModelTest(unittest.TestCase):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_labels(**input_ids_dict)
def test_bert_encoder_decoder_model_lm_labels(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_lm_labels(**input_ids_dict)
def test_bert_encoder_decoder_model_generate(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_generate(**input_ids_dict)
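Generation through the encoder-decoder still goes through the decoder's simplified `prepare_inputs_for_generation`. A sketch of the path the generate test above exercises (not part of the diff; it uses tiny, arbitrary config sizes, untrained weights, and BERT's pad token standing in for the missing bos token, and assumes the installed version supports generation for `EncoderDecoderModel`):
# Sketch only, not part of the diff.
import torch
from transformers import BertConfig, BertModel, EncoderDecoderModel
from transformers.modeling_bert import BertLMHeadModel

enc_config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                        num_attention_heads=2, intermediate_size=64)
dec_config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                        num_attention_heads=2, intermediate_size=64, is_decoder=True)
enc_dec = EncoderDecoderModel(encoder=BertModel(enc_config), decoder=BertLMHeadModel(dec_config))

input_ids = torch.randint(0, 100, (1, 8))
generated = enc_dec.generate(
    input_ids,
    decoder_start_token_id=enc_dec.config.decoder.pad_token_id,  # BERT has no bos token
    max_length=10,
)
print(generated.shape)  # (1, 10): random tokens, since the weights are untrained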