4444if TYPE_CHECKING :
4545 from ...aio ._cosmos_client_connection_async import CosmosClientConnection
4646
47- # Module-level shared state, keyed by endpoint URL. All four dicts and the
# Module-level shared state, keyed by endpoint URL. The top-level,
# endpoint-keyed entries of all five dicts (the refcount dict is one of the
# five) are mutated only while holding ``_shared_cache_lock``. Sharing across
4949# every async CosmosClient that targets the same endpoint is what eliminates
5050# the per-client duplicate copies of the routing map (the memory win driving
7575# and defeat the single-flight invariant.
7676_shared_locks_locks : Dict [str , threading .Lock ] = {}
7777
78+ # endpoint -> { (loop_id, collection_id) -> asyncio.Task }. The single
79+ # in-flight fetch-and-publish task per (loop, collection). Any caller that
80+ # arrives during a cold-cache fetch joins this task via ``asyncio.shield``
81+ # instead of issuing its own network round trip, so concurrent callers
82+ # share a single fetch. The task body owns both the fetch and the cache
83+ # write, so the publish survives any individual caller being cancelled
84+ # (e.g. by ``asyncio.wait_for``) while awaiting it.
85+ _shared_inflight_fetches : Dict [str , Dict [tuple , asyncio .Task ]] = {}
86+
7887# endpoint -> int. Number of live async ``PartitionKeyRangeCache`` instances
7988# using this endpoint. Incremented on construction and decremented in
8089# ``release`` (called from ``CosmosClient.__aexit__`` / ``close`` / ``__del__``).
81- # When the count hits zero we drop the entry from all four dicts so an idle
90+ # When the count hits zero we drop the entry from all five dicts so an idle
8291# endpoint does not pin memory forever. ``clear_cache`` does NOT touch this
8392# count — it only wipes routing-map contents.
8493_shared_cache_refcounts : Dict [str , int ] = {}
8594
86- # Process-wide lock guarding the four dicts above for *this* (async) module.
95+ # Process-wide lock guarding the five dicts above for *this* (async) module.
8796# Note: the sync module ``_routing/routing_map_provider.py`` defines its own
8897# independent set of module-level dicts and its own ``_shared_cache_lock`` —
8998# state is NOT shared between the sync and async modules. A sync and an async
@@ -123,20 +132,23 @@ def __init__(self, client: Any):
123132 self ._endpoint = _resolve_endpoint (client )
124133 self ._released = False
125134
126- # Share routing map cache, per-collection asyncio locks, and the
127- # per-endpoint meta-lock that guards the per-collection-lock dict
128- # across all clients with the same endpoint. Refcount lets us evict
129- # the entry when the last sharing client releases it (see ``release``).
135+ # Share routing map cache, per-collection asyncio locks, the
136+ # per-endpoint meta-lock that guards the per-collection-lock dict,
137+ # and the in-flight fetch-task dict across all clients with the same
138+ # endpoint. Refcount lets us evict the entry when the last sharing
139+ # client releases it (see ``release``).
130140 with _shared_cache_lock :
131141 if self ._endpoint not in _shared_routing_map_cache :
132142 _shared_routing_map_cache [self ._endpoint ] = {}
133143 _shared_collection_locks [self ._endpoint ] = {}
134144 _shared_locks_locks [self ._endpoint ] = threading .Lock ()
145+ _shared_inflight_fetches [self ._endpoint ] = {}
135146 _shared_cache_refcounts [self ._endpoint ] = 0
136147 _shared_cache_refcounts [self ._endpoint ] += 1
137148 self ._collection_routing_map_by_item = _shared_routing_map_cache [self ._endpoint ]
138149 self ._collection_locks : Dict [tuple , asyncio .Lock ] = _shared_collection_locks [self ._endpoint ]
139150 self ._locks_lock : threading .Lock = _shared_locks_locks [self ._endpoint ]
151+ self ._inflight_fetches : Dict [tuple , asyncio .Task ] = _shared_inflight_fetches [self ._endpoint ]
140152
141153 def clear_cache (self ):
142154 """Clear the shared routing map cache for this endpoint.
@@ -145,13 +157,13 @@ def clear_cache(self):
145157 client references to the same dict object, so concurrent clients
146158 sharing the endpoint continue to share a single cache instance.
147159
148- The per-collection locks dict is intentionally **not** cleared here:
149- an in-flight ``_fetch_routing_map`` caller holds one of those locks
150- and will write its result into the (now-empty) shared cache when it
151- completes. Keeping the lock in place ensures that any concurrent
152- arrival serialises behind the in-flight refresh (single-flight
153- invariant) instead of racing it with a fresh lock. The locks dict
154- is evicted in ``release()`` once the endpoint refcount hits zero.
160+ The per-collection locks dict and the in-flight fetch-task dict are
161+ intentionally **not** cleared here. A fetch task scheduled before
162+ this call keeps a reference to the (now-empty) cache dict and will
163+ publish its result into it when it completes; any concurrent arrival
164+ meanwhile joins that same task instead of racing it. Both auxiliary
165+ dicts are evicted in ``release()`` once the endpoint refcount hits
166+ zero.
155167 """
156168 with _shared_cache_lock :
157169 if self ._endpoint in _shared_routing_map_cache :
@@ -180,6 +192,7 @@ def release(self) -> None:
180192 _shared_routing_map_cache .pop (endpoint , None )
181193 _shared_collection_locks .pop (endpoint , None )
182194 _shared_locks_locks .pop (endpoint , None )
195+ _shared_inflight_fetches .pop (endpoint , None )
183196 else :
184197 _shared_cache_refcounts [endpoint ] = count
185198 except Exception : # pylint: disable=broad-except
@@ -267,9 +280,13 @@ async def get_routing_map(
267280 ) -> Optional [CollectionRoutingMap ]:
268281 """Gets or refreshes the routing map for a collection.
269282
270- This method handles the logic for fetching, caching, and updating the
271- collection's routing map. It uses a locking mechanism to prevent race
272- conditions during concurrent updates.
283+ Concurrent callers that arrive while a fetch is already in flight for
284+ the same collection join that fetch via ``asyncio.shield`` rather than
285+ issuing their own network round trip. The fetch task owns the cache
286+ write, so the publish completes even if every awaiting caller is
287+ cancelled (for example by ``asyncio.wait_for``) before the fetch
288+ returns. The next caller — whether the original caller retrying or a
289+ new one — finds the cache populated.
273290
274291 :param str collection_link: The link to the collection.
275292 :param Optional[Dict[str, Any]] feed_options: Optional query options.
@@ -281,37 +298,136 @@ async def get_routing_map(
281298 """
282299 collection_id = _base .GetResourceIdOrFullNameFromLink (collection_link )
283300
284- # First check (no lock) for the fast path .
301+ # Fast path: cache hit without acquiring any lock .
285302 if not force_refresh :
286303 cached_map = self ._collection_routing_map_by_item .get (collection_id )
287304 if cached_map :
288305 return cached_map
289306
290- # Acquire lock only when a refresh or initial load is likely needed.
307+ fetch_task = await self ._register_or_join_inflight_fetch (
308+ collection_id ,
309+ collection_link ,
310+ feed_options ,
311+ force_refresh ,
312+ previous_routing_map ,
313+ kwargs ,
314+ )
315+
316+ if fetch_task is not None :
317+ # ``shield`` ensures our cancellation only unwinds *this* awaiter;
318+ # the underlying task keeps running on the event loop and the
319+ # cache write inside the task body still happens. Other waiters
320+ # (and any subsequent caller hitting the now-populated cache) are
321+ # unaffected by our cancellation.
322+ await asyncio .shield (fetch_task )
323+
324+ return self ._collection_routing_map_by_item .get (collection_id )
325+
async def _register_or_join_inflight_fetch(
    self,
    collection_id: str,
    collection_link: str,
    feed_options: Optional[Dict[str, Any]],
    force_refresh: bool,
    previous_routing_map: Optional[CollectionRoutingMap],
    fetch_kwargs: Dict[str, Any],
) -> Optional[asyncio.Task]:
    """Hand back the fetch task a caller should await for this collection.

    The in-flight slot is keyed by ``(id(running loop), collection_id)``.
    While holding the lock returned by ``_get_lock_for_collection``, this
    method either finds a task already registered in that slot, or asks
    ``determine_refresh_action`` whether a fetch is required and, if so,
    schedules ``_fetch_and_publish`` with ``asyncio.create_task`` and
    registers the new task in the slot. Nothing is awaited while the lock is
    held: the fetch is only *scheduled* here, never run to completion.

    :param str collection_id: The resolved collection identifier used as the cache key.
    :param str collection_link: The link to the collection.
    :param Optional[Dict[str, Any]] feed_options: Optional query options.
    :param bool force_refresh: Whether the caller asked for a refresh.
    :param Optional[CollectionRoutingMap] previous_routing_map: The caller's last
        observed routing map, forwarded to ``determine_refresh_action``.
    :param Dict[str, Any] fetch_kwargs: Keyword arguments forwarded to the fetch.
    :return: The task registered for this slot (pre-existing or newly
        scheduled), or ``None`` when ``determine_refresh_action`` reports that
        no fetch is needed.
    :rtype: Optional[asyncio.Task]
    """
    # The slot is scoped to the running event loop as well as the collection.
    running_loop = asyncio.get_running_loop()
    slot = (id(running_loop), collection_id)

    async with await self._get_lock_for_collection(collection_id):
        task = self._inflight_fetches.get(slot)
        if task is None:
            # Nobody is fetching for this slot yet: decide whether we must.
            needs_fetch, seed_map = determine_refresh_action(
                self._collection_routing_map_by_item,
                collection_id,
                force_refresh,
                previous_routing_map,
            )
            if needs_fetch:
                publish = self._fetch_and_publish(
                    collection_id,
                    collection_link,
                    seed_map,
                    feed_options,
                    slot,
                    fetch_kwargs,
                )
                task = asyncio.create_task(publish)
                # Register before leaving the lock so later arrivals join
                # this task instead of scheduling their own.
                self._inflight_fetches[slot] = task
        # ``task`` is still ``None`` here when no fetch was required.
        return task
383+
async def _fetch_and_publish(
    self,
    collection_id: str,
    collection_link: str,
    base_routing_map: Optional[CollectionRoutingMap],
    feed_options: Optional[Dict[str, Any]],
    inflight_key: tuple,
    fetch_kwargs: Dict[str, Any],
) -> Optional[CollectionRoutingMap]:
    """Run ``_fetch_routing_map``, publish its result, then free the in-flight slot.

    The cache assignment lives inside this coroutine, so whoever runs it
    (normally a task scheduled for it) performs the publish itself rather
    than leaving it to an awaiting caller. The ``finally`` block frees the
    in-flight slot on success, on a fetch error, and on cancellation alike,
    so a later caller can schedule a fresh attempt.

    :param str collection_id: The resolved collection identifier used as the cache key.
    :param str collection_link: The link to the collection.
    :param Optional[CollectionRoutingMap] base_routing_map: The base routing map
        handed to ``_fetch_routing_map``, or ``None``.
    :param Optional[Dict[str, Any]] feed_options: Optional query options.
    :param tuple inflight_key: The key of this fetch in ``self._inflight_fetches``.
    :param Dict[str, Any] fetch_kwargs: Keyword arguments forwarded to the fetch.
    :return: Whatever ``_fetch_routing_map`` returned (possibly ``None``).
    :rtype: Optional[CollectionRoutingMap]
    """
    try:
        new_routing_map = await self._fetch_routing_map(
            collection_link,
            collection_id,
            base_routing_map,
            feed_options,
            **fetch_kwargs,
        )

        # Only a truthy result replaces the cached entry; a falsy result
        # leaves whatever is currently cached untouched.
        if new_routing_map:
            self._collection_routing_map_by_item[collection_id] = new_routing_map

        return new_routing_map
    finally:
        # Single-call removal that tolerates an already-missing key. Runs on
        # success, on fetch error, and on cancellation alike, so the next
        # caller can register a fresh fetch immediately.
        self._inflight_fetches.pop(inflight_key, None)
315431
316432
317433 async def _fetch_routing_map (
0 commit comments