Added tests for yaml and json parser (#19219)
* Added tests for yaml and json * Added tests for yaml and json
This commit is contained in:
parent
2d95695825
commit
2df602870b
|
@ -281,7 +281,9 @@ class HfArgumentParser(ArgumentParser):
|
|||
|
||||
- the dataclass instances in the same order as they were passed to the initializer.
|
||||
"""
|
||||
outputs = self.parse_dict(json.loads(Path(json_file).read_text()), allow_extra_keys=allow_extra_keys)
|
||||
open_json_file = open(Path(json_file))
|
||||
data = json.loads(open_json_file.read())
|
||||
outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
|
||||
return tuple(outputs)
|
||||
|
||||
def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
|
||||
|
@ -301,5 +303,5 @@ class HfArgumentParser(ArgumentParser):
|
|||
|
||||
- the dataclass instances in the same order as they were passed to the initializer.
|
||||
"""
|
||||
outputs = self.parse_dict(yaml.safe_load(yaml_file), allow_extra_keys=allow_extra_keys)
|
||||
outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
|
||||
return tuple(outputs)
|
||||
|
|
|
@ -13,12 +13,17 @@
|
|||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import yaml
|
||||
from transformers import HfArgumentParser, TrainingArguments
|
||||
from transformers.hf_argparser import string_to_bool
|
||||
|
||||
|
@ -258,6 +263,43 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||
|
||||
self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
|
||||
|
||||
def test_parse_json(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
args_dict_for_json = {
|
||||
"foo": 12,
|
||||
"bar": 3.14,
|
||||
"baz": "42",
|
||||
"flag": True,
|
||||
}
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
temp_local_path = os.path.join(tmp_dir, "temp_json")
|
||||
os.mkdir(temp_local_path)
|
||||
with open(temp_local_path + ".json", "w+") as f:
|
||||
json.dump(args_dict_for_json, f)
|
||||
parsed_args = parser.parse_yaml_file(Path(temp_local_path + ".json"))[0]
|
||||
|
||||
args = BasicExample(**args_dict_for_json)
|
||||
self.assertEqual(parsed_args, args)
|
||||
|
||||
def test_parse_yaml(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
args_dict_for_yaml = {
|
||||
"foo": 12,
|
||||
"bar": 3.14,
|
||||
"baz": "42",
|
||||
"flag": True,
|
||||
}
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
temp_local_path = os.path.join(tmp_dir, "temp_yaml")
|
||||
os.mkdir(temp_local_path)
|
||||
with open(temp_local_path + ".yaml", "w+") as f:
|
||||
yaml.dump(args_dict_for_yaml, f)
|
||||
parsed_args = parser.parse_yaml_file(Path(temp_local_path + ".yaml"))[0]
|
||||
args = BasicExample(**args_dict_for_yaml)
|
||||
self.assertEqual(parsed_args, args)
|
||||
|
||||
def test_integration_training_args(self):
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
self.assertIsNotNone(parser)
|
||||
|
|
Loading…
Reference in New Issue