|
8 | 8 | import pytest_asyncio |
9 | 9 | from pytest_codspeed import BenchmarkFixture # type: ignore[import-untyped] |
10 | 10 |
|
| 11 | +from aioesphomeapi._frame_helper.base import ( |
| 12 | + MAX_EXPLANATION_LEN, |
| 13 | + MAX_MAC_LEN, |
| 14 | + MAX_NAME_LEN, |
| 15 | + safe_label_str, |
| 16 | +) |
11 | 17 | from aioesphomeapi._frame_helper.noise_encryption import EncryptCipher |
12 | 18 | from aioesphomeapi._frame_helper.packets import make_noise_packets |
13 | 19 |
|
@@ -214,3 +220,90 @@ def _drop(data: Iterable[bytes]) -> None: |
214 | 220 | def write_packets() -> None: |
215 | 221 | for _ in range(100): |
216 | 222 | helper.write_packets(packets, False) |
| 223 | + |
| 224 | + |
| 225 | +async def test_noise_hello_parse(benchmark: BenchmarkFixture) -> None: |
| 226 | + """Benchmark the noise hello parse path (name + mac sanitize).""" |
| 227 | + noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" |
| 228 | + |
| 229 | + def _drop(data: Iterable[bytes]) -> None: |
| 230 | + """Skip the actual write so we measure only the parse path.""" |
| 231 | + |
| 232 | + hello_pkt = _make_noise_hello_pkt(b"\x01servicetest\0aabbccddeeff\0") |
| 233 | + |
| 234 | + @benchmark |
| 235 | + def parse_hello_packets() -> None: |
| 236 | + for _ in range(20): |
| 237 | + connection, _ = _make_mock_connection() |
| 238 | + helper = MockAPINoiseFrameHelper( |
| 239 | + connection=connection, |
| 240 | + noise_psk=noise_psk, |
| 241 | + expected_name=None, |
| 242 | + expected_mac=None, |
| 243 | + client_info="my client", |
| 244 | + log_name="test", |
| 245 | + writer=_drop, |
| 246 | + ) |
| 247 | + mock_data_received(helper, hello_pkt) |
| 248 | + helper.close() |
| 249 | + |
| 250 | + |
| 251 | +async def test_noise_handshake_reject_parse(benchmark: BenchmarkFixture) -> None: |
| 252 | + """Benchmark the handshake-reject parse path (explanation sanitize).""" |
| 253 | + noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" |
| 254 | + |
| 255 | + def _drop(data: Iterable[bytes]) -> None: |
| 256 | + """Skip the actual write so we measure only the parse path.""" |
| 257 | + |
| 258 | + hello_pkt = _make_noise_hello_pkt(b"\x01servicetest\0aabbccddeeff\0") |
| 259 | + |
| 260 | + # Inner first byte != 0x00 routes _handle_handshake into |
| 261 | + # _error_on_incorrect_preamble. |
| 262 | + explanation = b"boom\r\nFAKE LOG\x1b[31mreason text" |
| 263 | + reject_inner = b"\x01" + explanation |
| 264 | + reject_header = bytes( |
| 265 | + (0x01, (len(reject_inner) >> 8) & 0xFF, len(reject_inner) & 0xFF) |
| 266 | + ) |
| 267 | + reject_frame = reject_header + reject_inner |
| 268 | + |
| 269 | + @benchmark |
| 270 | + def parse_reject_packets() -> None: |
| 271 | + for _ in range(20): |
| 272 | + connection, _ = _make_mock_connection() |
| 273 | + helper = MockAPINoiseFrameHelper( |
| 274 | + connection=connection, |
| 275 | + noise_psk=noise_psk, |
| 276 | + expected_name=None, |
| 277 | + expected_mac=None, |
| 278 | + client_info="my client", |
| 279 | + log_name="test", |
| 280 | + writer=_drop, |
| 281 | + ) |
| 282 | + mock_data_received(helper, hello_pkt) |
| 283 | + mock_data_received(helper, reject_frame) |
| 284 | + helper.close() |
| 285 | + |
| 286 | + |
| 287 | +_SAFE_LABEL_CASES = [ |
| 288 | + ("name", "servicetest", MAX_NAME_LEN), |
| 289 | + ("name_noisy", "service\r\ntest\x1b[31m", MAX_NAME_LEN), |
| 290 | + ("mac", "aabbccddeeff", MAX_MAC_LEN), |
| 291 | + ("explanation", "Handshake MAC failure", MAX_EXPLANATION_LEN), |
| 292 | + ("explanation_noisy", "Handshake\r\nMAC\x1b[31m failure", MAX_EXPLANATION_LEN), |
| 293 | +] |
| 294 | + |
| 295 | + |
| 296 | +@pytest.mark.parametrize( |
| 297 | + ("label", "raw", "limit"), |
| 298 | + _SAFE_LABEL_CASES, |
| 299 | + ids=[case[0] for case in _SAFE_LABEL_CASES], |
| 300 | +) |
| 301 | +def test_safe_label_str_throughput( |
| 302 | + benchmark: BenchmarkFixture, label: str, raw: str, limit: int |
| 303 | +) -> None: |
| 304 | + """Throughput benchmark for `safe_label_str` (printable filter + cap).""" |
| 305 | + |
| 306 | + @benchmark |
| 307 | + def run_safe_label_str() -> None: |
| 308 | + for _ in range(1000): |
| 309 | + safe_label_str(raw, limit) |
0 commit comments