245 changes: 194 additions & 51 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
@@ -29,7 +29,7 @@
import urllib.parse
import uuid
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Callable, Any, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast
from typing import Callable, Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast
from typing_extensions import TypedDict
from urllib3.util.retry import Retry

@@ -73,7 +73,18 @@
from ._read_items_helper import ReadItemsHelperSync
from ._request_object import RequestObject
from ._retry_utility import ConnectionRetryPolicy
from ._routing import routing_map_provider, routing_range
from ._routing import routing_map_provider
from ._routing.feed_range_continuation import (
_apply_feedrange_request_headers,
_build_scope_from_overlaps,
_decode_token,
_derive_initial_feedranges,
_explode_feedrange_on_multi_overlap,
_extract_resume_state,
_normalize_max_item_count,
_set_outbound_continuation,
_validate_token_identity,
)
from ._query_advisor import get_query_advice_info
from ._inference_service import _InferenceService
from .documents import ConnectionPolicy, DatabaseAccount
@@ -3240,7 +3251,17 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma
if timeout is not None:
kwargs.setdefault("timeout", timeout)

internal_headers_capture = kwargs.pop("_internal_response_headers_capture", None)
internal_headers_capture: Optional[Dict[str, Any]] = kwargs.pop(
"_internal_response_headers_capture", None
🟡 Recommendation · Correctness: _internal_response_headers_capture may be routed through wrong parameter path

The execution context injects the capture dict into new_options (the options parameter):

# base_execution_context.py line 126
new_options["_internal_response_headers_capture"] = self._internal_response_headers_capture

But __QueryFeed reads it from kwargs:

# _cosmos_client_connection.py line 3265-3266
internal_headers_capture: Optional[Dict[str, Any]] = kwargs.pop(
    "_internal_response_headers_capture", None
)

Whether this works depends on how the execution context's fetch_function forwards new_options. If fetch_function(new_options) passes new_options as the options positional parameter (which __QueryFeed accepts as options), then the capture dict is in options, not kwargs — and kwargs.pop(...) always returns None.

I traced the QueryItems fetch_fn closure and it does: self.__QueryFeed(path, ..., options, ..., **kwargs) where options is positional. So _internal_response_headers_capture ends up in options but is read from kwargs.

If this is the case, the capture dict never reaches __QueryFeed, and _capture_internal_headers is always a no-op — making the entire checkpoint mechanism (including the fix in Finding #1) non-functional. Please verify the fetch_fn closure chain and confirm the capture dict actually reaches kwargs.
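
If the capture dict can arrive via either path, a small defensive read would keep the checkpoint mechanism working regardless of how fetch_fn forwards it. The following is only a sketch of one possible fix, not part of this PR, and the fallback-to-options behavior is an assumption about where the dict actually lands:

# Hypothetical sketch: read the capture dict from kwargs first, then fall back
# to the positional options dict in case fetch_fn forwarded it there.
internal_headers_capture: Optional[Dict[str, Any]] = kwargs.pop(
    "_internal_response_headers_capture", None
)
if internal_headers_capture is None and options:
    internal_headers_capture = options.pop("_internal_response_headers_capture", None)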


)

def _capture_internal_headers(headers: Mapping[str, Any]) -> None:
# Local helper so flow analysis can narrow Optional[Dict] once
# and every call site stays a single line.
if internal_headers_capture is None:
return
internal_headers_capture.clear()
internal_headers_capture.update(headers)

if query:
__GetBodiesFromQueryResult = result_fn
@@ -3293,8 +3314,7 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]:
result, last_response_headers = self.__Get(path, request_params, headers, **kwargs)
self.last_response_headers = last_response_headers
if internal_headers_capture is not None:
internal_headers_capture.clear()
internal_headers_capture.update(last_response_headers)
_capture_internal_headers(last_response_headers)
if response_headers_list is not None:
response_headers_list.append(last_response_headers.copy())
if response_hook:
@@ -3348,64 +3368,187 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]:

# If feed_range_epk exists, query with the range
if feed_range_epk is not None:
over_lapping_ranges = self._routing_map_provider.get_overlapping_ranges(resource_id, [feed_range_epk],
options)
# It is possible to get more than one over lapping range. We need to get the query results for each one
# (a) Look at the continuation the caller passed in.
# - Empty or from a pre-fix SDK: start fresh.
# - One of our v=1 envelopes: check the collection, query, and
# feed_range still match before resuming from it.
inbound = _decode_token(options.get("continuation"))
if inbound is not None:
_validate_token_identity(inbound, resource_id, query, feed_range_epk)
current_feedrange, remaining_feedranges, next_backend_cont = _extract_resume_state(inbound)
else:
# First call (or legacy passthrough). Ask the routing map which
# partitions the input feed_range overlaps right now and turn
# each overlap into a feedrange (intersection of that partition
# and the input feed_range).
first_overlaps = self._routing_map_provider.get_overlapping_ranges(
resource_id, [feed_range_epk], options
)
all_feedranges = _derive_initial_feedranges(feed_range_epk, first_overlaps)
if not all_feedranges:
# The input feed_range does not overlap any current
# physical partition. Mirror the async path: stamp
# an empty headers object onto self.last_response_headers
# so callers reading it after a no-op call see a
# consistent (empty) value instead of whatever the
# previous request left behind.
empty_headers: CaseInsensitiveDict = CaseInsensitiveDict()
self.last_response_headers = empty_headers
return [], empty_headers
current_feedrange, remaining_feedranges = all_feedranges[0], all_feedranges[1:]
next_backend_cont = None

# (b) max_item_count is the cap for the page we hand back to the
# caller. We may need several backend POSTs to fill it (one per
# feedrange we have to query). Once the cap is reached we stop;
# any feedranges we have not started yet go into the outbound
# token so the next call picks up from there.
#
# Non-positive or non-numeric caps are normalized to "unbounded"
# (None) by _normalize_max_item_count: a cap of 0 / -1 would
# otherwise make the loop short-circuit before issuing any POST
# and emit a continuation token whose current_feedrange is
# unchanged - the caller would then pull empty pages forever
# without making progress.
remaining_budget = _normalize_max_item_count(options.get("maxItemCount"))

results: dict[str, Any] = {}
# For each over lapping range we will take a sub range of the feed range EPK that overlaps with the over
# lapping physical partition. The EPK sub range will be one of four:
# 1) Will have a range min equal to the feed range EPK min, and a range max equal to the over lapping
# partition
# 2) Will have a range min equal to the over lapping partition range min, and a range max equal to the
# feed range EPK range max.
# 3) will match exactly with the current over lapping physical partition, so we just return the over lapping
# physical partition's partition key id.
# 4) Will equal the feed range EPK since it is a sub range of a single physical partition
for over_lapping_range in over_lapping_ranges:
single_range = routing_range.Range.PartitionKeyRangeToRange(over_lapping_range)
# Since the range min and max are all Upper Cased string Hex Values,
# we can compare the values lexicographically
EPK_sub_range = routing_range.Range(range_min=max(single_range.min, feed_range_epk.min),
range_max=min(single_range.max, feed_range_epk.max),
isMinInclusive=True, isMaxInclusive=False)

# set the session token for this specific partition to avoid sending compound token for all partitions
base.set_session_token_header(self, req_headers, path, request_params, options,
over_lapping_range["id"])
if single_range.min == EPK_sub_range.min and EPK_sub_range.max == single_range.max:
# The Epk Sub Range spans exactly one physical partition
# In this case we can route to the physical pk range id
req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"]
last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict()

while True:
if remaining_budget is not None and remaining_budget <= 0:
break # cap reached; carry the unfinished feedranges forward

# Look up the live routing map for the current feedrange.
# Doing this every iteration is what makes the token
# split-safe.
overlapping = self._routing_map_provider.get_overlapping_ranges(
resource_id, [current_feedrange], options
)
overlapping, partition_scope = _build_scope_from_overlaps(
overlapping, current_feedrange
)

# Handle the case where Cosmos split a partition between the
# previous run and this one. Example: the saved
# current_feedrange used to live inside one partition X, but X
# has since been split into children X1 and X2. The routing
# map now returns two partitions for the same feedrange. If
# we sent one POST to X1 with X's full range as the EPK
# filter, the backend would filter in-partition only and
# silently drop every row living on X2 (that is how a resume
# after a split came back 19 ids short of ground truth in
# test_post_split_resume_async).
#
# So when the lookup returns more than one partition, slice
# the saved feedrange into one sub-feedrange per child
# (intersection with the saved feedrange, ordered by EPK
# min), make the first sub-feedrange the new current one,
# put the rest in front of the remaining list, and clear the
# saved backend continuation - it was issued by the old
# parent partition and the children won't accept it. The next
# loop iteration sees a single overlap and falls through to
# the normal single-partition POST below.
#
# Note: if the caller had already pulled some rows from X
# before the split, those rows show up again on this resume.
# The customer dedupes by document id.
current_feedrange, remaining_feedranges, did_explode = _explode_feedrange_on_multi_overlap(
current_feedrange,
overlapping,
remaining_feedranges,
)
if did_explode:
next_backend_cont = None
overlapping = self._routing_map_provider.get_overlapping_ranges(
resource_id, [current_feedrange], options
)
overlapping, partition_scope = _build_scope_from_overlaps(
overlapping, current_feedrange
)

sub_options = dict(options)
if remaining_budget is not None:
sub_options["maxItemCount"] = remaining_budget
if next_backend_cont is not None:
sub_options["continuation"] = next_backend_cont
else:
# The Epk Sub Range spans less than a single physical partition
# In this case we route to the physical partition and
# pass the epk sub range to the headers to filter within partition
req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"]
req_headers[http_constants.HttpHeaders.StartEpkString] = EPK_sub_range.min
req_headers[http_constants.HttpHeaders.EndEpkString] = EPK_sub_range.max
req_headers[http_constants.HttpHeaders.ReadFeedKeyType] = "EffectivePartitionKeyRange"
partial_result, last_response_headers = self.__Post(
sub_options.pop("continuation", None)

# Populate request headers for this single backend POST.
# The shared helper handles partition routing (PKR id +
# optional EPK filter), page-size cap, and continuation
# set/clear so the same rules apply to sync and async.
_apply_feedrange_request_headers(
req_headers,
overlapping,
partition_scope,
current_feedrange,
remaining_budget,
sub_options.get("continuation"),
)
# Use the session token for this specific partition so we don't
# send a compound token covering all partitions.
base.set_session_token_header(
self, req_headers, path, request_params, sub_options, overlapping[0]["id"]
)

partial_result, sub_response_headers = self.__Post(
path, request_params, query, req_headers, **kwargs
)
last_response_headers = sub_response_headers
self.last_response_headers = last_response_headers
if internal_headers_capture is not None:
internal_headers_capture.clear()
internal_headers_capture.update(last_response_headers)
self._UpdateSessionIfRequired(req_headers, partial_result, last_response_headers)
# Introducing a temporary complex function into a critical path to handle aggregated queries
# during splits, as a precaution falling back to the original logic if anything goes wrong
_capture_internal_headers(sub_response_headers)
self._UpdateSessionIfRequired(req_headers, partial_result, sub_response_headers)

# Merge results, falling back to a plain extend if the
# aggregating merge raises (it can on aggregated queries
# during splits).
try:
results = base._merge_query_results(results, partial_result, query)
except Exception: # pylint: disable=broad-exception-caught
# If the new merge logic fails, fall back to the original logic.
except Exception: # pylint: disable=broad-exception-caught
if results:
results["Documents"].extend(partial_result["Documents"])
else:
results = partial_result

items_returned = len(partial_result.get("Documents", []))
if remaining_budget is not None:
remaining_budget -= items_returned
if response_headers_list is not None:
response_headers_list.append(last_response_headers.copy())
response_headers_list.append(sub_response_headers.copy())
if response_hook:
response_hook(last_response_headers, partial_result)
response_hook(sub_response_headers, partial_result)

next_backend_cont = sub_response_headers.get(http_constants.HttpHeaders.Continuation)
if next_backend_cont:
# Current feedrange has more to give. Stay on it; the
# budget check at the top of the next iteration decides
# whether to issue another POST.
continue

# Current feedrange is drained. Move to the next one if there
# is one; otherwise we are done.
if not remaining_feedranges:
current_feedrange = None
break
current_feedrange = remaining_feedranges.pop(0)
next_backend_cont = None

# (c) Build the outbound token. Clear the continuation header if
# there is no work left at all.
_set_outbound_continuation(
last_response_headers,
resource_id,
query,
feed_range_epk,
current_feedrange,
remaining_feedranges,
next_backend_cont,
)
self.last_response_headers = last_response_headers

# if the prefix partition query has results, return them
if results:
if last_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None:
@@ -3417,12 +3560,12 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]:
last_response_headers[http_constants.HttpHeaders.QueryAdvice] = (
get_query_advice_info(query_advice_raw))
return __GetBodiesFromQueryResult(results), last_response_headers
return [], last_response_headers

result, last_response_headers = self.__Post(path, request_params, query, req_headers, **kwargs)
self.last_response_headers = last_response_headers
if internal_headers_capture is not None:
internal_headers_capture.clear()
internal_headers_capture.update(last_response_headers)
_capture_internal_headers(last_response_headers)
self._UpdateSessionIfRequired(req_headers, result, last_response_headers)
if last_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None:
INDEX_METRICS_HEADER = http_constants.HttpHeaders.IndexUtilization