From 7edf8bfafd464de082051b319df2cea338083f36 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Mon, 4 Dec 2023 07:38:22 +0100 Subject: [PATCH] Improve forward signature test (#27729) * First draft * Extend test_forward_signature * Update tests/test_modeling_common.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Revert suggestion --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- tests/test_modeling_common.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 79c630c0d5..b76a8025e6 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -542,6 +542,12 @@ class ModelTesterMixin: else ["encoder_outputs"] ) self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and self.has_attentions: + expected_arg_names = ["pixel_values", "output_hidden_states", "output_attentions", "return_dict"] + self.assertListEqual(arg_names, expected_arg_names) + elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and not self.has_attentions: + expected_arg_names = ["pixel_values", "output_hidden_states", "return_dict"] + self.assertListEqual(arg_names, expected_arg_names) else: expected_arg_names = [model.main_input_name] self.assertListEqual(arg_names[:1], expected_arg_names)