fix(yaml): limit maximum nodes

Prevent a billion laughs attack carried out with YAML anchors.

Thanks to Taichi Kotake of Akatsuki Inc. for finding this vulnerability
and to JPCERT/CC for reporting it.

JVN#86156389
This commit is contained in:
D. Bohdan 2023-09-05 12:10:06 +00:00
parent d95a1846b9
commit fd6ac799a0
4 changed files with 55 additions and 5 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "Remarshal" name = "Remarshal"
version = "0.17.0" version = "0.17.1"
description = "Convert between CBOR, JSON, MessagePack, TOML, and YAML" description = "Convert between CBOR, JSON, MessagePack, TOML, and YAML"
authors = ["D. Bohdan <dbohdan@dbohdan.com>"] authors = ["D. Bohdan <dbohdan@dbohdan.com>"]
license = "MIT" license = "MIT"
@ -136,8 +136,9 @@ max-complexity = 14
[tool.ruff.pylint] [tool.ruff.pylint]
allow-magic-value-types = ["int", "str"] allow-magic-value-types = ["int", "str"]
max-args = 11 max-args = 12
max-branches = 19 max-branches = 20
max-statements = 100
[tool.ruff.per-file-ignores] [tool.ruff.per-file-ignores]
"remarshal.py" = ["ARG001", "B904", "EM103", "RET506", "S506", "SIM115"] "remarshal.py" = ["ARG001", "B904", "EM103", "RET506", "S506", "SIM115"]

View File

@ -24,8 +24,9 @@ import yaml
import yaml.parser import yaml.parser
import yaml.scanner import yaml.scanner
__version__ = "0.17.0" __version__ = "0.17.1"
DEFAULT_MAX_NODES = 100000
FORMATS = ["cbor", "json", "msgpack", "toml", "yaml"] FORMATS = ["cbor", "json", "msgpack", "toml", "yaml"]
@ -142,6 +143,15 @@ def parse_command_line(argv: List[str]) -> argparse.Namespace: # noqa: C901.
), ),
) )
parser.add_argument(
"--max-nodes",
dest="max_nodes",
metavar="n",
type=int,
default=DEFAULT_MAX_NODES,
help="maximum number of nodes in input data (default %(default)s)",
)
output_group = parser.add_mutually_exclusive_group() output_group = parser.add_mutually_exclusive_group()
output_group.add_argument("output", nargs="?", default="-", help="output file") output_group.add_argument("output", nargs="?", default="-", help="output file")
output_group.add_argument( output_group.add_argument(
@ -431,6 +441,27 @@ def decode(input_format: str, input_data: bytes) -> Document:
return decoder[input_format](input_data) return decoder[input_format](input_data)
class TooManyNodesError(BaseException):
def __init__(self, msg: str = "document has too many nodes", *args, **kwargs):
super().__init__(msg, *args, **kwargs)
def validate_node_count(doc: Document, *, limit: int) -> None:
count = 0
def count_callback(x: Any) -> Any:
nonlocal count
nonlocal limit
count += 1
if count > limit:
raise TooManyNodesError
return x
traverse(doc, instance_callbacks={(object, count_callback)})
def reject_special_keys(key: Any) -> Any: def reject_special_keys(key: Any) -> Any:
if isinstance(key, bool): if isinstance(key, bool):
msg = "boolean key" msg = "boolean key"
@ -627,6 +658,7 @@ def run(argv: List[str]) -> None:
args.input_format, args.input_format,
args.output_format, args.output_format,
json_indent=args.json_indent, json_indent=args.json_indent,
max_nodes=args.max_nodes,
ordered=args.ordered, ordered=args.ordered,
stringify=args.stringify, stringify=args.stringify,
unwrap=args.unwrap, unwrap=args.unwrap,
@ -642,6 +674,7 @@ def remarshal(
output_format: str, output_format: str,
*, *,
json_indent: Union[int, None] = None, json_indent: Union[int, None] = None,
max_nodes: int = DEFAULT_MAX_NODES,
ordered: bool = True, ordered: bool = True,
stringify: bool = False, stringify: bool = False,
transform: Union[Callable[[Document], Document], None] = None, transform: Union[Callable[[Document], Document], None] = None,
@ -663,6 +696,8 @@ def remarshal(
parsed = decode(input_format, input_data) parsed = decode(input_format, input_data)
validate_node_count(parsed, limit=max_nodes)
if unwrap is not None: if unwrap is not None:
if not isinstance(parsed, Mapping): if not isinstance(parsed, Mapping):
msg = ( msg = (
@ -701,7 +736,7 @@ def main() -> None:
run(sys.argv) run(sys.argv)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
except (OSError, TypeError, ValueError) as e: except (OSError, TooManyNodesError, TypeError, ValueError) as e:
print(f"Error: {e}", file=sys.stderr) # noqa: T201 print(f"Error: {e}", file=sys.stderr) # noqa: T201
sys.exit(1) sys.exit(1)

10
tests/lol.yml Normal file
View File

@ -0,0 +1,10 @@
lol1: &lol1 "lol"
lol2: &lol2 [*lol1,*lol1,*lol1,*lol1,*lol1,*lol1,*lol1,*lol1,*lol1]
lol3: &lol3 [*lol2,*lol2,*lol2,*lol2,*lol2,*lol2,*lol2,*lol2,*lol2]
lol4: &lol4 [*lol3,*lol3,*lol3,*lol3,*lol3,*lol3,*lol3,*lol3,*lol3]
lol5: &lol5 [*lol4,*lol4,*lol4,*lol4,*lol4,*lol4,*lol4,*lol4,*lol4]
lol6: &lol6 [*lol5,*lol5,*lol5,*lol5,*lol5,*lol5,*lol5,*lol5,*lol5]
lol7: &lol7 [*lol6,*lol6,*lol6,*lol6,*lol6,*lol6,*lol6,*lol6,*lol6]
lol8: &lol8 [*lol7,*lol7,*lol7,*lol7,*lol7,*lol7,*lol7,*lol7,*lol7]
lol9: &lol9 [*lol8,*lol8,*lol8,*lol8,*lol8,*lol8,*lol8,*lol8,*lol8]
lol10: &lol10 [*lol9,*lol9,*lol9,*lol9,*lol9,*lol9,*lol9,*lol9,*lol9]

View File

@ -601,3 +601,7 @@ class TestRemarshal(unittest.TestCase):
) )
reference = read_file("numeric-key-null-value.toml") reference = read_file("numeric-key-null-value.toml")
assert output == reference assert output == reference
def test_yaml_billion_laughs(self) -> None:
with pytest.raises(remarshal.TooManyNodesError):
self.convert_and_read("lol.yml", "yaml", "json")