Encoder-Decoder: add informative exception when the decoder is not compatible (#23426)

This commit is contained in:
Joao Gante 2023-05-17 17:42:54 +01:00 committed by GitHub
parent 939a65aba7
commit a574de302f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 1 deletions

View File

@ -16,6 +16,7 @@
import gc
import inspect
import os
import tempfile
import warnings
@ -245,6 +246,13 @@ class EncoderDecoderModel(PreTrainedModel):
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
)
decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
if "encoder_hidden_states" not in decoder_signature:
raise ValueError(
"The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
)
# tie encoder, decoder weights if config set accordingly
self.tie_weights()

View File

@ -14,7 +14,7 @@
# limitations under the License.
""" Classes to support TF Encoder-Decoder architectures"""
import inspect
import re
import warnings
from typing import Optional, Tuple, Union
@ -266,6 +266,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
)
decoder_signature = set(inspect.signature(self.decoder.call).parameters.keys())
if "encoder_hidden_states" not in decoder_signature:
raise ValueError(
"The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
)
@property
def dummy_inputs(self):
"""