diff --git a/esphome_device_builder/discover.py b/esphome_device_builder/discover.py index bc6db3e5..be42fc67 100644 --- a/esphome_device_builder/discover.py +++ b/esphome_device_builder/discover.py @@ -22,6 +22,7 @@ import asyncio import contextlib import logging +import re import sys from zeroconf import IPVersion, ServiceStateChange, Zeroconf @@ -41,6 +42,27 @@ ) _UNKNOWN = "unknown" +# Per-column display caps for peer-supplied mDNS labels, derived from the +# _FORMAT widths so a hostile broadcaster can't widen a column by stuffing a +# long value; deriving from _FORMAT keeps the caps in lock-step if the table +# layout is ever re-tuned. +_COLUMN_WIDTHS = tuple(int(w) for w in re.findall(r"<\s*(\d+)", _FORMAT)) +if len(_COLUMN_WIDTHS) != len(_COLUMN_NAMES): + # Runtime check, not `assert`, so the invariant still holds under + # `python -O` (which strips assert statements). + raise RuntimeError( + "_FORMAT width count must match _COLUMN_NAMES; update one and the other together" + ) +_MAX_NAME_DISPLAY = _COLUMN_WIDTHS[_COLUMN_NAMES.index("Name")] +_MAX_SERVER_DISPLAY = _COLUMN_WIDTHS[_COLUMN_NAMES.index("Server")] +_MAX_ESPHOME_DISPLAY = _COLUMN_WIDTHS[_COLUMN_NAMES.index("ESPHome")] +_MAX_PORT_DISPLAY = _COLUMN_WIDTHS[_COLUMN_NAMES.index("RB Port")] +# Pin column is 16 chars wide but `_truncate_pin` collapses a full 64-hex pin +# to 12 chars + ellipsis at print time, so the raw cap stays at 64 to keep +# legitimate pins intact; an oversized hostile value is still bounded by the +# subsequent truncation. +_MAX_PIN_DISPLAY = 64 + def main() -> None: """CLI entry point. @@ -122,16 +144,20 @@ async def _run(args: argparse.Namespace) -> None: await aiozc.async_close() -def _decode(data: str | bytes | None) -> str: - """Decode a TXT-record value to ``str``, or return ``unknown``.""" +def _safe_label(raw: str, limit: int) -> str: + """Strip non-printables and length-cap a peer-supplied label for stdout.""" + return "".join(filter(str.isprintable, raw))[:limit] + + +def _decode_mdns_label_or_unknown(data: str | bytes | None, limit: int = _MAX_NAME_DISPLAY) -> str: + """Decode peer-supplied mDNS bytes, strip non-printables, length-cap.""" if data is None: return _UNKNOWN if isinstance(data, bytes): - try: - return data.decode("utf-8") - except UnicodeDecodeError: - return data.decode("utf-8", errors="replace") - return data + # A device on the LAN can broadcast arbitrary bytes; use "replace" so + # a malformed UTF-8 payload doesn't raise out of the zeroconf callback. + data = data.decode("utf-8", "replace") + return _safe_label(data, limit) def _truncate_pin(pin: str) -> str: @@ -172,7 +198,10 @@ def _on_service_state_change( :mod:`controllers._device_state_monitor` / :mod:`controllers.remote_build.controller`). """ - short_name = name.partition(".")[0] + # The mDNS service name is peer-controlled; sanitize before printing so a + # hostile broadcaster can't inject ANSI escapes / newlines / null bytes + # into the terminal via the instance label. + short_name = _safe_label(name.partition(".")[0], _MAX_NAME_DISPLAY) state = "OFFLINE" if state_change is ServiceStateChange.Removed else "ONLINE" info = AsyncServiceInfo(service_type, name) # ``load_from_cache`` returns ``False`` when the browser @@ -185,10 +214,16 @@ def _on_service_state_change( # resolve catches up. info.load_from_cache(zeroconf) properties = info.properties or {} - server_version = _decode(properties.get(b"server_version")) - esphome_version = _decode(properties.get(b"esphome_version")) - pin_sha256 = _decode(properties.get(b"pin_sha256")) - remote_build_port = _decode(properties.get(b"remote_build_port")) + server_version = _decode_mdns_label_or_unknown( + properties.get(b"server_version"), _MAX_SERVER_DISPLAY + ) + esphome_version = _decode_mdns_label_or_unknown( + properties.get(b"esphome_version"), _MAX_ESPHOME_DISPLAY + ) + pin_sha256 = _decode_mdns_label_or_unknown(properties.get(b"pin_sha256"), _MAX_PIN_DISPLAY) + remote_build_port = _decode_mdns_label_or_unknown( + properties.get(b"remote_build_port"), _MAX_PORT_DISPLAY + ) address = "" if v4_addresses := info.ip_addresses_by_version(IPVersion.V4Only): diff --git a/tests/test_discover.py b/tests/test_discover.py index 5acae16d..f2b1b40f 100644 --- a/tests/test_discover.py +++ b/tests/test_discover.py @@ -19,6 +19,7 @@ import argparse import asyncio import logging +import re from collections.abc import Coroutine from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -28,11 +29,18 @@ from esphome_device_builder.discover import ( _COLUMN_NAMES, + _FORMAT, + _MAX_ESPHOME_DISPLAY, + _MAX_NAME_DISPLAY, + _MAX_PIN_DISPLAY, + _MAX_PORT_DISPLAY, + _MAX_SERVER_DISPLAY, _UNKNOWN, _build_parser, - _decode, + _decode_mdns_label_or_unknown, _on_service_state_change, _run, + _safe_label, _truncate_pin, main, ) @@ -44,16 +52,133 @@ (b"hello", "hello"), ("plain", "plain"), (None, _UNKNOWN), - # ``bytes`` containing a non-UTF-8 sequence falls through - # the strict decode and lands on the replacement-char path - # (rather than raising) so a malformed TXT entry doesn't - # crash the browse loop. + # ``bytes`` containing a non-UTF-8 sequence falls through the + # ``"replace"`` handler so a malformed TXT entry doesn't crash the + # browse loop. Pin the actual U+FFFD output (one per invalid byte) + # so a future refactor that silently swaps the handler for an + # UNKNOWN-or-empty fallback trips a red test. (b"\xff\xfe", "��"), ], ) def test_decode_handles_every_txt_wire_shape(raw: str | bytes | None, expected: str) -> None: - """``_decode`` round-trips bytes, leaves strings alone, marks missing.""" - assert _decode(raw) == expected + """``_decode_mdns_label_or_unknown`` decodes / sanitizes bytes + str, marks missing.""" + assert _decode_mdns_label_or_unknown(raw) == expected + + +def test_safe_label_strips_ansi_escape_introducer() -> None: + """ESC bytes are stripped; trailing printable tail survives.""" + assert _safe_label("\x1b[2Jvers1.0", 32) == "[2Jvers1.0" + + +def test_safe_label_strips_newline_cr_null_tab() -> None: + """Control bytes that could reflow or terminate the printed row are dropped.""" + assert _safe_label("line1\r\nline2", 32) == "line1line2" + assert _safe_label("col\tumn", 32) == "column" + assert _safe_label("esp\x0032", 32) == "esp32" + + +def test_safe_label_caps_length() -> None: + """Oversized peer-supplied labels can't break the column-aligned table.""" + assert _safe_label("x" * 200, 10) == "x" * 10 + + +def test_safe_label_preserves_non_ascii_printable() -> None: + """Non-ASCII printable characters survive (``str.isprintable`` is Unicode-aware).""" + assert _safe_label("café", 32) == "café" + + +def test_decode_mdns_label_or_unknown_strips_control_chars_in_bytes() -> None: + """Bytes path runs the ANSI / CR / LF / NUL / TAB strip.""" + assert _decode_mdns_label_or_unknown(b"\x1b[2J0.1.62", 32) == "[2J0.1.62" + assert _decode_mdns_label_or_unknown(b"line1\r\nline2", 32) == "line1line2" + assert _decode_mdns_label_or_unknown(b"col\tumn", 32) == "column" + assert _decode_mdns_label_or_unknown(b"esp\x0032", 32) == "esp32" + + +def test_decode_mdns_label_or_unknown_strips_control_chars_in_str() -> None: + """Str path also runs the sanitizer (peer-provided strs are equally hostile).""" + assert _decode_mdns_label_or_unknown("\x1b[2J0.1.62", 32) == "[2J0.1.62" + + +def test_decode_mdns_label_or_unknown_caps_length_with_explicit_limit() -> None: + assert _decode_mdns_label_or_unknown(b"x" * 200, 10) == "x" * 10 + + +def test_decode_mdns_label_or_unknown_default_limit_caps_long_value() -> None: + """Default cap is the Name column width from ``_FORMAT``.""" + assert len(_decode_mdns_label_or_unknown("a" * 200)) == _MAX_NAME_DISPLAY + + +def test_decode_mdns_label_or_unknown_unicode_printable_survives() -> None: + assert _decode_mdns_label_or_unknown("café") == "café" + + +def test_per_column_caps_match_format_widths() -> None: + """Per-column caps stay locked to the ``_FORMAT`` widths. + + A peer-controlled value can never widen a column past its slot; + if ``_FORMAT`` changes and this fires, update the cap derivation + in ``discover.py`` rather than bumping the expected values. The + pin cap stays at 64 because ``_truncate_pin`` collapses to 12 + chars + ellipsis at print time, bounded independently. + """ + widths = tuple(int(w) for w in re.findall(r"<\s*(\d+)", _FORMAT)) + assert widths[_COLUMN_NAMES.index("Name")] == _MAX_NAME_DISPLAY + assert widths[_COLUMN_NAMES.index("Server")] == _MAX_SERVER_DISPLAY + assert widths[_COLUMN_NAMES.index("ESPHome")] == _MAX_ESPHOME_DISPLAY + assert widths[_COLUMN_NAMES.index("RB Port")] == _MAX_PORT_DISPLAY + assert _MAX_PIN_DISPLAY == 64 + + +def test_on_service_state_change_sanitizes_hostile_service_name( + capsys: pytest.CaptureFixture[str], +) -> None: + """ESC bytes in the mDNS instance name don't reach stdout.""" + fake_info = MagicMock() + fake_info.properties = {} + fake_info.ip_addresses_by_version.return_value = ["192.168.1.10"] + fake_info.port = 6052 + + with patch("esphome_device_builder.discover.AsyncServiceInfo", return_value=fake_info): + _on_service_state_change( + MagicMock(), + "_esphomebuilder._tcp.local.", + "\x1b[2Jevil._esphomebuilder._tcp.local.", + ServiceStateChange.Added, + ) + + captured = capsys.readouterr().out + assert "\x1b" not in captured + assert "[2Jevil" in captured + + +def test_on_service_state_change_sanitizes_hostile_txt_values( + capsys: pytest.CaptureFixture[str], +) -> None: + """ESC / CR / LF in TXT values don't reach stdout.""" + fake_info = MagicMock() + fake_info.properties = { + b"server_version": b"\x1b[2J0.1.62", + b"esphome_version": b"line1\r\nline2", + b"pin_sha256": b"a" * 64, + b"remote_build_port": b"6053", + } + fake_info.ip_addresses_by_version.return_value = ["192.168.1.10"] + fake_info.port = 6052 + + with patch("esphome_device_builder.discover.AsyncServiceInfo", return_value=fake_info): + _on_service_state_change( + MagicMock(), + "_esphomebuilder._tcp.local.", + "build-server._esphomebuilder._tcp.local.", + ServiceStateChange.Added, + ) + + captured = capsys.readouterr().out + assert "\x1b" not in captured + assert "\r" not in captured + assert "[2J0.1.62" in captured + assert "line1line2" in captured @pytest.mark.parametrize(