refactor: add type annotations; simplify some function types

This commit is contained in:
D. Bohdan 2023-07-19 11:40:40 +00:00
parent 5c0bfb5107
commit 121e622052
2 changed files with 107 additions and 71 deletions

View File

@ -4,12 +4,15 @@
# License: MIT
from __future__ import annotations
import argparse
import datetime
import json
import os.path
import re
import sys
from typing import Any, Callable, Dict, List, Mapping, Sequence, Set, Tuple, Union
import cbor2 # type: ignore
import dateutil.parser
@ -57,7 +60,7 @@ for loader in loaders:
# === JSON ===
def json_default(obj):
def json_default(obj) -> str:
if isinstance(obj, datetime.datetime):
return obj.isoformat()
msg = f"{obj!r} is not JSON serializable"
@ -67,14 +70,14 @@ def json_default(obj):
# === CLI ===
def argv0_to_format(argv0):
def argv0_to_format(argv0: str) -> Tuple[str, str]:
possible_format = "(" + "|".join(FORMATS) + ")"
match = re.search("^" + possible_format + "2" + possible_format, argv0)
from_, to = match.groups() if match else (None, None)
return bool(match), from_, to
from_, to = match.groups() if match else ("", "")
return from_, to
def extension_to_format(path):
def extension_to_format(path: str) -> str:
_, ext = os.path.splitext(path)
ext = ext[1:]
@ -82,10 +85,10 @@ def extension_to_format(path):
if ext == "yml":
ext = "yaml"
return ext if ext in FORMATS else None
return ext if ext in FORMATS else ""
def parse_command_line(argv): # noqa: C901.
def parse_command_line(argv: List[str]) -> argparse.Namespace: # noqa: C901.
defaults = {
"json_indent": 0,
"ordered": True,
@ -93,7 +96,8 @@ def parse_command_line(argv): # noqa: C901.
}
me = os.path.basename(argv[0])
format_from_argv0, argv0_from, argv0_to = argv0_to_format(me)
argv0_from, argv0_to = argv0_to_format(me)
format_from_argv0 = argv0_to != ""
parser = argparse.ArgumentParser(
description="Convert between CBOR, JSON, MessagePack, TOML, and YAML."
@ -127,6 +131,7 @@ def parse_command_line(argv): # noqa: C901.
"-if",
"--input-format",
dest="input_format",
default="",
help="input format",
choices=FORMATS,
)
@ -135,6 +140,7 @@ def parse_command_line(argv): # noqa: C901.
"-of",
"--output-format",
dest="output_format",
default="",
help="output format",
choices=FORMATS,
)
@ -223,14 +229,14 @@ def parse_command_line(argv): # noqa: C901.
args.input_format = argv0_from
args.output_format = argv0_to
else:
if args.input_format is None:
if args.input_format == "":
args.input_format = extension_to_format(args.input)
if args.input_format is None:
if args.input_format == "":
parser.error("Need an explicit input format")
if args.output_format is None:
if args.output_format == "":
args.output_format = extension_to_format(args.output)
if args.output_format is None:
if args.output_format == "":
parser.error("Need an explicit output format")
for key, value in defaults.items():
@ -254,11 +260,11 @@ def parse_command_line(argv): # noqa: C901.
def traverse(
col,
dict_callback=lambda x: dict(x),
list_callback=lambda x: x,
key_callback=lambda x: x,
instance_callbacks=[],
default_callback=lambda x: x,
dict_callback: Callable = lambda x: dict(x),
list_callback: Callable = lambda x: x,
key_callback: Callable = lambda x: x,
instance_callbacks: Set[Tuple[type, Any]] = set(),
default_callback: Callable = lambda x: x,
):
if isinstance(col, dict):
res = dict_callback(
@ -302,7 +308,10 @@ def traverse(
return res
def decode_json(input_data):
Document = Union[bool, bytes, datetime.datetime, Mapping, None, Sequence, str]
def decode_json(input_data: bytes) -> Document:
try:
return json.loads(
input_data.decode("utf-8"),
@ -312,7 +321,7 @@ def decode_json(input_data):
raise ValueError(msg)
def decode_msgpack(input_data):
def decode_msgpack(input_data: bytes) -> Document:
try:
return umsgpack.unpackb(input_data)
except umsgpack.UnpackException as e:
@ -320,7 +329,7 @@ def decode_msgpack(input_data):
raise ValueError(msg)
def decode_cbor(input_data):
def decode_cbor(input_data: bytes) -> Document:
try:
return cbor2.loads(input_data)
except cbor2.CBORDecodeError as e:
@ -328,7 +337,7 @@ def decode_cbor(input_data):
raise ValueError(msg)
def decode_toml(input_data):
def decode_toml(input_data: bytes) -> Document:
try:
# Remove TOML Kit's custom classes.
# https://github.com/sdispater/tomlkit/issues/43
@ -377,7 +386,7 @@ def decode_toml(input_data):
raise ValueError(msg)
def decode_yaml(input_data):
def decode_yaml(input_data: bytes) -> Document:
try:
loader = TimezoneLoader
return yaml.load(input_data, loader)
@ -386,7 +395,7 @@ def decode_yaml(input_data):
raise ValueError(msg)
def decode(input_format, input_data):
def decode(input_format: str, input_data: bytes) -> Document:
decoder = {
"cbor": decode_cbor,
"json": decode_json,
@ -402,7 +411,9 @@ def decode(input_format, input_data):
return decoder[input_format](input_data)
def encode_json(data, ordered, indent):
def encode_json(
data: Document, ordered: bool, indent: Union[bool, int] # noqa: FBT001
) -> str:
if indent is True:
indent = 2
@ -433,7 +444,7 @@ def encode_json(data, ordered, indent):
raise ValueError(msg)
def encode_msgpack(data):
def encode_msgpack(data: Document) -> bytes:
try:
return umsgpack.packb(data)
except umsgpack.UnsupportedTypeException as e:
@ -441,7 +452,7 @@ def encode_msgpack(data):
raise ValueError(msg)
def encode_cbor(data):
def encode_cbor(data: Document) -> bytes:
try:
return cbor2.dumps(data)
except cbor2.CBOREncodeError as e:
@ -449,7 +460,7 @@ def encode_cbor(data):
raise ValueError(msg)
def encode_toml(data, ordered):
def encode_toml(data: Mapping, ordered: bool) -> str: # noqa: FBT001
try:
return tomlkit.dumps(data, sort_keys=not ordered)
except AttributeError as e:
@ -466,7 +477,9 @@ def encode_toml(data, ordered):
raise ValueError(msg)
def encode_yaml(data, ordered, yaml_options):
def encode_yaml(
data: Document, ordered: bool, yaml_options: Dict # noqa: FBT001
) -> str:
dumper = OrderedDumper if ordered else yaml.SafeDumper
try:
return yaml.dump(
@ -483,10 +496,41 @@ def encode_yaml(data, ordered, yaml_options):
raise ValueError(msg)
def encode(
    output_format: str,
    data: Document,
    *,
    json_indent: int,
    ordered: bool,
    yaml_options: Dict,
) -> bytes:
    """Serialize ``data`` in ``output_format`` and return the result as bytes.

    Text-based formats (JSON, TOML, YAML) are produced as ``str`` by their
    encoders and converted to UTF-8 here; binary formats (MessagePack, CBOR)
    are returned exactly as their encoders produce them.

    Args:
        output_format: One of "json", "msgpack", "toml", "yaml", "cbor".
        data: The decoded document to serialize.
        json_indent: Indentation passed to the JSON encoder.
        ordered: Passed through to the JSON, TOML, and YAML encoders.
        yaml_options: Extra options passed to the YAML encoder.

    Raises:
        TypeError: If the output format is TOML and ``data`` is not a mapping.
        ValueError: If ``output_format`` is not a recognized format name.
    """
    if output_format == "json":
        return encode_json(data, ordered, json_indent).encode("utf-8")

    if output_format == "msgpack":
        return encode_msgpack(data)

    if output_format == "toml":
        # The TOML encoder is only handed mappings; anything else is
        # rejected before it is called.
        if isinstance(data, Mapping):
            return encode_toml(data, ordered).encode("utf-8")
        msg = (
            f"Top-level value of type '{type(data).__name__}' cannot "
            "be encoded as TOML"
        )
        raise TypeError(msg)

    if output_format == "yaml":
        return encode_yaml(data, ordered, yaml_options).encode("utf-8")

    if output_format == "cbor":
        return encode_cbor(data)

    msg = f"Unknown output format: {output_format}"
    raise ValueError(msg)
# === Main ===
def run(argv):
def run(argv: List[str]) -> None:
args = parse_command_line(argv)
remarshal(
args.input,
@ -502,33 +546,35 @@ def run(argv):
def remarshal(
input,
output,
input_format,
output_format,
wrap=None,
unwrap=None,
json_indent=0,
yaml_options={},
ordered=True,
transform=None,
):
input: str,
output: str,
input_format: str,
output_format: str,
wrap: Union[str, None] = None,
unwrap: Union[str, None] = None,
json_indent: int = 0,
yaml_options: Dict = {},
ordered: bool = True, # noqa: FBT001
transform: Union[Callable[[Document], Document], None] = None,
) -> None:
try:
if input == "-":
input_file = getattr(sys.stdin, "buffer", sys.stdin)
else:
input_file = open(input, "rb")
if output == "-":
output_file = getattr(sys.stdout, "buffer", sys.stdout)
else:
output_file = open(output, "wb")
input_file = sys.stdin.buffer if input == "-" else open(input, "rb")
output_file = sys.stdout.buffer if output == "-" else open(output, "wb")
input_data = input_file.read()
if not isinstance(input_data, bytes):
msg = "input_data must be bytes"
raise TypeError(msg)
parsed = decode(input_format, input_data)
if unwrap is not None:
if not isinstance(parsed, Mapping):
msg = (
f"Top-level value of type '{type(parsed).__name__}' "
"cannot be unwrapped"
)
raise TypeError(msg)
parsed = parsed[unwrap]
if wrap is not None:
temp = {}
@ -538,24 +584,14 @@ def remarshal(
if transform:
parsed = transform(parsed)
if output_format == "json":
output_data = encode_json(parsed, ordered, json_indent)
elif output_format == "msgpack":
output_data = encode_msgpack(parsed)
elif output_format == "toml":
output_data = encode_toml(parsed, ordered)
elif output_format == "yaml":
output_data = encode_yaml(parsed, ordered, yaml_options)
elif output_format == "cbor":
output_data = encode_cbor(parsed)
else:
msg = f"Unknown output format: {output_format}"
raise ValueError(msg)
encoded = encode(
output_format,
parsed,
json_indent=json_indent,
ordered=ordered,
yaml_options=yaml_options,
)
if output_format == "msgpack" or output_format == "cbor":
encoded = output_data
else:
encoded = output_data.encode("utf-8")
output_file.write(encoded)
finally:
if "input_file" in locals():
@ -564,12 +600,12 @@ def remarshal(
output_file.close()
def main():
def main() -> None:
try:
run(sys.argv)
except KeyboardInterrupt:
pass
except (OSError, ValueError) as e:
except (OSError, TypeError, ValueError) as e:
print(f"Error: {e}", file=sys.stderr) # noqa: T201
sys.exit(1)

View File

@ -312,7 +312,7 @@ class TestRemarshal(unittest.TestCase):
assert output == reference
def test_missing_wrap(self):
with pytest.raises(ValueError):
with pytest.raises(TypeError):
self.convert_and_read("array.json", "json", "toml")
def test_wrap(self):
@ -397,10 +397,10 @@ class TestRemarshal(unittest.TestCase):
def test_format_string(s):
for from_str in "json", "toml", "yaml":
for to_str in "json", "toml", "yaml":
found, from_parsed, to_parsed = remarshal.argv0_to_format(
from_parsed, to_parsed = remarshal.argv0_to_format(
s.format(from_str, to_str)
)
assert (found, from_parsed, to_parsed) == (found, from_str, to_str)
assert (from_parsed, to_parsed) == (from_str, to_str)
test_format_string("{0}2{1}")
test_format_string("{0}2{1}.exe")