diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 0ede4f9c78..1806e43655 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -404,9 +404,7 @@ def unpack_inputs(func):
         fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))

         # process the inputs and call the wrapped function
-        main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1])
-        main_input = fn_args_and_kwargs.pop(main_input_name, None)
-        unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
+        unpacked_inputs = input_processing(func, self.config, **fn_args_and_kwargs)
         return func(self, **unpacked_inputs)

     # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
@@ -417,7 +415,7 @@ def unpack_inputs(func):

     return run_call_with_unpacked_inputs


-def input_processing(func, config, input_ids, **kwargs):
+def input_processing(func, config, **kwargs):
     """
     Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
     has to be named accordingly to the parameters name, i.e. `input_ids = tf.keras.Input(shape=(128,), dtype='int32',
@@ -438,6 +436,8 @@ def input_processing(func, config, input_ids, **kwargs):
     has_kwargs = bool(signature.pop("kwargs", None))
     signature.pop("self", None)
     parameter_names = list(signature.keys())
+    main_input_name = parameter_names[0]
+    main_input = kwargs.pop(main_input_name, None)
     output = {}
     allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray, KerasTensor)

@@ -483,8 +483,8 @@ def input_processing(func, config, input_ids, **kwargs):
         else:
             raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")

-    if isinstance(input_ids, (tuple, list)):
-        for i, input in enumerate(input_ids):
+    if isinstance(main_input, (tuple, list)):
+        for i, input in enumerate(main_input):
             # EagerTensors don't allow to use the .name property so we check for a real Tensor
             if type(input) == tf.Tensor:
                 # Tensor names have always the pattern `name:id` then we check only the
@@ -502,25 +502,25 @@ def input_processing(func, config, input_ids, **kwargs):
                     f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
                     f" {parameter_names[i]}."
                 )
-    elif isinstance(input_ids, Mapping):
-        if "inputs" in input_ids:
+    elif isinstance(main_input, Mapping):
+        if "inputs" in main_input:
             warnings.warn(
                 "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
                 " instead.",
                 FutureWarning,
             )

-            output["input_ids"] = input_ids.pop("inputs")
+            output["input_ids"] = main_input.pop("inputs")

-        if "decoder_cached_states" in input_ids:
+        if "decoder_cached_states" in main_input:
             warnings.warn(
                 "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
                 " `past_key_values` instead.",
                 FutureWarning,
             )
-            output["past_key_values"] = input_ids.pop("decoder_cached_states")
+            output["past_key_values"] = main_input.pop("decoder_cached_states")

-        for k, v in dict(input_ids).items():
+        for k, v in dict(main_input).items():
             if isinstance(v, allowed_types) or v is None:
                 output[k] = v
             elif k not in parameter_names and "args" not in parameter_names:
@@ -531,12 +531,12 @@ def input_processing(func, config, input_ids, **kwargs):
             else:
                 raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
     else:
-        if isinstance(input_ids, (tf.Tensor, KerasTensor)) or input_ids is None:
-            output[parameter_names[0]] = input_ids
+        if isinstance(main_input, (tf.Tensor, KerasTensor)) or main_input is None:
+            output[main_input_name] = main_input
         else:
             raise ValueError(
-                f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for"
-                f" {parameter_names[0]}."
+                f"Data of type {type(main_input)} is not allowed only {allowed_types} is accepted for"
+                f" {main_input_name}."
             )

     # Populates any unspecified argument with their default value, according to the signature.
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 87516228f2..d27ecaccb0 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -1881,6 +1881,7 @@ class UtilsFunctionsTest(unittest.TestCase):
             def __init__(self):
                 config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
                 self.config = PretrainedConfig(**config_kwargs)
+                self.main_input_name = "input_ids"

             @unpack_inputs
             def call(
@@ -1888,9 +1889,14 @@ class UtilsFunctionsTest(unittest.TestCase):
             ):
                 return input_ids, past, output_attentions, output_hidden_states, return_dict

+            @unpack_inputs
+            def foo(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
+                return pixel_values, output_attentions, output_hidden_states, return_dict
+
         dummy_model = DummyModel()
         input_ids = tf.constant([0, 1, 2, 3])
         past = tf.constant([4, 5, 6, 7])
+        pixel_values = tf.constant([8, 9, 10, 11])

         # test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
         output = dummy_model.call(input_ids=input_ids, past=past)
@@ -1937,6 +1943,14 @@ class UtilsFunctionsTest(unittest.TestCase):
         self.assertFalse(output[3])
         self.assertFalse(output[4])

+        # test case 7: the decorator is independent from `main_input_name` -- it treats the first argument of the
+        # decorated function as its main input.
+        output = dummy_model.foo(pixel_values=pixel_values)
+        tf.debugging.assert_equal(output[0], pixel_values)
+        self.assertFalse(output[1])
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+
     # Tests whether the stable softmax is stable on CPU, with and without XLA
     def test_xla_stable_softmax(self):
         large_penalty = -1e9
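
The change above makes the decorator infer its main input from the wrapped function's own signature rather than from the model's `main_input_name` attribute. A minimal standalone sketch of that resolution logic follows; the helper `resolve_main_input` and the toy `call`/`foo` signatures are illustrative assumptions, not library code.

import inspect


def resolve_main_input(func, **kwargs):
    # Mirrors the patched behavior: the first non-self parameter of the
    # decorated function is treated as the main input, whatever its name is.
    parameter_names = [name for name in inspect.signature(func).parameters if name != "self"]
    main_input_name = parameter_names[0]
    main_input = kwargs.pop(main_input_name, None)
    return main_input_name, main_input, kwargs


def call(self, input_ids=None, past=None):
    ...


def foo(self, pixel_values=None, output_attentions=None):
    ...


# Two methods of the same model can now resolve different main inputs, which is
# what test case 7 above exercises via `dummy_model.foo(...)`.
print(resolve_main_input(call, input_ids=[0, 1, 2, 3]))       # ('input_ids', [0, 1, 2, 3], {})
print(resolve_main_input(foo, pixel_values=[8, 9, 10, 11]))   # ('pixel_values', [8, 9, 10, 11], {})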