Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 108 additions & 131 deletions tests/unit/events/test_apify_event_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextlib
import json
import logging
from collections import defaultdict
Expand All @@ -20,7 +21,40 @@
from apify.events._types import SystemInfoEventData

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable


@contextlib.asynccontextmanager
async def _platform_ws_server(
monkeypatch: pytest.MonkeyPatch,
) -> AsyncGenerator[tuple[set[websockets.asyncio.server.ServerConnection], asyncio.Event]]:
"""Create a local WebSocket server that simulates Apify platform events.

Binds explicitly to ``127.0.0.1`` instead of ``localhost`` so that only a
single IPv4 socket is created. On Windows, ``localhost`` resolves to both
``127.0.0.1`` *and* ``::1``, and the OS may assign **different** random
ports to each address — causing the client to connect to the wrong port.

Yields a ``(connected_ws_clients, client_connected_event)`` tuple. After
opening an `ApifyEventManager`, ``await client_connected_event.wait()``
before sending any messages to guarantee the server handler has registered
the connection.
"""
connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set()
client_connected = asyncio.Event()

async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None:
connected_ws_clients.add(websocket)
client_connected.set()
try:
await websocket.wait_closed()
finally:
connected_ws_clients.remove(websocket)

async with websockets.asyncio.server.serve(handler, host='127.0.0.1') as ws_server:
port: int = ws_server.sockets[0].getsockname()[1]
monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://127.0.0.1:{port}')
yield connected_ws_clients, client_connected


async def test_lifecycle_local(caplog: pytest.LogCaptureFixture) -> None:
Expand Down Expand Up @@ -137,47 +171,23 @@ async def test_lifecycle_on_platform_without_websocket(monkeypatch: pytest.Monke


async def test_lifecycle_on_platform(monkeypatch: pytest.MonkeyPatch) -> None:
connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set()

async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None:
connected_ws_clients.add(websocket)
try:
await websocket.wait_closed()
finally:
connected_ws_clients.remove(websocket)

async with websockets.asyncio.server.serve(handler, host='localhost') as ws_server:
# When you don't specify a port explicitly, the websocket connection is opened on a random free port.
# We need to find out which port is that.
port: int = ws_server.sockets[0].getsockname()[1]
monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://localhost:{port}')

async with ApifyEventManager(Configuration.get_global_configuration()):
assert len(connected_ws_clients) == 1
async with (
_platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected),
ApifyEventManager(Configuration.get_global_configuration()),
):
await client_connected.wait()
assert len(connected_ws_clients) == 1


async def test_event_handling_on_platform(monkeypatch: pytest.MonkeyPatch) -> None:
connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set()
async with _platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected):

async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None:
connected_ws_clients.add(websocket)
try:
await websocket.wait_closed()
finally:
connected_ws_clients.remove(websocket)
async def send_platform_event(event_name: Event, data: Any = None) -> None:
message: dict[str, Any] = {'name': event_name.value}
if data:
message['data'] = data

async def send_platform_event(event_name: Event, data: Any = None) -> None:
message: dict[str, Any] = {'name': event_name.value}
if data:
message['data'] = data

websockets.broadcast(connected_ws_clients, json.dumps(message))

async with websockets.asyncio.server.serve(handler, host='localhost') as ws_server:
# When you don't specify a port explicitly, the websocket connection is opened on a random free port.
# We need to find out which port is that.
port: int = ws_server.sockets[0].getsockname()[1]
monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://localhost:{port}')
websockets.broadcast(connected_ws_clients, json.dumps(message))

dummy_system_info = {
'memAvgBytes': 19328860.328293584,
Expand All @@ -192,6 +202,7 @@ async def send_platform_event(event_name: Event, data: Any = None) -> None:
SystemInfoEventData.model_validate(dummy_system_info)

async with ApifyEventManager(Configuration.get_global_configuration()) as event_manager:
await client_connected.wait()
event_calls = []

def listener(data: Any) -> None:
Expand Down Expand Up @@ -232,124 +243,90 @@ async def handler(_data: Any) -> None:

async def test_deprecated_event_is_skipped(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that deprecated events (like CPU_INFO) are silently skipped."""
connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set()

async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None:
connected_ws_clients.add(websocket)
try:
await websocket.wait_closed()
finally:
connected_ws_clients.remove(websocket)

async with websockets.asyncio.server.serve(handler, host='localhost') as ws_server:
port: int = ws_server.sockets[0].getsockname()[1]
monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://localhost:{port}')

async with ApifyEventManager(Configuration.get_global_configuration()) as event_manager:
event_calls: list[Any] = []

def listener(data: Any) -> None:
event_calls.append(data)
async with (
_platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected),
ApifyEventManager(Configuration.get_global_configuration()) as event_manager,
):
await client_connected.wait()
event_calls: list[Any] = []

def listener(data: Any) -> None:
event_calls.append(data)

event_manager.on(event=Event.SYSTEM_INFO, listener=listener)
event_manager.on(event=Event.SYSTEM_INFO, listener=listener)

# Send a deprecated event (cpuInfo is deprecated)
deprecated_message = json.dumps({'name': 'cpuInfo', 'data': {}})
websockets.broadcast(connected_ws_clients, deprecated_message)
await asyncio.sleep(0.2)
# Send a deprecated event (cpuInfo is deprecated)
deprecated_message = json.dumps({'name': 'cpuInfo', 'data': {}})
websockets.broadcast(connected_ws_clients, deprecated_message)
await asyncio.sleep(0.2)

# No events should have been emitted
assert len(event_calls) == 0
# No events should have been emitted
assert len(event_calls) == 0


async def test_unknown_event_is_logged(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
"""Test that unknown events are logged and not emitted."""
connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set()
async with (
_platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected),
ApifyEventManager(Configuration.get_global_configuration()),
):
await client_connected.wait()

async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None:
connected_ws_clients.add(websocket)
try:
await websocket.wait_closed()
finally:
connected_ws_clients.remove(websocket)
# Send an unknown event
unknown_message = json.dumps({'name': 'totallyNewEvent2099', 'data': {'foo': 'bar'}})
websockets.broadcast(connected_ws_clients, unknown_message)
await asyncio.sleep(0.2)

async with websockets.asyncio.server.serve(handler, host='localhost') as ws_server:
port: int = ws_server.sockets[0].getsockname()[1]
monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://localhost:{port}')

async with ApifyEventManager(Configuration.get_global_configuration()):
# Send an unknown event
unknown_message = json.dumps({'name': 'totallyNewEvent2099', 'data': {'foo': 'bar'}})
websockets.broadcast(connected_ws_clients, unknown_message)
await asyncio.sleep(0.2)

assert 'Unknown message received' in caplog.text
assert 'totallyNewEvent2099' in caplog.text
assert 'Unknown message received' in caplog.text
assert 'totallyNewEvent2099' in caplog.text


async def test_migrating_event_triggers_persist_state(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that a MIGRATING event triggers a PERSIST_STATE event with is_migrating=True."""
connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set()
async with (
_platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected),
ApifyEventManager(Configuration.get_global_configuration()) as event_manager,
):
await client_connected.wait()
persist_calls: list[Any] = []
migrating_calls: list[Any] = []

async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None:
connected_ws_clients.add(websocket)
try:
await websocket.wait_closed()
finally:
connected_ws_clients.remove(websocket)

async with websockets.asyncio.server.serve(handler, host='localhost') as ws_server:
port: int = ws_server.sockets[0].getsockname()[1]
monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://localhost:{port}')

async with ApifyEventManager(Configuration.get_global_configuration()) as event_manager:
persist_calls: list[Any] = []
migrating_calls: list[Any] = []

def persist_listener(data: Any) -> None:
persist_calls.append(data)
def persist_listener(data: Any) -> None:
persist_calls.append(data)

def migrating_listener(data: Any) -> None:
migrating_calls.append(data)
def migrating_listener(data: Any) -> None:
migrating_calls.append(data)

event_manager.on(event=Event.PERSIST_STATE, listener=persist_listener)
event_manager.on(event=Event.MIGRATING, listener=migrating_listener)
event_manager.on(event=Event.PERSIST_STATE, listener=persist_listener)
event_manager.on(event=Event.MIGRATING, listener=migrating_listener)

# Clear any initial persist state events
await asyncio.sleep(0.2)
persist_calls.clear()
# Clear any initial persist state events
await asyncio.sleep(0.2)
persist_calls.clear()

# Send migrating event
migrating_message = json.dumps({'name': 'migrating'})
websockets.broadcast(connected_ws_clients, migrating_message)
await asyncio.sleep(0.2)
# Send migrating event
migrating_message = json.dumps({'name': 'migrating'})
websockets.broadcast(connected_ws_clients, migrating_message)
await asyncio.sleep(0.2)

assert len(migrating_calls) == 1
# MIGRATING should also trigger a PERSIST_STATE with is_migrating=True
migration_persist_events = [c for c in persist_calls if hasattr(c, 'is_migrating') and c.is_migrating]
assert len(migration_persist_events) >= 1
assert len(migrating_calls) == 1
# MIGRATING should also trigger a PERSIST_STATE with is_migrating=True
migration_persist_events = [c for c in persist_calls if hasattr(c, 'is_migrating') and c.is_migrating]
assert len(migration_persist_events) >= 1


async def test_malformed_message_logs_exception(
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that malformed websocket messages are logged and don't crash the event manager."""
connected_ws_clients: set[websockets.asyncio.server.ServerConnection] = set()

async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None:
connected_ws_clients.add(websocket)
try:
await websocket.wait_closed()
finally:
connected_ws_clients.remove(websocket)

async with websockets.asyncio.server.serve(handler, host='localhost') as ws_server:
port: int = ws_server.sockets[0].getsockname()[1]
monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://localhost:{port}')
async with (
_platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected),
ApifyEventManager(Configuration.get_global_configuration()),
):
await client_connected.wait()

async with ApifyEventManager(Configuration.get_global_configuration()):
# Send malformed message
websockets.broadcast(connected_ws_clients, 'this is not valid json{{{')
await asyncio.sleep(0.2)
# Send malformed message
websockets.broadcast(connected_ws_clients, 'this is not valid json{{{')
await asyncio.sleep(0.2)

assert 'Cannot parse Actor event' in caplog.text
assert 'Cannot parse Actor event' in caplog.text