From 21150cb0f3739a69d767384d58c0247d72735b50 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 25 Jul 2023 20:26:00 +0200 Subject: [PATCH] Hotfix for failing `MusicgenForConditionalGeneration` tests (#25091) Co-authored-by: ydshieh --- src/transformers/generation/utils.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b47e221b3f..7465e88161 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1154,19 +1154,24 @@ class GenerationMixin: # allow encoder kwargs encoder = getattr(self, "encoder", None) - if encoder is None: - encoder = getattr(base_model, "encoder") + # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`. + # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder` + # TODO: A better way to handle this. + if encoder is None and base_model is not None: + encoder = getattr(base_model, "encoder", None) - encoder_model_args = set(inspect.signature(encoder.forward).parameters) - model_args |= encoder_model_args + if encoder is not None: + encoder_model_args = set(inspect.signature(encoder.forward).parameters) + model_args |= encoder_model_args # allow decoder kwargs decoder = getattr(self, "decoder", None) - if decoder is None: - decoder = getattr(base_model, "decoder") + if decoder is None and base_model is not None: + decoder = getattr(base_model, "decoder", None) - decoder_model_args = set(inspect.signature(decoder.forward).parameters) - model_args |= {f"decoder_{x}" for x in decoder_model_args} + if decoder is not None: + decoder_model_args = set(inspect.signature(decoder.forward).parameters) + model_args |= {f"decoder_{x}" for x in decoder_model_args} for key, value in model_kwargs.items(): if value is not None and key not in model_args: