diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index c8d97247e7..b47e221b3f 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1147,6 +1147,27 @@ class GenerationMixin:
         # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
         if "kwargs" in model_args or "model_kwargs" in model_args:
             model_args |= set(inspect.signature(self.forward).parameters)
+
+        # Encoder-Decoder models may also need Encoder arguments from `model_kwargs`
+        if self.config.is_encoder_decoder:
+            base_model = getattr(self, self.base_model_prefix, None)
+
+            # allow encoder kwargs
+            encoder = getattr(self, "encoder", None)
+            if encoder is None:
+                encoder = getattr(base_model, "encoder")
+
+            encoder_model_args = set(inspect.signature(encoder.forward).parameters)
+            model_args |= encoder_model_args
+
+            # allow decoder kwargs
+            decoder = getattr(self, "decoder", None)
+            if decoder is None:
+                decoder = getattr(base_model, "decoder")
+
+            decoder_model_args = set(inspect.signature(decoder.forward).parameters)
+            model_args |= {f"decoder_{x}" for x in decoder_model_args}
+
         for key, value in model_kwargs.items():
             if value is not None and key not in model_args:
                 unused_model_args.append(key)
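
The hunk above extends the `model_kwargs` validation in `GenerationMixin` so that, for encoder-decoder models, arguments accepted by the encoder's and decoder's `forward` (the latter under a `decoder_` prefix) are treated as valid. Below is a minimal, self-contained sketch of that collection logic under stated assumptions: `ToyEncoder`, `ToyDecoder`, `ToyBase`, and `ToySeq2Seq` are hypothetical stand-ins for a real transformers model, and `valid_generate_kwargs` is an illustrative helper, not library API.

    # Standalone sketch of the kwarg collection added in the hunk above.
    # The Toy* classes are hypothetical stand-ins, not from transformers.
    import inspect


    class ToyEncoder:
        def forward(self, input_ids=None, attention_mask=None, head_mask=None):
            pass


    class ToyDecoder:
        def forward(self, input_ids=None, encoder_hidden_states=None, head_mask=None):
            pass


    class ToyBase:
        def __init__(self):
            self.encoder = ToyEncoder()
            self.decoder = ToyDecoder()


    class ToySeq2Seq:
        base_model_prefix = "model"

        def __init__(self):
            # encoder/decoder live on the base model, not on the head class,
            # which is why the patch falls back to `base_model_prefix`
            self.model = ToyBase()

        def forward(self, input_ids=None, decoder_input_ids=None):
            pass


    def valid_generate_kwargs(model):
        # Mirror the patched logic: start from the top-level forward signature ...
        model_args = set(inspect.signature(model.forward).parameters)

        base_model = getattr(model, model.base_model_prefix, None)

        # ... accept encoder kwargs under their own names ...
        encoder = getattr(model, "encoder", None)
        if encoder is None:
            encoder = getattr(base_model, "encoder")
        model_args |= set(inspect.signature(encoder.forward).parameters)

        # ... and accept decoder kwargs under a `decoder_` prefix.
        decoder = getattr(model, "decoder", None)
        if decoder is None:
            decoder = getattr(base_model, "decoder")
        model_args |= {f"decoder_{x}" for x in inspect.signature(decoder.forward).parameters}

        return model_args


    print(sorted(valid_generate_kwargs(ToySeq2Seq())))
    # ['attention_mask', 'decoder_encoder_hidden_states', 'decoder_head_mask',
    #  'decoder_input_ids', 'head_mask', 'input_ids']

The practical effect, assuming a model whose decoder `forward` takes `head_mask`, is that a call such as `model.generate(input_ids, decoder_head_mask=...)` should no longer be flagged by the "`model_kwargs` are not used by the model" check.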