diff --git a/remarshal.py b/remarshal.py index 41aef75..3ca7e65 100755 --- a/remarshal.py +++ b/remarshal.py @@ -81,6 +81,17 @@ def argv0_to_format(argv0): return False, None, None +def extension_to_format(path): + _, ext = os.path.splitext(path) + + ext = ext[1:] + + if ext == 'yml': + ext = 'yaml' + + return ext if ext in FORMATS else None + + def json_serialize(obj): if isinstance(obj, datetime.datetime): return obj.isoformat() @@ -119,12 +130,10 @@ def parse_command_line(argv): if not format_from_argv0: parser.add_argument('--if', '-if', '--input-format', dest='input_format', - required=True, help="input format", choices=FORMATS) parser.add_argument('--of', '-of', '--output-format', dest='output_format', - required=True, help="output format", choices=FORMATS) @@ -177,6 +186,16 @@ def parse_command_line(argv): args.__dict__['indent_json'] = None if argv0_to != 'yaml': args.__dict__['yaml_style'] = None + else: + if args.input_format is None: + args.input_format = extension_to_format(args.input) + if args.input_format is None: + parser.error('Cannot determine the input format') + + if args.output_format is None: + args.output_format = extension_to_format(args.output) + if args.output_format is None: + parser.error('Cannot determine the output format') # Wrap yaml_style. args.__dict__['yaml_options'] = {'default_style': args.yaml_style} diff --git a/tests/test_remarshal.py b/tests/test_remarshal.py index 9418fc1..5d9ecda 100755 --- a/tests/test_remarshal.py +++ b/tests/test_remarshal.py @@ -208,6 +208,44 @@ class TestRemarshal(unittest.TestCase): test_format_string('{0}2{1}.exe') test_format_string('{0}2{1}-script.py') + def test_format_detection(self): + ext_to_fmt = { + 'json': 'json', + 'toml': 'toml', + 'yaml': 'yaml', + 'yml': 'yaml', + } + + for from_ext in ext_to_fmt.keys(): + for to_ext in ext_to_fmt.keys(): + args = remarshal.parse_command_line([ + sys.argv[0], + 'input.' + from_ext, + 'output.' + to_ext + ]) + + self.assertEqual(args.input_format, ext_to_fmt[from_ext]) + self.assertEqual(args.output_format, ext_to_fmt[to_ext]) + + def test_format_detection_failure_input_stdin(self): + with self.assertRaises(SystemExit) as cm: + remarshal.parse_command_line([sys.argv[0], '-']) + self.assertEqual(cm.exception.code, 2) + + def test_format_detection_failure_input_txt(self): + with self.assertRaises(SystemExit) as cm: + remarshal.parse_command_line([sys.argv[0], 'input.txt']) + self.assertEqual(cm.exception.code, 2) + + def test_format_detection_failure_output_txt(self): + with self.assertRaises(SystemExit) as cm: + remarshal.parse_command_line([ + sys.argv[0], + 'input.json', + 'output.txt' + ]) + self.assertEqual(cm.exception.code, 2) + def test_run_no_args(self): with self.assertRaises(SystemExit) as cm: remarshal.run([sys.argv[0]]) @@ -233,6 +271,11 @@ class TestRemarshal(unittest.TestCase): remarshal.run(args) self.assertEqual(cm.exception.errno, errno.ENOENT) + def test_run_no_output_format(self): + with self.assertRaises(SystemExit) as cm: + remarshal.run([sys.argv[0], test_file_path('array.toml')]) + self.assertEqual(cm.exception.code, 2) + def test_ordered_simple(self): for from_ in 'json', 'toml', 'yaml': for to in 'json', 'toml', 'yaml':