diff --git a/tests/benchmarks/test_noise.py b/tests/benchmarks/test_noise.py index f5ff01d3..0df3fbc3 100644 --- a/tests/benchmarks/test_noise.py +++ b/tests/benchmarks/test_noise.py @@ -8,6 +8,12 @@ import pytest_asyncio from pytest_codspeed import BenchmarkFixture # type: ignore[import-untyped] +from aioesphomeapi._frame_helper.base import ( + MAX_EXPLANATION_LEN, + MAX_MAC_LEN, + MAX_NAME_LEN, + safe_label_str, +) from aioesphomeapi._frame_helper.noise_encryption import EncryptCipher from aioesphomeapi._frame_helper.packets import make_noise_packets @@ -214,3 +220,90 @@ def _drop(data: Iterable[bytes]) -> None: def write_packets() -> None: for _ in range(100): helper.write_packets(packets, False) + + +async def test_noise_hello_parse(benchmark: BenchmarkFixture) -> None: + """Benchmark the noise hello parse path (name + mac sanitize).""" + noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" + + def _drop(data: Iterable[bytes]) -> None: + """Skip the actual write so we measure only the parse path.""" + + hello_pkt = _make_noise_hello_pkt(b"\x01servicetest\0aabbccddeeff\0") + + @benchmark + def parse_hello_packets() -> None: + for _ in range(20): + connection, _ = _make_mock_connection() + helper = MockAPINoiseFrameHelper( + connection=connection, + noise_psk=noise_psk, + expected_name=None, + expected_mac=None, + client_info="my client", + log_name="test", + writer=_drop, + ) + mock_data_received(helper, hello_pkt) + helper.close() + + +async def test_noise_handshake_reject_parse(benchmark: BenchmarkFixture) -> None: + """Benchmark the handshake-reject parse path (explanation sanitize).""" + noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" + + def _drop(data: Iterable[bytes]) -> None: + """Skip the actual write so we measure only the parse path.""" + + hello_pkt = _make_noise_hello_pkt(b"\x01servicetest\0aabbccddeeff\0") + + # Inner first byte != 0x00 routes _handle_handshake into + # _error_on_incorrect_preamble. + explanation = b"boom\r\nFAKE LOG\x1b[31mreason text" + reject_inner = b"\x01" + explanation + reject_header = bytes( + (0x01, (len(reject_inner) >> 8) & 0xFF, len(reject_inner) & 0xFF) + ) + reject_frame = reject_header + reject_inner + + @benchmark + def parse_reject_packets() -> None: + for _ in range(20): + connection, _ = _make_mock_connection() + helper = MockAPINoiseFrameHelper( + connection=connection, + noise_psk=noise_psk, + expected_name=None, + expected_mac=None, + client_info="my client", + log_name="test", + writer=_drop, + ) + mock_data_received(helper, hello_pkt) + mock_data_received(helper, reject_frame) + helper.close() + + +_SAFE_LABEL_CASES = [ + ("name", "servicetest", MAX_NAME_LEN), + ("name_noisy", "service\r\ntest\x1b[31m", MAX_NAME_LEN), + ("mac", "aabbccddeeff", MAX_MAC_LEN), + ("explanation", "Handshake MAC failure", MAX_EXPLANATION_LEN), + ("explanation_noisy", "Handshake\r\nMAC\x1b[31m failure", MAX_EXPLANATION_LEN), +] + + +@pytest.mark.parametrize( + ("label", "raw", "limit"), + _SAFE_LABEL_CASES, + ids=[case[0] for case in _SAFE_LABEL_CASES], +) +def test_safe_label_str_throughput( + benchmark: BenchmarkFixture, label: str, raw: str, limit: int +) -> None: + """Throughput benchmark for `safe_label_str` (printable filter + cap).""" + + @benchmark + def run_safe_label_str() -> None: + for _ in range(1000): + safe_label_str(raw, limit)