From 6f3ca412e3bb84c005f99788496db125ec9a9d46 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 20 May 2026 08:11:25 -0500 Subject: [PATCH] Benchmark safe_label_str hello and handshake-reject paths Adds CodSpeed coverage for the three places `safe_label_str` runs on the wire-format hot path so any future regression to the printable filter or length-cap shows up as a measurable delta. - `test_noise_hello_parse` drives a hello frame through a fresh MockAPINoiseFrameHelper, exercising `_handle_hello` and its two safe_label_str calls (server name and MAC). - `test_noise_handshake_reject_parse` drives hello + reject frames so `_error_on_incorrect_preamble` runs the explanation sanitize. - `test_safe_label_str_throughput` is a parametrized direct benchmark of `safe_label_str` over realistic name / MAC / explanation inputs (clean and control-char-laden), catching algorithm-level regressions in the printable-filter / length-cap path. --- tests/benchmarks/test_noise.py | 93 ++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/tests/benchmarks/test_noise.py b/tests/benchmarks/test_noise.py index f5ff01d3..0df3fbc3 100644 --- a/tests/benchmarks/test_noise.py +++ b/tests/benchmarks/test_noise.py @@ -8,6 +8,12 @@ import pytest_asyncio from pytest_codspeed import BenchmarkFixture # type: ignore[import-untyped] +from aioesphomeapi._frame_helper.base import ( + MAX_EXPLANATION_LEN, + MAX_MAC_LEN, + MAX_NAME_LEN, + safe_label_str, +) from aioesphomeapi._frame_helper.noise_encryption import EncryptCipher from aioesphomeapi._frame_helper.packets import make_noise_packets @@ -214,3 +220,90 @@ def _drop(data: Iterable[bytes]) -> None: def write_packets() -> None: for _ in range(100): helper.write_packets(packets, False) + + +async def test_noise_hello_parse(benchmark: BenchmarkFixture) -> None: + """Benchmark the noise hello parse path (name + mac sanitize).""" + noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" + + def _drop(data: Iterable[bytes]) -> None: + """Skip the actual write so we measure only the parse path.""" + + hello_pkt = _make_noise_hello_pkt(b"\x01servicetest\0aabbccddeeff\0") + + @benchmark + def parse_hello_packets() -> None: + for _ in range(20): + connection, _ = _make_mock_connection() + helper = MockAPINoiseFrameHelper( + connection=connection, + noise_psk=noise_psk, + expected_name=None, + expected_mac=None, + client_info="my client", + log_name="test", + writer=_drop, + ) + mock_data_received(helper, hello_pkt) + helper.close() + + +async def test_noise_handshake_reject_parse(benchmark: BenchmarkFixture) -> None: + """Benchmark the handshake-reject parse path (explanation sanitize).""" + noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" + + def _drop(data: Iterable[bytes]) -> None: + """Skip the actual write so we measure only the parse path.""" + + hello_pkt = _make_noise_hello_pkt(b"\x01servicetest\0aabbccddeeff\0") + + # Inner first byte != 0x00 routes _handle_handshake into + # _error_on_incorrect_preamble. + explanation = b"boom\r\nFAKE LOG\x1b[31mreason text" + reject_inner = b"\x01" + explanation + reject_header = bytes( + (0x01, (len(reject_inner) >> 8) & 0xFF, len(reject_inner) & 0xFF) + ) + reject_frame = reject_header + reject_inner + + @benchmark + def parse_reject_packets() -> None: + for _ in range(20): + connection, _ = _make_mock_connection() + helper = MockAPINoiseFrameHelper( + connection=connection, + noise_psk=noise_psk, + expected_name=None, + expected_mac=None, + client_info="my client", + log_name="test", + writer=_drop, + ) + mock_data_received(helper, hello_pkt) + mock_data_received(helper, reject_frame) + helper.close() + + +_SAFE_LABEL_CASES = [ + ("name", "servicetest", MAX_NAME_LEN), + ("name_noisy", "service\r\ntest\x1b[31m", MAX_NAME_LEN), + ("mac", "aabbccddeeff", MAX_MAC_LEN), + ("explanation", "Handshake MAC failure", MAX_EXPLANATION_LEN), + ("explanation_noisy", "Handshake\r\nMAC\x1b[31m failure", MAX_EXPLANATION_LEN), +] + + +@pytest.mark.parametrize( + ("label", "raw", "limit"), + _SAFE_LABEL_CASES, + ids=[case[0] for case in _SAFE_LABEL_CASES], +) +def test_safe_label_str_throughput( + benchmark: BenchmarkFixture, label: str, raw: str, limit: int +) -> None: + """Throughput benchmark for `safe_label_str` (printable filter + cap).""" + + @benchmark + def run_safe_label_str() -> None: + for _ in range(1000): + safe_label_str(raw, limit)