Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 47 additions & 12 deletions esphome_device_builder/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import asyncio
import contextlib
import logging
import re
import sys

from zeroconf import IPVersion, ServiceStateChange, Zeroconf
Expand All @@ -41,6 +42,27 @@
)
_UNKNOWN = "unknown"

# Display caps for peer-supplied mDNS labels, one per table column.  The
# widths are parsed out of _FORMAT instead of being hard-coded, so re-tuning
# the table layout re-tunes the caps with it, and a hostile broadcaster can
# never push a column wider than its slot by sending an over-long value.
_COLUMN_WIDTHS = tuple(map(int, re.findall(r"<\s*(\d+)", _FORMAT)))
if len(_COLUMN_NAMES) != len(_COLUMN_WIDTHS):
    # Deliberately a ``raise`` rather than an ``assert``: ``python -O`` strips
    # assert statements and this invariant has to hold there as well.
    raise RuntimeError(
        "_FORMAT width count must match _COLUMN_NAMES; update one and the other together"
    )
# Look each cap up by column title so reordering the columns cannot silently
# attach a cap to the wrong field.
_MAX_NAME_DISPLAY, _MAX_SERVER_DISPLAY, _MAX_ESPHOME_DISPLAY, _MAX_PORT_DISPLAY = (
    _COLUMN_WIDTHS[_COLUMN_NAMES.index(column)]
    for column in ("Name", "Server", "ESPHome", "RB Port")
)
# The Pin column is 16 chars wide, but `_truncate_pin` shortens a full 64-hex
# pin to 12 chars + ellipsis at print time.  The raw cap therefore stays at 64
# so a legitimate pin arrives intact; an oversized hostile value is still
# bounded by that later truncation.
_MAX_PIN_DISPLAY = 64


def main() -> None:
"""CLI entry point.
Expand Down Expand Up @@ -122,16 +144,20 @@ async def _run(args: argparse.Namespace) -> None:
await aiozc.async_close()


def _decode(data: str | bytes | None) -> str:
"""Decode a TXT-record value to ``str``, or return ``unknown``."""
def _safe_label(raw: str, limit: int) -> str:
"""Strip non-printables and length-cap a peer-supplied label for stdout."""
return "".join(filter(str.isprintable, raw))[:limit]


def _decode_mdns_label_or_unknown(data: str | bytes | None, limit: int = _MAX_NAME_DISPLAY) -> str:
    """Decode a peer-supplied mDNS value into a sanitized, length-capped ``str``.

    Args:
        data: Raw value (typically a TXT-record entry); ``None`` when the
            entry is missing.
        limit: Maximum number of characters kept after sanitizing.

    Returns:
        ``_UNKNOWN`` when *data* is ``None``.  Otherwise the value decoded
        as UTF-8 if it was ``bytes``, with non-printable characters removed
        and the result truncated to *limit* by ``_safe_label``.
    """
    if data is None:
        return _UNKNOWN
    if isinstance(data, bytes):
        # A device on the LAN can broadcast arbitrary bytes; use "replace" so
        # a malformed UTF-8 payload doesn't raise out of the zeroconf callback.
        data = data.decode("utf-8", "replace")
    # Both the bytes and the str path go through the sanitizer: no early
    # return may hand an unsanitized value back to the caller.
    return _safe_label(data, limit)


def _truncate_pin(pin: str) -> str:
Expand Down Expand Up @@ -172,7 +198,10 @@ def _on_service_state_change(
:mod:`controllers._device_state_monitor` /
:mod:`controllers.remote_build.controller`).
"""
short_name = name.partition(".")[0]
# The mDNS service name is peer-controlled; sanitize before printing so a
# hostile broadcaster can't inject ANSI escapes / newlines / null bytes
# into the terminal via the instance label.
short_name = _safe_label(name.partition(".")[0], _MAX_NAME_DISPLAY)
state = "OFFLINE" if state_change is ServiceStateChange.Removed else "ONLINE"
info = AsyncServiceInfo(service_type, name)
# ``load_from_cache`` returns ``False`` when the browser
Expand All @@ -185,10 +214,16 @@ def _on_service_state_change(
# resolve catches up.
info.load_from_cache(zeroconf)
properties = info.properties or {}
server_version = _decode(properties.get(b"server_version"))
esphome_version = _decode(properties.get(b"esphome_version"))
pin_sha256 = _decode(properties.get(b"pin_sha256"))
remote_build_port = _decode(properties.get(b"remote_build_port"))
server_version = _decode_mdns_label_or_unknown(
properties.get(b"server_version"), _MAX_SERVER_DISPLAY
)
esphome_version = _decode_mdns_label_or_unknown(
properties.get(b"esphome_version"), _MAX_ESPHOME_DISPLAY
)
pin_sha256 = _decode_mdns_label_or_unknown(properties.get(b"pin_sha256"), _MAX_PIN_DISPLAY)
remote_build_port = _decode_mdns_label_or_unknown(
properties.get(b"remote_build_port"), _MAX_PORT_DISPLAY
)

address = ""
if v4_addresses := info.ip_addresses_by_version(IPVersion.V4Only):
Expand Down
139 changes: 132 additions & 7 deletions tests/test_discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import argparse
import asyncio
import logging
import re
from collections.abc import Coroutine
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
Expand All @@ -28,11 +29,18 @@

from esphome_device_builder.discover import (
_COLUMN_NAMES,
_FORMAT,
_MAX_ESPHOME_DISPLAY,
_MAX_NAME_DISPLAY,
_MAX_PIN_DISPLAY,
_MAX_PORT_DISPLAY,
_MAX_SERVER_DISPLAY,
_UNKNOWN,
_build_parser,
_decode,
_decode_mdns_label_or_unknown,
_on_service_state_change,
_run,
_safe_label,
_truncate_pin,
main,
)
Expand All @@ -44,16 +52,133 @@
(b"hello", "hello"),
("plain", "plain"),
(None, _UNKNOWN),
# ``bytes`` containing a non-UTF-8 sequence falls through
# the strict decode and lands on the replacement-char path
# (rather than raising) so a malformed TXT entry doesn't
# crash the browse loop.
# ``bytes`` containing a non-UTF-8 sequence falls through the
# ``"replace"`` handler so a malformed TXT entry doesn't crash the
# browse loop. Pin the actual U+FFFD output (one per invalid byte)
# so a future refactor that silently swaps the handler for an
# UNKNOWN-or-empty fallback trips a red test.
(b"\xff\xfe", "��"),
],
)
def test_decode_handles_every_txt_wire_shape(raw: str | bytes | None, expected: str) -> None:
"""``_decode`` round-trips bytes, leaves strings alone, marks missing."""
assert _decode(raw) == expected
"""``_decode_mdns_label_or_unknown`` decodes / sanitizes bytes + str, marks missing."""
assert _decode_mdns_label_or_unknown(raw) == expected


def test_safe_label_strips_ansi_escape_introducer() -> None:
    """The ESC byte is dropped while the printable remainder is kept."""
    hostile = "\x1b[2Jvers1.0"
    assert _safe_label(hostile, 32) == "[2Jvers1.0"


def test_safe_label_strips_newline_cr_null_tab() -> None:
    """CR, LF, TAB and NUL are removed so they cannot reflow or cut the row."""
    cases = (
        ("line1\r\nline2", "line1line2"),
        ("col\tumn", "column"),
        ("esp\x0032", "esp32"),
    )
    for raw, cleaned in cases:
        assert _safe_label(raw, 32) == cleaned


def test_safe_label_caps_length() -> None:
    """An over-long peer-supplied label is cut to the requested limit."""
    oversized = "x" * 200
    assert _safe_label(oversized, 10) == "x" * 10


def test_safe_label_preserves_non_ascii_printable() -> None:
    """Printable non-ASCII text passes through (``str.isprintable`` is Unicode-aware)."""
    accented = "café"
    assert _safe_label(accented, 32) == "café"


def test_decode_mdns_label_or_unknown_strips_control_chars_in_bytes() -> None:
    """``bytes`` input has ESC / CR / LF / TAB / NUL removed after decoding."""
    cases = (
        (b"\x1b[2J0.1.62", "[2J0.1.62"),
        (b"line1\r\nline2", "line1line2"),
        (b"col\tumn", "column"),
        (b"esp\x0032", "esp32"),
    )
    for raw, cleaned in cases:
        assert _decode_mdns_label_or_unknown(raw, 32) == cleaned


def test_decode_mdns_label_or_unknown_strips_control_chars_in_str() -> None:
    """``str`` input is sanitized too; a peer-provided str is just as hostile."""
    hostile = "\x1b[2J0.1.62"
    assert _decode_mdns_label_or_unknown(hostile, 32) == "[2J0.1.62"


def test_decode_mdns_label_or_unknown_caps_length_with_explicit_limit() -> None:
    """An explicit ``limit`` truncates an over-long ``bytes`` value."""
    oversized = b"x" * 200
    assert _decode_mdns_label_or_unknown(oversized, 10) == "x" * 10


def test_decode_mdns_label_or_unknown_default_limit_caps_long_value() -> None:
    """Without an explicit limit the result is capped at ``_MAX_NAME_DISPLAY``."""
    capped = _decode_mdns_label_or_unknown("a" * 200)
    assert len(capped) == _MAX_NAME_DISPLAY


def test_decode_mdns_label_or_unknown_unicode_printable_survives() -> None:
    """Printable non-ASCII characters are kept by the decode helper."""
    accented = "café"
    assert _decode_mdns_label_or_unknown(accented) == "café"


def test_per_column_caps_match_format_widths() -> None:
    """Each per-column cap equals the width of its slot in ``_FORMAT``.

    This keeps a peer-controlled value from ever widening a column. If
    ``_FORMAT`` is changed and this test fails, fix the cap derivation in
    ``discover.py`` instead of editing the expectations here. The pin cap
    is checked as the literal 64: ``_truncate_pin`` shortens the pin to
    12 chars + ellipsis at print time, so it is bounded separately.
    """
    widths = tuple(map(int, re.findall(r"<\s*(\d+)", _FORMAT)))
    cap_by_column = {
        "Name": _MAX_NAME_DISPLAY,
        "Server": _MAX_SERVER_DISPLAY,
        "ESPHome": _MAX_ESPHOME_DISPLAY,
        "RB Port": _MAX_PORT_DISPLAY,
    }
    for column, cap in cap_by_column.items():
        assert widths[_COLUMN_NAMES.index(column)] == cap
    assert _MAX_PIN_DISPLAY == 64


def test_on_service_state_change_sanitizes_hostile_service_name(
    capsys: pytest.CaptureFixture[str],
) -> None:
    """A hostile mDNS instance name is printed without its ESC byte."""
    fake_info = MagicMock(properties={}, port=6052)
    fake_info.ip_addresses_by_version.return_value = ["192.168.1.10"]
    hostile_name = "\x1b[2Jevil._esphomebuilder._tcp.local."

    with patch("esphome_device_builder.discover.AsyncServiceInfo", return_value=fake_info):
        _on_service_state_change(
            MagicMock(),
            "_esphomebuilder._tcp.local.",
            hostile_name,
            ServiceStateChange.Added,
        )

    printed = capsys.readouterr().out
    assert "\x1b" not in printed
    assert "[2Jevil" in printed


def test_on_service_state_change_sanitizes_hostile_txt_values(
    capsys: pytest.CaptureFixture[str],
) -> None:
    """Hostile TXT values are printed without their ESC / CR bytes."""
    txt_records = {
        b"server_version": b"\x1b[2J0.1.62",
        b"esphome_version": b"line1\r\nline2",
        b"pin_sha256": b"a" * 64,
        b"remote_build_port": b"6053",
    }
    fake_info = MagicMock(properties=txt_records, port=6052)
    fake_info.ip_addresses_by_version.return_value = ["192.168.1.10"]

    with patch("esphome_device_builder.discover.AsyncServiceInfo", return_value=fake_info):
        _on_service_state_change(
            MagicMock(),
            "_esphomebuilder._tcp.local.",
            "build-server._esphomebuilder._tcp.local.",
            ServiceStateChange.Added,
        )

    printed = capsys.readouterr().out
    # Control bytes must be gone from the output ...
    for forbidden in ("\x1b", "\r"):
        assert forbidden not in printed
    # ... while the printable remainder of each value is still shown.
    for survivor in ("[2J0.1.62", "line1line2"):
        assert survivor in printed


@pytest.mark.parametrize(
Expand Down
Loading