EncoderDecoderConfigs should not create new objects (#11300)

* removes the creation of separate config objects and uses the existing ones instead; also overwrites resize_token_embeddings from the parent class because it does not work for the EncoderDecoderModel

* rollback to current version of the huggingface master branch

* reworked version that ties the encoder and decoder config of the parent encoderdecoder instance

* overwrite of resize_token_embeddings throws an error now

* review comment suggestion

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* implemented a warning in case an EncoderDecoderModel is created whose EncoderDecoderConfig differs from the encoder config or the decoder config

* added a test to avoid diverging configs of the wrapper class and the wrapped classes

* Update src/transformers/models/encoder_decoder/modeling_encoder_decoder.py

* make style

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: cronoik, 2021-04-25 11:45:46 +02:00 (committed by GitHub)
Parent: f45cb66bf6
Commit: 35cd8eed88
2 changed files with 57 additions and 0 deletions

src/transformers/models/encoder_decoder/modeling_encoder_decoder.py

@@ -175,6 +175,21 @@ class EncoderDecoderModel(PreTrainedModel):
        self.encoder = encoder
        self.decoder = decoder

        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
            )

        # make sure that the individual model's config refers to the shared config
        # so that the updates to the config will be synced
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        assert (
            self.encoder.get_output_embeddings() is None
        ), "The encoder {} should not have a LM Head. Please use a model without LM Head"
@@ -458,6 +473,12 @@
        }
        return input_dict

    def resize_token_embeddings(self, *args, **kwargs):
        raise NotImplementedError(
            "Resizing the embedding layers via the EncoderDecoderModel directly is not supported."
            "Please use the respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or model.decoder.resize_token_embeddings(...))"
        )

    def _reorder_cache(self, past, beam_idx):
        # apply decoder cache reordering here
        return self.decoder._reorder_cache(past, beam_idx)
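
For illustration only (not part of the diff): since resize_token_embeddings now raises on the wrapper, resizing has to be done on the wrapped encoder and decoder directly, e.g. after adding tokens to a tokenizer. The added token below is a hypothetical example.

    from transformers import BertTokenizer, EncoderDecoderModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.add_tokens(["<new_token>"])  # hypothetical extra token

    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

    # model.resize_token_embeddings(len(tokenizer))  # would raise NotImplementedError after this change
    model.encoder.resize_token_embeddings(len(tokenizer))
    model.decoder.resize_token_embeddings(len(tokenizer))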

tests/test_modeling_encoder_decoder.py

@@ -34,6 +34,7 @@ if is_torch_available():
    import torch

    from transformers import (
        AutoConfig,
        AutoTokenizer,
        BartForCausalLM,
        BertGenerationDecoder,
@@ -884,3 +885,38 @@ class BartEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
    def test_encoder_decoder_model_shared_weights(self):
        pass


@require_torch
class EncoderDecoderModelTest(unittest.TestCase):
    def get_from_encoderdecoder_pretrained_model(self):
        return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

    def get_decoder_config(self):
        config = AutoConfig.from_pretrained("bert-base-uncased")
        config.is_decoder = True
        config.add_cross_attention = True
        return config

    def get_encoderdecoder_model(self):
        return EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")

    def get_encoder_decoder_models(self):
        encoder_model = BertModel.from_pretrained("bert-base-uncased")
        decoder_model = BertLMHeadModel.from_pretrained("bert-base-uncased", config=self.get_decoder_config())
        return {"encoder": encoder_model, "decoder": decoder_model}

    def _check_configuration_tie(self, model):
        assert id(model.decoder.config) == id(model.config.decoder)
        assert id(model.encoder.config) == id(model.config.encoder)

    @slow
    def test_configuration_tie(self):
        model = self.get_from_encoderdecoder_pretrained_model()
        self._check_configuration_tie(model)

        model = EncoderDecoderModel(**self.get_encoder_decoder_models())
        self._check_configuration_tie(model)

        model = self.get_encoderdecoder_model()
        self._check_configuration_tie(model)
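
Note (not part of the diff): the test compares object identities with id() rather than config contents, so it verifies that the wrapper and the wrapped models literally share one config object and therefore stay in sync after future mutations, covering all three construction paths (from_encoder_decoder_pretrained, direct __init__ with wrapped models, and from_pretrained). Because it downloads pretrained checkpoints, it is marked @slow and only runs when slow tests are enabled, which in the transformers test suite is typically controlled via the RUN_SLOW environment variable.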