refactor: add types anns; simplify some func types
This commit is contained in:
parent
5c0bfb5107
commit
121e622052
172
remarshal.py
172
remarshal.py
|
@ -4,12 +4,15 @@
|
|||
# License: MIT
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import os.path
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Callable, Dict, List, Mapping, Sequence, Set, Tuple, Union
|
||||
|
||||
import cbor2 # type: ignore
|
||||
import dateutil.parser
|
||||
|
@ -57,7 +60,7 @@ for loader in loaders:
|
|||
# === JSON ===
|
||||
|
||||
|
||||
def json_default(obj):
|
||||
def json_default(obj) -> str:
|
||||
if isinstance(obj, datetime.datetime):
|
||||
return obj.isoformat()
|
||||
msg = f"{obj!r} is not JSON serializable"
|
||||
|
@ -67,14 +70,14 @@ def json_default(obj):
|
|||
# === CLI ===
|
||||
|
||||
|
||||
def argv0_to_format(argv0):
|
||||
def argv0_to_format(argv0: str) -> Tuple[str, str]:
|
||||
possible_format = "(" + "|".join(FORMATS) + ")"
|
||||
match = re.search("^" + possible_format + "2" + possible_format, argv0)
|
||||
from_, to = match.groups() if match else (None, None)
|
||||
return bool(match), from_, to
|
||||
from_, to = match.groups() if match else ("", "")
|
||||
return from_, to
|
||||
|
||||
|
||||
def extension_to_format(path):
|
||||
def extension_to_format(path: str) -> str:
|
||||
_, ext = os.path.splitext(path)
|
||||
|
||||
ext = ext[1:]
|
||||
|
@ -82,10 +85,10 @@ def extension_to_format(path):
|
|||
if ext == "yml":
|
||||
ext = "yaml"
|
||||
|
||||
return ext if ext in FORMATS else None
|
||||
return ext if ext in FORMATS else ""
|
||||
|
||||
|
||||
def parse_command_line(argv): # noqa: C901.
|
||||
def parse_command_line(argv: List[str]) -> argparse.Namespace: # noqa: C901.
|
||||
defaults = {
|
||||
"json_indent": 0,
|
||||
"ordered": True,
|
||||
|
@ -93,7 +96,8 @@ def parse_command_line(argv): # noqa: C901.
|
|||
}
|
||||
|
||||
me = os.path.basename(argv[0])
|
||||
format_from_argv0, argv0_from, argv0_to = argv0_to_format(me)
|
||||
argv0_from, argv0_to = argv0_to_format(me)
|
||||
format_from_argv0 = argv0_to != ""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert between CBOR, JSON, MessagePack, TOML, and YAML."
|
||||
|
@ -127,6 +131,7 @@ def parse_command_line(argv): # noqa: C901.
|
|||
"-if",
|
||||
"--input-format",
|
||||
dest="input_format",
|
||||
default="",
|
||||
help="input format",
|
||||
choices=FORMATS,
|
||||
)
|
||||
|
@ -135,6 +140,7 @@ def parse_command_line(argv): # noqa: C901.
|
|||
"-of",
|
||||
"--output-format",
|
||||
dest="output_format",
|
||||
default="",
|
||||
help="output format",
|
||||
choices=FORMATS,
|
||||
)
|
||||
|
@ -223,14 +229,14 @@ def parse_command_line(argv): # noqa: C901.
|
|||
args.input_format = argv0_from
|
||||
args.output_format = argv0_to
|
||||
else:
|
||||
if args.input_format is None:
|
||||
if args.input_format == "":
|
||||
args.input_format = extension_to_format(args.input)
|
||||
if args.input_format is None:
|
||||
if args.input_format == "":
|
||||
parser.error("Need an explicit input format")
|
||||
|
||||
if args.output_format is None:
|
||||
if args.output_format == "":
|
||||
args.output_format = extension_to_format(args.output)
|
||||
if args.output_format is None:
|
||||
if args.output_format == "":
|
||||
parser.error("Need an explicit output format")
|
||||
|
||||
for key, value in defaults.items():
|
||||
|
@ -254,11 +260,11 @@ def parse_command_line(argv): # noqa: C901.
|
|||
|
||||
def traverse(
|
||||
col,
|
||||
dict_callback=lambda x: dict(x),
|
||||
list_callback=lambda x: x,
|
||||
key_callback=lambda x: x,
|
||||
instance_callbacks=[],
|
||||
default_callback=lambda x: x,
|
||||
dict_callback: Callable = lambda x: dict(x),
|
||||
list_callback: Callable = lambda x: x,
|
||||
key_callback: Callable = lambda x: x,
|
||||
instance_callbacks: Set[Tuple[type, Any]] = set(),
|
||||
default_callback: Callable = lambda x: x,
|
||||
):
|
||||
if isinstance(col, dict):
|
||||
res = dict_callback(
|
||||
|
@ -302,7 +308,10 @@ def traverse(
|
|||
return res
|
||||
|
||||
|
||||
def decode_json(input_data):
|
||||
Document = Union[bool, bytes, datetime.datetime, Mapping, None, Sequence, str]
|
||||
|
||||
|
||||
def decode_json(input_data: bytes) -> Document:
|
||||
try:
|
||||
return json.loads(
|
||||
input_data.decode("utf-8"),
|
||||
|
@ -312,7 +321,7 @@ def decode_json(input_data):
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
def decode_msgpack(input_data):
|
||||
def decode_msgpack(input_data: bytes) -> Document:
|
||||
try:
|
||||
return umsgpack.unpackb(input_data)
|
||||
except umsgpack.UnpackException as e:
|
||||
|
@ -320,7 +329,7 @@ def decode_msgpack(input_data):
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
def decode_cbor(input_data):
|
||||
def decode_cbor(input_data: bytes) -> Document:
|
||||
try:
|
||||
return cbor2.loads(input_data)
|
||||
except cbor2.CBORDecodeError as e:
|
||||
|
@ -328,7 +337,7 @@ def decode_cbor(input_data):
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
def decode_toml(input_data):
|
||||
def decode_toml(input_data: bytes) -> Document:
|
||||
try:
|
||||
# Remove TOML Kit's custom classes.
|
||||
# https://github.com/sdispater/tomlkit/issues/43
|
||||
|
@ -377,7 +386,7 @@ def decode_toml(input_data):
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
def decode_yaml(input_data):
|
||||
def decode_yaml(input_data: bytes) -> Document:
|
||||
try:
|
||||
loader = TimezoneLoader
|
||||
return yaml.load(input_data, loader)
|
||||
|
@ -386,7 +395,7 @@ def decode_yaml(input_data):
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
def decode(input_format, input_data):
|
||||
def decode(input_format: str, input_data: bytes) -> Document:
|
||||
decoder = {
|
||||
"cbor": decode_cbor,
|
||||
"json": decode_json,
|
||||
|
@ -402,7 +411,9 @@ def decode(input_format, input_data):
|
|||
return decoder[input_format](input_data)
|
||||
|
||||
|
||||
def encode_json(data, ordered, indent):
|
||||
def encode_json(
|
||||
data: Document, ordered: bool, indent: Union[bool, int] # noqa: FBT001
|
||||
) -> str:
|
||||
if indent is True:
|
||||
indent = 2
|
||||
|
||||
|
@ -433,7 +444,7 @@ def encode_json(data, ordered, indent):
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
def encode_msgpack(data):
|
||||
def encode_msgpack(data: Document) -> bytes:
|
||||
try:
|
||||
return umsgpack.packb(data)
|
||||
except umsgpack.UnsupportedTypeException as e:
|
||||
|
@ -441,7 +452,7 @@ def encode_msgpack(data):
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
def encode_cbor(data):
|
||||
def encode_cbor(data: Document) -> bytes:
|
||||
try:
|
||||
return cbor2.dumps(data)
|
||||
except cbor2.CBOREncodeError as e:
|
||||
|
@ -449,7 +460,7 @@ def encode_cbor(data):
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
def encode_toml(data, ordered):
|
||||
def encode_toml(data: Mapping, ordered: bool) -> str: # noqa: FBT001
|
||||
try:
|
||||
return tomlkit.dumps(data, sort_keys=not ordered)
|
||||
except AttributeError as e:
|
||||
|
@ -466,7 +477,9 @@ def encode_toml(data, ordered):
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
def encode_yaml(data, ordered, yaml_options):
|
||||
def encode_yaml(
|
||||
data: Document, ordered: bool, yaml_options: Dict # noqa: FBT001
|
||||
) -> str:
|
||||
dumper = OrderedDumper if ordered else yaml.SafeDumper
|
||||
try:
|
||||
return yaml.dump(
|
||||
|
@ -483,10 +496,41 @@ def encode_yaml(data, ordered, yaml_options):
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
def encode(
|
||||
output_format: str,
|
||||
data: Document,
|
||||
*,
|
||||
json_indent: int,
|
||||
ordered: bool,
|
||||
yaml_options: Dict,
|
||||
) -> bytes:
|
||||
if output_format == "json":
|
||||
encoded = encode_json(data, ordered, json_indent).encode("utf-8")
|
||||
elif output_format == "msgpack":
|
||||
encoded = encode_msgpack(data)
|
||||
elif output_format == "toml":
|
||||
if not isinstance(data, Mapping):
|
||||
msg = (
|
||||
f"Top-level value of type '{type(data).__name__}' cannot "
|
||||
"be encoded as TOML"
|
||||
)
|
||||
raise TypeError(msg)
|
||||
encoded = encode_toml(data, ordered).encode("utf-8")
|
||||
elif output_format == "yaml":
|
||||
encoded = encode_yaml(data, ordered, yaml_options).encode("utf-8")
|
||||
elif output_format == "cbor":
|
||||
encoded = encode_cbor(data)
|
||||
else:
|
||||
msg = f"Unknown output format: {output_format}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return encoded
|
||||
|
||||
|
||||
# === Main ===
|
||||
|
||||
|
||||
def run(argv):
|
||||
def run(argv: List[str]) -> None:
|
||||
args = parse_command_line(argv)
|
||||
remarshal(
|
||||
args.input,
|
||||
|
@ -502,33 +546,35 @@ def run(argv):
|
|||
|
||||
|
||||
def remarshal(
|
||||
input,
|
||||
output,
|
||||
input_format,
|
||||
output_format,
|
||||
wrap=None,
|
||||
unwrap=None,
|
||||
json_indent=0,
|
||||
yaml_options={},
|
||||
ordered=True,
|
||||
transform=None,
|
||||
):
|
||||
input: str,
|
||||
output: str,
|
||||
input_format: str,
|
||||
output_format: str,
|
||||
wrap: Union[str, None] = None,
|
||||
unwrap: Union[str, None] = None,
|
||||
json_indent: int = 0,
|
||||
yaml_options: Dict = {},
|
||||
ordered: bool = True, # noqa: FBT001
|
||||
transform: Union[Callable[[Document], Document], None] = None,
|
||||
) -> None:
|
||||
try:
|
||||
if input == "-":
|
||||
input_file = getattr(sys.stdin, "buffer", sys.stdin)
|
||||
else:
|
||||
input_file = open(input, "rb")
|
||||
|
||||
if output == "-":
|
||||
output_file = getattr(sys.stdout, "buffer", sys.stdout)
|
||||
else:
|
||||
output_file = open(output, "wb")
|
||||
input_file = sys.stdin.buffer if input == "-" else open(input, "rb")
|
||||
output_file = sys.stdout.buffer if output == "-" else open(output, "wb")
|
||||
|
||||
input_data = input_file.read()
|
||||
if not isinstance(input_data, bytes):
|
||||
msg = "input_data must be bytes"
|
||||
raise TypeError(msg)
|
||||
|
||||
parsed = decode(input_format, input_data)
|
||||
|
||||
if unwrap is not None:
|
||||
if not isinstance(parsed, Mapping):
|
||||
msg = (
|
||||
f"Top-level value of type '{type(parsed).__name__}' "
|
||||
"cannot be unwrapped"
|
||||
)
|
||||
raise TypeError(msg)
|
||||
parsed = parsed[unwrap]
|
||||
if wrap is not None:
|
||||
temp = {}
|
||||
|
@ -538,24 +584,14 @@ def remarshal(
|
|||
if transform:
|
||||
parsed = transform(parsed)
|
||||
|
||||
if output_format == "json":
|
||||
output_data = encode_json(parsed, ordered, json_indent)
|
||||
elif output_format == "msgpack":
|
||||
output_data = encode_msgpack(parsed)
|
||||
elif output_format == "toml":
|
||||
output_data = encode_toml(parsed, ordered)
|
||||
elif output_format == "yaml":
|
||||
output_data = encode_yaml(parsed, ordered, yaml_options)
|
||||
elif output_format == "cbor":
|
||||
output_data = encode_cbor(parsed)
|
||||
else:
|
||||
msg = f"Unknown output format: {output_format}"
|
||||
raise ValueError(msg)
|
||||
encoded = encode(
|
||||
output_format,
|
||||
parsed,
|
||||
json_indent=json_indent,
|
||||
ordered=ordered,
|
||||
yaml_options=yaml_options,
|
||||
)
|
||||
|
||||
if output_format == "msgpack" or output_format == "cbor":
|
||||
encoded = output_data
|
||||
else:
|
||||
encoded = output_data.encode("utf-8")
|
||||
output_file.write(encoded)
|
||||
finally:
|
||||
if "input_file" in locals():
|
||||
|
@ -564,12 +600,12 @@ def remarshal(
|
|||
output_file.close()
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
try:
|
||||
run(sys.argv)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except (OSError, ValueError) as e:
|
||||
except (OSError, TypeError, ValueError) as e:
|
||||
print(f"Error: {e}", file=sys.stderr) # noqa: T201
|
||||
sys.exit(1)
|
||||
|
||||
|
|
|
@ -312,7 +312,7 @@ class TestRemarshal(unittest.TestCase):
|
|||
assert output == reference
|
||||
|
||||
def test_missing_wrap(self):
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
self.convert_and_read("array.json", "json", "toml")
|
||||
|
||||
def test_wrap(self):
|
||||
|
@ -397,10 +397,10 @@ class TestRemarshal(unittest.TestCase):
|
|||
def test_format_string(s):
|
||||
for from_str in "json", "toml", "yaml":
|
||||
for to_str in "json", "toml", "yaml":
|
||||
found, from_parsed, to_parsed = remarshal.argv0_to_format(
|
||||
from_parsed, to_parsed = remarshal.argv0_to_format(
|
||||
s.format(from_str, to_str)
|
||||
)
|
||||
assert (found, from_parsed, to_parsed) == (found, from_str, to_str)
|
||||
assert (from_parsed, to_parsed) == (from_str, to_str)
|
||||
|
||||
test_format_string("{0}2{1}")
|
||||
test_format_string("{0}2{1}.exe")
|
||||
|
|
Loading…
Reference in New Issue