diff --git a/tests/unit/events/test_apify_event_manager.py b/tests/unit/events/test_apify_event_manager.py index 0954bed4..a6b9c4e9 100644 --- a/tests/unit/events/test_apify_event_manager.py +++ b/tests/unit/events/test_apify_event_manager.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import json import logging from collections import defaultdict @@ -20,7 +21,40 @@ from apify.events._types import SystemInfoEventData if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import AsyncGenerator, Callable + + +@contextlib.asynccontextmanager +async def _platform_ws_server( + monkeypatch: pytest.MonkeyPatch, +) -> AsyncGenerator[tuple[set[websockets.asyncio.server.ServerConnection], asyncio.Event]]: + """Create a local WebSocket server that simulates Apify platform events. + + Binds explicitly to ``127.0.0.1`` instead of ``localhost`` so that only a + single IPv4 socket is created. On Windows, ``localhost`` resolves to both + ``127.0.0.1`` *and* ``::1``, and the OS may assign **different** random + ports to each address — causing the client to connect to the wrong port. + + Yields a ``(connected_ws_clients, client_connected_event)`` tuple. After + opening an `ApifyEventManager`, ``await client_connected_event.wait()`` + before sending any messages to guarantee the server handler has registered + the connection. + """ + connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set() + client_connected = asyncio.Event() + + async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None: + connected_ws_clients.add(websocket) + client_connected.set() + try: + await websocket.wait_closed() + finally: + connected_ws_clients.remove(websocket) + + async with websockets.asyncio.server.serve(handler, host='127.0.0.1') as ws_server: + port: int = ws_server.sockets[0].getsockname()[1] + monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://127.0.0.1:{port}') + yield connected_ws_clients, client_connected async def test_lifecycle_local(caplog: pytest.LogCaptureFixture) -> None: @@ -137,47 +171,23 @@ async def test_lifecycle_on_platform_without_websocket(monkeypatch: pytest.Monke async def test_lifecycle_on_platform(monkeypatch: pytest.MonkeyPatch) -> None: - connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set() - - async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None: - connected_ws_clients.add(websocket) - try: - await websocket.wait_closed() - finally: - connected_ws_clients.remove(websocket) - - async with websockets.asyncio.server.serve(handler, host='localhost') as ws_server: - # When you don't specify a port explicitly, the websocket connection is opened on a random free port. - # We need to find out which port is that. - port: int = ws_server.sockets[0].getsockname()[1] - monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://localhost:{port}') - - async with ApifyEventManager(Configuration.get_global_configuration()): - assert len(connected_ws_clients) == 1 + async with ( + _platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected), + ApifyEventManager(Configuration.get_global_configuration()), + ): + await client_connected.wait() + assert len(connected_ws_clients) == 1 async def test_event_handling_on_platform(monkeypatch: pytest.MonkeyPatch) -> None: - connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set() + async with _platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected): - async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None: - connected_ws_clients.add(websocket) - try: - await websocket.wait_closed() - finally: - connected_ws_clients.remove(websocket) + async def send_platform_event(event_name: Event, data: Any = None) -> None: + message: dict[str, Any] = {'name': event_name.value} + if data: + message['data'] = data - async def send_platform_event(event_name: Event, data: Any = None) -> None: - message: dict[str, Any] = {'name': event_name.value} - if data: - message['data'] = data - - websockets.broadcast(connected_ws_clients, json.dumps(message)) - - async with websockets.asyncio.server.serve(handler, host='localhost') as ws_server: - # When you don't specify a port explicitly, the websocket connection is opened on a random free port. - # We need to find out which port is that. - port: int = ws_server.sockets[0].getsockname()[1] - monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://localhost:{port}') + websockets.broadcast(connected_ws_clients, json.dumps(message)) dummy_system_info = { 'memAvgBytes': 19328860.328293584, @@ -192,6 +202,7 @@ async def send_platform_event(event_name: Event, data: Any = None) -> None: SystemInfoEventData.model_validate(dummy_system_info) async with ApifyEventManager(Configuration.get_global_configuration()) as event_manager: + await client_connected.wait() event_calls = [] def listener(data: Any) -> None: @@ -232,124 +243,90 @@ async def handler(_data: Any) -> None: async def test_deprecated_event_is_skipped(monkeypatch: pytest.MonkeyPatch) -> None: """Test that deprecated events (like CPU_INFO) are silently skipped.""" - connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set() - - async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None: - connected_ws_clients.add(websocket) - try: - await websocket.wait_closed() - finally: - connected_ws_clients.remove(websocket) - - async with websockets.asyncio.server.serve(handler, host='localhost') as ws_server: - port: int = ws_server.sockets[0].getsockname()[1] - monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://localhost:{port}') - - async with ApifyEventManager(Configuration.get_global_configuration()) as event_manager: - event_calls: list[Any] = [] - - def listener(data: Any) -> None: - event_calls.append(data) + async with ( + _platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected), + ApifyEventManager(Configuration.get_global_configuration()) as event_manager, + ): + await client_connected.wait() + event_calls: list[Any] = [] + + def listener(data: Any) -> None: + event_calls.append(data) - event_manager.on(event=Event.SYSTEM_INFO, listener=listener) + event_manager.on(event=Event.SYSTEM_INFO, listener=listener) - # Send a deprecated event (cpuInfo is deprecated) - deprecated_message = json.dumps({'name': 'cpuInfo', 'data': {}}) - websockets.broadcast(connected_ws_clients, deprecated_message) - await asyncio.sleep(0.2) + # Send a deprecated event (cpuInfo is deprecated) + deprecated_message = json.dumps({'name': 'cpuInfo', 'data': {}}) + websockets.broadcast(connected_ws_clients, deprecated_message) + await asyncio.sleep(0.2) - # No events should have been emitted - assert len(event_calls) == 0 + # No events should have been emitted + assert len(event_calls) == 0 async def test_unknown_event_is_logged(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None: """Test that unknown events are logged and not emitted.""" - connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set() + async with ( + _platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected), + ApifyEventManager(Configuration.get_global_configuration()), + ): + await client_connected.wait() - async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None: - connected_ws_clients.add(websocket) - try: - await websocket.wait_closed() - finally: - connected_ws_clients.remove(websocket) + # Send an unknown event + unknown_message = json.dumps({'name': 'totallyNewEvent2099', 'data': {'foo': 'bar'}}) + websockets.broadcast(connected_ws_clients, unknown_message) + await asyncio.sleep(0.2) - async with websockets.asyncio.server.serve(handler, host='localhost') as ws_server: - port: int = ws_server.sockets[0].getsockname()[1] - monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://localhost:{port}') - - async with ApifyEventManager(Configuration.get_global_configuration()): - # Send an unknown event - unknown_message = json.dumps({'name': 'totallyNewEvent2099', 'data': {'foo': 'bar'}}) - websockets.broadcast(connected_ws_clients, unknown_message) - await asyncio.sleep(0.2) - - assert 'Unknown message received' in caplog.text - assert 'totallyNewEvent2099' in caplog.text + assert 'Unknown message received' in caplog.text + assert 'totallyNewEvent2099' in caplog.text async def test_migrating_event_triggers_persist_state(monkeypatch: pytest.MonkeyPatch) -> None: """Test that a MIGRATING event triggers a PERSIST_STATE event with is_migrating=True.""" - connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set() + async with ( + _platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected), + ApifyEventManager(Configuration.get_global_configuration()) as event_manager, + ): + await client_connected.wait() + persist_calls: list[Any] = [] + migrating_calls: list[Any] = [] - async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None: - connected_ws_clients.add(websocket) - try: - await websocket.wait_closed() - finally: - connected_ws_clients.remove(websocket) - - async with websockets.asyncio.server.serve(handler, host='localhost') as ws_server: - port: int = ws_server.sockets[0].getsockname()[1] - monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://localhost:{port}') - - async with ApifyEventManager(Configuration.get_global_configuration()) as event_manager: - persist_calls: list[Any] = [] - migrating_calls: list[Any] = [] - - def persist_listener(data: Any) -> None: - persist_calls.append(data) + def persist_listener(data: Any) -> None: + persist_calls.append(data) - def migrating_listener(data: Any) -> None: - migrating_calls.append(data) + def migrating_listener(data: Any) -> None: + migrating_calls.append(data) - event_manager.on(event=Event.PERSIST_STATE, listener=persist_listener) - event_manager.on(event=Event.MIGRATING, listener=migrating_listener) + event_manager.on(event=Event.PERSIST_STATE, listener=persist_listener) + event_manager.on(event=Event.MIGRATING, listener=migrating_listener) - # Clear any initial persist state events - await asyncio.sleep(0.2) - persist_calls.clear() + # Clear any initial persist state events + await asyncio.sleep(0.2) + persist_calls.clear() - # Send migrating event - migrating_message = json.dumps({'name': 'migrating'}) - websockets.broadcast(connected_ws_clients, migrating_message) - await asyncio.sleep(0.2) + # Send migrating event + migrating_message = json.dumps({'name': 'migrating'}) + websockets.broadcast(connected_ws_clients, migrating_message) + await asyncio.sleep(0.2) - assert len(migrating_calls) == 1 - # MIGRATING should also trigger a PERSIST_STATE with is_migrating=True - migration_persist_events = [c for c in persist_calls if hasattr(c, 'is_migrating') and c.is_migrating] - assert len(migration_persist_events) >= 1 + assert len(migrating_calls) == 1 + # MIGRATING should also trigger a PERSIST_STATE with is_migrating=True + migration_persist_events = [c for c in persist_calls if hasattr(c, 'is_migrating') and c.is_migrating] + assert len(migration_persist_events) >= 1 async def test_malformed_message_logs_exception( monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture ) -> None: """Test that malformed websocket messages are logged and don't crash the event manager.""" - connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set() - - async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None: - connected_ws_clients.add(websocket) - try: - await websocket.wait_closed() - finally: - connected_ws_clients.remove(websocket) - - async with websockets.asyncio.server.serve(handler, host='localhost') as ws_server: - port: int = ws_server.sockets[0].getsockname()[1] - monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://localhost:{port}') + async with ( + _platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected), + ApifyEventManager(Configuration.get_global_configuration()), + ): + await client_connected.wait() - async with ApifyEventManager(Configuration.get_global_configuration()): - # Send malformed message - websockets.broadcast(connected_ws_clients, 'this is not valid json{{{') - await asyncio.sleep(0.2) + # Send malformed message + websockets.broadcast(connected_ws_clients, 'this is not valid json{{{') + await asyncio.sleep(0.2) - assert 'Cannot parse Actor event' in caplog.text + assert 'Cannot parse Actor event' in caplog.text