Generate: `model_kwargs` can also be an input to `prepare_inputs_for_generation` (#20353)

Joao Gante 2022-11-21 16:20:27 +00:00 committed by GitHub
parent d21c97cc0f
commit 4cf38148dc
4 changed files with 15 additions and 11 deletions
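
In short: `generate()` rejects any entry in `model_kwargs` that none of the model's generation-related signatures can use, so that typos in generate arguments surface as errors. Previously, the check only fell back to the forward-pass signature when `prepare_inputs_for_generation` declared a catch-all named `kwargs`; models whose catch-all is named `model_kwargs` were wrongly told their perfectly valid inputs were unused. A rough, standalone sketch of the relaxed PyTorch-side check introduced here (the helper name and error message are illustrative, not the library's exact code):

import inspect


def validate_model_kwargs_sketch(model, model_kwargs):
    """Illustrative re-implementation of the relaxed validation performed by `generate()`."""
    model_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters)
    # A catch-all named either `kwargs` or `model_kwargs` means optional inputs such as
    # `attention_mask` may be routed through it, so anything the forward pass accepts is also valid.
    if "kwargs" in model_args or "model_kwargs" in model_args:
        model_args |= set(inspect.signature(model.forward).parameters)

    unused = [key for key, value in model_kwargs.items() if value is not None and key not in model_args]
    if unused:
        raise ValueError(f"The following `model_kwargs` are not used by the model: {unused}")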


@@ -194,9 +194,9 @@ class FlaxGenerationMixin:
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         unused_model_args = []
         model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
-        # `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
-        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
-        if "kwargs" in model_args:
+        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
+        if "kwargs" in model_args or "model_kwargs" in model_args:
             model_args |= set(inspect.signature(self.__call__).parameters)
         for key, value in model_kwargs.items():
             if value is not None and key not in model_args:


@@ -1445,9 +1445,9 @@ class TFGenerationMixin:
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         unused_model_args = []
         model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
-        # `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
-        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
-        if "kwargs" in model_args:
+        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
+        if "kwargs" in model_args or "model_kwargs" in model_args:
             model_args |= set(inspect.signature(self.call).parameters)
         for key, value in model_kwargs.items():
             if value is not None and key not in model_args:


@@ -981,9 +981,9 @@ class GenerationMixin:
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         unused_model_args = []
         model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
-        # `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
-        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
-        if "kwargs" in model_args:
+        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
+        if "kwargs" in model_args or "model_kwargs" in model_args:
             model_args |= set(inspect.signature(self.forward).parameters)
         for key, value in model_kwargs.items():
             if value is not None and key not in model_args:
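
The same relaxation is applied in all three frameworks (Flax falls back to `self.__call__`, TF to `self.call`, PyTorch to `self.forward`). The case it fixes looks roughly like the toy class below, which is purely illustrative (not a real transformers model) and reuses the hypothetical `validate_model_kwargs_sketch` helper sketched above:

import torch


class ToyGenerativeModel(torch.nn.Module):
    """Toy stand-in whose `prepare_inputs_for_generation` catch-all is named `model_kwargs`."""

    def forward(self, input_ids, attention_mask=None):
        return input_ids

    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
        # Optional generation inputs (e.g. `attention_mask`) travel through `model_kwargs`.
        return {"input_ids": input_ids, **model_kwargs}


toy = ToyGenerativeModel()
# Before #20353, the stricter check flagged `attention_mask` as unused here; now it passes.
validate_model_kwargs_sketch(toy, {"attention_mask": torch.ones(1, 4, dtype=torch.long)})
# Genuine typos are still caught:
# validate_model_kwargs_sketch(toy, {"attentoin_mask": torch.ones(1, 4)})  # raises ValueError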


@@ -3007,8 +3007,8 @@ class GenerationIntegrationTests(unittest.TestCase):
         self.assertTrue(max_score_diff < 1e-5)
 
     def test_validate_generation_inputs(self):
-        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
-        model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")
 
         encoder_input_str = "Hello world"
         input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
@@ -3021,3 +3021,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         with self.assertRaisesRegex(ValueError, "foo"):
             fake_model_kwargs = {"foo": "bar"}
             model.generate(input_ids, **fake_model_kwargs)
+
+        # However, valid model_kwargs are accepted
+        valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
+        model.generate(input_ids, **valid_model_kwargs)
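
Outside the test suite, the behaviour covered by the new assertions looks roughly like the snippet below (the tiny checkpoint is the one used by the updated test; the exact wording of the error message is not guaranteed, only that it mentions the unused key):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

# Valid optional inputs pass validation and reach the forward pass.
model.generate(input_ids, attention_mask=torch.zeros_like(input_ids))

# A typo (or any argument the model cannot use) still raises a ValueError.
try:
    model.generate(input_ids, foo="bar")
except ValueError as err:
    print(err)  # the message mentions "foo" as an unused model kwarg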