Encoder-Decoder: add informative exception when the decoder is not compatible (#23426)
This commit is contained in:
parent
939a65aba7
commit
a574de302f
|
@ -16,6 +16,7 @@
|
|||
|
||||
|
||||
import gc
|
||||
import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
|
@ -245,6 +246,13 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
|
||||
)
|
||||
|
||||
decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
|
||||
if "encoder_hidden_states" not in decoder_signature:
|
||||
raise ValueError(
|
||||
"The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
|
||||
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
|
||||
)
|
||||
|
||||
# tie encoder, decoder weights if config set accordingly
|
||||
self.tie_weights()
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
""" Classes to support TF Encoder-Decoder architectures"""
|
||||
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
|
@ -266,6 +266,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
|
||||
)
|
||||
|
||||
decoder_signature = set(inspect.signature(self.decoder.call).parameters.keys())
|
||||
if "encoder_hidden_states" not in decoder_signature:
|
||||
raise ValueError(
|
||||
"The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
|
||||
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
|
||||
)
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue