Add an option to `HfArgumentParser.parse_{dict,json_file}` to raise an Exception when there extra keys (#18692)

* Update parser to track unneeded keys, off by default

* Fix formatting

* Fix docstrings and defaults in HfArgparser

* Fix formatting
This commit is contained in:
Felix Schneider 2022-08-31 20:26:45 +02:00 committed by GitHub
parent f210e2a414
commit 86387fe87f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 4 deletions

View File

@ -234,29 +234,60 @@ class HfArgumentParser(ArgumentParser):
return (*outputs,)
def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
dataclass types.
Args:
json_file (`str` or `os.PathLike`):
File name of the json file to parse
allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the json file contains keys that are not
parsed.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
"""
data = json.loads(Path(json_file).read_text())
unused_keys = set(data.keys())
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in data.items() if k in keys}
unused_keys.difference_update(inputs.keys())
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)
if not allow_extra_keys and unused_keys:
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
return tuple(outputs)
def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
types.
Args:
args (`dict`):
dict containing config values
allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
"""
unused_keys = set(args.keys())
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in args.items() if k in keys}
unused_keys.difference_update(inputs.keys())
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)
if not allow_extra_keys and unused_keys:
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
return tuple(outputs)

View File

@ -245,6 +245,19 @@ class HfArgumentParserTest(unittest.TestCase):
args = BasicExample(**args_dict)
self.assertEqual(parsed_args, args)
def test_parse_dict_extra_key(self):
parser = HfArgumentParser(BasicExample)
args_dict = {
"foo": 12,
"bar": 3.14,
"baz": "42",
"flag": True,
"extra": 42,
}
self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
def test_integration_training_args(self):
parser = HfArgumentParser(TrainingArguments)
self.assertIsNotNone(parser)