diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index ac3245a29c..140651b2e8 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -234,29 +234,60 @@ class HfArgumentParser(ArgumentParser): return (*outputs,) - def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: + def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]: """ Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the dataclass types. + + Args: + json_file (`str` or `os.PathLike`): + File name of the json file to parse + allow_extra_keys (`bool`, *optional*, defaults to `False`): + Defaults to False. If False, will raise an exception if the json file contains keys that are not + parsed. + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer. """ data = json.loads(Path(json_file).read_text()) + unused_keys = set(data.keys()) outputs = [] for dtype in self.dataclass_types: keys = {f.name for f in dataclasses.fields(dtype) if f.init} inputs = {k: v for k, v in data.items() if k in keys} + unused_keys.difference_update(inputs.keys()) obj = dtype(**inputs) outputs.append(obj) - return (*outputs,) + if not allow_extra_keys and unused_keys: + raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}") + return tuple(outputs) - def parse_dict(self, args: dict) -> Tuple[DataClass, ...]: + def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]: """ Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass types. + + Args: + args (`dict`): + dict containing config values + allow_extra_keys (`bool`, *optional*, defaults to `False`): + Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed. + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer. """ + unused_keys = set(args.keys()) outputs = [] for dtype in self.dataclass_types: keys = {f.name for f in dataclasses.fields(dtype) if f.init} inputs = {k: v for k, v in args.items() if k in keys} + unused_keys.difference_update(inputs.keys()) obj = dtype(**inputs) outputs.append(obj) - return (*outputs,) + if not allow_extra_keys and unused_keys: + raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}") + return tuple(outputs) diff --git a/tests/utils/test_hf_argparser.py b/tests/utils/test_hf_argparser.py index 5ef63080a6..827888509b 100644 --- a/tests/utils/test_hf_argparser.py +++ b/tests/utils/test_hf_argparser.py @@ -245,6 +245,19 @@ class HfArgumentParserTest(unittest.TestCase): args = BasicExample(**args_dict) self.assertEqual(parsed_args, args) + def test_parse_dict_extra_key(self): + parser = HfArgumentParser(BasicExample) + + args_dict = { + "foo": 12, + "bar": 3.14, + "baz": "42", + "flag": True, + "extra": 42, + } + + self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False) + def test_integration_training_args(self): parser = HfArgumentParser(TrainingArguments) self.assertIsNotNone(parser)