diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py
index 5b4d00c46e..4eb857db21 100644
--- a/src/transformers/modeling_bert.py
+++ b/src/transformers/modeling_bert.py
@@ -873,7 +873,116 @@ class BertForPreTraining(BertPreTrainedModel):
         return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
 
 
-# TODO: Split with a different BertWithLMHead to get rid of `lm_labels` here and in encoder_decoder.
+@add_start_docstrings(
+    """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
+)
+class BertLMHeadModel(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."
+
+        self.bert = BertModel(config)
+        self.cls = BertOnlyMLMHead(config)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        output_attentions=None,
+        **kwargs
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the left-to-right language modeling loss (next word prediction).
+            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
+            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
+            in ``[0, ..., config.vocab_size]``
+        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
+            Used to hide legacy arguments that have been deprecated.
+
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+        ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Next token prediction loss.
+        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+        Examples::
+
+            from transformers import BertTokenizer, BertLMHeadModel
+            import torch
+
+            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+            model = BertLMHeadModel.from_pretrained('bert-base-uncased', is_decoder=True)
+
+            input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+            outputs = model(input_ids, labels=input_ids)
+
+            loss, prediction_scores = outputs[:2]
+
+        """
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+
+        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
+
+        if labels is not None:
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            prediction_scores = prediction_scores[:, :-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            outputs = (ltr_lm_loss,) + outputs
+
+        return outputs  # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
+
+    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class BertForMaskedLM(BertPreTrainedModel):
     def __init__(self, config):
@@ -899,7 +1008,6 @@ class BertForMaskedLM(BertPreTrainedModel):
         labels=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
-        lm_labels=None,
         output_attentions=None,
         **kwargs
     ):
@@ -909,11 +1017,6 @@ class BertForMaskedLM(BertPreTrainedModel):
             Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
             Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
-        lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
-            Labels for computing the left-to-right language modeling loss (next word prediction).
-            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
-            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
-            in ``[0, ..., config.vocab_size]``
         kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
             Used to hide legacy arguments that have been deprecated.
@@ -921,8 +1024,6 @@ class BertForMaskedLM(BertPreTrainedModel):
         :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
         masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
             Masked language modeling loss.
-        ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_labels` is provided):
-            Next token prediction loss.
         prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
@@ -957,6 +1058,7 @@ class BertForMaskedLM(BertPreTrainedModel):
                 DeprecationWarning,
             )
             labels = kwargs.pop("masked_lm_labels")
+        assert "lm_labels" not in kwargs, "Use `BertLMHeadModel` for autoregressive language modeling task."
         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
 
         outputs = self.bert(
@@ -976,46 +1078,24 @@ class BertForMaskedLM(BertPreTrainedModel):
 
         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
 
-        # Although this may seem awkward, BertForMaskedLM supports two scenarios:
-        # 1. If a tensor that contains the indices of masked labels is provided,
-        #    the cross-entropy is the MLM cross-entropy that measures the likelihood
-        #    of predictions for masked words.
-        # 2. If `lm_labels` is provided we are in a causal scenario where we
-        #    try to predict the next token for each input in the decoder.
         if labels is not None:
             loss_fct = CrossEntropyLoss()  # -100 index = padding token
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
             outputs = (masked_lm_loss,) + outputs
 
-        if lm_labels is not None:
-            # we are doing next-token prediction; shift prediction scores and input ids by one
-            prediction_scores = prediction_scores[:, :-1, :].contiguous()
-            lm_labels = lm_labels[:, 1:].contiguous()
-            loss_fct = CrossEntropyLoss()
-            ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
-            outputs = (ltr_lm_loss,) + outputs
-
-        return outputs  # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
+        return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
 
     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
         effective_batch_size = input_shape[0]
 
-        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
-        if attention_mask is None:
-            attention_mask = input_ids.new_ones(input_shape)
-
-        # if model is does not use a causal mask then add a dummy token
-        if self.config.is_decoder is False:
-            assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
-            attention_mask = torch.cat(
-                [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
-            )
-
-            dummy_token = torch.full(
-                (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
-            )
-            input_ids = torch.cat([input_ids, dummy_token], dim=1)
+        # add a dummy token
+        assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
+        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
+        dummy_token = torch.full(
+            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
+        )
+        input_ids = torch.cat([input_ids, dummy_token], dim=1)
 
         return {"input_ids": input_ids, "attention_mask": attention_mask}
 
diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py
index e7a8e154ea..cb44671a9f 100644
--- a/src/transformers/modeling_encoder_decoder.py
+++ b/src/transformers/modeling_encoder_decoder.py
@@ -192,7 +192,6 @@ class EncoderDecoderModel(PreTrainedModel):
         decoder_head_mask=None,
         decoder_inputs_embeds=None,
         labels=None,
-        lm_labels=None,
         **kwargs,
     ):
@@ -239,11 +238,6 @@ class EncoderDecoderModel(PreTrainedModel):
             Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
             Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
-        lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
-            Labels for computing the left-to-right language modeling loss (next word prediction) for the decoder.
-            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
-            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
-            in ``[0, ..., config.vocab_size]``
         kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
             - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
             - With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
@@ -293,7 +287,6 @@ class EncoderDecoderModel(PreTrainedModel):
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=attention_mask,
             head_mask=decoder_head_mask,
-            lm_labels=lm_labels,
             labels=labels,
             **kwargs_decoder,
         )
diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py
index ed42031232..737a0082c9 100644
--- a/tests/test_modeling_bert.py
+++ b/tests/test_modeling_bert.py
@@ -35,7 +35,7 @@ if is_torch_available():
         BertForTokenClassification,
         BertForMultipleChoice,
     )
-    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
+    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BertLMHeadModel
 
 
 class BertModelTester:
@@ -211,6 +211,33 @@ class BertModelTester:
         )
         self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
 
+    def create_and_check_bert_for_causal_lm(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        encoder_hidden_states,
+        encoder_attention_mask,
+    ):
+        model = BertLMHeadModel(config=config)
+        model.to(torch_device)
+        model.eval()
+        loss, prediction_scores = model(
+            input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
+        )
+        result = {
+            "loss": loss,
+            "prediction_scores": prediction_scores,
+        }
+        self.parent.assertListEqual(
+            list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
+        )
+        self.check_loss_output(result)
+
     def create_and_check_bert_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
@@ -229,7 +256,7 @@ class BertModelTester:
         )
         self.check_loss_output(result)
 
-    def create_and_check_bert_model_for_masked_lm_as_decoder(
+    def create_and_check_bert_model_for_causal_lm_as_decoder(
         self,
         config,
         input_ids,
@@ -241,7 +268,7 @@ class BertModelTester:
         encoder_hidden_states,
         encoder_attention_mask,
     ):
-        model = BertForMaskedLM(config=config)
+        model = BertLMHeadModel(config=config)
         model.to(torch_device)
         model.eval()
         loss, prediction_scores = model(
@@ -461,13 +488,17 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
             encoder_attention_mask,
         )
 
+    def test_for_causal_lm(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+        self.model_tester.create_and_check_bert_for_causal_lm(*config_and_inputs)
+
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
 
-    def test_for_masked_lm_decoder(self):
+    def test_for_causal_lm_decoder(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
-        self.model_tester.create_and_check_bert_model_for_masked_lm_as_decoder(*config_and_inputs)
+        self.model_tester.create_and_check_bert_model_for_causal_lm_as_decoder(*config_and_inputs)
 
     def test_for_multiple_choice(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py
index caf2891aab..65f479e8f3 100644
--- a/tests/test_modeling_encoder_decoder.py
+++ b/tests/test_modeling_encoder_decoder.py
@@ -27,7 +27,8 @@ from .utils import require_torch, slow, torch_device
 
 if is_torch_available():
-    from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel, EncoderDecoderConfig
+    from transformers import BertModel, EncoderDecoderModel, EncoderDecoderConfig
+    from transformers.modeling_bert import BertLMHeadModel
     import numpy as np
     import torch
@@ -70,7 +71,6 @@ class EncoderDecoderModelTest(unittest.TestCase):
             "decoder_token_labels": decoder_token_labels,
             "decoder_choice_labels": decoder_choice_labels,
             "encoder_hidden_states": encoder_hidden_states,
-            "lm_labels": decoder_token_labels,
             "labels": decoder_token_labels,
         }
@@ -116,7 +116,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
         **kwargs
     ):
         encoder_model = BertModel(config)
-        decoder_model = BertForMaskedLM(decoder_config)
+        decoder_model = BertLMHeadModel(decoder_config)
         enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
         self.assertTrue(enc_dec_model.config.decoder.is_decoder)
         self.assertTrue(enc_dec_model.config.is_encoder_decoder)
@@ -153,7 +153,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
         **kwargs
     ):
         encoder_model = BertModel(config)
-        decoder_model = BertForMaskedLM(decoder_config)
+        decoder_model = BertLMHeadModel(decoder_config)
         kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
         enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
         enc_dec_model.to(torch_device)
@@ -179,7 +179,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
         **kwargs
    ):
         encoder_model = BertModel(config)
-        decoder_model = BertForMaskedLM(decoder_config)
+        decoder_model = BertLMHeadModel(decoder_config)
         enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
         enc_dec_model.to(torch_device)
         enc_dec_model.eval()
@@ -220,7 +220,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
         **kwargs
     ):
         encoder_model = BertModel(config)
-        decoder_model = BertForMaskedLM(decoder_config)
+        decoder_model = BertLMHeadModel(decoder_config)
         enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
         enc_dec_model.to(torch_device)
         enc_dec_model.eval()
@@ -269,7 +269,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
         **kwargs
     ):
         encoder_model = BertModel(config)
-        decoder_model = BertForMaskedLM(decoder_config)
+        decoder_model = BertLMHeadModel(decoder_config)
         enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
         enc_dec_model.to(torch_device)
         outputs_encoder_decoder = enc_dec_model(
@@ -288,41 +288,9 @@ class EncoderDecoderModelTest(unittest.TestCase):
         self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
         self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
 
-    def create_and_check_bert_encoder_decoder_model_lm_labels(
-        self,
-        config,
-        input_ids,
-        attention_mask,
-        encoder_hidden_states,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        lm_labels,
-        **kwargs
-    ):
-        encoder_model = BertModel(config)
-        decoder_model = BertForMaskedLM(decoder_config)
-        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
-        enc_dec_model.to(torch_device)
-        outputs_encoder_decoder = enc_dec_model(
-            input_ids=input_ids,
-            decoder_input_ids=decoder_input_ids,
-            attention_mask=attention_mask,
-            decoder_attention_mask=decoder_attention_mask,
-            lm_labels=lm_labels,
-        )
-
-        lm_loss = outputs_encoder_decoder[0]
-        self.check_loss_output(lm_loss)
-        # check that backprop works
-        lm_loss.backward()
-
-        self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
-        self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
-
     def create_and_check_bert_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
         encoder_model = BertModel(config)
-        decoder_model = BertForMaskedLM(decoder_config)
+        decoder_model = BertLMHeadModel(decoder_config)
         enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
         enc_dec_model.to(torch_device)
 
@@ -356,10 +324,6 @@ class EncoderDecoderModelTest(unittest.TestCase):
         input_ids_dict = self.prepare_config_and_inputs_bert()
         self.create_and_check_bert_encoder_decoder_model_labels(**input_ids_dict)
 
-    def test_bert_encoder_decoder_model_lm_labels(self):
-        input_ids_dict = self.prepare_config_and_inputs_bert()
-        self.create_and_check_bert_encoder_decoder_model_lm_labels(**input_ids_dict)
-
     def test_bert_encoder_decoder_model_generate(self):
         input_ids_dict = self.prepare_config_and_inputs_bert()
         self.create_and_check_bert_encoder_decoder_model_generate(**input_ids_dict)
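
Usage sketch (not part of the patch): the snippet below mirrors the APIs exercised by the new code and tests above, namely the standalone `BertLMHeadModel` with `is_decoder=True` and an `EncoderDecoderModel` built from a `BertModel` encoder and a `BertLMHeadModel` decoder, whose loss is now driven by `labels` instead of the removed `lm_labels`. The tuple-style return values and the `from transformers.modeling_bert` import path are assumptions tied to the transformers version this patch targets.

    # Minimal usage sketch, not part of the patch. Assumes a transformers version
    # that already contains this change; outputs are plain tuples, as in the diff.
    import torch
    from transformers import BertModel, BertTokenizer, EncoderDecoderModel
    from transformers.modeling_bert import BertLMHeadModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    input_ids = torch.tensor(
        tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)
    ).unsqueeze(0)  # Batch size 1

    # Standalone causal LM: `is_decoder=True` is required by the assert in __init__.
    lm_model = BertLMHeadModel.from_pretrained("bert-base-uncased", is_decoder=True)
    ltr_lm_loss, prediction_scores = lm_model(input_ids, labels=input_ids)[:2]

    # Encoder-decoder: the decoder head is now BertLMHeadModel, and the decoder
    # loss comes from `labels` alone, since `lm_labels` was removed from forward().
    encoder = BertModel.from_pretrained("bert-base-uncased")
    decoder = BertLMHeadModel.from_pretrained("bert-base-uncased", is_decoder=True)
    enc_dec = EncoderDecoderModel(encoder=encoder, decoder=decoder)
    seq2seq_loss = enc_dec(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]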