Added tests for yaml and json parser (#19219)

* Added tests for yaml and json

* Added tests for yaml and json
This commit is contained in:
IMvision12 2022-09-28 01:55:57 +05:30 committed by GitHub
parent 2d95695825
commit 2df602870b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 2 deletions

View File

@ -281,7 +281,9 @@ class HfArgumentParser(ArgumentParser):
- the dataclass instances in the same order as they were passed to the initializer.
"""
outputs = self.parse_dict(json.loads(Path(json_file).read_text()), allow_extra_keys=allow_extra_keys)
open_json_file = open(Path(json_file))
data = json.loads(open_json_file.read())
outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
return tuple(outputs)
def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
@ -301,5 +303,5 @@ class HfArgumentParser(ArgumentParser):
- the dataclass instances in the same order as they were passed to the initializer.
"""
outputs = self.parse_dict(yaml.safe_load(yaml_file), allow_extra_keys=allow_extra_keys)
outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
return tuple(outputs)

View File

@ -13,12 +13,17 @@
# limitations under the License.
import argparse
import json
import os
import tempfile
import unittest
from argparse import Namespace
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import List, Optional
import yaml
from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import string_to_bool
@ -258,6 +263,43 @@ class HfArgumentParserTest(unittest.TestCase):
self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
def test_parse_json(self):
parser = HfArgumentParser(BasicExample)
args_dict_for_json = {
"foo": 12,
"bar": 3.14,
"baz": "42",
"flag": True,
}
with tempfile.TemporaryDirectory() as tmp_dir:
temp_local_path = os.path.join(tmp_dir, "temp_json")
os.mkdir(temp_local_path)
with open(temp_local_path + ".json", "w+") as f:
json.dump(args_dict_for_json, f)
parsed_args = parser.parse_yaml_file(Path(temp_local_path + ".json"))[0]
args = BasicExample(**args_dict_for_json)
self.assertEqual(parsed_args, args)
def test_parse_yaml(self):
parser = HfArgumentParser(BasicExample)
args_dict_for_yaml = {
"foo": 12,
"bar": 3.14,
"baz": "42",
"flag": True,
}
with tempfile.TemporaryDirectory() as tmp_dir:
temp_local_path = os.path.join(tmp_dir, "temp_yaml")
os.mkdir(temp_local_path)
with open(temp_local_path + ".yaml", "w+") as f:
yaml.dump(args_dict_for_yaml, f)
parsed_args = parser.parse_yaml_file(Path(temp_local_path + ".yaml"))[0]
args = BasicExample(**args_dict_for_yaml)
self.assertEqual(parsed_args, args)
def test_integration_training_args(self):
parser = HfArgumentParser(TrainingArguments)
self.assertIsNotNone(parser)