diff --git a/cognite/client/_api/data_modeling/records.py b/cognite/client/_api/data_modeling/records.py index 3d7591d727..fdc319d885 100644 --- a/cognite/client/_api/data_modeling/records.py +++ b/cognite/client/_api/data_modeling/records.py @@ -2,11 +2,11 @@ import asyncio from collections.abc import Sequence -from typing import TYPE_CHECKING, ClassVar, Literal +from typing import TYPE_CHECKING, Literal from cognite.client._api_client import APIClient from cognite.client.data_classes.data_modeling.records import RecordId, RecordIdSequence, RecordWrite -from cognite.client.utils._concurrency import RecordsConcurrencyOperation +from cognite.client.utils._concurrency import HierarchicalBoundedSemaphore, RecordsConcurrencyOperation from cognite.client.utils._experimental import FeaturePreviewWarning from cognite.client.utils._url import interpolate_and_url_encode @@ -14,6 +14,8 @@ from cognite.client import AsyncCogniteClient from cognite.client.config import ClientConfig +StreamType = Literal["mutable", "immutable"] + class RecordsAPI(APIClient): def __init__(self, config: ClientConfig, api_version: str | None, cognite_client: AsyncCogniteClient) -> None: @@ -22,18 +24,42 @@ def __init__(self, config: ClientConfig, api_version: str | None, cognite_client api_maturity="General Availability", sdk_maturity="alpha", feature_name="Records" ) - _OPERATION_TO_RATE_LIMIT: ClassVar[dict[str, RecordsConcurrencyOperation]] = { - "write": RecordsConcurrencyOperation.WRITE, - "delete": RecordsConcurrencyOperation.WRITE, - } - - def _get_semaphore(self, operation: Literal["write", "delete"]) -> asyncio.BoundedSemaphore: + def _get_semaphore( # type: ignore[override] + self, + operation: Literal["write", "delete", "query", "retrieve", "aggregate"], + stream_type: StreamType = "immutable", + ) -> asyncio.BoundedSemaphore | HierarchicalBoundedSemaphore: from cognite.client import global_config - return global_config.concurrency_settings.records._semaphore_factory( - self._OPERATION_TO_RATE_LIMIT[operation], project=self._cognite_client.config.project + write_op = RecordsConcurrencyOperation.WRITE + query_op = ( + RecordsConcurrencyOperation.QUERY_MUTABLE + if stream_type == "mutable" + else RecordsConcurrencyOperation.QUERY_IMMUTABLE ) + factory = global_config.concurrency_settings.records._semaphore_factory + project = self._cognite_client.config.project + match operation: + case "write" | "delete": + return factory(write_op, project) + case "query": + return factory(query_op, project) + case "retrieve": + dedicated_op = ( + RecordsConcurrencyOperation.RETRIEVE_MUTABLE + if stream_type == "mutable" + else RecordsConcurrencyOperation.RETRIEVE_IMMUTABLE + ) + return HierarchicalBoundedSemaphore(factory(dedicated_op, project), factory(query_op, project)) + case "aggregate": + dedicated_op = ( + RecordsConcurrencyOperation.AGGREGATE_MUTABLE + if stream_type == "mutable" + else RecordsConcurrencyOperation.AGGREGATE_IMMUTABLE + ) + return HierarchicalBoundedSemaphore(factory(dedicated_op, project), factory(query_op, project)) + def _records_url(self, stream_id: str, suffix: str = "") -> str: # Encode only stream_id; the suffix is a literal path segment (e.g. "/upsert"), # so it must not be percent-encoded. diff --git a/cognite/client/_sync_api/data_modeling/records.py b/cognite/client/_sync_api/data_modeling/records.py index 9897313c9d..26a3eb4d7e 100644 --- a/cognite/client/_sync_api/data_modeling/records.py +++ b/cognite/client/_sync_api/data_modeling/records.py @@ -1,6 +1,6 @@ """ =============================================================================== -f86364d61385123f12bc60dd004ea1c2 +5920bce88870da17ef2034ac258b62ac This file is auto-generated from the Async API modules, - do not edit manually! =============================================================================== """ @@ -18,6 +18,8 @@ if TYPE_CHECKING: from cognite.client import AsyncCogniteClient +StreamType = Literal["mutable", "immutable"] + class SyncRecordsAPI(SyncAPIClient): """Auto-generated, do not modify manually.""" diff --git a/cognite/client/utils/_concurrency.py b/cognite/client/utils/_concurrency.py index b5df86a34b..60244b9d42 100644 --- a/cognite/client/utils/_concurrency.py +++ b/cognite/client/utils/_concurrency.py @@ -222,26 +222,172 @@ def __repr__(self) -> str: ) +class HierarchicalBoundedSemaphore: + """Acquires multiple semaphores in order, releases in reverse. + + Used to model the Records API's hierarchical rate limits where an endpoint + (e.g. retrieve) must pass both its dedicated budget and the shared query budget. + + If acquisition is interrupted (e.g. by cancellation), all already-acquired + semaphores are released before the exception propagates. Similarly, if a + release raises, the remaining semaphores are still released. + """ + + def __init__(self, *semaphores: asyncio.BoundedSemaphore) -> None: + self._semaphores = semaphores + + async def __aenter__(self) -> None: + acquired: list[asyncio.BoundedSemaphore] = [] + try: + for sem in self._semaphores: + await sem.__aenter__() + acquired.append(sem) + except BaseException: + for sem in reversed(acquired): + await sem.__aexit__(None, None, None) + raise + + async def __aexit__(self, *exc: Any) -> None: + first_err: BaseException | None = None + for sem in reversed(self._semaphores): + try: + await sem.__aexit__(*exc) + except BaseException as e: + if first_err is None: + first_err = e + if first_err is not None: + raise first_err + + class RecordsConcurrencyOperation(Enum): WRITE = "write" + QUERY_MUTABLE = "query_mutable" + QUERY_IMMUTABLE = "query_immutable" + RETRIEVE_MUTABLE = "retrieve_mutable" + RETRIEVE_IMMUTABLE = "retrieve_immutable" + AGGREGATE_MUTABLE = "aggregate_mutable" + AGGREGATE_IMMUTABLE = "aggregate_immutable" class RecordsGlobalConcurrencyConfig(ConcurrencyConfig): """ - Global concurrency settings for the Records API. Named "global" to distinguish from - future per-endpoint rate limits that may be added later. + Global concurrency settings for the Records API. + + The Records API has separate rate-limit budgets for reads and writes, and read budgets + differ between mutable and immutable streams. Read budgets are hierarchical: the + retrieve and aggregate endpoints each have a dedicated budget that is checked *before* + the shared query budget (both must pass). + + - **write**: Shared across ingest, upsert, and delete (same for both stream types). + - **query_mutable / query_immutable**: Shared read budget for all query endpoints. + - **retrieve_mutable / retrieve_immutable**: Dedicated retrieve budget (+ shared query). + - **aggregate_mutable / aggregate_immutable**: Dedicated aggregate budget (+ shared query). Args: concurrency_settings (ConcurrencySettings): Reference to the parent settings object. - write (int): Maximum concurrent write requests (ingest, delete). + write (int): Maximum concurrent write requests (ingest, upsert, delete). + query_mutable (int): Maximum concurrent query requests against mutable streams. + query_immutable (int): Maximum concurrent query requests against immutable streams. + retrieve_mutable (int): Dedicated retrieve concurrency for mutable streams. + retrieve_immutable (int): Dedicated retrieve concurrency for immutable streams. + aggregate_mutable (int): Dedicated aggregate concurrency for mutable streams. + aggregate_immutable (int): Dedicated aggregate concurrency for immutable streams. """ def __init__( self, concurrency_settings: ConcurrencySettings, write: int, + query_mutable: int, + query_immutable: int, + retrieve_mutable: int, + retrieve_immutable: int, + aggregate_mutable: int, + aggregate_immutable: int, ) -> None: super().__init__(concurrency_settings, "records", read=0, write=write, delete=0) + self._query_mutable = query_mutable + self._query_immutable = query_immutable + self._retrieve_mutable = retrieve_mutable + self._retrieve_immutable = retrieve_immutable + self._aggregate_mutable = aggregate_mutable + self._aggregate_immutable = aggregate_immutable + self._validate_budgets() + + def _validate_budgets(self, **overrides: int) -> None: + for name in overrides: + self._check_frozen(name) + + def resolve(name: str) -> int: + return overrides.get(name, getattr(self, f"_{name}")) + + for dedicated_name, shared_name in [ + ("retrieve_mutable", "query_mutable"), + ("retrieve_immutable", "query_immutable"), + ("aggregate_mutable", "query_mutable"), + ("aggregate_immutable", "query_immutable"), + ]: + dedicated = resolve(dedicated_name) + shared = resolve(shared_name) + if dedicated > shared: + raise ValueError( + f"Dedicated budget must be <= shared query budget " + f"({dedicated_name} vs {shared_name}): {dedicated} > {shared}" + ) + + @property + def query_mutable(self) -> int: + return self._query_mutable + + @query_mutable.setter + def query_mutable(self, value: int) -> None: + self._validate_budgets(query_mutable=value) + self._query_mutable = value + + @property + def query_immutable(self) -> int: + return self._query_immutable + + @query_immutable.setter + def query_immutable(self, value: int) -> None: + self._validate_budgets(query_immutable=value) + self._query_immutable = value + + @property + def retrieve_mutable(self) -> int: + return self._retrieve_mutable + + @retrieve_mutable.setter + def retrieve_mutable(self, value: int) -> None: + self._validate_budgets(retrieve_mutable=value) + self._retrieve_mutable = value + + @property + def retrieve_immutable(self) -> int: + return self._retrieve_immutable + + @retrieve_immutable.setter + def retrieve_immutable(self, value: int) -> None: + self._validate_budgets(retrieve_immutable=value) + self._retrieve_immutable = value + + @property + def aggregate_mutable(self) -> int: + return self._aggregate_mutable + + @aggregate_mutable.setter + def aggregate_mutable(self, value: int) -> None: + self._validate_budgets(aggregate_mutable=value) + self._aggregate_mutable = value + + @property + def aggregate_immutable(self) -> int: + return self._aggregate_immutable + + @aggregate_immutable.setter + def aggregate_immutable(self, value: int) -> None: + self._validate_budgets(aggregate_immutable=value) + self._aggregate_immutable = value def _semaphore_factory(self, operation: RecordsConcurrencyOperation, project: str) -> asyncio.BoundedSemaphore: key = (operation.value, project, asyncio.get_running_loop()) @@ -254,13 +400,31 @@ def _semaphore_factory(self, operation: RecordsConcurrencyOperation, project: st match operation: case RecordsConcurrencyOperation.WRITE: sem = asyncio.BoundedSemaphore(self._write) + case RecordsConcurrencyOperation.QUERY_MUTABLE: + sem = asyncio.BoundedSemaphore(self._query_mutable) + case RecordsConcurrencyOperation.QUERY_IMMUTABLE: + sem = asyncio.BoundedSemaphore(self._query_immutable) + case RecordsConcurrencyOperation.RETRIEVE_MUTABLE: + sem = asyncio.BoundedSemaphore(self._retrieve_mutable) + case RecordsConcurrencyOperation.RETRIEVE_IMMUTABLE: + sem = asyncio.BoundedSemaphore(self._retrieve_immutable) + case RecordsConcurrencyOperation.AGGREGATE_MUTABLE: + sem = asyncio.BoundedSemaphore(self._aggregate_mutable) + case RecordsConcurrencyOperation.AGGREGATE_IMMUTABLE: + sem = asyncio.BoundedSemaphore(self._aggregate_immutable) case _: assert_never(operation) self._semaphore_cache[key] = sem return sem def __repr__(self) -> str: - return f"Concurrency[records](write={self._write})" + return ( + f"Concurrency[records](" + f"write={self._write}, " + f"query_mutable={self._query_mutable}, query_immutable={self._query_immutable}, " + f"retrieve_mutable={self._retrieve_mutable}, retrieve_immutable={self._retrieve_immutable}, " + f"aggregate_mutable={self._aggregate_mutable}, aggregate_immutable={self._aggregate_immutable})" + ) class FileConcurrencyConfig(ConcurrencyConfig): @@ -425,7 +589,16 @@ def __init__(self) -> None: write_schema=1, ) self._files = FileConcurrencyConfig(self, read=4, write=2, upload=5, download=5, delete=2, open_files=15) - self._records = RecordsGlobalConcurrencyConfig(self, write=20) + self._records = RecordsGlobalConcurrencyConfig( + self, + write=20, + query_mutable=30, + query_immutable=10, + retrieve_mutable=20, + retrieve_immutable=10, + aggregate_mutable=10, + aggregate_immutable=5, + ) @functools.cached_property def _all_concurrency_configs(self) -> list[ConcurrencyConfig]: diff --git a/tests/tests_unit/test_utils/test_concurrency.py b/tests/tests_unit/test_utils/test_concurrency.py index c817a5862c..37da6bf255 100644 --- a/tests/tests_unit/test_utils/test_concurrency.py +++ b/tests/tests_unit/test_utils/test_concurrency.py @@ -7,13 +7,15 @@ import pytest -from cognite.client import global_config +from cognite.client import AsyncCogniteClient, global_config from cognite.client.exceptions import CogniteAPIError from cognite.client.utils._concurrency import ( AsyncSDKTask, ConcurrencyConfig, ConcurrencySettings, EventLoopThreadExecutor, + HierarchicalBoundedSemaphore, + RecordsConcurrencyOperation, _get_event_loop_executor, execute_async_tasks, ) @@ -136,6 +138,516 @@ async def test_invalid_operation_hits_assert_never(self) -> None: self.cs.general._semaphore_factory("totally_invalid", "proj-a") # type: ignore[arg-type] +class TestRecordsConcurrencyConfig: + def test_defaults(self) -> None: + cs = ConcurrencySettings() + assert cs.records.write == 20 + assert cs.records.query_mutable == 30 + assert cs.records.query_immutable == 10 + assert cs.records.retrieve_mutable == 20 + assert cs.records.retrieve_immutable == 10 + assert cs.records.aggregate_mutable == 10 + assert cs.records.aggregate_immutable == 5 + + def test_setters_work_before_freeze(self) -> None: + cs = ConcurrencySettings() + cs.records.write = 10 + cs.records.retrieve_mutable = 12 + cs.records.retrieve_immutable = 4 + cs.records.aggregate_mutable = 8 + cs.records.aggregate_immutable = 3 + cs.records.query_mutable = 15 + cs.records.query_immutable = 5 + assert cs.records.write == 10 + assert cs.records.query_mutable == 15 + assert cs.records.query_immutable == 5 + assert cs.records.retrieve_mutable == 12 + assert cs.records.retrieve_immutable == 4 + assert cs.records.aggregate_mutable == 8 + assert cs.records.aggregate_immutable == 3 + + @pytest.mark.parametrize( + "attr", + [ + "write", + "query_mutable", + "query_immutable", + "retrieve_mutable", + "retrieve_immutable", + "aggregate_mutable", + "aggregate_immutable", + ], + ) + def test_setter_raises_after_freeze(self, attr: str) -> None: + cs = ConcurrencySettings() + cs._freeze() + with pytest.raises(RuntimeError, match="Cannot modify"): + setattr(cs.records, attr, 1) + + def test_repr(self) -> None: + cs = ConcurrencySettings() + r = repr(cs.records) + assert "write=20" in r + assert "query_mutable=30" in r + assert "query_immutable=10" in r + assert "retrieve_mutable=20" in r + assert "retrieve_immutable=10" in r + assert "aggregate_mutable=10" in r + assert "aggregate_immutable=5" in r + + @pytest.mark.parametrize( + "dedicated, shared", + [ + ("retrieve_mutable", "query_mutable"), + ("retrieve_immutable", "query_immutable"), + ("aggregate_mutable", "query_mutable"), + ("aggregate_immutable", "query_immutable"), + ], + ) + def test_dedicated_exceeding_shared_raises_on_init(self, dedicated: str, shared: str) -> None: + cs = ConcurrencySettings() + defaults = { + "write": 20, + "query_mutable": 30, + "query_immutable": 10, + "retrieve_mutable": 20, + "retrieve_immutable": 10, + "aggregate_mutable": 10, + "aggregate_immutable": 5, + } + shared_val = defaults[shared] + defaults[dedicated] = shared_val + 1 + from cognite.client.utils._concurrency import RecordsGlobalConcurrencyConfig + + with pytest.raises(ValueError, match="Dedicated budget must be <= shared query budget"): + RecordsGlobalConcurrencyConfig(cs, **defaults) + + @pytest.mark.parametrize( + "dedicated, shared", + [ + ("retrieve_mutable", "query_mutable"), + ("retrieve_immutable", "query_immutable"), + ("aggregate_mutable", "query_mutable"), + ("aggregate_immutable", "query_immutable"), + ], + ) + def test_dedicated_exceeding_shared_raises_on_setter(self, dedicated: str, shared: str) -> None: + cs = ConcurrencySettings() + shared_val = getattr(cs.records, shared) + with pytest.raises(ValueError, match="Dedicated budget must be <= shared query budget"): + setattr(cs.records, dedicated, shared_val + 1) + + def test_lowering_shared_below_dedicated_raises(self) -> None: + cs = ConcurrencySettings() + with pytest.raises(ValueError, match="Dedicated budget must be <= shared query budget"): + cs.records.query_mutable = 5 # retrieve_mutable=20 > 5 + + def test_dedicated_equal_to_shared_is_valid(self) -> None: + cs = ConcurrencySettings() + cs.records.retrieve_mutable = 30 # equal to query_mutable=30, should be fine + assert cs.records.retrieve_mutable == 30 + + +@pytest.mark.usefixtures("fresh_unfrozen_global_concurrency") +class TestRecordsSemaphoreFactory: + cs: ClassVar[ConcurrencySettings] = global_config.concurrency_settings + + @pytest.mark.parametrize( + "operation, expected_value", + [ + (RecordsConcurrencyOperation.WRITE, 20), + (RecordsConcurrencyOperation.QUERY_MUTABLE, 30), + (RecordsConcurrencyOperation.QUERY_IMMUTABLE, 10), + (RecordsConcurrencyOperation.RETRIEVE_MUTABLE, 20), + (RecordsConcurrencyOperation.RETRIEVE_IMMUTABLE, 10), + (RecordsConcurrencyOperation.AGGREGATE_MUTABLE, 10), + (RecordsConcurrencyOperation.AGGREGATE_IMMUTABLE, 5), + ], + ) + async def test_semaphore_values(self, operation: RecordsConcurrencyOperation, expected_value: int) -> None: + sem = self.cs.records._semaphore_factory(operation, "proj-a") + assert sem._value == expected_value + + async def test_all_operations_produce_distinct_semaphores(self) -> None: + sems = {op: self.cs.records._semaphore_factory(op, "proj-a") for op in RecordsConcurrencyOperation} + assert len(set(id(s) for s in sems.values())) == len(RecordsConcurrencyOperation) + + async def test_cache_hit(self) -> None: + sem1 = self.cs.records._semaphore_factory(RecordsConcurrencyOperation.QUERY_MUTABLE, "proj-a") + sem2 = self.cs.records._semaphore_factory(RecordsConcurrencyOperation.QUERY_MUTABLE, "proj-a") + assert sem1 is sem2 + + async def test_different_project_different_semaphore(self) -> None: + sem_a = self.cs.records._semaphore_factory(RecordsConcurrencyOperation.QUERY_MUTABLE, "proj-a") + sem_b = self.cs.records._semaphore_factory(RecordsConcurrencyOperation.QUERY_MUTABLE, "proj-b") + assert sem_a is not sem_b + + +class TestHierarchicalBoundedSemaphore: + async def test_acquires_both_semaphores(self) -> None: + outer = asyncio.BoundedSemaphore(2) + inner = asyncio.BoundedSemaphore(3) + h = HierarchicalBoundedSemaphore(outer, inner) + async with h: + assert outer._value == 1 + assert inner._value == 2 + + async def test_releases_both_on_exit(self) -> None: + outer = asyncio.BoundedSemaphore(1) + inner = asyncio.BoundedSemaphore(1) + h = HierarchicalBoundedSemaphore(outer, inner) + async with h: + pass + assert outer._value == 1 + assert inner._value == 1 + + async def test_releases_on_exception(self) -> None: + outer = asyncio.BoundedSemaphore(1) + inner = asyncio.BoundedSemaphore(1) + h = HierarchicalBoundedSemaphore(outer, inner) + with pytest.raises(ValueError): + async with h: + raise ValueError("boom") + assert outer._value == 1 + assert inner._value == 1 + + async def test_limits_concurrency_to_min(self) -> None: + dedicated = asyncio.BoundedSemaphore(2) + query = asyncio.BoundedSemaphore(5) + entered = asyncio.Event() + hold = asyncio.Event() + + async def worker() -> None: + async with HierarchicalBoundedSemaphore(dedicated, query): + entered.set() + await hold.wait() + + tasks = [asyncio.create_task(worker()) for _ in range(3)] + await asyncio.sleep(0.01) + assert dedicated._value == 0 + assert query._value == 3 + hold.set() + await asyncio.gather(*tasks) + + +class TestHierarchicalBoundedSemaphoreAdversarial: + """Adversarial tests targeting real failure modes in HierarchicalBoundedSemaphore. + + Two confirmed bugs are documented in the tests below: + + BUG-1 (semaphore leak on cancellation): When a task is cancelled while + __aenter__ is blocked waiting on the second semaphore, the first semaphore + has already been acquired but __aexit__ is never called, so it leaks. + + BUG-2 (incomplete release on mid-exit exception): When __aexit__ iterates + in reversed order and one semaphore's release raises, the remaining + semaphores (earlier in the list) are never released. + """ + + # --- BUG-1: Cancellation during acquisition leaks already-acquired semaphores --- + + async def test_bug1_cancellation_while_waiting_on_second_semaphore_leaks_first(self) -> None: + """BUG: If cancelled between acquiring sem[0] and sem[1], sem[0] is never released. + + Root cause: __aenter__ acquires semaphores in a plain for-loop with no + try/except around individual acquisitions. A CancelledError raised inside + sem[1].__aenter__() (while it is blocked) unwinds the coroutine without + giving __aexit__ a chance to run, so sem[0] stays acquired forever. + """ + dedicated = asyncio.BoundedSemaphore(1) + query = asyncio.BoundedSemaphore(0) # permanently blocked + h = HierarchicalBoundedSemaphore(dedicated, query) + + async def worker() -> None: + async with h: + pass + + task = asyncio.create_task(worker()) + await asyncio.sleep(0.02) # let it acquire dedicated and block on query + + assert dedicated._value == 0, "dedicated should be held at this point" + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # BUG: dedicated._value is 0, not 1 — it was never released + assert dedicated._value == 1, ( + "BUG-1: dedicated semaphore leaked after cancellation. " + "sem[0] was acquired by __aenter__ but CancelledError prevented __aexit__ from running." + ) + + async def test_bug1_two_tasks_cancelled_both_leak(self) -> None: + """BUG: Each cancelled task leaks one slot; with two tasks both slots are gone.""" + dedicated = asyncio.BoundedSemaphore(2) + query = asyncio.BoundedSemaphore(0) # permanently blocked + h = HierarchicalBoundedSemaphore(dedicated, query) + + async def worker() -> None: + async with h: + pass + + t1 = asyncio.create_task(worker()) + t2 = asyncio.create_task(worker()) + await asyncio.sleep(0.02) + + t1.cancel() + t2.cancel() + await asyncio.gather(t1, t2, return_exceptions=True) + + # BUG: both slots leaked — dedicated is exhausted even though no work was done + assert dedicated._value == 2, ( + "BUG-1: both dedicated slots leaked; subsequent real work can never acquire the semaphore." + ) + + async def test_bug1_wait_for_timeout_leaks_first_semaphore(self) -> None: + """BUG: asyncio.wait_for cancels the coroutine on timeout, triggering the same leak.""" + dedicated = asyncio.BoundedSemaphore(1) + query = asyncio.BoundedSemaphore(0) # permanently blocked + h = HierarchicalBoundedSemaphore(dedicated, query) + + async def worker() -> None: + async with h: + pass + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(worker(), timeout=0.05) + + # BUG: dedicated._value remains 0 after the timeout + assert dedicated._value == 1, ( + "BUG-1: dedicated semaphore leaked after asyncio.wait_for timeout. " + "Timeout internally cancels the task, hitting the same code path." + ) + + async def test_bug1_cancellation_with_three_semaphores_leaks_two(self) -> None: + """BUG: With three semaphores, cancellation while waiting on sem[2] leaks both sem[0] and sem[1].""" + s0 = asyncio.BoundedSemaphore(1) + s1 = asyncio.BoundedSemaphore(1) + s2 = asyncio.BoundedSemaphore(0) # permanently blocked + h = HierarchicalBoundedSemaphore(s0, s1, s2) + + async def worker() -> None: + async with h: + pass + + task = asyncio.create_task(worker()) + await asyncio.sleep(0.02) + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + assert s0._value == 1, "BUG-1: s0 leaked (first semaphore)" + assert s1._value == 1, "BUG-1: s1 leaked (second semaphore)" + + # --- BUG-2: Exception in __aexit__ of one semaphore skips releasing earlier ones --- + + async def test_bug2_exception_in_middle_release_skips_earlier_releases(self) -> None: + """BUG: If releasing sem[1] raises, sem[0] is never released. + + __aexit__ iterates with a plain for-loop over reversed semaphores. + An exception from any intermediate release propagates immediately, + abandoning all remaining release calls. + """ + + class BoomSemaphore: + """Always acquires fine; always raises on release.""" + + async def __aenter__(self) -> None: + pass + + async def __aexit__(self, *exc: object) -> None: + raise RuntimeError("boom on release") + + sem0 = asyncio.BoundedSemaphore(1) + sem1 = BoomSemaphore() + sem2 = asyncio.BoundedSemaphore(1) + # Acquire order: sem0, sem1, sem2 + # Release order (reversed): sem2, sem1 (boom!), sem0 — sem0 is never reached + h = HierarchicalBoundedSemaphore(sem0, sem1, sem2) # type: ignore[arg-type] + + with pytest.raises(RuntimeError, match="boom on release"): + async with h: + pass + + assert sem0._value == 1, ( + "BUG-2: sem0 was not released because the exception from sem1's __aexit__ " + "aborted the release loop before sem0's turn." + ) + assert sem2._value == 1, "sem2 (released before the bomb) should be fine" + + async def test_bug2_exception_in_first_release_skips_all_remaining(self) -> None: + """BUG: Exception from the first release (last-acquired semaphore) skips all others.""" + + class BoomSemaphore: + async def __aenter__(self) -> None: + pass + + async def __aexit__(self, *exc: object) -> None: + raise RuntimeError("boom") + + sem0 = asyncio.BoundedSemaphore(1) + sem1 = asyncio.BoundedSemaphore(1) + sem2 = BoomSemaphore() # last acquired = first released = first to explode + h = HierarchicalBoundedSemaphore(sem0, sem1, sem2) # type: ignore[arg-type] + + with pytest.raises(RuntimeError, match="boom"): + async with h: + pass + + assert sem0._value == 1, "BUG-2: sem0 leaked because release loop aborted at sem2" + assert sem1._value == 1, "BUG-2: sem1 leaked because release loop aborted at sem2" + + # --- Non-bug adversarial cases (expected to pass) --- + + async def test_zero_semaphores_is_a_noop(self) -> None: + """Edge case: constructing with no semaphores should be a transparent noop.""" + h = HierarchicalBoundedSemaphore() + async with h: + pass # should not raise or block + + async def test_single_semaphore_behaves_like_plain_async_with(self) -> None: + sem = asyncio.BoundedSemaphore(1) + h = HierarchicalBoundedSemaphore(sem) + async with h: + assert sem._value == 0 + assert sem._value == 1 + + async def test_nested_usage_deadlocks_with_value_one(self) -> None: + """Nested async with on the same HierarchicalBoundedSemaphore(value=1) deadlocks. + + This is expected — BoundedSemaphore is not reentrant. The test confirms + that the implementation does NOT protect against reentrancy. + """ + sem = asyncio.BoundedSemaphore(1) + h = HierarchicalBoundedSemaphore(sem) + + inner_started = False + + async def nested_worker() -> None: + nonlocal inner_started + async with h: + inner_started = True + # Attempt reentrant acquire — will deadlock + async with h: + pass # unreachable + + task = asyncio.create_task(nested_worker()) + await asyncio.sleep(0.05) + + assert inner_started, "outer context should have been entered" + assert not task.done(), "task should still be blocked waiting on inner acquire" + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + async def test_release_on_exception_inside_context_body_both_semaphores_freed(self) -> None: + """Normal exception inside the body (not in __aexit__) must release all semaphores.""" + sem0 = asyncio.BoundedSemaphore(1) + sem1 = asyncio.BoundedSemaphore(1) + h = HierarchicalBoundedSemaphore(sem0, sem1) + + with pytest.raises(ValueError, match="body exception"): + async with h: + raise ValueError("body exception") + + assert sem0._value == 1 + assert sem1._value == 1 + + async def test_concurrent_retrieve_and_list_complete_without_deadlock(self) -> None: + """retrieve (HierarchicalBoundedSemaphore) and list (plain semaphore) sharing + the query semaphore must not deadlock each other.""" + dedicated = asyncio.BoundedSemaphore(2) + query = asyncio.BoundedSemaphore(3) + h_retrieve = HierarchicalBoundedSemaphore(dedicated, query) + + completed: list[str] = [] + + async def retrieve(name: str) -> None: + async with h_retrieve: + await asyncio.sleep(0.01) + completed.append(f"retrieve_{name}") + + async def list_op(name: str) -> None: + async with query: + await asyncio.sleep(0.01) + completed.append(f"list_{name}") + + await asyncio.gather( + asyncio.create_task(retrieve("A")), + asyncio.create_task(retrieve("B")), + asyncio.create_task(list_op("C")), + asyncio.create_task(list_op("D")), + asyncio.create_task(list_op("E")), + ) + + assert len(completed) == 5 + assert all(name in completed for name in ["retrieve_A", "retrieve_B", "list_C", "list_D", "list_E"]) + + async def test_semaphores_fully_restored_after_many_sequential_uses(self) -> None: + """Semaphore values must be fully restored after many normal acquire/release cycles.""" + sem0 = asyncio.BoundedSemaphore(3) + sem1 = asyncio.BoundedSemaphore(5) + h = HierarchicalBoundedSemaphore(sem0, sem1) + + for _ in range(20): + async with h: + pass + + assert sem0._value == 3 + assert sem1._value == 5 + + +@pytest.mark.usefixtures("fresh_unfrozen_global_concurrency") +class TestRecordsGetSemaphore: + """Tests that RecordsAPI._get_semaphore returns the right semaphore type and composition.""" + + async def test_write_returns_plain_semaphore(self, async_client: AsyncCogniteClient) -> None: + sem = async_client.data_modeling.records._get_semaphore("write") + assert isinstance(sem, asyncio.BoundedSemaphore) + + async def test_delete_returns_plain_semaphore(self, async_client: AsyncCogniteClient) -> None: + sem = async_client.data_modeling.records._get_semaphore("delete") + assert isinstance(sem, asyncio.BoundedSemaphore) + + async def test_query_returns_plain_semaphore(self, async_client: AsyncCogniteClient) -> None: + sem = async_client.data_modeling.records._get_semaphore("query", "mutable") + assert isinstance(sem, asyncio.BoundedSemaphore) + + async def test_retrieve_returns_hierarchical(self, async_client: AsyncCogniteClient) -> None: + sem = async_client.data_modeling.records._get_semaphore("retrieve", "mutable") + assert isinstance(sem, HierarchicalBoundedSemaphore) + + async def test_aggregate_returns_hierarchical(self, async_client: AsyncCogniteClient) -> None: + sem = async_client.data_modeling.records._get_semaphore("aggregate", "immutable") + assert isinstance(sem, HierarchicalBoundedSemaphore) + + async def test_retrieve_hierarchical_wraps_correct_semaphores(self, async_client: AsyncCogniteClient) -> None: + sem = async_client.data_modeling.records._get_semaphore("retrieve", "mutable") + assert isinstance(sem, HierarchicalBoundedSemaphore) + dedicated, query = sem._semaphores + assert dedicated._value == 20 # retrieve_mutable default + assert query._value == 30 # query_mutable default + + async def test_aggregate_immutable_wraps_correct_semaphores(self, async_client: AsyncCogniteClient) -> None: + sem = async_client.data_modeling.records._get_semaphore("aggregate", "immutable") + assert isinstance(sem, HierarchicalBoundedSemaphore) + dedicated, query = sem._semaphores + assert dedicated._value == 5 # aggregate_immutable default + assert query._value == 10 # query_immutable default + + async def test_retrieve_and_query_share_query_semaphore(self, async_client: AsyncCogniteClient) -> None: + retrieve_sem = async_client.data_modeling.records._get_semaphore("retrieve", "mutable") + query_sem = async_client.data_modeling.records._get_semaphore("query", "mutable") + assert isinstance(retrieve_sem, HierarchicalBoundedSemaphore) + assert retrieve_sem._semaphores[1] is query_sem + + async def i_dont_like_5(i: int) -> int: if i < 5: return i diff --git a/tests/tests_unit/test_utils/test_concurrency_api_routing.py b/tests/tests_unit/test_utils/test_concurrency_api_routing.py index 0c869de293..f4084928c3 100644 --- a/tests/tests_unit/test_utils/test_concurrency_api_routing.py +++ b/tests/tests_unit/test_utils/test_concurrency_api_routing.py @@ -9,15 +9,20 @@ from __future__ import annotations +import asyncio import re from collections.abc import Awaitable, Callable, Iterator from typing import Any +import httpx import pytest from pytest_httpx import HTTPXMock from cognite.client import AsyncCogniteClient from cognite.client.data_classes.data_modeling.ids import NodeId +from cognite.client.data_classes.data_modeling.records import RecordId, RecordWrite +from cognite.client.exceptions import CogniteAPIError +from cognite.client.utils._concurrency import HierarchicalBoundedSemaphore from tests.utils import fresh_concurrency_state SemCall = tuple[str, str, str] # (sub_config_name (eg 'general'), operation, project) @@ -152,6 +157,240 @@ async def test_raw_read( assert_routed(semaphore_spy, "raw", "read") +class TestRecordsSemaphoreRouting: + """Records API uses RecordsConcurrencyOperation enums (not plain strings), + so we spy on the enum values directly.""" + + @pytest.fixture + def records_spy(self, monkeypatch: pytest.MonkeyPatch) -> Iterator[list[tuple[str, str]]]: + calls: list[tuple[str, str]] = [] + with fresh_concurrency_state() as settings: + original = settings.records._semaphore_factory + + def spy(operation: Any, project: str) -> Any: + calls.append((operation.value, project)) + return original(operation, project) + + monkeypatch.setattr(settings.records, "_semaphore_factory", spy) + yield calls + + @pytest.mark.usefixtures("mock_any_request") + @pytest.mark.parametrize( + "api_call, expected_operation", + [ + pytest.param( + lambda c: c.data_modeling.records.ingest( + RecordWrite(space="sp", external_id="r1", sources=[]), + stream_id="s1", + ), + "write", + id="ingest_write", + ), + pytest.param( + lambda c: c.data_modeling.records.upsert( + RecordWrite(space="sp", external_id="r1", sources=[]), + stream_id="s1", + ), + "write", + id="upsert_write", + ), + pytest.param( + lambda c: c.data_modeling.records.delete( + RecordId(space="sp", external_id="r1"), + stream_id="s1", + ), + "write", + id="delete_write", + ), + ], + ) + async def test_write_routing( + self, + async_client: AsyncCogniteClient, + records_spy: list[tuple[str, str]], + api_call: ApiCall, + expected_operation: str, + ) -> None: + await api_call(async_client) + ops = [op for op, _ in records_spy] + assert expected_operation in ops, f"Expected {expected_operation!r} in {ops}" + assert async_client.config.project in {proj for _, proj in records_spy} + + +class TestHierarchicalSemaphoreThroughHTTPClient: + """Verify that HierarchicalBoundedSemaphore works through the real SDK HTTP chain. + + The type hints say asyncio.BoundedSemaphore, but at runtime the HTTP client + just does ``async with semaphore``. These tests confirm that a + HierarchicalBoundedSemaphore actually acquires and releases both inner + semaphores when passed through _post. + """ + + async def test_post_with_hierarchical_semaphore_succeeds( + self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock + ) -> None: + httpx_mock.add_response(method="POST", url=re.compile(r".*"), status_code=200, json={"items": []}) + dedicated = asyncio.BoundedSemaphore(1) + query = asyncio.BoundedSemaphore(1) + h = HierarchicalBoundedSemaphore(dedicated, query) + + await async_client.data_modeling.records._post( + url_path="/streams/test/records/filter", + json={"limit": 10}, + semaphore=h, + ) + assert dedicated._value == 1, "dedicated semaphore should be released after request" + assert query._value == 1, "query semaphore should be released after request" + + async def test_post_with_hierarchical_semaphore_releases_on_http_error( + self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock + ) -> None: + httpx_mock.add_response(method="POST", url=re.compile(r".*"), status_code=500, json={"error": {"message": "fail", "code": 500}}) + dedicated = asyncio.BoundedSemaphore(1) + query = asyncio.BoundedSemaphore(1) + h = HierarchicalBoundedSemaphore(dedicated, query) + + with pytest.raises(CogniteAPIError): + await async_client.data_modeling.records._post( + url_path="/streams/test/records/filter", + json={"limit": 10}, + semaphore=h, + ) + assert dedicated._value == 1, "dedicated semaphore should be released after error" + assert query._value == 1, "query semaphore should be released after error" + + async def test_hierarchical_semaphore_limits_concurrent_requests( + self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock + ) -> None: + hold = asyncio.Event() + + async def slow_response(request: Any) -> Any: + await hold.wait() + return httpx.Response(200, json={"items": []}) + + httpx_mock.add_callback(slow_response, method="POST", url=re.compile(r".*"), is_optional=True) + httpx_mock.add_callback(slow_response, method="POST", url=re.compile(r".*"), is_optional=True) + httpx_mock.add_callback(slow_response, method="POST", url=re.compile(r".*"), is_optional=True) + + dedicated = asyncio.BoundedSemaphore(1) + query = asyncio.BoundedSemaphore(2) + + async def make_request() -> None: + h = HierarchicalBoundedSemaphore(dedicated, query) + await async_client.data_modeling.records._post( + url_path="/streams/test/records/filter", + json={"limit": 10}, + semaphore=h, + ) + + t1 = asyncio.create_task(make_request()) + t2 = asyncio.create_task(make_request()) + await asyncio.sleep(0.05) + + assert dedicated._value == 0, "dedicated(1) should be fully consumed by first request" + assert query._value == 1, "query(2) should have 1 slot consumed" + + hold.set() + await asyncio.gather(t1, t2) + assert dedicated._value == 1 + assert query._value == 2 + + +class TestRecordsSemaphoreEndpointPatterns: + """Simulate the exact patterns that records endpoints will use: + - list/sync: _get_semaphore("query", stream_type) → override_semaphore in _list + - retrieve: _get_semaphore("retrieve", stream_type) → override_semaphore in _post (hierarchical) + - aggregate: _get_semaphore("aggregate", stream_type) → override_semaphore in _post (hierarchical) + """ + + @pytest.fixture(autouse=True) + def _fresh_state(self) -> Iterator[None]: + with fresh_concurrency_state(): + yield + + async def test_list_pattern_mutable(self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock) -> None: + httpx_mock.add_response(method="POST", url=re.compile(r".*"), status_code=200, json={"items": []}) + records_api = async_client.data_modeling.records + sem = records_api._get_semaphore("query", "mutable") + assert isinstance(sem, asyncio.BoundedSemaphore) + assert sem._value == 30 + + await records_api._post(url_path="/streams/s1/records/filter", json={"limit": 10}, semaphore=sem) + assert sem._value == 30 + + async def test_list_pattern_immutable(self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock) -> None: + httpx_mock.add_response(method="POST", url=re.compile(r".*"), status_code=200, json={"items": []}) + records_api = async_client.data_modeling.records + sem = records_api._get_semaphore("query", "immutable") + assert isinstance(sem, asyncio.BoundedSemaphore) + assert sem._value == 10 + + await records_api._post(url_path="/streams/s1/records/filter", json={"limit": 10}, semaphore=sem) + assert sem._value == 10 + + async def test_retrieve_pattern_mutable(self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock) -> None: + httpx_mock.add_response(method="POST", url=re.compile(r".*"), status_code=200, json={"items": []}) + records_api = async_client.data_modeling.records + sem = records_api._get_semaphore("retrieve", "mutable") + assert isinstance(sem, HierarchicalBoundedSemaphore) + + await records_api._post(url_path="/streams/s1/records/retrieve", json={"items": []}, semaphore=sem) + dedicated, query = sem._semaphores + assert dedicated._value == 20 + assert query._value == 30 + + async def test_retrieve_pattern_immutable(self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock) -> None: + httpx_mock.add_response(method="POST", url=re.compile(r".*"), status_code=200, json={"items": []}) + records_api = async_client.data_modeling.records + sem = records_api._get_semaphore("retrieve", "immutable") + assert isinstance(sem, HierarchicalBoundedSemaphore) + + await records_api._post(url_path="/streams/s1/records/retrieve", json={"items": []}, semaphore=sem) + dedicated, query = sem._semaphores + assert dedicated._value == 10 + assert query._value == 10 + + async def test_aggregate_pattern_mutable(self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock) -> None: + httpx_mock.add_response(method="POST", url=re.compile(r".*"), status_code=200, json={"items": []}) + records_api = async_client.data_modeling.records + sem = records_api._get_semaphore("aggregate", "mutable") + assert isinstance(sem, HierarchicalBoundedSemaphore) + + await records_api._post(url_path="/streams/s1/records/aggregate", json={}, semaphore=sem) + dedicated, query = sem._semaphores + assert dedicated._value == 10 + assert query._value == 30 + + async def test_aggregate_pattern_immutable(self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock) -> None: + httpx_mock.add_response(method="POST", url=re.compile(r".*"), status_code=200, json={"items": []}) + records_api = async_client.data_modeling.records + sem = records_api._get_semaphore("aggregate", "immutable") + assert isinstance(sem, HierarchicalBoundedSemaphore) + + await records_api._post(url_path="/streams/s1/records/aggregate", json={}, semaphore=sem) + dedicated, query = sem._semaphores + assert dedicated._value == 5 + assert query._value == 10 + + async def test_retrieve_and_list_share_query_semaphore_through_post( + self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock + ) -> None: + """A retrieve request and a list request against the same stream type + must compete for the same query semaphore.""" + httpx_mock.add_response(method="POST", url=re.compile(r".*"), status_code=200, json={"items": []}, is_optional=True) + httpx_mock.add_response(method="POST", url=re.compile(r".*"), status_code=200, json={"items": []}, is_optional=True) + records_api = async_client.data_modeling.records + + retrieve_sem = records_api._get_semaphore("retrieve", "mutable") + list_sem = records_api._get_semaphore("query", "mutable") + + assert isinstance(retrieve_sem, HierarchicalBoundedSemaphore) + assert retrieve_sem._semaphores[1] is list_sem + + await records_api._post(url_path="/streams/s1/records/retrieve", json={"items": []}, semaphore=retrieve_sem) + await records_api._post(url_path="/streams/s1/records/filter", json={"limit": 10}, semaphore=list_sem) + + class TestStrictFixtureCatchesMissingSemaphore: """Sanity check that the suite-wide strict fixture in tests/conftest.py still works.