Hotfix for failing `MusicgenForConditionalGeneration` tests (#25091)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-07-25 20:26:00 +02:00 committed by GitHub
parent f9cc333805
commit 21150cb0f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 8 deletions

View File

@ -1154,19 +1154,24 @@ class GenerationMixin:
# allow encoder kwargs
encoder = getattr(self, "encoder", None)
if encoder is None:
encoder = getattr(base_model, "encoder")
# `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`.
# Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder`
# TODO: A better way to handle this.
if encoder is None and base_model is not None:
encoder = getattr(base_model, "encoder", None)
encoder_model_args = set(inspect.signature(encoder.forward).parameters)
model_args |= encoder_model_args
if encoder is not None:
encoder_model_args = set(inspect.signature(encoder.forward).parameters)
model_args |= encoder_model_args
# allow decoder kwargs
decoder = getattr(self, "decoder", None)
if decoder is None:
decoder = getattr(base_model, "decoder")
if decoder is None and base_model is not None:
decoder = getattr(base_model, "decoder", None)
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
model_args |= {f"decoder_{x}" for x in decoder_model_args}
if decoder is not None:
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
model_args |= {f"decoder_{x}" for x in decoder_model_args}
for key, value in model_kwargs.items():
if value is not None and key not in model_args: