Allow generic composite models to pass more kwargs (#24927)
* fix * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * update --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
parent
b51312e24d
commit
1e662f0f07
|
@ -1147,6 +1147,27 @@ class GenerationMixin:
|
|||
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
|
||||
if "kwargs" in model_args or "model_kwargs" in model_args:
|
||||
model_args |= set(inspect.signature(self.forward).parameters)
|
||||
|
||||
# Encoder-Decoder models may also need Encoder arguments from `model_kwargs`
|
||||
if self.config.is_encoder_decoder:
|
||||
base_model = getattr(self, self.base_model_prefix, None)
|
||||
|
||||
# allow encoder kwargs
|
||||
encoder = getattr(self, "encoder", None)
|
||||
if encoder is None:
|
||||
encoder = getattr(base_model, "encoder")
|
||||
|
||||
encoder_model_args = set(inspect.signature(encoder.forward).parameters)
|
||||
model_args |= encoder_model_args
|
||||
|
||||
# allow decoder kwargs
|
||||
decoder = getattr(self, "decoder", None)
|
||||
if decoder is None:
|
||||
decoder = getattr(base_model, "decoder")
|
||||
|
||||
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
|
||||
model_args |= {f"decoder_{x}" for x in decoder_model_args}
|
||||
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args:
|
||||
unused_model_args.append(key)
|
||||
|
|
Loading…
Reference in New Issue