Generate: `model_kwargs` can also be an input to `prepare_inputs_for_generation` (#20353)
This commit is contained in:
parent
d21c97cc0f
commit
4cf38148dc
|
@ -194,9 +194,9 @@ class FlaxGenerationMixin:
|
|||
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
|
||||
unused_model_args = []
|
||||
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
||||
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
|
||||
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
|
||||
if "kwargs" in model_args:
|
||||
# `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
|
||||
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
|
||||
if "kwargs" in model_args or "model_kwargs" in model_args:
|
||||
model_args |= set(inspect.signature(self.__call__).parameters)
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args:
|
||||
|
|
|
@ -1445,9 +1445,9 @@ class TFGenerationMixin:
|
|||
|
||||
unused_model_args = []
|
||||
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
||||
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
|
||||
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
|
||||
if "kwargs" in model_args:
|
||||
# `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
|
||||
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
|
||||
if "kwargs" in model_args or "model_kwargs" in model_args:
|
||||
model_args |= set(inspect.signature(self.call).parameters)
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args:
|
||||
|
|
|
@ -981,9 +981,9 @@ class GenerationMixin:
|
|||
|
||||
unused_model_args = []
|
||||
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
||||
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
|
||||
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
|
||||
if "kwargs" in model_args:
|
||||
# `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
|
||||
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
|
||||
if "kwargs" in model_args or "model_kwargs" in model_args:
|
||||
model_args |= set(inspect.signature(self.forward).parameters)
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args:
|
||||
|
|
|
@ -3007,8 +3007,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||
self.assertTrue(max_score_diff < 1e-5)
|
||||
|
||||
def test_validate_generation_inputs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")
|
||||
|
||||
encoder_input_str = "Hello world"
|
||||
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
|
@ -3021,3 +3021,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||
with self.assertRaisesRegex(ValueError, "foo"):
|
||||
fake_model_kwargs = {"foo": "bar"}
|
||||
model.generate(input_ids, **fake_model_kwargs)
|
||||
|
||||
# However, valid model_kwargs are accepted
|
||||
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
|
||||
model.generate(input_ids, **valid_model_kwargs)
|
||||
|
|
Loading…
Reference in New Issue