Skip to content

Commit 904b3e1

Browse files
committed
PKRange Cache fix
1 parent d6f84be commit 904b3e1

4 files changed

Lines changed: 762 additions & 73 deletions

File tree

sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py

Lines changed: 144 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
if TYPE_CHECKING:
4545
from ...aio._cosmos_client_connection_async import CosmosClientConnection
4646

47-
# Module-level shared state, keyed by endpoint URL. All four dicts and the
47+
# Module-level shared state, keyed by endpoint URL. All five dicts (the refcount
4848
# dict included) are mutated at the endpoint level only while holding ``_shared_cache_lock``; the per-endpoint inner dicts are also written outside it. Sharing across
4949
# every async CosmosClient that targets the same endpoint is what eliminates
5050
# the per-client duplicate copies of the routing map (the memory win driving
@@ -75,15 +75,24 @@
7575
# and defeat the single-flight invariant.
7676
_shared_locks_locks: Dict[str, threading.Lock] = {}
7777

78+
# endpoint -> { (loop_id, collection_id) -> asyncio.Task }. The single
79+
# in-flight fetch-and-publish task per (loop, collection). Any caller that
80+
# arrives during a cold-cache fetch joins this task via ``asyncio.shield``
81+
# instead of issuing its own network round trip, so concurrent callers
82+
# share a single fetch. The task body owns both the fetch and the cache
83+
# write, so the publish survives any individual caller being cancelled
84+
# (e.g. by ``asyncio.wait_for``) while awaiting it.
85+
_shared_inflight_fetches: Dict[str, Dict[tuple, asyncio.Task]] = {}
86+
7887
# endpoint -> int. Number of live async ``PartitionKeyRangeCache`` instances
7988
# using this endpoint. Incremented on construction and decremented in
8089
# ``release`` (called from ``CosmosClient.__aexit__`` / ``close`` / ``__del__``).
81-
# When the count hits zero we drop the entry from all four dicts so an idle
90+
# When the count hits zero we drop the entry from all five dicts so an idle
8291
# endpoint does not pin memory forever. ``clear_cache`` does NOT touch this
8392
# count — it only wipes routing-map contents.
8493
_shared_cache_refcounts: Dict[str, int] = {}
8594

86-
# Process-wide lock guarding the four dicts above for *this* (async) module.
95+
# Process-wide lock guarding the five dicts above for *this* (async) module.
8796
# Note: the sync module ``_routing/routing_map_provider.py`` defines its own
8897
# independent set of module-level dicts and its own ``_shared_cache_lock`` —
8998
# state is NOT shared between the sync and async modules. A sync and an async
@@ -123,20 +132,23 @@ def __init__(self, client: Any):
123132
self._endpoint = _resolve_endpoint(client)
124133
self._released = False
125134

126-
# Share routing map cache, per-collection asyncio locks, and the
127-
# per-endpoint meta-lock that guards the per-collection-lock dict
128-
# across all clients with the same endpoint. Refcount lets us evict
129-
# the entry when the last sharing client releases it (see ``release``).
135+
# Share routing map cache, per-collection asyncio locks, the
136+
# per-endpoint meta-lock that guards the per-collection-lock dict,
137+
# and the in-flight fetch-task dict across all clients with the same
138+
# endpoint. Refcount lets us evict the entry when the last sharing
139+
# client releases it (see ``release``).
130140
with _shared_cache_lock:
131141
if self._endpoint not in _shared_routing_map_cache:
132142
_shared_routing_map_cache[self._endpoint] = {}
133143
_shared_collection_locks[self._endpoint] = {}
134144
_shared_locks_locks[self._endpoint] = threading.Lock()
145+
_shared_inflight_fetches[self._endpoint] = {}
135146
_shared_cache_refcounts[self._endpoint] = 0
136147
_shared_cache_refcounts[self._endpoint] += 1
137148
self._collection_routing_map_by_item = _shared_routing_map_cache[self._endpoint]
138149
self._collection_locks: Dict[tuple, asyncio.Lock] = _shared_collection_locks[self._endpoint]
139150
self._locks_lock: threading.Lock = _shared_locks_locks[self._endpoint]
151+
self._inflight_fetches: Dict[tuple, asyncio.Task] = _shared_inflight_fetches[self._endpoint]
140152

141153
def clear_cache(self):
142154
"""Clear the shared routing map cache for this endpoint.
@@ -145,13 +157,13 @@ def clear_cache(self):
145157
client references to the same dict object, so concurrent clients
146158
sharing the endpoint continue to share a single cache instance.
147159
148-
The per-collection locks dict is intentionally **not** cleared here:
149-
an in-flight ``_fetch_routing_map`` caller holds one of those locks
150-
and will write its result into the (now-empty) shared cache when it
151-
completes. Keeping the lock in place ensures that any concurrent
152-
arrival serialises behind the in-flight refresh (single-flight
153-
invariant) instead of racing it with a fresh lock. The locks dict
154-
is evicted in ``release()`` once the endpoint refcount hits zero.
160+
The per-collection locks dict and the in-flight fetch-task dict are
161+
intentionally **not** cleared here. A fetch task scheduled before
162+
this call keeps a reference to the (now-empty) cache dict and will
163+
publish its result into it when it completes; any concurrent arrival
164+
meanwhile joins that same task instead of racing it. Both auxiliary
165+
dicts are evicted in ``release()`` once the endpoint refcount hits
166+
zero.
155167
"""
156168
with _shared_cache_lock:
157169
if self._endpoint in _shared_routing_map_cache:
@@ -180,6 +192,7 @@ def release(self) -> None:
180192
_shared_routing_map_cache.pop(endpoint, None)
181193
_shared_collection_locks.pop(endpoint, None)
182194
_shared_locks_locks.pop(endpoint, None)
195+
_shared_inflight_fetches.pop(endpoint, None)
183196
else:
184197
_shared_cache_refcounts[endpoint] = count
185198
except Exception: # pylint: disable=broad-except
@@ -267,9 +280,13 @@ async def get_routing_map(
267280
) -> Optional[CollectionRoutingMap]:
268281
"""Gets or refreshes the routing map for a collection.
269282
270-
This method handles the logic for fetching, caching, and updating the
271-
collection's routing map. It uses a locking mechanism to prevent race
272-
conditions during concurrent updates.
283+
Concurrent callers that arrive while a fetch is already in flight for
284+
the same collection join that fetch via ``asyncio.shield`` rather than
285+
issuing their own network round trip. The fetch task owns the cache
286+
write, so the publish completes even if every awaiting caller is
287+
cancelled (for example by ``asyncio.wait_for``) before the fetch
288+
returns. The next caller — whether the original caller retrying or a
289+
new one — finds the cache populated, provided the fetch succeeded and produced a routing map.
273290
274291
:param str collection_link: The link to the collection.
275292
:param Optional[Dict[str, Any]] feed_options: Optional query options.
@@ -281,37 +298,136 @@ async def get_routing_map(
281298
"""
282299
collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link)
283300

284-
# First check (no lock) for the fast path.
301+
# Fast path: cache hit without acquiring any lock.
285302
if not force_refresh:
286303
cached_map = self._collection_routing_map_by_item.get(collection_id)
287304
if cached_map:
288305
return cached_map
289306

290-
# Acquire lock only when a refresh or initial load is likely needed.
307+
fetch_task = await self._register_or_join_inflight_fetch(
308+
collection_id,
309+
collection_link,
310+
feed_options,
311+
force_refresh,
312+
previous_routing_map,
313+
kwargs,
314+
)
315+
316+
if fetch_task is not None:
317+
# ``shield`` ensures our cancellation only unwinds *this* awaiter;
318+
# the underlying task keeps running on the event loop and the
319+
# cache write inside the task body still happens. Other waiters
320+
# (and any subsequent caller hitting the now-populated cache) are
321+
# unaffected by our cancellation.
322+
await asyncio.shield(fetch_task)
323+
324+
return self._collection_routing_map_by_item.get(collection_id)
325+
326+
async def _register_or_join_inflight_fetch(
327+
self,
328+
collection_id: str,
329+
collection_link: str,
330+
feed_options: Optional[Dict[str, Any]],
331+
force_refresh: bool,
332+
previous_routing_map: Optional[CollectionRoutingMap],
333+
fetch_kwargs: Dict[str, Any],
334+
) -> Optional[asyncio.Task]:
335+
"""Return the in-flight fetch task for this collection, creating one if needed.
336+
337+
Holding the per-collection lock for just the check-or-create window
338+
(no network I/O inside the lock) keeps the critical section small.
339+
The returned task may be one this call just scheduled or one a
340+
concurrent caller scheduled moments earlier — either way the caller
341+
should await it through ``asyncio.shield``.
342+
343+
:param str collection_id: The resolved collection identifier used as the cache key.
344+
:param str collection_link: The link to the collection.
345+
:param Optional[Dict[str, Any]] feed_options: Optional query options.
346+
:param bool force_refresh: Whether the caller asked for a refresh.
347+
:param Optional[CollectionRoutingMap] previous_routing_map: The caller's last
348+
observed routing map, used by the refresh-decision helper.
349+
:param Dict[str, Any] fetch_kwargs: Pipeline kwargs forwarded to the fetch.
350+
:return: A running ``asyncio.Task`` to await, or ``None`` if no fetch
351+
is needed (cache was populated by a concurrent caller after the
352+
fast-path check).
353+
:rtype: Optional[asyncio.Task]
354+
"""
355+
inflight_key = (id(asyncio.get_running_loop()), collection_id)
291356
collection_lock = await self._get_lock_for_collection(collection_id)
292357
async with collection_lock:
293-
# Second check (with lock) — use shared helper for the decision logic.
358+
existing_task = self._inflight_fetches.get(inflight_key)
359+
if existing_task is not None:
360+
return existing_task
361+
294362
should_fetch, base_routing_map = determine_refresh_action(
295363
self._collection_routing_map_by_item,
296364
collection_id,
297365
force_refresh,
298366
previous_routing_map,
299367
)
368+
if not should_fetch:
369+
return None
300370

301-
if should_fetch:
302-
new_routing_map = await self._fetch_routing_map(
303-
collection_link,
371+
new_task = asyncio.create_task(
372+
self._fetch_and_publish(
304373
collection_id,
374+
collection_link,
305375
base_routing_map,
306376
feed_options,
307-
**kwargs
377+
inflight_key,
378+
fetch_kwargs,
308379
)
380+
)
381+
self._inflight_fetches[inflight_key] = new_task
382+
return new_task
383+
384+
async def _fetch_and_publish(
385+
self,
386+
collection_id: str,
387+
collection_link: str,
388+
base_routing_map: Optional[CollectionRoutingMap],
389+
feed_options: Optional[Dict[str, Any]],
390+
inflight_key: tuple,
391+
fetch_kwargs: Dict[str, Any],
392+
) -> Optional[CollectionRoutingMap]:
393+
"""Run ``_fetch_routing_map`` and publish its result, then free the in-flight slot.
394+
395+
The cache assignment lives inside this task body so a caller's
396+
cancellation while awaiting the task cannot interrupt the publish.
397+
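Once the task body has started, the ``finally`` block frees the in-flight slot — on success,
398+
on a fetch error, or on cancellation — so the next caller is free to
399+
schedule a fresh attempt. A task cancelled before its first step never enters the ``try`` and leaves its slot set.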
309400
310-
# Update the cache.
311-
if new_routing_map:
312-
self._collection_routing_map_by_item[collection_id] = new_routing_map
401+
:param str collection_id: The resolved collection identifier used as the cache key.
402+
:param str collection_link: The link to the collection.
403+
:param Optional[CollectionRoutingMap] base_routing_map: The base routing map
404+
for incremental updates, or ``None`` for a full load.
405+
:param Optional[Dict[str, Any]] feed_options: Optional query options.
406+
:param tuple inflight_key: The ``(loop_id, collection_id)`` key into the in-flight dict.
407+
:param Dict[str, Any] fetch_kwargs: Pipeline kwargs forwarded to the fetch.
408+
:return: The new routing map, or ``None`` if the fetch produced nothing.
409+
:rtype: Optional[CollectionRoutingMap]
410+
"""
411+
try:
412+
new_routing_map = await self._fetch_routing_map(
413+
collection_link,
414+
collection_id,
415+
base_routing_map,
416+
feed_options,
417+
**fetch_kwargs,
418+
)
313419

314-
return self._collection_routing_map_by_item.get(collection_id)
420+
if new_routing_map:
421+
self._collection_routing_map_by_item[collection_id] = new_routing_map
422+
423+
return new_routing_map
424+
finally:
425+
# Check-then-delete with no ``await`` in between, so no lock is needed. Runs on success,
426+
# on fetch error, and on cancellation alike, so the next caller
427+
# can register a fresh fetch immediately.
428+
inflight_fetches = self._inflight_fetches
429+
if inflight_key in inflight_fetches:
430+
del inflight_fetches[inflight_key]
315431

316432

317433
async def _fetch_routing_map(

0 commit comments

Comments
 (0)