Hotfix for failing `MusicgenForConditionalGeneration` tests (#25091)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
f9cc333805
commit
21150cb0f3
|
@ -1154,19 +1154,24 @@ class GenerationMixin:
|
|||
|
||||
# allow encoder kwargs
|
||||
encoder = getattr(self, "encoder", None)
|
||||
if encoder is None:
|
||||
encoder = getattr(base_model, "encoder")
|
||||
# `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`.
|
||||
# Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder`
|
||||
# TODO: A better way to handle this.
|
||||
if encoder is None and base_model is not None:
|
||||
encoder = getattr(base_model, "encoder", None)
|
||||
|
||||
encoder_model_args = set(inspect.signature(encoder.forward).parameters)
|
||||
model_args |= encoder_model_args
|
||||
if encoder is not None:
|
||||
encoder_model_args = set(inspect.signature(encoder.forward).parameters)
|
||||
model_args |= encoder_model_args
|
||||
|
||||
# allow decoder kwargs
|
||||
decoder = getattr(self, "decoder", None)
|
||||
if decoder is None:
|
||||
decoder = getattr(base_model, "decoder")
|
||||
if decoder is None and base_model is not None:
|
||||
decoder = getattr(base_model, "decoder", None)
|
||||
|
||||
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
|
||||
model_args |= {f"decoder_{x}" for x in decoder_model_args}
|
||||
if decoder is not None:
|
||||
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
|
||||
model_args |= {f"decoder_{x}" for x in decoder_model_args}
|
||||
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args:
|
||||
|
|
Loading…
Reference in New Issue