Add an option to `HfArgumentParser.parse_{dict,json_file}` to raise an Exception when there are extra keys (#18692)
* Update parser to track unneeded keys, off by default * Fix formatting * Fix docstrings and defaults in HfArgparser * Fix formatting
This commit is contained in:
parent
f210e2a414
commit
86387fe87f
|
@ -234,29 +234,60 @@ class HfArgumentParser(ArgumentParser):
|
|||
|
||||
return (*outputs,)
|
||||
|
||||
def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
    """
    Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
    dataclass types.

    Args:
        json_file (`str` or `os.PathLike`):
            File name of the json file to parse.
        allow_extra_keys (`bool`, *optional*, defaults to `False`):
            If `False`, raise a `ValueError` when the json file contains keys that are not consumed by any of
            the dataclass types.

    Returns:
        Tuple consisting of:

            - the dataclass instances in the same order as they were passed to the initializer.

    Raises:
        ValueError: If `allow_extra_keys` is `False` and the file contains unparsed keys.
    """
    data = json.loads(Path(json_file).read_text())
    # Delegate to `parse_dict` so the extra-key bookkeeping lives in exactly one place.
    return self.parse_dict(data, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
    """
    Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
    types.

    Args:
        args (`dict`):
            dict containing config values
        allow_extra_keys (`bool`, *optional*, defaults to `False`):
            If `False`, raise a `ValueError` when the dict contains keys that are not consumed by any of the
            dataclass types.

    Returns:
        Tuple consisting of:

            - the dataclass instances in the same order as they were passed to the initializer.

    Raises:
        ValueError: If `allow_extra_keys` is `False` and `args` contains unparsed keys.
    """
    # Track every input key; keys consumed by some dataclass are removed below.
    unused_keys = set(args.keys())
    outputs = []
    for dtype in self.dataclass_types:
        # Only fields that participate in __init__ can be populated from the dict.
        keys = {f.name for f in dataclasses.fields(dtype) if f.init}
        inputs = {k: v for k, v in args.items() if k in keys}
        unused_keys.difference_update(inputs.keys())
        outputs.append(dtype(**inputs))
    if not allow_extra_keys and unused_keys:
        raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
    return tuple(outputs)
|
||||
|
|
|
@ -245,6 +245,19 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||
args = BasicExample(**args_dict)
|
||||
self.assertEqual(parsed_args, args)
|
||||
|
||||
def test_parse_dict_extra_key(self):
    """An unconsumed key raises by default but is tolerated with `allow_extra_keys=True`."""
    parser = HfArgumentParser(BasicExample)

    args_dict = {
        "foo": 12,
        "bar": 3.14,
        "baz": "42",
        "flag": True,
        "extra": 42,  # not a field of BasicExample
    }

    # Default (strict) behavior: the stray "extra" key must raise.
    self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)

    # Opt-in lenient behavior: the stray key is silently ignored.
    (parsed,) = parser.parse_dict(args_dict, allow_extra_keys=True)
    self.assertEqual(parsed, BasicExample(foo=12, bar=3.14, baz="42", flag=True))
||||
def test_integration_training_args(self):
    """Smoke test: the full `TrainingArguments` dataclass is accepted by the parser."""
    training_args_parser = HfArgumentParser(TrainingArguments)
    self.assertIsNotNone(training_args_parser)
|
||||
|
|
Loading…
Reference in New Issue