EncoderDecoderConfigs should not create new objects (#11300)
* removes the creation of separate config objects and uses the existing ones instead + overwrite resize_token_embeddings from parent class because it is not working for the EncoderDecoderModel
* rollback to current version of the huggingface master branch
* reworked version that ties the encoder and decoder config of the parent encoderdecoder instance
* overwrite of resize_token_embeddings throws an error now
* review comment suggestion

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* implemented warning in case encoderdecoder is created with differing configs of encoderdecoderconfig and decoderconfig or encoderconfig
* added test to avoid diverging configs of wrapper class and wrapped classes
* Update src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
* make style

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
f45cb66bf6
commit
35cd8eed88
@@ -175,6 +175,21 @@ class EncoderDecoderModel(PreTrainedModel):
 
         self.encoder = encoder
         self.decoder = decoder
+
+        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
+            logger.warning(
+                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
+            )
+        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+            logger.warning(
+                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
+            )
+
+        # make sure that the individual model's config refers to the shared config
+        # so that the updates to the config will be synced
+        self.encoder.config = self.config.encoder
+        self.decoder.config = self.config.decoder
+
         assert (
             self.encoder.get_output_embeddings() is None
         ), "The encoder {} should not have a LM Head. Please use a model without LM Head"
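The upshot of the hunk above: the wrapped encoder and decoder no longer hold equal copies of the sub-configs, they hold the very objects stored on the shared config. A minimal sketch of the resulting behavior (not part of the diff; it assumes network access to the public bert-base-uncased checkpoint used in the tests below):

from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)

# The sub-configs are now shared objects, not equal copies.
assert model.encoder.config is model.config.encoder
assert model.decoder.config is model.config.decoder

# An update made through the shared config is therefore immediately
# visible on the wrapped model as well (and vice versa).
model.config.decoder.max_length = 128
assert model.decoder.config.max_length == 128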
@@ -458,6 +473,12 @@ class EncoderDecoderModel(PreTrainedModel):
         }
         return input_dict
 
+    def resize_token_embeddings(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Resizing the embedding layers via the EncoderDecoderModel directly is not supported."
+            "Please use the respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or model.decoder.resize_token_embeddings(...))"
+        )
+
     def _reorder_cache(self, past, beam_idx):
         # apply decoder cache reordering here
         return self.decoder._reorder_cache(past, beam_idx)
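Since the wrapper-level resize_token_embeddings now raises, callers resize through the wrapped models instead. A small usage sketch (not part of the diff; the target vocabulary size 30524 is an arbitrary example):

from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)

try:
    model.resize_token_embeddings(30524)  # arbitrary example size
except NotImplementedError:
    pass  # expected: resizing via the wrapper is no longer supported

# Resize each wrapped model explicitly instead.
model.encoder.resize_token_embeddings(30524)
model.decoder.resize_token_embeddings(30524)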
@@ -34,6 +34,7 @@ if is_torch_available():
     import torch
 
     from transformers import (
+        AutoConfig,
         AutoTokenizer,
         BartForCausalLM,
         BertGenerationDecoder,
@@ -884,3 +885,38 @@ class BartEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
 
     def test_encoder_decoder_model_shared_weights(self):
         pass
+
+
+@require_torch
+class EncoderDecoderModelTest(unittest.TestCase):
+    def get_from_encoderdecoder_pretrained_model(self):
+        return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
+
+    def get_decoder_config(self):
+        config = AutoConfig.from_pretrained("bert-base-uncased")
+        config.is_decoder = True
+        config.add_cross_attention = True
+        return config
+
+    def get_encoderdecoder_model(self):
+        return EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
+
+    def get_encoder_decoder_models(self):
+        encoder_model = BertModel.from_pretrained("bert-base-uncased")
+        decoder_model = BertLMHeadModel.from_pretrained("bert-base-uncased", config=self.get_decoder_config())
+        return {"encoder": encoder_model, "decoder": decoder_model}
+
+    def _check_configuration_tie(self, model):
+        assert id(model.decoder.config) == id(model.config.decoder)
+        assert id(model.encoder.config) == id(model.config.encoder)
+
+    @slow
+    def test_configuration_tie(self):
+        model = self.get_from_encoderdecoder_pretrained_model()
+        self._check_configuration_tie(model)
+
+        model = EncoderDecoderModel(**self.get_encoder_decoder_models())
+        self._check_configuration_tie(model)
+
+        model = self.get_encoderdecoder_model()
+        self._check_configuration_tie(model)
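For illustration, a hypothetical sketch of how the new constructor warning can be triggered (not part of the diff): EncoderDecoderConfig.from_encoder_decoder_configs copies the sub-configs by value, so mutating the shared copy makes it diverge from the decoder's own config, and constructing the wrapper then logs the overwrite warning before tying the configs together.

from transformers import (
    BertConfig,
    BertLMHeadModel,
    BertModel,
    EncoderDecoderConfig,
    EncoderDecoderModel,
)

decoder_config = BertConfig.from_pretrained("bert-base-uncased")
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True

encoder = BertModel.from_pretrained("bert-base-uncased")
decoder = BertLMHeadModel.from_pretrained("bert-base-uncased", config=decoder_config)

# from_encoder_decoder_configs copies the sub-configs by value, so this
# mutation makes the shared config diverge from decoder.config ...
shared = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
shared.decoder.max_length = 64

# ... and constructing the wrapper logs the "overwritten by shared decoder
# config" warning, then ties the configs back together.
model = EncoderDecoderModel(config=shared, encoder=encoder, decoder=decoder)
assert model.decoder.config is model.config.decoder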