Improve forward signature test (#27729)
* First draft * Extend test_forward_signature * Update tests/test_modeling_common.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Revert suggestion --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
bcd0a91a01
commit
7edf8bfafd
|
@ -542,6 +542,12 @@ class ModelTesterMixin:
|
|||
else ["encoder_outputs"]
|
||||
)
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and self.has_attentions:
|
||||
expected_arg_names = ["pixel_values", "output_hidden_states", "output_attentions", "return_dict"]
|
||||
self.assertListEqual(arg_names, expected_arg_names)
|
||||
elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and not self.has_attentions:
|
||||
expected_arg_names = ["pixel_values", "output_hidden_states", "return_dict"]
|
||||
self.assertListEqual(arg_names, expected_arg_names)
|
||||
else:
|
||||
expected_arg_names = [model.main_input_name]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
|
Loading…
Reference in New Issue