Skip to content

Commit 49c258f

Browse files
authored
Sanitize peer-supplied mDNS labels in discover CLI (#911)
1 parent a15196a commit 49c258f

2 files changed

Lines changed: 179 additions & 19 deletions

File tree

esphome_device_builder/discover.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import asyncio
2323
import contextlib
2424
import logging
25+
import re
2526
import sys
2627

2728
from zeroconf import IPVersion, ServiceStateChange, Zeroconf
@@ -41,6 +42,27 @@
4142
)
4243
_UNKNOWN = "unknown"
4344

45+
# Display caps for peer-supplied mDNS labels, one per table column. They are
# read out of the _FORMAT field widths instead of being hard-coded, so a
# hostile broadcaster cannot widen a column with an over-long value and the
# caps follow automatically whenever the table layout is re-tuned.
_COLUMN_WIDTHS = tuple(map(int, re.findall(r"<\s*(\d+)", _FORMAT)))
# Raised explicitly rather than asserted: `python -O` strips assert
# statements, and this invariant has to hold there as well.
if len(_COLUMN_NAMES) != len(_COLUMN_WIDTHS):
    raise RuntimeError(
        "_FORMAT width count must match _COLUMN_NAMES; update one and the other together"
    )
_MAX_NAME_DISPLAY = _COLUMN_WIDTHS[_COLUMN_NAMES.index("Name")]
_MAX_SERVER_DISPLAY = _COLUMN_WIDTHS[_COLUMN_NAMES.index("Server")]
_MAX_ESPHOME_DISPLAY = _COLUMN_WIDTHS[_COLUMN_NAMES.index("ESPHome")]
_MAX_PORT_DISPLAY = _COLUMN_WIDTHS[_COLUMN_NAMES.index("RB Port")]
# The Pin column is 16 chars wide, but `_truncate_pin` shortens a full 64-hex
# pin to 12 chars + ellipsis when the row is printed. The raw cap is therefore
# left at 64 so legitimate pins stay intact; an oversized hostile value is
# still bounded by that later truncation.
_MAX_PIN_DISPLAY = 64
65+
4466

4567
def main() -> None:
4668
"""CLI entry point.
@@ -122,16 +144,20 @@ async def _run(args: argparse.Namespace) -> None:
122144
await aiozc.async_close()
123145

124146

125-
def _decode(data: str | bytes | None) -> str:
126-
"""Decode a TXT-record value to ``str``, or return ``unknown``."""
147+
def _safe_label(raw: str, limit: int) -> str:
148+
"""Strip non-printables and length-cap a peer-supplied label for stdout."""
149+
return "".join(filter(str.isprintable, raw))[:limit]
150+
151+
152+
def _decode_mdns_label_or_unknown(data: str | bytes | None, limit: int = _MAX_NAME_DISPLAY) -> str:
    """Turn a peer-supplied mDNS value into a printable, length-capped ``str``.

    A missing value (``None``) yields ``_UNKNOWN``. ``bytes`` are decoded as
    UTF-8 and ``str`` is taken as-is; either way the text is then passed
    through ``_safe_label`` with ``limit``.
    """
    if data is None:
        return _UNKNOWN
    # Any device on the LAN can broadcast arbitrary bytes, so decode with the
    # "replace" handler: malformed UTF-8 must not raise out of the zeroconf
    # callback.
    text = data.decode("utf-8", "replace") if isinstance(data, bytes) else data
    return _safe_label(text, limit)
135161

136162

137163
def _truncate_pin(pin: str) -> str:
@@ -172,7 +198,10 @@ def _on_service_state_change(
172198
:mod:`controllers._device_state_monitor` /
173199
:mod:`controllers.remote_build.controller`).
174200
"""
175-
short_name = name.partition(".")[0]
201+
# The mDNS service name is peer-controlled; sanitize before printing so a
202+
# hostile broadcaster can't inject ANSI escapes / newlines / null bytes
203+
# into the terminal via the instance label.
204+
short_name = _safe_label(name.partition(".")[0], _MAX_NAME_DISPLAY)
176205
state = "OFFLINE" if state_change is ServiceStateChange.Removed else "ONLINE"
177206
info = AsyncServiceInfo(service_type, name)
178207
# ``load_from_cache`` returns ``False`` when the browser
@@ -185,10 +214,16 @@ def _on_service_state_change(
185214
# resolve catches up.
186215
info.load_from_cache(zeroconf)
187216
properties = info.properties or {}
188-
server_version = _decode(properties.get(b"server_version"))
189-
esphome_version = _decode(properties.get(b"esphome_version"))
190-
pin_sha256 = _decode(properties.get(b"pin_sha256"))
191-
remote_build_port = _decode(properties.get(b"remote_build_port"))
217+
server_version = _decode_mdns_label_or_unknown(
218+
properties.get(b"server_version"), _MAX_SERVER_DISPLAY
219+
)
220+
esphome_version = _decode_mdns_label_or_unknown(
221+
properties.get(b"esphome_version"), _MAX_ESPHOME_DISPLAY
222+
)
223+
pin_sha256 = _decode_mdns_label_or_unknown(properties.get(b"pin_sha256"), _MAX_PIN_DISPLAY)
224+
remote_build_port = _decode_mdns_label_or_unknown(
225+
properties.get(b"remote_build_port"), _MAX_PORT_DISPLAY
226+
)
192227

193228
address = ""
194229
if v4_addresses := info.ip_addresses_by_version(IPVersion.V4Only):

tests/test_discover.py

Lines changed: 132 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import argparse
2020
import asyncio
2121
import logging
22+
import re
2223
from collections.abc import Coroutine
2324
from typing import Any
2425
from unittest.mock import AsyncMock, MagicMock, patch
@@ -28,11 +29,18 @@
2829

2930
from esphome_device_builder.discover import (
3031
_COLUMN_NAMES,
32+
_FORMAT,
33+
_MAX_ESPHOME_DISPLAY,
34+
_MAX_NAME_DISPLAY,
35+
_MAX_PIN_DISPLAY,
36+
_MAX_PORT_DISPLAY,
37+
_MAX_SERVER_DISPLAY,
3138
_UNKNOWN,
3239
_build_parser,
33-
_decode,
40+
_decode_mdns_label_or_unknown,
3441
_on_service_state_change,
3542
_run,
43+
_safe_label,
3644
_truncate_pin,
3745
main,
3846
)
@@ -44,16 +52,133 @@
4452
(b"hello", "hello"),
4553
("plain", "plain"),
4654
(None, _UNKNOWN),
47-
# ``bytes`` containing a non-UTF-8 sequence falls through
48-
# the strict decode and lands on the replacement-char path
49-
# (rather than raising) so a malformed TXT entry doesn't
50-
# crash the browse loop.
55+
# ``bytes`` containing a non-UTF-8 sequence falls through the
56+
# ``"replace"`` handler so a malformed TXT entry doesn't crash the
57+
# browse loop. Pin the actual U+FFFD output (one per invalid byte)
58+
# so a future refactor that silently swaps the handler for an
59+
# UNKNOWN-or-empty fallback trips a red test.
5160
(b"\xff\xfe", "��"),
5261
],
5362
)
5463
def test_decode_handles_every_txt_wire_shape(raw: str | bytes | None, expected: str) -> None:
55-
"""``_decode`` round-trips bytes, leaves strings alone, marks missing."""
56-
assert _decode(raw) == expected
64+
"""``_decode_mdns_label_or_unknown`` decodes / sanitizes bytes + str, marks missing."""
65+
assert _decode_mdns_label_or_unknown(raw) == expected
66+
67+
68+
def test_safe_label_strips_ansi_escape_introducer() -> None:
    """The ESC introducer is removed; the printable remainder is kept."""
    sanitized = _safe_label("\x1b[2Jvers1.0", 32)
    assert sanitized == "[2Jvers1.0"
71+
72+
73+
def test_safe_label_strips_newline_cr_null_tab() -> None:
    """CR, LF, TAB and NUL, which could reflow or end the printed row, are dropped."""
    cases = [
        ("line1\r\nline2", "line1line2"),
        ("col\tumn", "column"),
        ("esp\x0032", "esp32"),
    ]
    for hostile, cleaned in cases:
        assert _safe_label(hostile, 32) == cleaned
78+
79+
80+
def test_safe_label_caps_length() -> None:
    """An oversized peer-supplied label is cut to the cap, keeping the table aligned."""
    oversized = "x" * 200
    assert _safe_label(oversized, 10) == "x" * 10
83+
84+
85+
def test_safe_label_preserves_non_ascii_printable() -> None:
    """Printable non-ASCII text passes through (``str.isprintable`` is Unicode-aware)."""
    label = "café"
    assert _safe_label(label, 32) == label
88+
89+
90+
def test_decode_mdns_label_or_unknown_strips_control_chars_in_bytes() -> None:
    """The bytes path applies the ANSI / CR / LF / NUL / TAB strip."""
    cases = [
        (b"\x1b[2J0.1.62", "[2J0.1.62"),
        (b"line1\r\nline2", "line1line2"),
        (b"col\tumn", "column"),
        (b"esp\x0032", "esp32"),
    ]
    for payload, cleaned in cases:
        assert _decode_mdns_label_or_unknown(payload, 32) == cleaned
96+
97+
98+
def test_decode_mdns_label_or_unknown_strips_control_chars_in_str() -> None:
    """The str path is sanitized too (peer-provided strs are just as hostile)."""
    cleaned = _decode_mdns_label_or_unknown("\x1b[2J0.1.62", 32)
    assert cleaned == "[2J0.1.62"
101+
102+
103+
def test_decode_mdns_label_or_unknown_caps_length_with_explicit_limit() -> None:
    """An explicit ``limit`` truncates an oversized bytes value."""
    oversized = b"x" * 200
    assert _decode_mdns_label_or_unknown(oversized, 10) == "x" * 10
105+
106+
107+
def test_decode_mdns_label_or_unknown_default_limit_caps_long_value() -> None:
    """With no explicit limit the value is capped at the Name column width."""
    capped = _decode_mdns_label_or_unknown("a" * 200)
    assert len(capped) == _MAX_NAME_DISPLAY
110+
111+
112+
def test_decode_mdns_label_or_unknown_unicode_printable_survives() -> None:
    """Printable non-ASCII text comes back unchanged under the default limit."""
    label = "café"
    assert _decode_mdns_label_or_unknown(label) == label
114+
115+
116+
def test_per_column_caps_match_format_widths() -> None:
    """Each per-column cap equals the matching ``_FORMAT`` field width.

    This keeps a peer-controlled value from widening a column past its
    slot. If ``_FORMAT`` changes and this test fires, fix the cap
    derivation in ``discover.py`` instead of bumping the expected values.
    The pin cap is pinned at 64 because ``_truncate_pin`` collapses to
    12 chars + ellipsis at print time and so bounds it independently.
    """
    widths = tuple(int(w) for w in re.findall(r"<\s*(\d+)", _FORMAT))
    caps_by_column = [
        ("Name", _MAX_NAME_DISPLAY),
        ("Server", _MAX_SERVER_DISPLAY),
        ("ESPHome", _MAX_ESPHOME_DISPLAY),
        ("RB Port", _MAX_PORT_DISPLAY),
    ]
    for column, cap in caps_by_column:
        assert widths[_COLUMN_NAMES.index(column)] == cap
    assert _MAX_PIN_DISPLAY == 64
131+
132+
133+
def test_on_service_state_change_sanitizes_hostile_service_name(
    capsys: pytest.CaptureFixture[str],
) -> None:
    """An ESC byte in the mDNS instance name never reaches stdout."""
    service_info = MagicMock()
    service_info.properties = {}
    service_info.ip_addresses_by_version.return_value = ["192.168.1.10"]
    service_info.port = 6052
    hostile_name = "\x1b[2Jevil._esphomebuilder._tcp.local."

    with patch("esphome_device_builder.discover.AsyncServiceInfo", return_value=service_info):
        _on_service_state_change(
            MagicMock(),
            "_esphomebuilder._tcp.local.",
            hostile_name,
            ServiceStateChange.Added,
        )

    printed = capsys.readouterr().out
    # The escape introducer is gone, while its printable tail is still shown.
    assert "\x1b" not in printed
    assert "[2Jevil" in printed
153+
154+
155+
def test_on_service_state_change_sanitizes_hostile_txt_values(
    capsys: pytest.CaptureFixture[str],
) -> None:
    """ESC and CR in TXT values never reach stdout; the printable text does."""
    hostile_txt = {
        b"server_version": b"\x1b[2J0.1.62",
        b"esphome_version": b"line1\r\nline2",
        b"pin_sha256": b"a" * 64,
        b"remote_build_port": b"6053",
    }
    service_info = MagicMock()
    service_info.properties = hostile_txt
    service_info.ip_addresses_by_version.return_value = ["192.168.1.10"]
    service_info.port = 6052

    with patch("esphome_device_builder.discover.AsyncServiceInfo", return_value=service_info):
        _on_service_state_change(
            MagicMock(),
            "_esphomebuilder._tcp.local.",
            "build-server._esphomebuilder._tcp.local.",
            ServiceStateChange.Added,
        )

    printed = capsys.readouterr().out
    for control_char in ("\x1b", "\r"):
        assert control_char not in printed
    for visible_text in ("[2J0.1.62", "line1line2"):
        assert visible_text in printed
57182

58183

59184
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)