diff --git a/tests/unit/_autoscaling/test_autoscaled_pool.py b/tests/unit/_autoscaling/test_autoscaled_pool.py index 74a3a75c60..d2cb0c869a 100644 --- a/tests/unit/_autoscaling/test_autoscaled_pool.py +++ b/tests/unit/_autoscaling/test_autoscaled_pool.py @@ -4,7 +4,7 @@ import asyncio from contextlib import suppress -from datetime import datetime, timedelta, timezone +from datetime import timedelta from itertools import chain, repeat from typing import TYPE_CHECKING, TypeVar, cast from unittest.mock import Mock @@ -17,7 +17,7 @@ from crawlee._utils.time import measure_time if TYPE_CHECKING: - from collections.abc import Awaitable + from collections.abc import Awaitable, Callable @pytest.fixture @@ -28,6 +28,16 @@ def system_status() -> SystemStatus | Mock: T = TypeVar('T') +async def _wait_for(condition: Callable[[], bool], *, timeout: float = 5.0, poll_interval: float = 0.05) -> None: + """Poll ``condition`` until it returns True, or raise ``AssertionError`` on timeout.""" + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: + if condition(): + return + await asyncio.sleep(poll_interval) + raise AssertionError(f'Condition not met within {timeout}s') + + def future(value: T, /) -> Awaitable[T]: f = asyncio.Future[T]() f.set_result(value) @@ -145,10 +155,6 @@ async def run() -> None: await pool.run() -@pytest.mark.flaky( - rerun=3, - reason='Test is flaky on Windows and MacOS, see https://github.com/apify/crawlee-python/issues/1655.', -) async def test_autoscales( monkeypatch: pytest.MonkeyPatch, system_status: SystemStatus | Mock, @@ -160,7 +166,7 @@ async def run() -> None: nonlocal done_count done_count += 1 - start = datetime.now(timezone.utc) + overload_active = False def get_historical_system_info() -> SystemInfo: result = SystemInfo( @@ -170,8 +176,7 @@ def get_historical_system_info() -> SystemInfo: client_info=LoadRatioInfo(limit_ratio=0.9, actual_ratio=0.3), ) - # 0.5 seconds after the start of the test, pretend the CPU became overloaded - if result.created_at - start >= timedelta(seconds=0.5): + if overload_active: result.cpu_info = LoadRatioInfo(limit_ratio=0.9, actual_ratio=1.0) return result @@ -196,24 +201,21 @@ def get_historical_system_info() -> SystemInfo: pool_run_task = asyncio.create_task(pool.run(), name='pool run task') try: - # After 0.2s, there should be an increase in concurrency - await asyncio.sleep(0.2) - assert pool.desired_concurrency > 1 + # Wait until concurrency scales up above 1. + await _wait_for(lambda: pool.desired_concurrency > 1, timeout=5.0) - # After 0.5s, the concurrency should reach max concurrency - await asyncio.sleep(0.3) - assert pool.desired_concurrency == 4 + # Wait until concurrency reaches maximum. + await _wait_for(lambda: pool.desired_concurrency == 4, timeout=5.0) - # The concurrency should guarantee completion of more than 10 tasks (a single worker would complete ~5) - assert done_count > 10 + # Multiple concurrent workers should have completed more tasks than a single worker could. + await _wait_for(lambda: done_count > 10, timeout=5.0) - # After 0.7s, the pretend overload should have kicked in and there should be a drop in desired concurrency - await asyncio.sleep(0.2) - assert pool.desired_concurrency < 4 + # Simulate CPU overload and wait for the pool to scale down. + overload_active = True + await _wait_for(lambda: pool.desired_concurrency < 4, timeout=5.0) - # After a full second, the pool should scale down all the way to 1 - await asyncio.sleep(0.3) - assert pool.desired_concurrency == 1 + # Wait until the pool scales all the way down to minimum. + await _wait_for(lambda: pool.desired_concurrency == 1, timeout=5.0) finally: pool_run_task.cancel() with suppress(asyncio.CancelledError):