Skip to content

Commit b98cd8a

Browse files
committed
addressing PR comments
1 parent 9a1e9df commit b98cd8a

8 files changed

Lines changed: 431 additions & 113 deletions

sdk/cosmos/azure-cosmos/azure/cosmos/_base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,12 @@ def _merge_query_results(
211211
results_docs[0] = min(results_docs[0], partial_docs[0]) # type: ignore[index]
212212
elif aggregate_fn == "MAX":
213213
results_docs[0] = max(results_docs[0], partial_docs[0]) # type: ignore[index]
214+
elif aggregate_fn == "AVG":
215+
raise ValueError(
216+
"VALUE AVG aggregate merge across partitions is not supported client-side."
217+
)
214218
else:
215-
# COUNT/SUM are additive; VALUE AVG is not fully supported client-side yet.
219+
# COUNT/SUM are additive.
216220
results_docs[0] += partial_docs[0] # type: ignore[index]
217221
return results
218222

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 24 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3376,19 +3376,21 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]:
33763376
# the feed_range helpers below read as ``resource_id_str`` instead
33773377
# of the generic ``resource_id``.
33783378
resource_id_str: str = resource_id
3379-
# (a) Look at the continuation the caller passed in.
3380-
# - Empty or from a pre-fix SDK: start fresh.
3381-
# - One of our v=1 envelopes: check the collection, query, and
3382-
# feed_range still match before resuming from it.
3383-
#
3384-
# Shared state transitions (resume, split handling, page-item update,
3385-
# outbound token) live in _FeedRangePaginationState so sync/async
3386-
# stay behaviorally aligned.
3379+
# Decode the inbound continuation. Empty/legacy -> start fresh
3380+
# (``_decode_token`` returns ``None``); a valid v=1 envelope
3381+
# is checked against the current collection/query/feed_range
3382+
# before we resume from it. The shared
3383+
# ``_FeedRangePaginationState`` owns all state transitions
3384+
# (resume, split handling, page-item update, outbound token)
3385+
# so the sync and async loops below remain twin code paths
3386+
# — change one, change the other.
33873387
items_left_in_page = _normalize_max_item_count(options.get("maxItemCount"))
3388-
inbound = _decode_token(options.get("continuation"))
3389-
if inbound is not None:
3390-
_validate_token_identity(inbound, resource_id_str, query, feed_range_epk)
3391-
pagination_state = _FeedRangePaginationState.from_inbound(inbound, items_left_in_page)
3388+
inbound_token_payload = _decode_token(options.get("continuation"))
3389+
if inbound_token_payload is not None:
3390+
_validate_token_identity(inbound_token_payload, resource_id_str, query, feed_range_epk)
3391+
pagination_state = _FeedRangePaginationState.from_inbound(
3392+
inbound_token_payload, items_left_in_page
3393+
)
33923394
else:
33933395
# First call (or legacy passthrough). Ask the routing map which
33943396
# partitions the input feed_range overlaps right now and turn
@@ -3432,30 +3434,10 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]:
34323434
overlapping, head_feedrange
34333435
)
34343436

3435-
# Handle the case where Cosmos split a partition between the
3436-
# previous run and this one. Example: the saved
3437-
# head_feedrange used to live inside one partition X, but X
3438-
# has since been split into children X1 and X2. The routing
3439-
# map now returns two partitions for the same feedrange. If
3440-
# we sent one POST to X1 with X's full range as the EPK
3441-
# filter, the backend would filter in-partition only and
3442-
# silently drop every row living on X2 (resume after a
3443-
# split would then come back short of ground truth).
3444-
#
3445-
# So when the lookup returns more than one partition, slice
3446-
# the saved feedrange into one sub-feedrange per child
3447-
# (intersection with the saved feedrange, ordered by EPK
3448-
# min), make the first sub-feedrange the new current one,
3449-
# put the rest in front of the remaining list, and clear the
3450-
# saved backend continuation - it was issued by the old
3451-
# parent partition and the children won't accept it. The next
3452-
# loop iteration sees a single overlap and falls through to
3453-
# the normal single-partition POST below.
3454-
#
3455-
# One edge case remains by design: if some rows were already
3456-
# read from parent X before it split, those rows can show up
3457-
# once more after resume when children X1/X2 restart from the
3458-
# start of their slices.
3437+
# If routing returns multiple overlaps, the head sub-range now spans a split
3438+
# that occurred after the token was created. Re-slice and re-resolve until
3439+
# each head maps to one partition. See
3440+
# ``_FeedRangePaginationState.explode_on_multi_overlap`` for details.
34593441
while pagination_state.explode_on_multi_overlap(overlapping):
34603442
head_feedrange = pagination_state.head_range
34613443
if head_feedrange is None:
@@ -3467,31 +3449,22 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]:
34673449
overlapping, head_feedrange
34683450
)
34693451

3470-
backend_request_options = dict(options)
3471-
if pagination_state.remaining_page_item_count is not None:
3472-
backend_request_options["maxItemCount"] = pagination_state.remaining_page_item_count
3473-
if pagination_state.head_bc is not None:
3474-
backend_request_options["continuation"] = pagination_state.head_bc
3475-
else:
3476-
backend_request_options.pop("continuation", None)
3477-
34783452
# Populate request headers for this single backend POST.
34793453
# The shared helper handles partition routing (PKR id +
34803454
# optional EPK filter), page-size cap, and continuation
34813455
# set/clear so the same rules apply to sync and async.
3482-
assert head_feedrange is not None # narrowed by the loop guards above
34833456
_apply_feedrange_request_headers(
34843457
req_headers,
34853458
overlapping,
34863459
partition_scope,
34873460
head_feedrange,
34883461
pagination_state.remaining_page_item_count,
3489-
backend_request_options.get("continuation"),
3462+
pagination_state.head_bc,
34903463
)
34913464
# Use the session token for this specific partition so we don't
34923465
# send a compound token covering all partitions.
34933466
base.set_session_token_header(
3494-
self, req_headers, path, request_params, backend_request_options, overlapping[0]["id"]
3467+
self, req_headers, path, request_params, options, overlapping[0]["id"]
34953468
)
34963469

34973470
try:
@@ -3500,14 +3473,14 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]:
35003473
)
35013474
except Exception: # pylint: disable=broad-exception-caught
35023475
# Preserve resume progress if a later POST fails mid-page.
3476+
self.last_response_headers = feedrange_response_headers
35033477
try:
35043478
pagination_state.write_outbound_continuation(
35053479
feedrange_response_headers,
35063480
resource_id_str,
35073481
query,
35083482
feed_range_epk,
35093483
)
3510-
self.last_response_headers = feedrange_response_headers
35113484
except Exception as continuation_write_error: # pylint: disable=broad-exception-caught
35123485
_LOGGER.warning(
35133486
"Failed to write continuation while handling query POST failure: %s",
@@ -3571,8 +3544,9 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]:
35713544
)
35723545
)
35733546

3574-
# (c) Build the outbound token. Clear the continuation header if
3575-
# there is no work left at all.
3547+
# Pagination loop is done — write the final outbound
3548+
# continuation (or clear the header if the queue is fully
3549+
# drained) so the caller's ``by_page`` loop terminates.
35763550
pagination_state.write_outbound_continuation(
35773551
feedrange_response_headers,
35783552
resource_id_str,

sdk/cosmos/azure-cosmos/azure/cosmos/_query_aggregate_utils.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,10 @@ def _extract_query_text(query: Optional[Union[str, dict[str, Any]]]) -> Optional
3939
def _get_select_value_aggregate_function(query: Optional[Union[str, dict[str, Any]]]) -> Optional[str]:
4040
"""Identify the aggregate function for ``SELECT VALUE`` aggregate queries.
4141
42-
This is a lightweight text heuristic over the full query string (not a SQL
43-
parser). Aggregate function tokens inside subqueries can therefore produce
44-
false positives for the outer query.
42+
This is a lightweight text heuristic (not a SQL parser). It extracts only
43+
the OUTER ``SELECT VALUE`` projection and then matches aggregate function
44+
names in that projection so nested subqueries do not drive outer
45+
classification.
4546
4647
:param query: Query text or query spec dictionary.
4748
:type query: Optional[Union[str, dict[str, Any]]]
@@ -53,18 +54,52 @@ def _get_select_value_aggregate_function(query: Optional[Union[str, dict[str, An
5354
return None
5455

5556
normalized = " ".join(query_text.upper().split())
56-
if "SELECT VALUE" not in normalized:
57+
projection = _extract_outer_select_value_projection(normalized)
58+
if projection is None:
5759
return None
5860

59-
# NOTE: This checks the full normalized query text, so aggregate function
60-
# names inside subqueries can still be matched as false positives.
6161
# Match whole function names only (avoid MYCOUNT) and allow COUNT (1).
6262
for aggregate_fn in ("COUNT", "SUM", "MIN", "MAX", "AVG"):
63-
if re.search(rf"(?<![A-Z0-9_]){aggregate_fn}\s*\(", normalized):
63+
if re.search(rf"(?<![A-Z0-9_]){aggregate_fn}\s*\(", projection):
6464
return aggregate_fn
6565
return None
6666

6767

68+
def _extract_outer_select_value_projection(normalized_query: str) -> Optional[str]:
69+
"""Return the outer ``SELECT VALUE`` projection text up to the outer ``FROM``.
70+
71+
Uses a lightweight parenthesis-depth scan so nested subqueries do not
72+
influence outer aggregate detection.
73+
"""
74+
select_value = "SELECT VALUE"
75+
start_idx = normalized_query.find(select_value)
76+
if start_idx < 0:
77+
return None
78+
79+
projection_start = start_idx + len(select_value)
80+
if projection_start < len(normalized_query) and normalized_query[projection_start] == " ":
81+
projection_start += 1
82+
83+
depth = 0
84+
index = projection_start
85+
while index <= len(normalized_query) - 4:
86+
ch = normalized_query[index]
87+
if ch == "(":
88+
depth += 1
89+
elif ch == ")" and depth > 0:
90+
depth -= 1
91+
92+
if depth == 0 and normalized_query[index:index + 4] == "FROM":
93+
prev_char = normalized_query[index - 1] if index > 0 else " "
94+
next_char = normalized_query[index + 4] if index + 4 < len(normalized_query) else " "
95+
if not (prev_char.isalnum() or prev_char == "_") and not (next_char.isalnum() or next_char == "_"):
96+
projection = normalized_query[projection_start:index].strip()
97+
return projection or None
98+
index += 1
99+
100+
return None
101+
102+
68103
def _classify_aggregate_partial(
69104
docs: Any,
70105
query: Optional[Union[str, dict[str, Any]]]

sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,10 @@ class _FeedRangePaginationState:
434434
structure scale uniformly to non-sequential merges / parallel
435435
fan-out: every entry is structurally equal, and any subset of
436436
entries may carry a non-null backend continuation simultaneously.
437+
438+
Not thread-safe. One instance is created per ``query_items`` call
439+
and is mutated only by that call's pagination loop (sync or async)
440+
— never shared across threads or concurrent tasks.
437441
"""
438442

439443
def __init__(
@@ -790,4 +794,3 @@ def _apply_feedrange_request_headers(
790794
req_headers[http_constants.HttpHeaders.Continuation] = inbound_continuation
791795
else:
792796
req_headers.pop(http_constants.HttpHeaders.Continuation, None)
793-

0 commit comments

Comments
 (0)