Improve forward signature test (#27729)

* First draft

* Extend test_forward_signature

* Update tests/test_modeling_common.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Revert suggestion

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
NielsRogge 2023-12-04 07:38:22 +01:00 committed by GitHub
parent bcd0a91a01
commit 7edf8bfafd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 0 deletions

View File

@ -542,6 +542,12 @@ class ModelTesterMixin:
else ["encoder_outputs"]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and self.has_attentions:
expected_arg_names = ["pixel_values", "output_hidden_states", "output_attentions", "return_dict"]
self.assertListEqual(arg_names, expected_arg_names)
elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and not self.has_attentions:
expected_arg_names = ["pixel_values", "output_hidden_states", "return_dict"]
self.assertListEqual(arg_names, expected_arg_names)
else:
expected_arg_names = [model.main_input_name]
self.assertListEqual(arg_names[:1], expected_arg_names)